/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationFunctionAdapter;
import io.trino.operator.aggregation.state.GenericBooleanState;
import io.trino.operator.aggregation.state.GenericBooleanStateSerializer;
import io.trino.operator.aggregation.state.GenericDoubleState;
import io.trino.operator.aggregation.state.GenericDoubleStateSerializer;
import io.trino.operator.aggregation.state.GenericLongState;
import io.trino.operator.aggregation.state.GenericLongStateSerializer;
import io.trino.operator.aggregation.state.StateCompiler;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.AggregationFunctionMetadata;
import io.trino.spi.function.AggregationImplementation;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.spi.type.TypeSignatureParameter;
import io.trino.sql.gen.lambda.BinaryFunctionInterface;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.List;

/**
 * The {@code reduce_agg} aggregate function.
 *
 * <p>The declared signature has two type variables, {@code T} (input) and {@code S} (state),
 * and four arguments: the input value ({@code T}), the initial state ({@code S}), an input
 * lambda typed {@code function(S, T, S)} and a combine lambda typed {@code function(S, S, S)}.
 * The intermediate and return types are both {@code S}.
 *
 * <p>Only state types whose Java representation is {@code long}, {@code double} or
 * {@code boolean} are supported; {@link #specialize(BoundSignature)} rejects anything else
 * with {@code NOT_SUPPORTED}.
 */
