/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.gen;

import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.Constant;
import io.trino.metadata.FunctionManager;
import io.trino.spi.type.Type;
import io.trino.sql.gen.AndCodeGenerator;
import io.trino.sql.gen.BetweenCodeGenerator;
import io.trino.sql.gen.BindCodeGenerator;
import io.trino.sql.gen.Binding;
import io.trino.sql.gen.BytecodeGenerator;
import io.trino.sql.gen.BytecodeGeneratorContext;
import io.trino.sql.gen.BytecodeUtils;
import io.trino.sql.gen.CachedInstanceBinder;
import io.trino.sql.gen.CallSiteBinder;
import io.trino.sql.gen.CoalesceCodeGenerator;
import io.trino.sql.gen.DereferenceCodeGenerator;
import io.trino.sql.gen.IfCodeGenerator;
import io.trino.sql.gen.InCodeGenerator;
import io.trino.sql.gen.IsNullCodeGenerator;
import io.trino.sql.gen.LambdaBytecodeGenerator;
import io.trino.sql.gen.NullIfCodeGenerator;
import io.trino.sql.gen.OrCodeGenerator;
import io.trino.sql.gen.RowConstructorCodeGenerator;
import io.trino.sql.gen.SwitchCodeGenerator;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.RowExpressionVisitor;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.VariableReferenceExpression;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Compiles {@link RowExpression} trees into {@link BytecodeNode}s.
 * <p>
 * Function calls and special forms are delegated to the matching code generators. Constants are either
 * loaded inline (boolean, long, double, String) or bound through the {@link CallSiteBinder}. Input
 * references and non-temporary variable references are delegated to the supplied
 * {@code fieldReferenceCompiler}.
 */
public class RowExpressionCompiler {
    /**
     * Name prefix marking a {@link VariableReferenceExpression} that refers to a temporary variable of the
     * current {@link Scope} (looked up with {@code Scope.getTempVariable}) instead of a field.
     */
    private static final String TEMP_PREFIX = "$$TEMP$$";

    private final CallSiteBinder callSiteBinder;
    private final CachedInstanceBinder cachedInstanceBinder;
    private final RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler;
    private final FunctionManager functionManager;
    private final Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap;

    /**
     * @param callSiteBinder binder for constants that cannot be loaded inline; also handed to code generators
     * @param cachedInstanceBinder handed to code generators through the generator context
     * @param fieldReferenceCompiler compiles input references and non-temporary variable references
     * @param functionManager handed to code generators through the generator context
     * @param compiledLambdaMap previously compiled lambdas, looked up for lambda definitions and BIND forms
     */
    public RowExpressionCompiler(CallSiteBinder callSiteBinder, CachedInstanceBinder cachedInstanceBinder, RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler, FunctionManager functionManager, Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap) {
        this.callSiteBinder = callSiteBinder;
        this.cachedInstanceBinder = cachedInstanceBinder;
        this.fieldReferenceCompiler = fieldReferenceCompiler;
        this.functionManager = functionManager;
        this.compiledLambdaMap = compiledLambdaMap;
    }

    /**
     * Compiles an expression that contains no lambda definitions or BIND forms.
     *
     * @param rowExpression expression to compile
     * @param scope scope providing variables such as {@code wasNull}
     * @return bytecode evaluating the expression
     */
    public BytecodeNode compile(RowExpression rowExpression, Scope scope) {
        return compile(rowExpression, scope, Optional.empty());
    }

    /**
     * Compiles an expression.
     *
     * @param rowExpression expression to compile
     * @param scope scope providing variables such as {@code wasNull}
     * @param lambdaInterface functional interface to generate lambdas for; it must be present when the
     *     expression is a lambda definition or a BIND special form
     * @return bytecode evaluating the expression
     * @throws VerifyException if a lambda or BIND form is compiled without a lambda interface
     */
    public BytecodeNode compile(RowExpression rowExpression, Scope scope, Optional<Class<?>> lambdaInterface) {
        return rowExpression.accept(new Visitor(), new Context(scope, lambdaInterface));
    }

    /**
     * Creates a variable reference that {@link #compile} resolves to the scope temp variable named like
     * {@code variable}, instead of delegating it to the field reference compiler.
     */
    public static VariableReferenceExpression createTempVariableReferenceExpression(Variable variable, Type type) {
        return new VariableReferenceExpression(TEMP_PREFIX + variable.getName(), type);
    }

    /** Builds a generator context sharing this compiler's binders and function manager. */
    private BytecodeGeneratorContext newGeneratorContext(Scope scope) {
        return new BytecodeGeneratorContext(this, scope, callSiteBinder, cachedInstanceBinder, functionManager);
    }

