/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableMap;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.FieldDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.operator.aggregation.AggregationMask;
import io.trino.operator.aggregation.AggregationMaskBuilder;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.LazyBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.ValueBlock;
import io.trino.util.CompilerUtils;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Generates specialized {@link AggregationMaskBuilder} implementations at runtime.
 * <p>
 * The generated {@code buildAggregationMask(Page, Optional<Block>)} selects a position of the
 * arguments page when (a) the optional mask block is absent, or holds a non-null, non-zero byte for
 * that position, and (b) every configured "non-null argument" channel is non-null at that position.
 */
public final class AggregationMaskCompiler {
    private AggregationMaskCompiler() {
    }

    /**
     * Generates a builder class for the given argument channels and returns its public no-arg constructor.
     * <p>
     * Fast paths in the generated method:
     * <ul>
     *     <li>empty page, an RLE mask whose single value is not selected, or an argument channel that is
     *     an RLE block of null: returns {@code AggregationMask.createSelectNone(positionCount)};</li>
     *     <li>no mask block (or an RLE mask whose value is selected) and no argument that may contain
     *     nulls: returns {@code AggregationMask.createSelectAll(positionCount)}.</li>
     * </ul>
     * Otherwise every position is tested and the selected ones are collected into a scratch array held
     * in a field of the builder instance, so a generated builder instance is stateful.
     * NOTE(review): the scratch array is passed to {@code createSelectedPositions}; presumably the
     * resulting mask is only valid until the next call on the same builder — confirm with AggregationMask.
     *
     * @param nonNullArgumentChannels channels of the arguments page that must be non-null for a position
     *        to be selected
     * @return the public no-arg constructor of the generated builder class
     * @throws RuntimeException if the generated class unexpectedly lacks a public no-arg constructor
     */
    public static Constructor<? extends AggregationMaskBuilder> generateAggregationMaskBuilder(int... nonNullArgumentChannels) {
        ClassDefinition definition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName(AggregationMaskBuilder.class.getSimpleName()),
                ParameterizedType.type(Object.class),
                ParameterizedType.type(AggregationMaskBuilder.class));

        // Scratch array reused across calls; grown on demand inside buildAggregationMask.
        FieldDefinition selectedPositionsField = definition.declareField(Access.a(Access.PRIVATE), "selectedPositions", int[].class);

        MethodDefinition constructor = definition.declareConstructor(Access.a(Access.PUBLIC));
        constructor.getBody()
                .comment("super();")
                .append(constructor.getThis())
                .invokeConstructor(Object.class)
                .append(constructor.getThis().setField(selectedPositionsField, BytecodeExpressions.newArray(ParameterizedType.type(int[].class), 0)))
                .ret();

        Parameter argumentsParameter = Parameter.arg("arguments", ParameterizedType.type(Page.class));
        Parameter maskBlockParameter = Parameter.arg("optionalMaskBlock", ParameterizedType.type(Optional.class, Block.class));
        MethodDefinition method = definition.declareMethod(
                Access.a(Access.PUBLIC),
                "buildAggregationMask",
                ParameterizedType.type(AggregationMask.class),
                argumentsParameter,
                maskBlockParameter);

        BytecodeBlock body = method.getBody();
        Scope scope = method.getScope();
        Variable positionCount = scope.declareVariable("positionCount", body, argumentsParameter.invoke("getPositionCount", int.class));

        // An empty page selects nothing.
        body.append(new IfStatement()
                .condition(BytecodeExpressions.equal(positionCount, BytecodeExpressions.constantInt(0)))
                .ifTrue(returnSelectNone(positionCount)));

        // Unwrap the optional mask to a nullable local; lazy masks are rejected.
        Variable maskBlock = scope.declareVariable(
                "maskBlock",
                body,
                maskBlockParameter.invoke("orElse", Object.class, BytecodeExpressions.constantNull(Object.class)).cast(Block.class));
        body.append(new IfStatement()
                .condition(maskBlock.instanceOf(LazyBlock.class))
                .ifTrue(new BytecodeBlock()
                        .append(BytecodeExpressions.newInstance(IllegalArgumentException.class, BytecodeExpressions.constantString("mask block must not be a LazyBlock")))
                        .throwObject()));

        Variable hasMaskBlock = scope.declareVariable("hasMaskBlock", body, BytecodeExpressions.isNotNull(maskBlock));
        Variable maskBlockMayHaveNull = scope.declareVariable(
                "maskBlockMayHaveNull",
                body,
                BytecodeExpressions.and(hasMaskBlock, maskBlock.invoke("mayHaveNull", boolean.class)));

