/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.metadata.BoundSignature;
import io.trino.operator.PagesIndex;
import io.trino.operator.aggregation.AccumulatorCompiler;
import io.trino.operator.aggregation.AccumulatorFactory;
import io.trino.operator.aggregation.AccumulatorFactoryBinder;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.LambdaProvider;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.SortOrder;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
import io.trino.type.BlockTypeOperators;
import java.lang.invoke.MethodHandle;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * An aggregation function for a concrete {@link BoundSignature}, backed by an
 * {@link AccumulatorFactoryBinder} generated by {@link AccumulatorCompiler}.
 *
 * <p>The constructor checks that the input, combine and output method handles in the supplied
 * {@link AggregationMetadata} are consistent with its declared parameter kinds, accumulator
 * state descriptors and lambda interfaces. Any mismatch is reported as an
 * {@link IllegalArgumentException}.
 */
public final class InternalAggregationFunction {
    /** Primitive Java types accepted for {@code INPUT_CHANNEL} parameters of the input function. */
    private static final Set<Class<?>> SUPPORTED_PRIMITIVE_TYPES = ImmutableSet.of(long.class, double.class, boolean.class);

    private final List<Class<?>> lambdaInterfaces;
    private final AccumulatorFactoryBinder factory;

    /**
     * Validates the aggregation metadata and generates the accumulator factory binder for it.
     *
     * @param boundSignature the signature the aggregation is bound to; its argument types are
     *     checked against the input function's {@code INPUT_CHANNEL} parameters
     * @param aggregationMetadata the input/combine/output functions and state descriptors
     * @throws NullPointerException if either argument is null
     * @throws IllegalArgumentException if the metadata's method handles are inconsistent
     */
    public InternalAggregationFunction(BoundSignature boundSignature, AggregationMetadata aggregationMetadata) {
        Objects.requireNonNull(boundSignature, "boundSignature is null");
        Objects.requireNonNull(aggregationMetadata, "aggregationMetadata is null");
        // Validate before generating bytecode so malformed metadata fails fast with a
        // descriptive IllegalArgumentException instead of first going through the compiler.
        verifyInputFunctionSignature(boundSignature, aggregationMetadata);
        verifyCombineFunction(aggregationMetadata);
        verifyExactOutputFunction(aggregationMetadata);
        this.lambdaInterfaces = ImmutableList.copyOf(aggregationMetadata.getLambdaInterfaces());
        this.factory = AccumulatorCompiler.generateAccumulatorFactoryBinder(boundSignature, aggregationMetadata);
    }

    /**
     * Returns the lambda interface types declared by the aggregation metadata.
     *
     * @return an immutable copy of the lambda interfaces, in declaration order
     */
    public List<Class<?>> getLambdaInterfaces() {
        return lambdaInterfaces;
    }

    /**
     * Binds the aggregation using only input channels and an optional mask channel.
     *
     * <p>No ordering, no DISTINCT and no lambda providers are used. {@code pagesIndexFactory},
     * {@code joinCompiler}, {@code blockTypeOperators} and {@code session} are passed as
     * {@code null}. NOTE(review): this presumably only works for factories that need none of
     * those collaborators; confirm against {@link AccumulatorFactoryBinder}.
     *
     * @param inputChannels channels providing the aggregation's input arguments
     * @param maskChannel optional channel used as a row mask
     * @return the bound accumulator factory
     */
    public AccumulatorFactory bind(List<Integer> inputChannels, Optional<Integer> maskChannel) {
        return factory.bind(
                inputChannels,
                maskChannel,
                ImmutableList.of(),
                ImmutableList.of(),
                ImmutableList.of(),
                null,
                false,
                null,
                null,
                ImmutableList.of(),
                null);
    }

