/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.MoreCollectors;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.trino.operator.aggregation.AggregationMask;
import io.trino.spi.block.Block;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.ValueBlock;
import io.trino.spi.function.GroupedAccumulatorState;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.util.CompilerUtils;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Method;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
 * Generates bytecode loops that feed the positions selected by an {@link AggregationMask} into an
 * aggregation input function.
 * <p>
 * The input function must have the shape
 * {@code void input(state_0 .. state_{s-1}, ValueBlock, int, ..., ValueBlock, int, lambda_0 .. lambda_{m-1})}:
 * {@code stateCount} state arguments, one {@code (ValueBlock, int position)} pair per aggregation
 * argument, then any trailing lambda arguments. The generated entry point instead takes one
 * {@link Block} per aggregation argument. At run time it checks the concrete class of every block
 * (value, dictionary, or run-length encoded). It then calls a core loop that was generated for exactly
 * that combination of encodings.
 */
final class AggregationLoopBuilder {
    private AggregationLoopBuilder() {
    }

    /**
     * Compiles a class with a static {@code invoke} method that drives {@code function} over every
     * position selected by the aggregation mask, and returns a handle to that method.
     * <p>
     * The returned handle's parameters are, in order:
     * <ul>
     * <li>the {@link AggregationMask};</li>
     * <li>an {@code int[]} of group ids (only when {@code grouped});</li>
     * <li>the state arguments;</li>
     * <li>one {@link Block} per aggregation argument;</li>
     * <li>the trailing lambda arguments of {@code function}.</li>
     * </ul>
     *
     * @param function the aggregation input function; its signature is validated first
     * @param stateCount number of leading state parameters of {@code function}
     * @param parameterCount number of {@code (ValueBlock, int)} pairs that follow the states
     * @param grouped if true, {@code setGroupId} is called on every state before each input call
     * @return a handle to the generated {@code invoke} method
     * @throws IllegalArgumentException if {@code function} does not have the expected signature
     */
    public static MethodHandle buildLoop(MethodHandle function, int stateCount, int parameterCount, boolean grouped) {
        verifyFunctionSignature(function, stateCount, parameterCount);
        CallSiteBinder binder = new CallSiteBinder();
        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.STATIC, Access.FINAL),
                CompilerUtils.makeClassName("AggregationLoop"),
                ParameterizedType.type(Object.class));
        definition.declareDefaultConstructor(Access.a(Access.PRIVATE));
        buildSpecializedLoop(binder, definition, function, stateCount, parameterCount, grouped);

        Class<?> clazz = CompilerUtils.defineClass(definition, Object.class, binder.getBindings(), AggregationLoopBuilder.class.getClassLoader());
        // The core loops are named "invoke_<types>", so only the entry point matches exactly.
        Method invokeMethod = Arrays.stream(clazz.getMethods())
                .filter(method -> method.getName().equals("invoke"))
                .collect(MoreCollectors.onlyElement());
        try {
            return MethodHandles.lookup().unreflect(invokeMethod);
        }
        catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Declares the public static {@code invoke} entry point on {@code classDefinition}. Its body is a
     * tree of {@code instanceof} checks over the block parameters. Each leaf calls a core loop that is
     * specialized for one combination of block encodings.
     */
    private static void buildSpecializedLoop(CallSiteBinder binder, ClassDefinition classDefinition, MethodHandle function, int stateCount, int parameterCount, boolean grouped) {
        AggregationParameters aggregationParameters = AggregationParameters.create(function, stateCount, parameterCount, grouped);
        MethodDefinition methodDefinition = classDefinition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                "invoke",
                ParameterizedType.type(void.class),
                aggregationParameters.allParameters());

        // Each leaf of the selection tree declares its own core loop method.
        // The leaf forwards all of the entry point's parameters to that method unchanged.
        Function<List<BlockType>, BytecodeNode> coreLoopBuilder = blockTypes -> {
            MethodDefinition coreLoop = buildCoreLoop(binder, classDefinition, function, blockTypes, aggregationParameters);
            return BytecodeExpressions.invokeStatic(coreLoop, aggregationParameters.allParameters().toArray(new BytecodeExpression[0]));
        };
        BytecodeNode selection = buildLoopSelection(coreLoopBuilder, new ArrayDeque<>(parameterCount), new ArrayDeque<>(aggregationParameters.blocks()));
        methodDefinition.getBody().append(selection).ret();
    }

    /**
     * Recursively builds the decision tree that picks a core loop from the runtime class of every
     * remaining block parameter.
     * <p>
     * {@code currentTypes} holds the encodings chosen so far. {@code remainingParameters} holds the
     * block parameters not yet dispatched on. Both deques are used as scratch stacks and are restored
     * to their original contents before this method returns.
     * A block that is none of the three handled encodings throws {@link UnsupportedOperationException}
     * at run time.
     */
    private static BytecodeNode buildLoopSelection(Function<List<BlockType>, BytecodeNode> coreLoopBuilder, ArrayDeque<BlockType> currentTypes, ArrayDeque<Parameter> remainingParameters) {
        if (remainingParameters.isEmpty()) {
            return coreLoopBuilder.apply(ImmutableList.copyOf(currentTypes));
        }

        Parameter blockParameter = remainingParameters.removeFirst();
        // The subtrees are generated in VALUE, DICTIONARY, RLE order.
        // This keeps the declaration order of the core loop methods stable.
        BytecodeNode valueLoop = buildLoopSelectionFor(BlockType.VALUE, coreLoopBuilder, currentTypes, remainingParameters);
        BytecodeNode dictionaryLoop = buildLoopSelectionFor(BlockType.DICTIONARY, coreLoopBuilder, currentTypes, remainingParameters);
        BytecodeNode rleLoop = buildLoopSelectionFor(BlockType.RLE, coreLoopBuilder, currentTypes, remainingParameters);

        IfStatement blockTypeSelection = new IfStatement()
                .condition(blockParameter.instanceOf(ValueBlock.class))
                .ifTrue(valueLoop)
                .ifFalse(new IfStatement()
                        .condition(blockParameter.instanceOf(DictionaryBlock.class))
                        .ifTrue(dictionaryLoop)
                        .ifFalse(new IfStatement()
                                .condition(blockParameter.instanceOf(RunLengthEncodedBlock.class))
                                .ifTrue(rleLoop)
                                // NOTE(review): this is reached for any other Block implementation.
                                // The message text is kept unchanged for compatibility.
                                .ifFalse(new BytecodeBlock()
                                        .append(BytecodeExpressions.newInstance(UnsupportedOperationException.class, BytecodeExpressions.constantString("Aggregation is not decomposable")))
                                        .throwObject())));

        remainingParameters.addFirst(blockParameter);
        return blockTypeSelection;
    }

    /**
     * Builds the selection subtree for the remaining parameters, given that the current block has the
     * encoding {@code blockType}. {@code currentTypes} is left as it was on entry.
     */
    private static BytecodeNode buildLoopSelectionFor(BlockType blockType, Function<List<BlockType>, BytecodeNode> coreLoopBuilder, ArrayDeque<BlockType> currentTypes, ArrayDeque<Parameter> remainingParameters) {
        currentTypes.addLast(blockType);
        BytecodeNode loop = buildLoopSelection(coreLoopBuilder, currentTypes, remainingParameters);
        currentTypes.removeLast();
        return loop;
    }

    /**
     * Declares one core loop method, specialized for the given block encodings.
     * <p>
     * The method is named {@code invoke_} followed by the first letter of each encoding (for example
     * {@code invoke_VD}), so every combination gets a distinct name. It calls {@code function} once
     * for every position selected by the mask:
     * <ul>
     * <li>when {@code isSelectAll()} is true, positions {@code 0 .. getSelectedPositionCount()-1};</li>
     * <li>otherwise, the entries of {@code getSelectedPositions()}.</li>
     * </ul>
     * When group ids are present, each state's {@code setGroupId(groupIds[position])} is called
     * before the input call.
     */
    private static MethodDefinition buildCoreLoop(CallSiteBinder binder, ClassDefinition classDefinition, MethodHandle function, List<BlockType> blockTypes, AggregationParameters aggregationParameters) {
        StringBuilder methodName = new StringBuilder("invoke_");
        for (BlockType blockType : blockTypes) {
            methodName.append(blockType.name().charAt(0));
        }
        MethodDefinition methodDefinition = classDefinition.declareMethod(
                Access.a(Access.PUBLIC, Access.STATIC),
                methodName.toString(),
                ParameterizedType.type(void.class),
                aggregationParameters.allParameters());
        Scope scope = methodDefinition.getScope();
        BytecodeBlock body = methodDefinition.getBody();

        Variable position = scope.declareVariable(int.class, "position");

        // The arguments follow the input function's layout: states, then (ValueBlock, int) pairs, then lambdas.
        ImmutableList.Builder<BytecodeExpression> aggregationArguments = ImmutableList.builder();
        aggregationArguments.addAll(aggregationParameters.states());
        addBlockPositionArguments(methodDefinition, position, blockTypes, aggregationParameters.blocks(), aggregationArguments);
        aggregationArguments.addAll(aggregationParameters.lambdas());

        // This per-position body is shared by both loop shapes below.
        BytecodeBlock invokeFunction = new BytecodeBlock();
        if (aggregationParameters.groupIds().isPresent()) {
            Variable groupId = scope.declareVariable(int.class, "groupId");
            invokeFunction.append(groupId.set(aggregationParameters.groupIds().get().getElement(position)));
            for (Parameter stateParameter : aggregationParameters.states()) {
                invokeFunction.append(stateParameter.cast(GroupedAccumulatorState.class).invoke("setGroupId", void.class, groupId));
            }
        }
        invokeFunction.append(BytecodeUtils.invoke(binder.bind(function), "input", aggregationArguments.build()));

        Variable positionCount = scope.declareVariable("positionCount", body, aggregationParameters.mask().invoke("getSelectedPositionCount", int.class));

        ForLoop selectAllLoop = new ForLoop()
                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                .condition(BytecodeExpressions.lessThan(position, positionCount))
                .update(position.increment())
                .body(invokeFunction);

        Variable index = scope.declareVariable("index", body, BytecodeExpressions.constantInt(0));
        Variable selectedPositions = scope.declareVariable(int[].class, "selectedPositions");
        ForLoop maskedLoop = new ForLoop()
                .initialize(selectedPositions.set(aggregationParameters.mask().invoke("getSelectedPositions", int[].class)))
                .condition(BytecodeExpressions.lessThan(index, positionCount))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(selectedPositions.getElement(index)))
                        .append(invokeFunction));

        body.append(new IfStatement()
                .condition(aggregationParameters.mask().invoke("isSelectAll", boolean.class))
                .ifTrue(selectAllLoop)
                .ifFalse(maskedLoop));
        body.ret();
        return methodDefinition;
    }

    /**
     * Appends, for each block parameter, the {@code (ValueBlock, int)} argument pair the input
     * function expects:
     * <ul>
     * <li>VALUE: the block itself and {@code position};</li>
     * <li>DICTIONARY: the dictionary and {@code rawIds[rawIdsOffset + position]};</li>
     * <li>RLE: the single value block and {@code 0}.</li>
     * </ul>
     * Dictionary and RLE accessors are stored in locals initialized through {@code methodBody}.
     * Those initializations come before the loops that {@code buildCoreLoop} appends afterwards.
     *
     * @throws IllegalArgumentException if {@code blockTypes} and {@code blockParameters} differ in size
     */
    private static void addBlockPositionArguments(MethodDefinition methodDefinition, Variable position, List<BlockType> blockTypes, List<Parameter> blockParameters, ImmutableList.Builder<BytecodeExpression> aggregationArguments) {
        Preconditions.checkArgument(
                blockTypes.size() == blockParameters.size(),
                "Expected one block type per block parameter: %s types, %s parameters",
                blockTypes.size(),
                blockParameters.size());
        Scope scope = methodDefinition.getScope();
        BytecodeBlock methodBody = methodDefinition.getBody();
        for (int i = 0; i < blockTypes.size(); i++) {
            Parameter blockParameter = blockParameters.get(i);
            // Switch on the enum constant rather than ordinal(), so reordering BlockType cannot mis-wire arguments.
            switch (blockTypes.get(i)) {
                case VALUE -> {
                    aggregationArguments.add(blockParameter.cast(ValueBlock.class));
                    aggregationArguments.add(position);
                }
                case DICTIONARY -> {
                    BytecodeExpression dictionaryBlock = blockParameter.cast(DictionaryBlock.class);
                    Variable valueBlock = scope.declareVariable("valueBlock" + i, methodBody, dictionaryBlock.invoke("getDictionary", ValueBlock.class));
                    Variable rawIds = scope.declareVariable("rawIds" + i, methodBody, dictionaryBlock.invoke("getRawIds", int[].class));
                    Variable rawIdsOffset = scope.declareVariable("rawIdsOffset" + i, methodBody, dictionaryBlock.invoke("getRawIdsOffset", int.class));
                    aggregationArguments.add(valueBlock);
                    aggregationArguments.add(rawIds.getElement(BytecodeExpressions.add(rawIdsOffset, position)));
                }
                case RLE -> {
                    Variable valueBlock = scope.declareVariable("valueBlock" + i, methodBody, blockParameter.cast(RunLengthEncodedBlock.class).invoke("getValue", ValueBlock.class));
                    aggregationArguments.add(valueBlock);
                    // Every position of an RLE block holds the same value, which is stored at index 0.
                    aggregationArguments.add(BytecodeExpressions.constantInt(0));
                }
            }
        }
    }

    /**
     * Checks that {@code function} is {@code void (states..., (ValueBlock, int) x parameterCount, lambdas...)}.
     *
     * @throws IllegalArgumentException if a count is negative, if the function has too few parameters,
     *         or if its signature differs from the expected one
     */
    private static void verifyFunctionSignature(MethodHandle function, int stateCount, int parameterCount) {
        Preconditions.checkArgument(
                stateCount >= 0 && parameterCount >= 0,
                "stateCount and parameterCount must be non-negative, but are %s and %s",
                stateCount,
                parameterCount);
        MethodType actualSignature = function.type();
        int fixedParameterCount = stateCount + parameterCount * 2;
        // Check arity first; otherwise the subList calls below would throw IndexOutOfBoundsException.
        Preconditions.checkArgument(
                actualSignature.parameterCount() >= fixedParameterCount,
                "Expected function to have at least %s parameters (%s states and %s block/position pairs), but is %s",
                fixedParameterCount,
                stateCount,
                parameterCount,
                actualSignature);

        List<Class<?>> actualParameterTypes = actualSignature.parameterList();
        List<Class<?>> expectedParameterTypes = ImmutableList.<Class<?>>builder()
                .addAll(actualParameterTypes.subList(0, stateCount))
                .addAll(Iterables.limit(Iterables.<Class<?>>cycle(ValueBlock.class, int.class), parameterCount * 2))
                .addAll(actualParameterTypes.subList(fixedParameterCount, actualParameterTypes.size()))
                .build();
        MethodType expectedSignature = MethodType.methodType(void.class, expectedParameterTypes);
        Preconditions.checkArgument(actualSignature.equals(expectedSignature), "Expected function signature to be %s, but is %s", expectedSignature, actualSignature);
    }

    /**
     * The bytecode parameters shared by the generated entry point and every core loop.
     *
     * @param mask the {@link AggregationMask} selecting positions
     * @param groupIds per-position group ids, present only for grouped aggregation
     * @param states state parameters, typed like the input function's leading parameters
     * @param blocks one {@link Block} parameter per aggregation argument
     * @param lambdas trailing parameters, typed like the input function's trailing parameters
     */
    private record AggregationParameters(Parameter mask, Optional<Parameter> groupIds, List<Parameter> states, List<Parameter> blocks, List<Parameter> lambdas) {
        static AggregationParameters create(MethodHandle function, int stateCount, int parameterCount, boolean grouped) {
            MethodType functionType = function.type();
            Parameter mask = Parameter.arg("aggregationMask", AggregationMask.class);
            Optional<Parameter> groupIds = grouped ? Optional.of(Parameter.arg("groupIds", int[].class)) : Optional.empty();

            ImmutableList.Builder<Parameter> states = ImmutableList.builder();
            for (int i = 0; i < stateCount; i++) {
                states.add(Parameter.arg("state" + i, functionType.parameterType(i)));
            }
            ImmutableList.Builder<Parameter> blocks = ImmutableList.builder();
            for (int i = 0; i < parameterCount; i++) {
                blocks.add(Parameter.arg("block" + i, Block.class));
            }
            // Everything after the states and the (ValueBlock, int) pairs is treated as a lambda argument.
            ImmutableList.Builder<Parameter> lambdas = ImmutableList.builder();
            int lambdaFunctionOffset = stateCount + parameterCount * 2;
            for (int i = 0; i < functionType.parameterCount() - lambdaFunctionOffset; i++) {
                lambdas.add(Parameter.arg("lambda" + i, functionType.parameterType(lambdaFunctionOffset + i)));
            }
            return new AggregationParameters(mask, groupIds, states.build(), blocks.build(), lambdas.build());
        }

        /** Returns all parameters in declaration order: mask, optional group ids, states, blocks, lambdas. */
        public List<Parameter> allParameters() {
            return ImmutableList.<Parameter>builder()
                    .add(mask)
                    .addAll(groupIds.stream().iterator())
                    .addAll(states)
                    .addAll(blocks)
                    .addAll(lambdas)
                    .build();
        }
    }

    /** Runtime encoding of a block argument; each constant selects a different argument-access pattern. */
    private enum BlockType {
        RLE,
        DICTIONARY,
        VALUE,
    }
}