        // An RLE mask either selects nothing (return early) or everything (then behave as if there were no mask).
        Variable rleValue = scope.declareVariable(ByteArrayBlock.class, "rleValue");
        body.append(new IfStatement()
                .condition(maskBlock.instanceOf(RunLengthEncodedBlock.class))
                .ifTrue(new BytecodeBlock()
                        .append(rleValue.set(maskBlock.cast(RunLengthEncodedBlock.class).invoke("getValue", ValueBlock.class).cast(ByteArrayBlock.class)))
                        .append(new IfStatement()
                                .condition(BytecodeExpressions.not(testMaskBlock(rleValue, maskBlockMayHaveNull, BytecodeExpressions.constantInt(0))))
                                .ifTrue(returnSelectNone(positionCount)))
                        .append(maskBlock.set(BytecodeExpressions.constantNull(Block.class)))
                        .append(hasMaskBlock.set(BytecodeExpressions.constantFalse()))
                        .append(maskBlockMayHaveNull.set(BytecodeExpressions.constantFalse()))));

        // Load every non-null argument channel; a channel that is entirely null selects nothing.
        List<Variable> nonNullArgs = new ArrayList<>(nonNullArgumentChannels.length);
        List<Variable> nonNullArgMayHaveNulls = new ArrayList<>(nonNullArgumentChannels.length);
        for (int channel : nonNullArgumentChannels) {
            Variable arg = scope.declareVariable("arg" + channel, body, argumentsParameter.invoke("getBlock", Block.class, BytecodeExpressions.constantInt(channel)));
            body.append(new IfStatement()
                    .condition(BytecodeExpressions.invokeStatic(AggregationMaskCompiler.class, "isAlwaysNull", boolean.class, arg))
                    .ifTrue(returnSelectNone(positionCount)));
            Variable mayHaveNull = scope.declareVariable("arg" + channel + "MayHaveNull", body, arg.invoke("mayHaveNull", boolean.class));
            nonNullArgs.add(arg);
            nonNullArgMayHaveNulls.add(mayHaveNull);
        }

        // Without a mask and without possible nulls in any argument, every position is selected.
        BytecodeExpression isSelectAll = BytecodeExpressions.not(hasMaskBlock);
        for (Variable mayHaveNull : nonNullArgMayHaveNulls) {
            isSelectAll = BytecodeExpressions.and(isSelectAll, BytecodeExpressions.not(mayHaveNull));
        }
        body.append(new IfStatement()
                .condition(isSelectAll)
                .ifTrue(BytecodeExpressions.invokeStatic(AggregationMask.class, "createSelectAll", AggregationMask.class, positionCount).ret()));

        // Grow the scratch array if it cannot hold every position of this page.
        Variable selectedPositions = scope.declareVariable("selectedPositions", body, method.getThis().getField(selectedPositionsField));
        body.append(new IfStatement()
                .condition(BytecodeExpressions.lessThan(selectedPositions.length(), positionCount))
                .ifTrue(new BytecodeBlock()
                        .append(selectedPositions.set(BytecodeExpressions.newArray(ParameterizedType.type(int[].class), positionCount)))
                        .append(method.getThis().setField(selectedPositionsField, selectedPositions))));

        // `position` is the row of the page being evaluated (loop counter). `maskPosition` is the
        // position read from the mask's value block; it differs from `position` only for dictionary masks.
        Variable maskValueBlock = scope.declareVariable(ByteArrayBlock.class, "maskValueBlock");
        Variable position = scope.declareVariable("position", body, BytecodeExpressions.constantInt(0));
        Variable maskPosition = scope.declareVariable("maskPosition", body, BytecodeExpressions.constantInt(0));
        Variable selectedPositionsIndex = scope.declareVariable("selectedPositionsIndex", body, BytecodeExpressions.constantInt(0));
        Variable rawIds = scope.declareVariable(int[].class, "rawIds");
        Variable rawIdsOffset = scope.declareVariable(int.class, "rawIdsOffset");

        BytecodeExpression isDictionaryPositionSelected = testPositionSelected(
                maskValueBlock, maskBlockMayHaveNull, maskPosition, nonNullArgs, nonNullArgMayHaveNulls, position);
        BytecodeExpression isFlatPositionSelected = testPositionSelected(
                maskValueBlock, maskBlockMayHaveNull, position, nonNullArgs, nonNullArgMayHaveNulls, position);

        body.append(new IfStatement()
                .condition(maskBlock.instanceOf(DictionaryBlock.class))
                .ifTrue(new BytecodeBlock()
                        .append(maskValueBlock.set(maskBlock.cast(DictionaryBlock.class).invoke("getDictionary", ValueBlock.class).cast(ByteArrayBlock.class)))
                        .append(rawIds.set(maskBlock.cast(DictionaryBlock.class).invoke("getRawIds", int[].class)))
                        .append(rawIdsOffset.set(maskBlock.cast(DictionaryBlock.class).invoke("getRawIdsOffset", int.class)))
                        .append(new ForLoop()
                                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                                .condition(BytecodeExpressions.lessThan(position, positionCount))
                                .update(position.increment())
                                .body(new BytecodeBlock()
                                        // Map the row to its dictionary id in a separate local; the loop counter must stay untouched.
                                        .append(maskPosition.set(rawIds.getElement(BytecodeExpressions.add(rawIdsOffset, position))))
                                        .append(new IfStatement()
                                                .condition(isDictionaryPositionSelected)
                                                .ifTrue(appendSelectedPosition(selectedPositions, selectedPositionsIndex, position))))))
                // Otherwise the mask is null or (assumed) a ByteArrayBlock; checkcast of null yields null,
                // which testMaskBlock treats as "selected".
                .ifFalse(new BytecodeBlock()
                        .append(maskValueBlock.set(maskBlock.cast(ByteArrayBlock.class)))
                        .append(new ForLoop()
                                .initialize(position.set(BytecodeExpressions.constantInt(0)))
                                .condition(BytecodeExpressions.lessThan(position, positionCount))
                                .update(position.increment())
                                .body(new IfStatement()
                                        .condition(isFlatPositionSelected)
                                        .ifTrue(appendSelectedPosition(selectedPositions, selectedPositionsIndex, position))))));