    /**
     * Binds the aggregation with full control over ordering, DISTINCT and lambda arguments.
     * All arguments are forwarded unchanged to the generated {@link AccumulatorFactoryBinder}.
     *
     * @return the bound accumulator factory
     */
    public AccumulatorFactory bind(List<Integer> inputChannels, Optional<Integer> maskChannel, List<Type> sourceTypes, List<Integer> orderByChannels, List<SortOrder> orderings, PagesIndex.Factory pagesIndexFactory, boolean distinct, JoinCompiler joinCompiler, BlockTypeOperators blockTypeOperators, List<LambdaProvider> lambdaProviders, Session session) {
        return factory.bind(inputChannels, maskChannel, sourceTypes, orderByChannels, orderings, pagesIndexFactory, distinct, joinCompiler, blockTypeOperators, lambdaProviders, session);
    }

    /**
     * Checks the input function's parameter list.
     *
     * <p>It must have one parameter per declared parameter kind, followed by one parameter per
     * lambda interface. The first parameter must be a state. The number of STATE parameters must
     * match the number of state descriptors, and each parameter's Java type must fit its kind.
     */
    private static void verifyInputFunctionSignature(BoundSignature boundSignature, AggregationMetadata aggregationMetadata) {
        MethodHandle inputFunction = aggregationMetadata.getInputFunction();
        List<AggregationMetadata.AggregationParameterKind> inputParameterKinds = aggregationMetadata.getInputParameterKinds();
        List<Class<?>> parameters = inputFunction.type().parameterList();
        List<Class<?>> lambdaInterfaces = aggregationMetadata.getLambdaInterfaces();
        List<AggregationMetadata.AccumulatorStateDescriptor<?>> stateDescriptors = aggregationMetadata.getAccumulatorStateDescriptors();

        int expectedParameterCount = inputParameterKinds.size() + lambdaInterfaces.size();
        Preconditions.checkArgument(parameters.size() == expectedParameterCount, "Expected input to have %s input arguments, but it has %s arguments", expectedParameterCount, parameters.size());

        long stateParameterCount = inputParameterKinds.stream()
                .filter(AggregationMetadata.AggregationParameterKind.STATE::equals)
                .count();
        Preconditions.checkArgument(stateParameterCount == stateDescriptors.size(), "Number of state parameter in input function must be the same as size of stateDescriptors");

        // Guard the empty case so an input function with no declared kinds is rejected with
        // IllegalArgumentException rather than IndexOutOfBoundsException from get(0).
        Preconditions.checkArgument(
                !inputParameterKinds.isEmpty() && inputParameterKinds.get(0) == AggregationMetadata.AggregationParameterKind.STATE,
                "First parameter must be state");

        int stateIndex = 0;
        // Position in the bound signature's argument list. It advances for every parameter that
        // consumes an input channel (BLOCK_INPUT_CHANNEL, NULLABLE_BLOCK_INPUT_CHANNEL, INPUT_CHANNEL).
        int argumentIndex = 0;
        for (int i = 0; i < inputParameterKinds.size(); i++) {
            AggregationMetadata.AggregationParameterKind parameterKind = inputParameterKinds.get(i);
            Class<?> parameterType = parameters.get(i);
            switch (parameterKind) {
                case STATE: {
                    Class<?> stateInterface = stateDescriptors.get(stateIndex).getStateInterface();
                    Preconditions.checkArgument(stateInterface == parameterType, "State argument must be of type %s", stateInterface);
                    stateIndex++;
                    break;
                }
                case BLOCK_INPUT_CHANNEL:
                case NULLABLE_BLOCK_INPUT_CHANNEL: {
                    Preconditions.checkArgument(parameterType == Block.class, "Parameter must be Block if it has @BlockPosition");
                    argumentIndex++;
                    break;
                }
                case INPUT_CHANNEL: {
                    Preconditions.checkArgument(!parameterType.isPrimitive() || SUPPORTED_PRIMITIVE_TYPES.contains(parameterType), "Unsupported type: %s", parameterType.getSimpleName());
                    verifyMethodParameterType(inputFunction, i, boundSignature.getArgumentTypes().get(argumentIndex));
                    argumentIndex++;
                    break;
                }
                case BLOCK_INDEX: {
                    Preconditions.checkArgument(parameterType == int.class, "Block index parameter must be an int");
                    break;
                }
                default:
                    throw new IllegalArgumentException("Unsupported parameter: " + parameterKind);
            }
        }
        Preconditions.checkArgument(stateIndex == stateDescriptors.size(), "Input function only has %s states, expected: %s", stateIndex, stateDescriptors.size());

        // Lambda parameters follow the declared parameter kinds.
        for (int i = 0; i < lambdaInterfaces.size(); i++) {
            verifyMethodParameterType(inputFunction, inputParameterKinds.size() + i, lambdaInterfaces.get(i), "function");
        }
    }