public class ReduceAggregationFunction
        extends SqlAggregationFunction {
    // NAME is a compile-time constant, so it is safe to use from the constructor even though
    // REDUCE_AGG is instantiated before the remaining static fields are initialized. The
    // constructor must not touch the MethodHandle constants below for the same reason.
    public static final ReduceAggregationFunction REDUCE_AGG = new ReduceAggregationFunction();
    private static final String NAME = "reduce_agg";

    // Handles to the static input(...) overloads of this class, one per supported state class.
    private static final MethodHandle LONG_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericLongState.class, Object.class, long.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle DOUBLE_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericDoubleState.class, Object.class, double.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle BOOLEAN_STATE_INPUT_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "input",
            GenericBooleanState.class, Object.class, boolean.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

    // Handles to the static combine(...) overloads of this class.
    private static final MethodHandle LONG_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericLongState.class, GenericLongState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle DOUBLE_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericDoubleState.class, GenericDoubleState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle BOOLEAN_STATE_COMBINE_FUNCTION = Reflection.methodHandle(
            ReduceAggregationFunction.class, "combine",
            GenericBooleanState.class, GenericBooleanState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);

    // Handles to the write(Type, state, BlockBuilder) methods declared on the state classes;
    // the leading Type parameter is bound to the concrete state type in specialize().
    private static final MethodHandle LONG_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericLongState.class, "write", Type.class, GenericLongState.class, BlockBuilder.class);
    private static final MethodHandle DOUBLE_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericDoubleState.class, "write", Type.class, GenericDoubleState.class, BlockBuilder.class);
    private static final MethodHandle BOOLEAN_STATE_OUTPUT_FUNCTION = Reflection.methodHandle(
            GenericBooleanState.class, "write", Type.class, GenericBooleanState.class, BlockBuilder.class);

    /**
     * Registers the function metadata: name, generic signature, description and the
     * intermediate type (the state type variable {@code S}).
     */
    public ReduceAggregationFunction() {
        super(
                FunctionMetadata.aggregateBuilder(NAME)
                        .signature(Signature.builder()
                                .typeVariable("T")
                                .typeVariable("S")
                                .returnType(new TypeSignature("S"))
                                .argumentType(new TypeSignature("T"))
                                .argumentType(new TypeSignature("S"))
                                .argumentType(TypeSignature.functionType(
                                        new TypeSignature("S"),
                                        new TypeSignature("T"),
                                        new TypeSignature("S")))
                                .argumentType(TypeSignature.functionType(
                                        new TypeSignature("S"),
                                        new TypeSignature("S"),
                                        new TypeSignature("S")))
                                .build())
                        .description("Reduce input elements into a single value")
                        .build(),
                AggregationFunctionMetadata.builder()
                        .intermediateType(new TypeSignature("S"))
                        .build());
    }

    /**
     * Builds the implementation for a bound signature, selecting the long, double or boolean
     * state variant from the Java type of the second argument (the initial state).
     *
     * @param boundSignature the signature with concrete argument types; argument 0 is the
     *        input type and argument 1 is the state type
     * @return the implementation wired to the matching input, combine and output handles
     * @throws TrinoException with {@code NOT_SUPPORTED} when the state type's Java type is
     *         not {@code long}, {@code double} or {@code boolean}
     */
    @Override
    public AggregationImplementation specialize(BoundSignature boundSignature) {
        Type inputType = boundSignature.getArgumentTypes().get(0);
        Type stateType = boundSignature.getArgumentTypes().get(1);
        Class<?> stateJavaType = stateType.getJavaType();

        if (stateJavaType == long.class) {
            return AggregationImplementation.builder()
                    .inputFunction(normalizeInputMethod(boundSignature, inputType, LONG_STATE_INPUT_FUNCTION))
                    .combineFunction(LONG_STATE_COMBINE_FUNCTION)
                    .outputFunction(LONG_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                    .accumulatorStateDescriptor(
                            GenericLongState.class,
                            new GenericLongStateSerializer(stateType),
                            StateCompiler.generateStateFactory(GenericLongState.class))
                    .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
                    .build();
        }
        if (stateJavaType == double.class) {
            return AggregationImplementation.builder()
                    .inputFunction(normalizeInputMethod(boundSignature, inputType, DOUBLE_STATE_INPUT_FUNCTION))
                    .combineFunction(DOUBLE_STATE_COMBINE_FUNCTION)
                    .outputFunction(DOUBLE_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                    .accumulatorStateDescriptor(
                            GenericDoubleState.class,
                            new GenericDoubleStateSerializer(stateType),
                            StateCompiler.generateStateFactory(GenericDoubleState.class))
                    .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
                    .build();
        }
        if (stateJavaType == boolean.class) {
            return AggregationImplementation.builder()
                    .inputFunction(normalizeInputMethod(boundSignature, inputType, BOOLEAN_STATE_INPUT_FUNCTION))
                    .combineFunction(BOOLEAN_STATE_COMBINE_FUNCTION)
                    .outputFunction(BOOLEAN_STATE_OUTPUT_FUNCTION.bindTo(stateType))
                    .accumulatorStateDescriptor(
                            GenericBooleanState.class,
                            new GenericBooleanStateSerializer(stateType),
                            StateCompiler.generateStateFactory(GenericBooleanState.class))
                    .lambdaInterfaces(BinaryFunctionInterface.class, BinaryFunctionInterface.class)
                    .build();
        }

        throw new TrinoException(
                StandardErrorCode.NOT_SUPPORTED,
                String.format("State type not supported for %s: %s", NAME, stateType.getDisplayName()));
    }

    /**
     * Adapts one of the {@code input} handles to the concrete input type.
     *
     * <p>The {@code Object value} parameter (index 1) is first retyped to the input type's
     * Java type, then the handle is passed through
     * {@link AggregationFunctionAdapter#normalizeInputMethod} with the parameter kinds
     * state, input channel, input channel.
     */
    private static MethodHandle normalizeInputMethod(BoundSignature boundSignature, Type inputType, MethodHandle inputMethodHandle) {
        MethodHandle typedHandle = inputMethodHandle.asType(
                inputMethodHandle.type().changeParameterType(1, inputType.getJavaType()));

        // A typed local lets the compiler infer ImmutableList<AggregationParameterKind>.
        // Casting an ImmutableList<Object> to this list type would not compile.
        List<AggregationFunctionAdapter.AggregationParameterKind> parameterKinds = ImmutableList.of(
                AggregationFunctionAdapter.AggregationParameterKind.STATE,
                AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL,
                AggregationFunctionAdapter.AggregationParameterKind.INPUT_CHANNEL);

        // NOTE(review): the trailing 2 presumably is the number of lambda arguments
        // (inputFunction and combineFunction) - confirm against AggregationFunctionAdapter.
        return AggregationFunctionAdapter.normalizeInputMethod(typedHandle, boundSignature, parameterKinds, 2);
    }

    /**
     * Folds one input value into a {@code long} state.
     *
     * <p>A null state is first seeded with {@code initialStateValue}; the state is then
     * replaced by {@code inputFunction.apply(state, value)}. A {@code null} lambda result
     * fails with a {@link NullPointerException} when unboxed.
     *
     * @param combineFunction unused here; present so input and combine share the lambda list
     */
    public static void input(GenericLongState state, Object value, long initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        state.setValue((long) inputFunction.apply(state.getValue(), value));
    }

    /**
     * Folds one input value into a {@code double} state.
     *
     * <p>A null state is first seeded with {@code initialStateValue}; the state is then
     * replaced by {@code inputFunction.apply(state, value)}. A {@code null} lambda result
     * fails with a {@link NullPointerException} when unboxed.
     *
     * @param combineFunction unused here; present so input and combine share the lambda list
     */
    public static void input(GenericDoubleState state, Object value, double initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        state.setValue((double) inputFunction.apply(state.getValue(), value));
    }

    /**
     * Folds one input value into a {@code boolean} state.
     *
     * <p>A null state is first seeded with {@code initialStateValue}; the state is then
     * replaced by {@code inputFunction.apply(state, value)}. A {@code null} lambda result
     * fails with a {@link NullPointerException} when unboxed.
     *
     * @param combineFunction unused here; present so input and combine share the lambda list
     */
    public static void input(GenericBooleanState state, Object value, boolean initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.setNull(false);
            state.setValue(initialStateValue);
        }
        state.setValue((boolean) inputFunction.apply(state.getValue(), value));
    }

    /**
     * Merges {@code otherState} into a {@code long} state.
     *
     * <p>If {@code state} is null it is overwritten with {@code otherState}; otherwise it
     * becomes {@code combineFunction.apply(state, otherState)}.
     *
     * @param inputFunction unused here; present so input and combine share the lambda list
     */
    public static void combine(GenericLongState state, GenericLongState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        state.setValue((long) combineFunction.apply(state.getValue(), otherState.getValue()));
    }

    /**
     * Merges {@code otherState} into a {@code double} state.
     *
     * <p>If {@code state} is null it is overwritten with {@code otherState}; otherwise it
     * becomes {@code combineFunction.apply(state, otherState)}.
     *
     * @param inputFunction unused here; present so input and combine share the lambda list
     */
    public static void combine(GenericDoubleState state, GenericDoubleState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        state.setValue((double) combineFunction.apply(state.getValue(), otherState.getValue()));
    }

    /**
     * Merges {@code otherState} into a {@code boolean} state.
     *
     * <p>If {@code state} is null it is overwritten with {@code otherState}; otherwise it
     * becomes {@code combineFunction.apply(state, otherState)}.
     *
     * @param inputFunction unused here; present so input and combine share the lambda list
     */
    public static void combine(GenericBooleanState state, GenericBooleanState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction) {
        if (state.isNull()) {
            state.set(otherState);
            return;
        }
        state.setValue((boolean) combineFunction.apply(state.getValue(), otherState.getValue()));
    }
}