        body.append(BytecodeExpressions.invokeStatic(
                        AggregationMask.class,
                        "createSelectedPositions",
                        AggregationMask.class,
                        positionCount,
                        selectedPositions,
                        selectedPositionsIndex)
                .ret());

        Class<? extends AggregationMaskBuilder> builderClass = CompilerUtils.defineClass(
                definition,
                AggregationMaskBuilder.class,
                ImmutableMap.<Long, MethodHandle>of(),
                AggregationMaskCompiler.class.getClassLoader());
        try {
            return builderClass.getConstructor();
        }
        catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the statement {@code return AggregationMask.createSelectNone(positionCount);}.
     */
    private static BytecodeNode returnSelectNone(BytecodeExpression positionCount) {
        return BytecodeExpressions.invokeStatic(AggregationMask.class, "createSelectNone", AggregationMask.class, positionCount).ret();
    }

    /**
     * Builds the statements {@code selectedPositions[selectedPositionsIndex] = position; selectedPositionsIndex++;}.
     */
    private static BytecodeNode appendSelectedPosition(Variable selectedPositions, Variable selectedPositionsIndex, BytecodeExpression position) {
        return new BytecodeBlock()
                .append(selectedPositions.setElement(selectedPositionsIndex, position))
                .append(selectedPositionsIndex.increment());
    }

    /**
     * Builds the per-row selection predicate: the mask test at {@code maskPosition}, AND-ed with a
     * not-null test of every argument at {@code position}.
     *
     * @param maskValueBlock nullable {@link ByteArrayBlock} holding the mask values
     * @param maskBlockMayHaveNull whether the mask may contain nulls
     * @param maskPosition position to read in {@code maskValueBlock}
     * @param nonNullArgs argument blocks that must be non-null
     * @param nonNullArgMayHaveNulls per-argument "may have null" flags, parallel to {@code nonNullArgs}
     * @param position row position used for the argument null checks
     */
    private static BytecodeExpression testPositionSelected(
            BytecodeExpression maskValueBlock,
            BytecodeExpression maskBlockMayHaveNull,
            BytecodeExpression maskPosition,
            List<Variable> nonNullArgs,
            List<Variable> nonNullArgMayHaveNulls,
            BytecodeExpression position) {
        BytecodeExpression selected = testMaskBlock(maskValueBlock, maskBlockMayHaveNull, maskPosition);
        for (int i = 0; i < nonNullArgs.size(); i++) {
            selected = BytecodeExpressions.and(selected, testPositionIsNotNull(nonNullArgs.get(i), nonNullArgMayHaveNulls.get(i), position));
        }
        return selected;
    }

    /**
     * Builds {@code !mayHaveNulls || !block.isNull(position)}; the null check is skipped when the block
     * reports it has no nulls.
     */
    private static BytecodeExpression testPositionIsNotNull(BytecodeExpression block, BytecodeExpression mayHaveNulls, BytecodeExpression position) {
        return BytecodeExpressions.or(
                BytecodeExpressions.not(mayHaveNulls),
                BytecodeExpressions.not(block.invoke("isNull", boolean.class, position)));
    }

    /**
     * Builds the mask test: true when {@code block} is null, or when the value at {@code position} is
     * non-null (checked only if {@code mayHaveNulls}) and its byte is non-zero.
     *
     * @param block expression of static type {@link ByteArrayBlock} (verified)
     */
    private static BytecodeExpression testMaskBlock(BytecodeExpression block, BytecodeExpression mayHaveNulls, BytecodeExpression position) {
        Verify.verify(block.getType().equals(ParameterizedType.type(ByteArrayBlock.class)));
        return BytecodeExpressions.or(
                BytecodeExpressions.isNull(block),
                BytecodeExpressions.and(
                        testPositionIsNotNull(block, mayHaveNulls, position),
                        BytecodeExpressions.notEqual(
                                block.invoke("getByte", byte.class, position).cast(int.class),
                                BytecodeExpressions.constantInt(0))));
    }

    /**
     * Returns {@code true} only for a {@link RunLengthEncodedBlock} whose single value is null; any
     * other block returns {@code false}. Called from the generated code.
     */
    @UsedByGeneratedCode
    public static boolean isAlwaysNull(Block block) {
        if (block instanceof RunLengthEncodedBlock) {
            RunLengthEncodedBlock rle = (RunLengthEncodedBlock) block;
            return rle.getValue().isNull(0);
        }
        return false;
    }
}