    /**
     * Checks the combine function's parameter list.
     *
     * <p>It must be {@code state1..stateN, otherState1..otherStateN}, followed by one parameter
     * per lambda interface.
     */
    private static void verifyCombineFunction(AggregationMetadata aggregationMetadata) {
        MethodHandle combineFunction = aggregationMetadata.getCombineFunction();
        Class<?>[] parameterTypes = combineFunction.type().parameterArray();
        List<Class<?>> lambdaInterfaces = aggregationMetadata.getLambdaInterfaces();
        List<AggregationMetadata.AccumulatorStateDescriptor<?>> stateDescriptors = aggregationMetadata.getAccumulatorStateDescriptors();
        // Each state appears twice: once for the target and once for the state being merged in.
        int stateParameterCount = stateDescriptors.size() * 2;

        Preconditions.checkArgument(parameterTypes.length == stateParameterCount + lambdaInterfaces.size(), "Number of arguments for combine function must be 2 times the size of states plus number of lambda channels.");
        for (int i = 0; i < stateParameterCount; i++) {
            // i % size wraps around, so the second half is checked against the same descriptors.
            Class<?> expectedState = stateDescriptors.get(i % stateDescriptors.size()).getStateInterface();
            Preconditions.checkArgument(parameterTypes[i].equals(expectedState), "Type for Parameter index %s is unexpected. Arguments for combine function must appear in the order of state1, state2, ..., otherState1, otherState2, ...", i);
        }
        for (int i = 0; i < lambdaInterfaces.size(); i++) {
            verifyMethodParameterType(combineFunction, stateParameterCount + i, lambdaInterfaces.get(i), "function");
        }
    }

    /**
     * Checks the output function's parameter list.
     *
     * <p>It must have one parameter per state descriptor, in order, plus exactly one extra
     * parameter. Exactly one parameter in the whole list must be a {@link BlockBuilder}.
     */
    private static void verifyExactOutputFunction(AggregationMetadata aggregationMetadata) {
        Class<?>[] parameterTypes = aggregationMetadata.getOutputFunction().type().parameterArray();
        List<AggregationMetadata.AccumulatorStateDescriptor<?>> stateDescriptors = aggregationMetadata.getAccumulatorStateDescriptors();

        Preconditions.checkArgument(parameterTypes.length == stateDescriptors.size() + 1, "Number of arguments for output function must be exactly one more than the number of states.");
        for (int i = 0; i < stateDescriptors.size(); i++) {
            Preconditions.checkArgument(parameterTypes[i].equals(stateDescriptors.get(i).getStateInterface()), "Type for Parameter index %s is unexpected", i);
        }
        long blockBuilderCount = Arrays.stream(parameterTypes)
                .filter(type -> type.equals(BlockBuilder.class))
                .count();
        Preconditions.checkArgument(blockBuilderCount == 1, "Output function must take exactly one BlockBuilder parameter");
    }

    /** Checks that parameter {@code index} of {@code method} accepts values of the SQL type's Java type. */
    private static void verifyMethodParameterType(MethodHandle method, int index, Type type) {
        verifyMethodParameterType(method, index, type.getJavaType(), type.getDisplayName());
    }

    /**
     * Checks that parameter {@code index} of {@code method} is assignable from {@code javaType}.
     * {@code sqlTypeDisplayName} is only used in the error message.
     */
    private static void verifyMethodParameterType(MethodHandle method, int index, Class<?> javaType, String sqlTypeDisplayName) {
        Class<?> declaredType = method.type().parameterType(index);
        Preconditions.checkArgument(declaredType.isAssignableFrom(javaType), "Expected method %s parameter %s type to be %s (%s)", method, index, javaType.getName(), sqlTypeDisplayName);
    }
}

