/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.routine;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.MoreCollectors;
import com.google.common.primitives.Primitives;
import io.airlift.bytecode.Access;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.ClassDefinition;
import io.airlift.bytecode.DynamicClassLoader;
import io.airlift.bytecode.MethodDefinition;
import io.airlift.bytecode.Parameter;
import io.airlift.bytecode.ParameterizedType;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.DoWhileLoop;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.control.WhileLoop;
import io.airlift.bytecode.expression.BytecodeExpression;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.Constant;
import io.airlift.bytecode.instruction.LabelNode;
import io.trino.metadata.FunctionManager;
import io.trino.operator.scalar.SpecializedSqlScalarFunction;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.function.ScalarFunctionAdapter;
import io.trino.spi.function.ScalarFunctionImplementation;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CachedInstanceBinder;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.LambdaBytecodeGenerator;
import io.trino.sql.gen.LambdaExpressionExtractor;
import io.trino.sql.gen.RowExpressionCompiler;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.RowExpressionVisitor;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;
import io.trino.sql.routine.ir.DefaultIrNodeVisitor;
import io.trino.sql.routine.ir.IrBlock;
import io.trino.sql.routine.ir.IrBreak;
import io.trino.sql.routine.ir.IrContinue;
import io.trino.sql.routine.ir.IrIf;
import io.trino.sql.routine.ir.IrLabel;
import io.trino.sql.routine.ir.IrLoop;
import io.trino.sql.routine.ir.IrNode;
import io.trino.sql.routine.ir.IrNodeVisitor;
import io.trino.sql.routine.ir.IrRepeat;
import io.trino.sql.routine.ir.IrReturn;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.routine.ir.IrSet;
import io.trino.sql.routine.ir.IrStatement;
import io.trino.sql.routine.ir.IrVariable;
import io.trino.sql.routine.ir.IrWhile;
import io.trino.util.CompilerUtils;
import io.trino.util.Reflection;
import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

/**
 * Compiles a routine IR tree ({@link IrRoutine}) into a generated JVM class whose public
 * {@code run} method executes the routine body.
 * <p>
 * The generated {@code run} method is an instance method taking a {@link ConnectorSession}
 * followed by one boxed argument per routine parameter, and it returns the boxed result
 * ({@code null} represents SQL NULL). Every loop construct periodically checks the current
 * thread's interrupted flag and throws a {@link RuntimeException} when it is set.
 */
public final class SqlRoutineCompiler {
    /** Number of loop iterations between two checks of the current thread's interrupted flag. */
    private static final int INTERRUPTION_CHECK_INTERVAL = 1000;

    private final FunctionManager functionManager;

    public SqlRoutineCompiler(FunctionManager functionManager) {
        this.functionManager = Objects.requireNonNull(functionManager, "functionManager is null");
    }

    /**
     * Compiles {@code routine} into a scalar function.
     * <p>
     * The generated {@code run} method uses the BOXED_NULLABLE argument convention and the
     * NULLABLE_RETURN return convention. The returned function adapts it on demand to whatever
     * convention the caller asks for.
     *
     * @param routine the routine to compile
     * @return a function producing an implementation for a requested invocation convention
     */
    public SpecializedSqlScalarFunction compile(IrRoutine routine) {
        Type returnType = routine.returnType();
        List<Type> parameterTypes = routine.parameters().stream()
                .map(IrVariable::type)
                .collect(ImmutableList.toImmutableList());

        // Matches the shape of run(): boxed nullable arguments, a nullable result, a leading
        // session parameter and an instance receiver created by the instance factory below.
        InvocationConvention callingConvention = new InvocationConvention(
                Collections.nCopies(parameterTypes.size(), InvocationConvention.InvocationArgumentConvention.BOXED_NULLABLE),
                InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN,
                true,
                true);

        Class<?> clazz = compileClass(routine);
        MethodHandle handle = Arrays.stream(clazz.getMethods())
                .filter(method -> method.getName().equals("run"))
                .map(Reflection::methodHandle)
                .collect(MoreCollectors.onlyElement());
        MethodHandle instanceFactory = Reflection.constructorMethodHandle(clazz);

        // Hide the generated class behind Object so the receiver type of the method handle
        // agrees with the return type of the instance factory.
        MethodHandle objectHandle = handle.asType(handle.type().changeParameterType(0, Object.class));
        MethodHandle objectInstanceFactory = instanceFactory.asType(instanceFactory.type().changeReturnType(Object.class));

        return invocationConvention -> {
            MethodHandle adapted = ScalarFunctionAdapter.adapt(objectHandle, returnType, parameterTypes, callingConvention, invocationConvention);
            return ScalarFunctionImplementation.builder()
                    .methodHandle(adapted)
                    .instanceFactory(objectInstanceFactory)
                    .build();
        };
    }

