/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.gen.columnar;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.ForLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.SwitchStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.JumpInstruction;
import io.airlift.bytecode.instruction.LabelNode;
import io.airlift.slice.Slice;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SourcePage;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.InCodeGenerator;
import io.trino.sql.gen.SqlTypeBytecodeExpression;
import io.trino.sql.gen.columnar.ColumnarFilter;
import io.trino.sql.gen.columnar.ColumnarFilterCompiler;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;
import io.trino.util.CompilerUtils;
import io.trino.util.FastutilSetHelper;
import java.lang.invoke.MethodHandle;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;

/**
 * Compiles a {@link ColumnarFilter} for an {@code IN} special form whose first argument is an
 * {@link InputReferenceExpression} and whose remaining arguments are all {@link ConstantExpression}s.
 * <p>
 * Null and indeterminate constants are dropped from the match set, and positions whose input value
 * is null are never selected. Membership is tested either with a generated int switch (short lists
 * of int-range constants of an integral or date type) or with {@code FastutilSetHelper.in} against
 * a set built from the type's {@code EQUAL} and {@code HASH_CODE} operators.
 */
public class InColumnarFilterGenerator {
    // useSwitchCaseGeneration falls back to the set lookup once the IN list has this many values or more
    private static final int SWITCH_CASE_VALUE_LIMIT = 8;

    private final InputReferenceExpression valueExpression;
    private final boolean useSwitchCase;
    // Non-null, determinate constant values from the IN list
    private final Set<Object> constantValues;
    private final MethodHandle equalsMethodHandle;
    private final MethodHandle hashCodeMethodHandle;

    /**
     * Validates the {@code IN} form and resolves the operators needed to evaluate it.
     *
     * @param specialForm an {@code IN} special form: an input reference followed by constants
     * @param functionManager used to obtain the EQUAL, HASH_CODE and INDETERMINATE implementations
     * @throws IllegalArgumentException if the form is not {@code IN}, has fewer than two arguments,
     *         or does not carry exactly three function dependencies
     * @throws UnsupportedOperationException if the probed value is not an input reference or any
     *         IN-list entry is not a constant
     */
    public InColumnarFilterGenerator(SpecialForm specialForm, FunctionManager functionManager) {
        Preconditions.checkArgument(specialForm.form() == SpecialForm.Form.IN, "specialForm should be IN");
        List<RowExpression> arguments = specialForm.arguments();
        Preconditions.checkArgument(arguments.size() >= 2, "At least two arguments are required");
        if (!(arguments.getFirst() instanceof InputReferenceExpression inputReference)) {
            throw new UnsupportedOperationException("IN clause columnar evaluation is supported only on input references");
        }
        this.valueExpression = inputReference;

        List<RowExpression> expressions = arguments.subList(1, arguments.size());
        ImmutableList.Builder<ConstantExpression> testExpressions = ImmutableList.builder();
        for (RowExpression expression : expressions) {
            if (!(expression instanceof ConstantExpression constantExpression)) {
                throw new UnsupportedOperationException("IN clause columnar evaluation is supported only on input reference against constants");
            }
            testExpressions.add(constantExpression);
        }

        // The three dependencies are the EQUAL, HASH_CODE and INDETERMINATE operators resolved below
        Preconditions.checkArgument(
                specialForm.functionDependencies().size() == 3,
                "Expected 3 function dependencies for IN, but got %s",
                specialForm.functionDependencies().size());
        ResolvedFunction resolvedEqualsFunction = specialForm.getOperatorDependency(OperatorType.EQUAL);
        ResolvedFunction resolvedHashCodeFunction = specialForm.getOperatorDependency(OperatorType.HASH_CODE);
        ResolvedFunction resolvedIsIndeterminate = specialForm.getOperatorDependency(OperatorType.INDETERMINATE);
        this.equalsMethodHandle = functionManager.getScalarFunctionImplementation(
                        resolvedEqualsFunction,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();
        this.hashCodeMethodHandle = functionManager.getScalarFunctionImplementation(
                        resolvedHashCodeFunction,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();
        MethodHandle indeterminateMethodHandle = functionManager.getScalarFunctionImplementation(
                        resolvedIsIndeterminate,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();

        // Null and indeterminate constants can never produce TRUE, so they are not matched against
        ImmutableSet.Builder<Object> constantValuesBuilder = ImmutableSet.builder();
        for (ConstantExpression testValue : testExpressions.build()) {
            if (isDeterminateConstant(testValue, indeterminateMethodHandle)) {
                constantValuesBuilder.add(testValue.value());
            }
        }
        this.constantValues = constantValuesBuilder.build();
        this.useSwitchCase = useSwitchCaseGeneration(valueExpression.type(), expressions);
    }

    /**
     * Generates and loads a {@link ColumnarFilter} class implementing {@code filterPositionsRange}
     * and {@code filterPositionsList} for this IN predicate.
     *
     * @return a supplier of instances of the generated filter class
     */
    public Supplier<ColumnarFilter> generateColumnarFilter() {
        ClassDefinition classDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName(ColumnarFilter.class.getSimpleName() + "_in", Optional.empty()),
                ParameterizedType.type(Object.class),
                ParameterizedType.type(ColumnarFilter.class));
        CallSiteBinder callSiteBinder = new CallSiteBinder();
        classDefinition.declareDefaultConstructor(Access.a(Access.PUBLIC));
        ColumnarFilterCompiler.generateGetInputChannels(callSiteBinder, classDefinition, valueExpression);

        // The lookup set is built with the SQL hash/equals operators and bound as a call-site constant
        Set<?> constantValuesSet = FastutilSetHelper.toFastutilHashSet(constantValues, valueExpression.type(), hashCodeMethodHandle, equalsMethodHandle);
        Binding constant = callSiteBinder.bind(constantValuesSet, constantValuesSet.getClass());
        generateFilterRangeMethod(callSiteBinder, classDefinition, constantValuesSet, constant);
        generateFilterListMethod(callSiteBinder, classDefinition, constantValuesSet, constant);
        return ColumnarFilterCompiler.createClassInstance(callSiteBinder, classDefinition);
    }

    /**
     * Emits {@code int filterPositionsRange(session, outputPositions, offset, size, page)}.
     * It tests every position in {@code [offset, offset + size)}, records matches through
     * {@link ColumnarFilterCompiler#updateOutputPositions}, and returns the number of matches.
     */
    private void generateFilterRangeMethod(CallSiteBinder binder, ClassDefinition classDefinition, Set<?> constantValuesSet, Binding constant) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter outputPositions = Parameter.arg("outputPositions", int[].class);
        Parameter offset = Parameter.arg("offset", int.class);
        Parameter size = Parameter.arg("size", int.class);
        Parameter page = Parameter.arg("page", SourcePage.class);
        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "filterPositionsRange",
                ParameterizedType.type(int.class),
                ImmutableList.of(session, outputPositions, offset, size, page));
        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();

        List<RowExpression> valueExpressions = ImmutableList.of(valueExpression);
        ColumnarFilterCompiler.declareBlockVariables(valueExpressions, page, scope, body);
        Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, BytecodeExpressions.constantInt(0));
        Variable position = scope.declareVariable(int.class, "position");
        Variable result = scope.declareVariable(boolean.class, "result");