    private class Visitor
            implements RowExpressionVisitor<BytecodeNode, Context> {
        @Override
        public BytecodeNode visitCall(CallExpression call, Context context) {
            return newGeneratorContext(context.getScope())
                    .generateFullCall(call.getResolvedFunction(), call.getArguments());
        }

        @Override
        public BytecodeNode visitSpecialForm(SpecialForm specialForm, Context context) {
            BytecodeGenerator generator = switch (specialForm.getForm()) {
                case IF -> new IfCodeGenerator(specialForm);
                case NULL_IF -> new NullIfCodeGenerator(specialForm);
                case SWITCH -> new SwitchCodeGenerator(specialForm);
                case BETWEEN -> new BetweenCodeGenerator(specialForm);
                case IS_NULL -> new IsNullCodeGenerator(specialForm);
                case COALESCE -> new CoalesceCodeGenerator(specialForm);
                case IN -> new InCodeGenerator(specialForm);
                case AND -> new AndCodeGenerator(specialForm);
                case OR -> new OrCodeGenerator(specialForm);
                case DEREFERENCE -> new DereferenceCodeGenerator(specialForm);
                case ROW_CONSTRUCTOR -> new RowConstructorCodeGenerator(specialForm);
                case BIND -> new BindCodeGenerator(specialForm, compiledLambdaMap, context.getRequiredLambdaInterface("BIND special form"));
                default -> throw new IllegalStateException("Cannot compile special form: " + specialForm.getForm());
            };
            return generator.generateExpression(newGeneratorContext(context.getScope()));
        }

        @Override
        public BytecodeNode visitConstant(ConstantExpression constant, Context context) {
            Object value = constant.getValue();
            Class<?> javaType = constant.getType().getJavaType();
            BytecodeBlock block = new BytecodeBlock();
            if (value == null) {
                // SQL NULL: raise the wasNull flag and leave the Java default value of the type on the stack
                return block.comment("constant null")
                        .append(context.getScope().getVariable("wasNull").set(BytecodeExpressions.constantTrue()))
                        .pushJavaDefault(javaType);
            }
            block.comment("constant " + constant.getType().getTypeSignature());
            // These Java types can be emitted directly as bytecode constants
            if (javaType == boolean.class) {
                return block.append(Constant.loadBoolean((Boolean) value));
            }
            if (javaType == long.class) {
                return block.append(Constant.loadLong((Long) value));
            }
            if (javaType == double.class) {
                return block.append(Constant.loadDouble((Double) value));
            }
            if (javaType == String.class) {
                return block.append(Constant.loadString((String) value));
            }
            // Any other value is bound through the call-site binder and loaded from the binding
            Binding binding = callSiteBinder.bind(value, javaType);
            return new BytecodeBlock()
                    .setDescription("constant " + constant.getType())
                    .comment(constant.toString())
                    .append(BytecodeUtils.loadConstant(binding));
        }

        @Override
        public BytecodeNode visitInputReference(InputReferenceExpression node, Context context) {
            return fieldReferenceCompiler.visitInputReference(node, context.getScope());
        }

        @Override
        public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Context context) {
            Preconditions.checkState(compiledLambdaMap.containsKey(lambda), "lambda expressions map does not contain this lambda definition");
            Class<?> lambdaInterface = context.getRequiredLambdaInterface("a lambda expression");
            if (!lambdaInterface.isAnnotationPresent(FunctionalInterface.class)) {
                throw new VerifyException("lambda should be generated as class annotated with FunctionalInterface");
            }
            // A bare lambda definition has no captured expressions, hence the empty capture list
            return LambdaBytecodeGenerator.generateLambda(newGeneratorContext(context.getScope()), ImmutableList.of(), compiledLambdaMap.get(lambda), lambdaInterface);
        }

        @Override
        public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Context context) {
            String name = reference.getName();
            if (name.startsWith(TEMP_PREFIX)) {
                // Produced by createTempVariableReferenceExpression: resolve against the scope's temp variables
                return context.getScope().getTempVariable(name.substring(TEMP_PREFIX.length()));
            }
            return fieldReferenceCompiler.visitVariableReference(reference, context.getScope());
        }
    }

    /** Per-compilation state passed down through the visitor. */
    private static final class Context {
        private final Scope scope;
        private final Optional<Class<?>> lambdaInterface;

        Context(Scope scope, Optional<Class<?>> lambdaInterface) {
            this.scope = scope;
            this.lambdaInterface = lambdaInterface;
        }

        Scope getScope() {
            return scope;
        }

        /**
         * Returns the lambda interface supplied to {@code compile}.
         *
         * @param compiledElement description of what is being compiled, used in the error message
         * @throws VerifyException if no lambda interface was supplied
         */
        Class<?> getRequiredLambdaInterface(String compiledElement) {
            return lambdaInterface.orElseThrow(() -> new VerifyException("lambda interface is required to compile " + compiledElement));
        }
    }
}