    /**
     * Generates and defines the class backing {@code routine}.
     * <p>
     * The class contains one pre-generated method per lambda found in the routine, the
     * {@code run} method and a public no-arg constructor that initializes cached instances.
     */
    @VisibleForTesting
    public Class<?> compileClass(IrRoutine routine) {
        ClassDefinition classDefinition = new ClassDefinition(
                Access.a(Access.PUBLIC, Access.FINAL),
                CompilerUtils.makeClassName("SqlRoutine"),
                ParameterizedType.type(Object.class));

        CallSiteBinder callSiteBinder = new CallSiteBinder();
        CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);

        Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap =
                generateMethodsForLambda(classDefinition, cachedInstanceBinder, routine);
        generateRunMethod(classDefinition, cachedInstanceBinder, compiledLambdaMap, routine);
        declareConstructor(classDefinition, cachedInstanceBinder);

        return CompilerUtils.defineClass(classDefinition, Object.class, callSiteBinder.getBindings(), new DynamicClassLoader(getClass().getClassLoader()));
    }

    /**
     * Pre-generates one method named {@code lambda_<n>} on the container class for every
     * distinct lambda expression in {@code node}.
     */
    private Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> generateMethodsForLambda(ClassDefinition containerClassDefinition, CachedInstanceBinder cachedInstanceBinder, IrNode node) {
        Set<LambdaDefinitionExpression> lambdaExpressions = extractLambda(node);
        ImmutableMap.Builder<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap = ImmutableMap.builder();
        int counter = 0;
        for (LambdaDefinitionExpression lambdaExpression : lambdaExpressions) {
            // Each lambda receives a snapshot of the lambdas compiled so far. NOTE(review):
            // presumably this lets later lambdas reference earlier (nested) ones; confirm
            // against LambdaExpressionExtractor's ordering.
            LambdaBytecodeGenerator.CompiledLambda compiledLambda = LambdaBytecodeGenerator.preGenerateLambdaExpression(
                    lambdaExpression,
                    "lambda_" + counter,
                    containerClassDefinition,
                    compiledLambdaMap.buildOrThrow(),
                    cachedInstanceBinder.getCallSiteBinder(),
                    cachedInstanceBinder,
                    functionManager);
            compiledLambdaMap.put(lambdaExpression, compiledLambda);
            counter++;
        }
        return compiledLambdaMap.buildOrThrow();
    }

    /**
     * Declares {@code run(ConnectorSession session, <boxed routine parameters...>)} and emits
     * bytecode for the routine body into it.
     */
    private void generateRunMethod(ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder, Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap, IrRoutine routine) {
        ImmutableList.Builder<Parameter> parameters = ImmutableList.builder();
        parameters.add(Parameter.arg("session", ConnectorSession.class));
        for (IrVariable sqlVariable : routine.parameters()) {
            parameters.add(Parameter.arg(name(sqlVariable), compilerType(sqlVariable.type())));
        }

        MethodDefinition method = classDefinition.declareMethod(
                Access.a(Access.PUBLIC),
                "run",
                compilerType(routine.returnType()),
                parameters.build());

        Scope scope = method.getScope();
        // Shared null flag consulted by generated expression code. NOTE(review): it is reset
        // per expression in BytecodeVisitor.compile; the exact reads and writes live in
        // RowExpressionCompiler/BytecodeUtils.
        scope.declareVariable(boolean.class, "wasNull");

        // Parameters already exist in the scope (declared above under the same v<field> names);
        // every other routine variable is declared here as a boxed local.
        Map<IrVariable, Variable> variables = VariableExtractor.extract(routine).stream()
                .distinct()
                .collect(ImmutableMap.toImmutableMap(Function.identity(), variable -> getOrDeclareVariable(scope, variable)));

        BytecodeVisitor visitor = new BytecodeVisitor(cachedInstanceBinder, compiledLambdaMap, variables);
        method.getBody().append(visitor.process(routine, scope));
    }

    /**
     * Returns bytecode that throws {@code RuntimeException("Thread interrupted")} when the
     * current thread's interrupted flag is set. The flag itself is left unchanged.
     */
    private static BytecodeNode throwIfInterrupted() {
        return new IfStatement()
                .condition(BytecodeExpressions.invokeStatic(Thread.class, "currentThread", Thread.class)
                        .invoke("isInterrupted", boolean.class))
                .ifTrue(new BytecodeBlock()
                        .append(BytecodeExpressions.newInstance(RuntimeException.class, BytecodeExpressions.constantString("Thread interrupted")))
                        .throwObject());
    }

    /** Declares a public no-arg constructor that calls {@code Object()} and initializes cached instances. */
    private static void declareConstructor(ClassDefinition classDefinition, CachedInstanceBinder cachedInstanceBinder) {
        MethodDefinition constructorDefinition = classDefinition.declareConstructor(Access.a(Access.PUBLIC));
        BytecodeBlock body = constructorDefinition.getBody();
        body.append(constructorDefinition.getThis())
                .invokeConstructor(Object.class);
        cachedInstanceBinder.generateInitializations(constructorDefinition.getThis(), body);
        body.ret();
    }

    private static Variable getOrDeclareVariable(Scope scope, IrVariable variable) {
        return getOrDeclareVariable(scope, compilerType(variable.type()), name(variable));
    }

    /** Returns the variable called {@code name} in {@code scope}, declaring it with {@code type} if absent. */
    private static Variable getOrDeclareVariable(Scope scope, ParameterizedType type, String name) {
        try {
            return scope.getVariable(name);
        }
        catch (IllegalArgumentException ignored) {
            // Scope.getVariable rejects unknown names; not yet declared, so declare it now
            return scope.declareVariable(type, name);
        }
    }

    /** Maps a SQL type to the boxed JVM type used for routine variables and the run() signature. */
    private static ParameterizedType compilerType(Type type) {
        return ParameterizedType.type(Primitives.wrap(type.getJavaType()));
    }

    private static String name(IrVariable variable) {
        return name(variable.field());
    }

    /** Bytecode variable name for the routine variable with the given field index ({@code v<field>}). */
    private static String name(int field) {
        return "v" + field;
    }

    /** Collects every distinct lambda expression found in any row expression of {@code node}. */
    private static Set<LambdaDefinitionExpression> extractLambda(IrNode node) {
        ImmutableSet.Builder<LambdaDefinitionExpression> expressions = ImmutableSet.builder();
        node.accept(new DefaultIrNodeVisitor() {
            @Override
            public void visitRowExpression(RowExpression expression) {
                expressions.addAll(LambdaExpressionExtractor.extractLambdaExpressions(expression));
            }
        }, null);
        return expressions.build();
    }

    /**
     * Collects every {@link IrVariable} node visited in an IR tree. The result may contain
     * duplicates, so callers de-duplicate it.
     */
    private static class VariableExtractor
            extends DefaultIrNodeVisitor {
        private final List<IrVariable> variables = new ArrayList<>();

        private VariableExtractor() {
        }

        @Override
        public Void visitVariable(IrVariable node, Void context) {
            variables.add(node);
            return null;
        }

        public static List<IrVariable> extract(IrNode node) {
            VariableExtractor extractor = new VariableExtractor();
            extractor.process(node, null);
            return extractor.variables;
        }
    }

    /**
     * Translates IR statements into bytecode. It tracks the continue and break targets of
     * labeled blocks and loops so that IrContinue and IrBreak can jump to them.
     */
    private class BytecodeVisitor
            implements IrNodeVisitor<Scope, BytecodeNode> {
        private final CachedInstanceBinder cachedInstanceBinder;
        private final Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap;
        private final Map<IrVariable, Variable> variables;
        private final Map<IrLabel, LabelNode> continueLabels = new HashMap<>();
        private final Map<IrLabel, LabelNode> breakLabels = new HashMap<>();

        public BytecodeVisitor(CachedInstanceBinder cachedInstanceBinder, Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap, Map<IrVariable, Variable> variables) {
            this.cachedInstanceBinder = Objects.requireNonNull(cachedInstanceBinder, "cachedInstanceBinder is null");
            this.compiledLambdaMap = Objects.requireNonNull(compiledLambdaMap, "compiledLambdaMap is null");
            this.variables = Objects.requireNonNull(variables, "variables is null");
        }

        @Override
        public BytecodeNode visitNode(IrNode node, Scope context) {
            throw new VerifyException("Unsupported node: " + node.getClass().getSimpleName());
        }

        @Override
        public BytecodeNode visitRoutine(IrRoutine node, Scope scope) {
            return process(node.body(), scope);
        }

        @Override
        public BytecodeNode visitSet(IrSet node, Scope scope) {
            return new BytecodeBlock()
                    .append(compile(node.value(), scope))
                    .putVariable(variables.get(node.target()));
        }

        @Override
        public BytecodeNode visitBlock(IrBlock node, Scope scope) {
            BytecodeBlock block = new BytecodeBlock();

            // Initialize each variable declared by this block to its default value
            for (IrVariable sqlVariable : node.variables()) {
                block.append(compile(sqlVariable.defaultValue(), scope))
                        .putVariable(variables.get(sqlVariable));
            }

            LabelNode continueLabel = new LabelNode("continue");
            LabelNode breakLabel = new LabelNode("break");
            node.label().ifPresent(label -> {
                registerLabel(label, continueLabel, breakLabel);
                block.visitLabel(continueLabel);
            });

            for (IrStatement statement : node.statements()) {
                block.append(process(statement, scope));
            }

            if (node.label().isPresent()) {
                block.visitLabel(breakLabel);
            }
            return block;
        }

        @Override
        public BytecodeNode visitReturn(IrReturn node, Scope scope) {
            return new BytecodeBlock()
                    .append(compile(node.value(), scope))
                    .ret(Primitives.wrap(node.value().type().getJavaType()));
        }

        @Override
        public BytecodeNode visitContinue(IrContinue node, Scope scope) {
            LabelNode label = continueLabels.get(node.target());
            Verify.verify(label != null, "continue target does not exist");
            return new BytecodeBlock().gotoLabel(label);
        }

        @Override
        public BytecodeNode visitBreak(IrBreak node, Scope scope) {
            LabelNode label = breakLabels.get(node.target());
            Verify.verify(label != null, "break target does not exist");
            return new BytecodeBlock().gotoLabel(label);
        }

        @Override
        public BytecodeNode visitIf(IrIf node, Scope scope) {
            IfStatement ifStatement = new IfStatement()
                    .condition(compileBoolean(node.condition(), scope))
                    .ifTrue(process(node.ifTrue(), scope));
            node.ifFalse().ifPresent(ifFalse -> ifStatement.ifFalse(process(ifFalse, scope)));
            return ifStatement;
        }

        @Override
        public BytecodeNode visitWhile(IrWhile node, Scope scope) {
            return compileLoop(scope, node.label(), interruption -> new WhileLoop()
                    .condition(compileBoolean(node.condition(), scope))
                    .body(new BytecodeBlock()
                            .append(interruption)
                            .append(process(node.body(), scope))));
        }

        @Override
        public BytecodeNode visitRepeat(IrRepeat node, Scope scope) {
            // REPEAT ... UNTIL cond: keep looping while the condition is not true (NULL counts as not true)
            return compileLoop(scope, node.label(), interruption -> new DoWhileLoop()
                    .condition(not(compileBoolean(node.condition(), scope)))
                    .body(new BytecodeBlock()
                            .append(interruption)
                            .append(process(node.block(), scope))));
        }

        @Override
        public BytecodeNode visitLoop(IrLoop node, Scope scope) {
            return compileLoop(scope, node.label(), interruption -> new WhileLoop()
                    .condition(Constant.loadBoolean(true))
                    .body(new BytecodeBlock()
                            .append(interruption)
                            .append(process(node.block(), scope))));
        }

        /**
         * Wraps a loop with an interruption counter and with optional continue/break targets.
         *
         * @param loop builds the loop node; it receives a block that must be executed at the
         *             start of every iteration to drive the periodic interruption check
         */
        private BytecodeNode compileLoop(Scope scope, Optional<IrLabel> label, Function<BytecodeBlock, BytecodeNode> loop) {
            BytecodeBlock block = new BytecodeBlock();

            Variable interruption = scope.getOrCreateTempVariable(int.class);
            block.putVariable(interruption, 0);

            // Count iterations and check the interrupted flag once every INTERRUPTION_CHECK_INTERVAL iterations
            BytecodeBlock interruptionBlock = new BytecodeBlock()
                    .append(interruption.increment())
                    .append(new IfStatement()
                            .condition(BytecodeExpressions.greaterThanOrEqual(interruption, BytecodeExpressions.constantInt(INTERRUPTION_CHECK_INTERVAL)))
                            .ifTrue(new BytecodeBlock()
                                    .append(interruption.set(BytecodeExpressions.constantInt(0)))
                                    .append(SqlRoutineCompiler.throwIfInterrupted())));

            LabelNode continueLabel = new LabelNode("continue");
            LabelNode breakLabel = new LabelNode("break");
            if (label.isPresent()) {
                registerLabel(label.get(), continueLabel, breakLabel);
                // Placed after the counter reset so that `continue` does not zero the counter
                block.visitLabel(continueLabel);
            }

            block.append(loop.apply(interruptionBlock));

            if (label.isPresent()) {
                block.visitLabel(breakLabel);
            }

            // The loop's bytecode is complete, so later siblings may reuse the counter slot
            scope.releaseTempVariableForReuse(interruption);
            return block;
        }

        /**
         * Records the continue and break targets for {@code label}.
         *
         * @throws VerifyException if either target was already registered for this label
         */
        private void registerLabel(IrLabel label, LabelNode continueLabel, LabelNode breakLabel) {
            Verify.verify(continueLabels.putIfAbsent(label, continueLabel) == null, "continue label for loop label %s already exists", label);
            Verify.verify(breakLabels.putIfAbsent(label, breakLabel) == null, "break label for loop label %s already exists", label);
        }

        /**
         * Emits bytecode that leaves the value of {@code expression} on the stack in boxed form
         * ({@code null} for SQL NULL).
         */
        private BytecodeNode compile(RowExpression expression, Scope scope) {
            // Routine variables are stored boxed, so a bare reference is loaded as-is
            if (expression instanceof InputReferenceExpression input) {
                return scope.getVariable(SqlRoutineCompiler.name(input.field()));
            }

            RowExpressionCompiler rowExpressionCompiler = new RowExpressionCompiler(
                    cachedInstanceBinder.getCallSiteBinder(),
                    cachedInstanceBinder,
                    FieldReferenceCompiler.INSTANCE,
                    SqlRoutineCompiler.this.functionManager,
                    compiledLambdaMap);

            Class<?> javaType = expression.type().getJavaType();
            return new BytecodeBlock()
                    .comment("boolean wasNull = false;")
                    // A void-typed expression starts out as NULL (presumably the unknown/NULL type)
                    .putVariable(scope.getVariable("wasNull"), javaType == void.class)
                    .comment("expression: " + expression)
                    .append(rowExpressionCompiler.compile(expression, scope))
                    .append(BytecodeUtils.boxPrimitiveIfNecessary(scope, Primitives.wrap(javaType)));
        }

        /** Emits bytecode that leaves a primitive boolean on the stack; SQL NULL becomes {@code false}. */
        private BytecodeNode compileBoolean(RowExpression expression, Scope scope) {
            Preconditions.checkArgument(expression.type().equals(BooleanType.BOOLEAN), "type must be boolean");
            LabelNode notNull = new LabelNode("notNull");
            LabelNode done = new LabelNode("done");
            return new BytecodeBlock()
                    .append(compile(expression, scope))
                    .comment("if value is null, return false, otherwise unbox")
                    .dup()
                    .ifNotNullGoto(notNull)
                    .pop()
                    .push(false)
                    .gotoLabel(done)
                    .visitLabel(notNull)
                    .invokeVirtual(Boolean.class, "booleanValue", boolean.class)
                    .visitLabel(done);
        }

        /** Emits the logical negation of the primitive boolean that {@code node} leaves on the stack. */
        private static BytecodeNode not(BytecodeNode node) {
            LabelNode trueLabel = new LabelNode("true");
            LabelNode endLabel = new LabelNode("end");
            return new BytecodeBlock()
                    .append(node)
                    .comment("boolean not")
                    .ifTrueGoto(trueLabel)
                    .push(true)
                    .gotoLabel(endLabel)
                    .visitLabel(trueLabel)
                    .push(false)
                    .visitLabel(endLabel);
        }
    }

    /**
     * Field-reference compiler handed to {@link RowExpressionCompiler}. It supports only input
     * references: the boxed routine variable is loaded and unboxed to its primitive stack form.
     */
    private static class FieldReferenceCompiler
            implements RowExpressionVisitor<BytecodeNode, Scope> {
        public static final FieldReferenceCompiler INSTANCE = new FieldReferenceCompiler();

        private FieldReferenceCompiler() {
        }

        @Override
        public BytecodeNode visitInputReference(InputReferenceExpression node, Scope scope) {
            Class<?> boxedType = Primitives.wrap(node.type().getJavaType());
            return new BytecodeBlock()
                    .append(scope.getVariable(SqlRoutineCompiler.name(node.field())))
                    .append(BytecodeUtils.unboxPrimitiveIfNecessary(scope, boxedType));
        }

        @Override
        public BytecodeNode visitCall(CallExpression call, Scope scope) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitSpecialForm(SpecialForm specialForm, Scope context) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitConstant(ConstantExpression literal, Scope scope) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context) {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context) {
            throw new UnsupportedOperationException();
        }
    }
}