        // Two loop variants: the per-position null check is only emitted when the block may contain nulls
        IfStatement ifStatement = new IfStatement()
                .condition(ColumnarFilterCompiler.generateBlockMayHaveNull(valueExpressions, scope));
        body.append(ifStatement);

        ifStatement.ifTrue(new ForLoop("nullable range based loop")
                .initialize(position.set(offset))
                .condition(BytecodeExpressions.lessThan(position, BytecodeExpressions.add(offset, size)))
                .update(position.increment())
                .body(new IfStatement()
                        .condition(ColumnarFilterCompiler.generateBlockPositionNotNull(valueExpressions, scope, position))
                        .ifTrue(generateMatchAndUpdate(binder, scope, constantValuesSet, constant, position, result, outputPositions, outputPositionsCount))));
        ifStatement.ifFalse(new ForLoop("non-nullable range based loop")
                .initialize(position.set(offset))
                .condition(BytecodeExpressions.lessThan(position, BytecodeExpressions.add(offset, size)))
                .update(position.increment())
                .body(generateMatchAndUpdate(binder, scope, constantValuesSet, constant, position, result, outputPositions, outputPositionsCount)));

        body.append(outputPositionsCount.ret());
    }

    /**
     * Emits {@code int filterPositionsList(session, outputPositions, activePositions, offset, size, page)}.
     * It tests the positions stored in {@code activePositions[offset .. offset + size)}, records
     * matches through {@link ColumnarFilterCompiler#updateOutputPositions}, and returns the number
     * of matches.
     */
    private void generateFilterListMethod(CallSiteBinder binder, ClassDefinition classDefinition, Set<?> constantValuesSet, Binding constant) {
        Parameter session = Parameter.arg("session", ConnectorSession.class);
        Parameter outputPositions = Parameter.arg("outputPositions", int[].class);
        Parameter activePositions = Parameter.arg("activePositions", int[].class);
        Parameter offset = Parameter.arg("offset", int.class);
        Parameter size = Parameter.arg("size", int.class);
        Parameter page = Parameter.arg("page", SourcePage.class);
        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "filterPositionsList",
                ParameterizedType.type(int.class),
                ImmutableList.of(session, outputPositions, activePositions, offset, size, page));
        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();

        List<RowExpression> valueExpressions = ImmutableList.of(valueExpression);
        ColumnarFilterCompiler.declareBlockVariables(valueExpressions, page, scope, body);
        Variable outputPositionsCount = scope.declareVariable("outputPositionsCount", body, BytecodeExpressions.constantInt(0));
        Variable index = scope.declareVariable(int.class, "index");
        Variable position = scope.declareVariable(int.class, "position");
        Variable result = scope.declareVariable(boolean.class, "result");

        // Two loop variants: the per-position null check is only emitted when the block may contain nulls
        IfStatement ifStatement = new IfStatement()
                .condition(ColumnarFilterCompiler.generateBlockMayHaveNull(valueExpressions, scope));
        body.append(ifStatement);

        ifStatement.ifTrue(new ForLoop("nullable positions loop")
                .initialize(index.set(offset))
                .condition(BytecodeExpressions.lessThan(index, BytecodeExpressions.add(offset, size)))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(activePositions.getElement(index)))
                        .append(new IfStatement()
                                .condition(ColumnarFilterCompiler.generateBlockPositionNotNull(valueExpressions, scope, position))
                                .ifTrue(generateMatchAndUpdate(binder, scope, constantValuesSet, constant, position, result, outputPositions, outputPositionsCount)))));
        ifStatement.ifFalse(new ForLoop("non-nullable positions loop")
                .initialize(index.set(offset))
                .condition(BytecodeExpressions.lessThan(index, BytecodeExpressions.add(offset, size)))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(position.set(activePositions.getElement(index)))
                        .append(generateMatchAndUpdate(binder, scope, constantValuesSet, constant, position, result, outputPositions, outputPositionsCount))));

        body.append(outputPositionsCount.ret());
    }

    // Emits the membership test for the value at `position` followed by the output-position update
    private BytecodeBlock generateMatchAndUpdate(
            CallSiteBinder binder,
            Scope scope,
            Set<?> constantValuesSet,
            Binding constant,
            Variable position,
            Variable result,
            Parameter outputPositions,
            Variable outputPositionsCount) {
        return new BytecodeBlock()
                .append(generateSetContainsCall(binder, scope, constantValuesSet, constant, position, result))
                .append(ColumnarFilterCompiler.updateOutputPositions(result, position, outputPositions, outputPositionsCount));
    }

    /**
     * Emits code that reads the input value at {@code position} and stores into {@code result}
     * whether it is one of the IN-list constants.
     */
    private BytecodeBlock generateSetContainsCall(CallSiteBinder binder, Scope scope, Set<?> constantValuesSet, Binding constant, BytecodeExpression position, Variable result) {
        Type valueType = valueExpression.type();
        Class<?> javaType = valueType.getJavaType();
        // The Type getter name is derived from the boxed call type (e.g. long -> getLong); values that are
        // neither primitive nor Slice are read as Object and cast back to their declared Java type
        Class<?> callType = (javaType.isPrimitive() || javaType == Slice.class) ? javaType : Object.class;
        String methodName = "get" + Primitives.wrap(callType).getSimpleName();
        BytecodeExpression value = SqlTypeBytecodeExpression.constantType(binder, valueType)
                .invoke(methodName, callType, scope.getVariable("block_" + valueExpression.field()), position);
        if (callType != javaType) {
            value = value.cast(javaType);
        }

        if (useSwitchCase) {
            return generateSwitchLookup(scope, javaType, value, result);
        }

        return new BytecodeBlock()
                .comment("inListSet.contains(<stackValue>)")
                .append(new BytecodeBlock()
                        .comment("value")
                        .append(value)
                        .comment("set")
                        .append(BytecodeUtils.loadConstant(constant))
                        .invokeStatic(FastutilSetHelper.class, "in", boolean.class, javaType.isPrimitive() ? javaType : Object.class, constantValuesSet.getClass())
                        .putVariable(result));
    }

    // Emits an int switch over the IN-list constants; only used when useSwitchCaseGeneration approved them
    private BytecodeBlock generateSwitchLookup(Scope scope, Class<?> javaType, BytecodeExpression value, Variable result) {
        LabelNode end = new LabelNode("end");
        LabelNode match = new LabelNode("match");
        LabelNode defaultLabel = new LabelNode("default");

        SwitchStatement.SwitchBuilder switchBuilder = SwitchStatement.switchBuilder();
        for (Object constantValue : constantValues) {
            // useSwitchCaseGeneration guarantees long-backed constants within int range
            switchBuilder.addCase(Math.toIntExact((Long) constantValue), JumpInstruction.jump(match));
        }
        switchBuilder.defaultCase(JumpInstruction.jump(defaultLabel));

        BytecodeBlock matchBlock = new BytecodeBlock()
                .setDescription("match")
                .visitLabel(match)
                .append(result.set(BytecodeExpressions.constantTrue()))
                .gotoLabel(end);
        BytecodeBlock defaultCaseBlock = new BytecodeBlock()
                .setDescription("default")
                .visitLabel(defaultLabel)
                .append(result.set(BytecodeExpressions.constantFalse()))
                .gotoLabel(end);

        Variable expression = scope.createTempVariable(javaType);
        return new BytecodeBlock()
                .comment("lookupSwitch(<stackValue>))")
                .append(expression.set(value))
                // A value outside int range cannot equal any case, and must not be truncated by the int cast
                .append(new IfStatement()
                        .condition(BytecodeExpressions.invokeStatic(InCodeGenerator.class, "isInteger", boolean.class, expression))
                        .ifFalse(new BytecodeBlock().gotoLabel(defaultLabel)))
                .append(switchBuilder.expression(expression.cast(int.class)).build())
                .append(matchBlock)
                .append(defaultCaseBlock)
                .visitLabel(end);
    }

    /**
     * Returns true if {@code expression} is a non-null constant for which the INDETERMINATE
     * operator returns false.
     */
    private static boolean isDeterminateConstant(RowExpression expression, MethodHandle isIndeterminateFunction) {
        if (!(expression instanceof ConstantExpression constantExpression)) {
            return false;
        }
        Object value = constantExpression.value();
        if (value == null) {
            return false;
        }
        try {
            // MethodHandle.invoke is signature-polymorphic: the cast fixes the call's return type to boolean
            return !(boolean) isIndeterminateFunction.invoke(value);
        }
        catch (Throwable t) {
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    /**
     * Decides whether membership can be compiled into an int switch over the raw {@code long} value.
     *
     * @param type the type of the probed value
     * @param values the IN-list expressions (excluding the probed value)
     * @return true when the list is short, the type compares by identity on its long value, and
     *         every non-null constant fits in an int
     * @throws UnsupportedOperationException for structural types or non-constant list entries
     */
    static boolean useSwitchCaseGeneration(Type type, List<RowExpression> values) {
        if (!type.getTypeParameters().isEmpty()) {
            throw new UnsupportedOperationException("Structural type not supported");
        }
        if (values.size() >= SWITCH_CASE_VALUE_LIMIT) {
            return false;
        }
        if (type.getJavaType() != long.class) {
            return false;
        }
        // The switch compares raw long bits, so it is only valid where SQL equality is identity on that long.
        // Other long-backed types (e.g. REAL stores float bits, where -0.0 and 0.0 are equal but their bits
        // differ) must use the set lookup built from the type's EQUAL/HASH_CODE operators.
        if (!isSwitchCompatibleType(type)) {
            return false;
        }
        for (RowExpression expression : values) {
            if (!(expression instanceof ConstantExpression constantExpression)) {
                throw new UnsupportedOperationException("IN clause columnar evaluation is supported only on input reference against constants");
            }
            // Null constants are never added to the switch cases, so they impose no range constraint
            Object constant = constantExpression.value();
            if (constant == null) {
                continue;
            }
            long longConstant = ((Number) constant).longValue();
            if (longConstant < Integer.MIN_VALUE || longConstant > Integer.MAX_VALUE) {
                return false;
            }
        }
        return true;
    }

    // Long-backed types whose SQL equality is plain equality of the stored long value
    private static boolean isSwitchCompatibleType(Type type) {
        return type instanceof BigintType
                || type instanceof IntegerType
                || type instanceof SmallintType
                || type instanceof TinyintType
                || type instanceof DateType;
    }
}

