/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.gen;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.SwitchStatement;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.JumpInstruction;
import io.airlift.bytecode.instruction.LabelNode;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.BytecodeGenerator;
import io.trino.sql.gen.BytecodeGeneratorContext;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;
import io.trino.util.FastutilSetHelper;
import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Generates bytecode for the SQL {@code IN} predicate: {@code value IN (test1, test2, ...)}.
 *
 * <p>Test values are split into two groups:
 * <ul>
 *   <li>Non-null constants that the type's INDETERMINATE operator reports as determinate.
 *       These are dispatched through a fast path chosen by {@link #checkSwitchGenerationCase}:
 *       a direct {@code int} lookup switch, a lookup switch on the value's hash code followed by
 *       equality checks, or a membership test against a precomputed hash set.</li>
 *   <li>Every other test value (non-constant expressions, null constants, indeterminate
 *       constants). These form a "default" chain of equality checks that records whether any
 *       comparison produced null, and sets {@code wasNull} accordingly when nothing matches.</li>
 * </ul>
 */
public class InCodeGenerator
        implements BytecodeGenerator {
    /**
     * Lists with at least this many test values use a hash-set lookup instead of a switch.
     * The value is not derived in this file; presumably it was chosen empirically.
     */
    private static final int SET_CONTAINS_THRESHOLD = 8;

    private final RowExpression valueExpression;
    private final List<RowExpression> testExpressions;
    private final ResolvedFunction resolvedEqualsFunction;
    private final ResolvedFunction resolvedHashCodeFunction;
    private final ResolvedFunction resolvedIsIndeterminate;

    /**
     * @param specialForm the IN special form. Its first argument is the probe value and the
     *        remaining arguments are the test values. It must declare exactly three function
     *        dependencies: the EQUAL, HASH_CODE and INDETERMINATE operators.
     * @throws IllegalArgumentException if there are fewer than two arguments or the number of
     *         function dependencies is not three
     */
    public InCodeGenerator(SpecialForm specialForm) {
        List<RowExpression> arguments = specialForm.arguments();
        Preconditions.checkArgument(arguments.size() >= 2, "At least two arguments are required");
        this.valueExpression = arguments.get(0);
        this.testExpressions = arguments.subList(1, arguments.size());
        Preconditions.checkArgument(specialForm.functionDependencies().size() == 3);
        this.resolvedEqualsFunction = specialForm.getOperatorDependency(OperatorType.EQUAL);
        this.resolvedHashCodeFunction = specialForm.getOperatorDependency(OperatorType.HASH_CODE);
        this.resolvedIsIndeterminate = specialForm.getOperatorDependency(OperatorType.INDETERMINATE);
    }

    /**
     * Chooses how the constant test values of an {@code IN} list are dispatched.
     *
     * @param type the type of the probe value
     * @param values the test expressions (the right-hand side of {@code IN})
     * @return {@link SwitchGenerationCase#SET_CONTAINS} for lists of {@value #SET_CONTAINS_THRESHOLD}
     *         or more entries; otherwise {@link SwitchGenerationCase#HASH_SWITCH} when the type is
     *         not backed by {@code long}, or when some non-null constant falls outside the
     *         {@code int} range; otherwise {@link SwitchGenerationCase#DIRECT_SWITCH}
     */
    @VisibleForTesting
    static SwitchGenerationCase checkSwitchGenerationCase(Type type, List<RowExpression> values) {
        if (values.size() >= SET_CONTAINS_THRESHOLD) {
            return SwitchGenerationCase.SET_CONTAINS;
        }
        if (type.getJavaType() != long.class) {
            return SwitchGenerationCase.HASH_SWITCH;
        }
        for (RowExpression expression : values) {
            // Non-constant and null test values always go to the default bucket in
            // generateExpression, so they do not constrain the choice of switch.
            if (!(expression instanceof ConstantExpression constantExpression)) {
                continue;
            }
            Object constant = constantExpression.value();
            if (constant == null) {
                continue;
            }
            long longConstant = ((Number) constant).longValue();
            if (longConstant < Integer.MIN_VALUE || longConstant > Integer.MAX_VALUE) {
                // The value cannot be used as an int switch key.
                return SwitchGenerationCase.HASH_SWITCH;
            }
        }
        return SwitchGenerationCase.DIRECT_SWITCH;
    }

    /**
     * Emits bytecode that leaves a boolean on the stack and sets {@code wasNull}.
     *
     * <p>On a match, {@code wasNull} is cleared and {@code true} is pushed. When nothing matches,
     * {@code false} is pushed and {@code wasNull} is set by the default case (see
     * {@link #buildInCase}). A null probe value is handled by
     * {@code BytecodeUtils.ifWasNullPopAndGoto}, which jumps straight to the end.
     */
    @Override
    public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext) {
        Type type = this.valueExpression.type();
        Class<?> javaType = type.getJavaType();
        SwitchGenerationCase switchGenerationCase = checkSwitchGenerationCase(type, this.testExpressions);

        MethodHandle equalsMethodHandle = generatorContext.getScalarFunctionImplementation(
                        this.resolvedEqualsFunction,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();
        MethodHandle hashCodeMethodHandle = generatorContext.getScalarFunctionImplementation(
                        this.resolvedHashCodeFunction,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();
        MethodHandle indeterminateMethodHandle = generatorContext.getScalarFunctionImplementation(
                        this.resolvedIsIndeterminate,
                        InvocationConvention.simpleConvention(
                                InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL,
                                InvocationConvention.InvocationArgumentConvention.NEVER_NULL))
                .getMethodHandle();

        ImmutableListMultimap.Builder<Integer, BytecodeNode> hashBucketsBuilder = ImmutableListMultimap.builder();
        ImmutableList.Builder<BytecodeNode> defaultBucket = ImmutableList.builder();
        ImmutableSet.Builder<Object> constantValuesBuilder = ImmutableSet.builder();
        for (RowExpression testValue : this.testExpressions) {
            // Every test value is generated, even when only its constant is used below.
            BytecodeNode testBytecode = generatorContext.generate(testValue);
            if (!isDeterminateConstant(testValue, indeterminateMethodHandle)) {
                // Non-constant, null, or indeterminate: compared at runtime with null tracking.
                defaultBucket.add(testBytecode);
                continue;
            }
            Object constant = ((ConstantExpression) testValue).value();
            switch (switchGenerationCase) {
                case DIRECT_SWITCH, SET_CONTAINS -> constantValuesBuilder.add(constant);
                case HASH_SWITCH -> hashBucketsBuilder.put(constantHashCode(hashCodeMethodHandle, constant), testBytecode);
            }
        }
        ImmutableListMultimap<Integer, BytecodeNode> hashBuckets = hashBucketsBuilder.build();
        ImmutableSet<Object> constantValues = constantValuesBuilder.build();

        LabelNode end = new LabelNode("end");
        LabelNode match = new LabelNode("match");
        LabelNode noMatch = new LabelNode("noMatch");
        LabelNode defaultLabel = new LabelNode("default");

        Scope scope = generatorContext.getScope();
        Variable value = scope.getOrCreateTempVariable(javaType);
        Variable switchKey = scope.getOrCreateTempVariable(int.class);
        SwitchStatement.SwitchBuilder switchBuilder = new SwitchStatement.SwitchBuilder().expression(switchKey);

        BytecodeBlock switchBlock = switch (switchGenerationCase) {
            case DIRECT_SWITCH -> {
                for (Object constantValue : constantValues) {
                    switchBuilder.addCase(Math.toIntExact((Long) constantValue), JumpInstruction.jump(match));
                }
                switchBuilder.defaultCase(JumpInstruction.jump(defaultLabel));
                yield new BytecodeBlock()
                        .comment("lookupSwitch(<stackValue>))")
                        // A long outside the int range cannot equal any case key, so go straight to the default bucket.
                        .append(new IfStatement()
                                .condition(BytecodeExpressions.invokeStatic(InCodeGenerator.class, "isInteger", boolean.class, value))
                                .ifFalse(new BytecodeBlock().gotoLabel(defaultLabel)))
                        .append(switchKey.set(value.cast(int.class)))
                        .append(switchBuilder.build());
            }
            case HASH_SWITCH -> {
                for (Map.Entry<Integer, Collection<BytecodeNode>> bucket : hashBuckets.asMap().entrySet()) {
                    // Hash codes can collide, so each bucket still compares with EQUAL.
                    BytecodeBlock caseBlock = buildInCase(generatorContext, scope, this.resolvedEqualsFunction, match, defaultLabel, value, bucket.getValue(), false, this.resolvedIsIndeterminate);
                    switchBuilder.addCase(bucket.getKey(), caseBlock);
                }
                switchBuilder.defaultCase(JumpInstruction.jump(defaultLabel));
                Binding hashCodeBinding = generatorContext.getCallSiteBinder().bind(hashCodeMethodHandle);
                // The key must be derived exactly as in constantHashCode: Long.hashCode(HASH_CODE(value)).
                yield new BytecodeBlock()
                        .comment("lookupSwitch(hashCode(<stackValue>))")
                        .getVariable(value)
                        .append(BytecodeUtils.invoke(hashCodeBinding, this.resolvedHashCodeFunction.signature()))
                        .invokeStatic(Long.class, "hashCode", int.class, long.class)
                        .putVariable(switchKey)
                        .append(switchBuilder.build());
            }
            case SET_CONTAINS -> {
                Set<?> constantValuesSet = FastutilSetHelper.toFastutilHashSet(constantValues, type, hashCodeMethodHandle, equalsMethodHandle);
                Binding setBinding = generatorContext.getCallSiteBinder().bind(constantValuesSet, constantValuesSet.getClass());
                yield new BytecodeBlock()
                        .comment("inListSet.contains(<stackValue>)")
                        .append(new IfStatement()
                                .condition(new BytecodeBlock()
                                        .comment("value")
                                        .getVariable(value)
                                        .comment("set")
                                        .append(BytecodeUtils.loadConstant(setBinding))
                                        .invokeStatic(FastutilSetHelper.class, "in", boolean.class, javaType.isPrimitive() ? javaType : Object.class, constantValuesSet.getClass()))
                                .ifTrue(JumpInstruction.jump(match)));
            }
        };

        BytecodeBlock defaultCaseBlock = buildInCase(generatorContext, scope, this.resolvedEqualsFunction, match, noMatch, value, defaultBucket.build(), true, this.resolvedIsIndeterminate)
                .setDescription("default");

        BytecodeBlock block = new BytecodeBlock()
                .comment("IN")
                .append(generatorContext.generate(this.valueExpression))
                .append(BytecodeUtils.ifWasNullPopAndGoto(scope, end, boolean.class, javaType))
                .putVariable(value)
                .append(switchBlock)
                // Constant fast paths fall through (or jump) here when they find no match.
                .visitLabel(defaultLabel)
                .append(defaultCaseBlock);
        block.append(new BytecodeBlock()
                .setDescription("match")
                .visitLabel(match)
                .append(generatorContext.wasNull().set(BytecodeExpressions.constantFalse()))
                .push(true)
                .gotoLabel(end));
        block.append(new BytecodeBlock()
                .setDescription("noMatch")
                .visitLabel(noMatch)
                .push(false)
                .gotoLabel(end));
        block.visitLabel(end);

        scope.releaseTempVariableForReuse(switchKey);
        scope.releaseTempVariableForReuse(value);
        return block;
    }

    /**
     * Computes the compile-time switch key for a constant: {@code Long.hashCode} of the type's
     * HASH_CODE operator result. It must match the key computed at runtime in the HASH_SWITCH
     * bytecode.
     *
     * @throws IllegalArgumentException wrapping any failure of the hash-code method handle
     */
    private static int constantHashCode(MethodHandle hashCodeMethodHandle, Object constant) {
        try {
            // The (long) cast fixes the signature-polymorphic return type of invoke.
            return Long.hashCode((long) hashCodeMethodHandle.invoke(constant));
        }
        catch (Throwable throwable) {
            throw new IllegalArgumentException("Error processing IN statement: error calculating hash code for " + constant, throwable);
        }
    }

    /**
     * Returns whether {@code value} lies within the {@code int} range, i.e. survives a round trip
     * through {@code int}. Generated DIRECT_SWITCH bytecode invokes this method by name, so its
     * name and signature must not change.
     */
    public static boolean isInteger(long value) {
        return value == (int) value;
    }

    /**
     * Builds a chain of EQUAL comparisons of {@code value} against each of {@code testValues}.
     *
     * <p>The first comparison that returns true jumps to {@code matchLabel}; if none do, control
     * jumps to {@code noMatchLabel}. The chain is assembled inside-out, so test values are
     * evaluated in reverse list order.
     *
     * <p>When {@code checkForNulls} is set, a comparison that sets {@code wasNull} is remembered in
     * a temporary flag and evaluation moves on to the next test value. On the no-match path,
     * {@code wasNull} is then set to that flag. If {@code testValues} is empty, it is instead set
     * to the result of {@code isIndeterminateFunction(value)}.
     */
    private static BytecodeBlock buildInCase(BytecodeGeneratorContext generatorContext, Scope scope, ResolvedFunction equals, LabelNode matchLabel, LabelNode noMatchLabel, Variable value, Collection<BytecodeNode> testValues, boolean checkForNulls, ResolvedFunction isIndeterminateFunction) {
        Variable caseWasNull = checkForNulls ? scope.getOrCreateTempVariable(boolean.class) : null;
        BytecodeBlock caseBlock = new BytecodeBlock();
        if (checkForNulls) {
            caseBlock.putVariable(caseWasNull, false);
        }

        // Terminal "no match" node; each test value below is wrapped around it.
        LabelNode elseLabel = new LabelNode("else");
        BytecodeBlock elseBlock = new BytecodeBlock().visitLabel(elseLabel);
        Variable wasNull = generatorContext.wasNull();
        if (checkForNulls) {
            if (testValues.isEmpty()) {
                elseBlock.append(new BytecodeBlock()
                        .append(generatorContext.generateCall(isIndeterminateFunction, ImmutableList.of(value)))
                        .putVariable(wasNull));
            }
            else {
                elseBlock.append(wasNull.set(caseWasNull));
            }
        }
        elseBlock.gotoLabel(noMatchLabel);

        BytecodeNode elseNode = elseBlock;
        for (BytecodeNode testNode : testValues) {
            LabelNode testLabel = new LabelNode("test");
            IfStatement test = new IfStatement();
            BytecodeNode equalsCall = generatorContext.generateCall(equals, ImmutableList.of(value, testNode));
            test.condition()
                    .visitLabel(testLabel)
                    .append(equalsCall);
            if (checkForNulls) {
                IfStatement wasNullCheck = new IfStatement("if wasNull, set caseWasNull to true, clear wasNull, pop boolean, and goto next test value");
                wasNullCheck.condition(wasNull);
                wasNullCheck.ifTrue(new BytecodeBlock()
                        .append(caseWasNull.set(BytecodeExpressions.constantTrue()))
                        .append(wasNull.set(BytecodeExpressions.constantFalse()))
                        .pop(boolean.class)
                        // elseLabel is the entry of the previously built (next evaluated) node.
                        .gotoLabel(elseLabel));
                test.condition().append(wasNullCheck);
            }
            test.ifTrue().gotoLabel(matchLabel);
            test.ifFalse(elseNode);
            elseNode = test;
            elseLabel = testLabel;
        }
        caseBlock.append(elseNode);

        if (checkForNulls) {
            scope.releaseTempVariableForReuse(caseWasNull);
        }
        return caseBlock;
    }

    /**
     * Returns true only for a non-null {@link ConstantExpression} whose value the given
     * INDETERMINATE method handle reports as determinate.
     *
     * @throws RuntimeException wrapping any checked throwable from the method handle; unchecked
     *         throwables are rethrown unchanged
     */
    private static boolean isDeterminateConstant(RowExpression expression, MethodHandle isIndeterminateFunction) {
        if (!(expression instanceof ConstantExpression constantExpression)) {
            return false;
        }
        Object value = constantExpression.value();
        if (value == null) {
            return false;
        }
        try {
            // The (boolean) cast fixes the signature-polymorphic return type of invoke.
            return !(boolean) isIndeterminateFunction.invoke(value);
        }
        catch (Throwable t) {
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }

    /** Strategy used to dispatch the determinate constant test values of an {@code IN} list. */
    enum SwitchGenerationCase {
        /** {@code int} lookup switch directly on the (long-backed, int-range) value. */
        DIRECT_SWITCH,
        /** Lookup switch on the value's hash code, followed by EQUAL checks within the bucket. */
        HASH_SWITCH,
        /** Membership test against a precomputed fastutil hash set. */
        SET_CONTAINS,
    }
}

