/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.routine;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.Analysis;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.Field;
import io.trino.sql.analyzer.RelationId;
import io.trino.sql.analyzer.RelationType;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.LogicalPlanner;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TranslationMap;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.rule.LambdaCaptureDesugaringRewriter;
import io.trino.sql.planner.sanity.SugarFreeChecker;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.Expressions;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SqlToRowExpressionTranslator;
import io.trino.sql.relational.StandardFunctionResolution;
import io.trino.sql.relational.optimizer.ExpressionOptimizer;
import io.trino.sql.routine.SqlRoutineAnalysis;
import io.trino.sql.routine.ir.IrBlock;
import io.trino.sql.routine.ir.IrBreak;
import io.trino.sql.routine.ir.IrContinue;
import io.trino.sql.routine.ir.IrIf;
import io.trino.sql.routine.ir.IrLabel;
import io.trino.sql.routine.ir.IrLoop;
import io.trino.sql.routine.ir.IrRepeat;
import io.trino.sql.routine.ir.IrReturn;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.routine.ir.IrSet;
import io.trino.sql.routine.ir.IrStatement;
import io.trino.sql.routine.ir.IrVariable;
import io.trino.sql.routine.ir.IrWhile;
import io.trino.sql.tree.AssignmentStatement;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.CaseStatement;
import io.trino.sql.tree.CaseStatementWhenClause;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.CompoundStatement;
import io.trino.sql.tree.ControlStatement;
import io.trino.sql.tree.ElseClause;
import io.trino.sql.tree.ElseIfClause;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionSpecification;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.IfStatement;
import io.trino.sql.tree.IterateStatement;
import io.trino.sql.tree.LambdaArgumentDeclaration;
import io.trino.sql.tree.LeaveStatement;
import io.trino.sql.tree.LoopStatement;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.RepeatStatement;
import io.trino.sql.tree.ReturnStatement;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.VariableDeclaration;
import io.trino.sql.tree.WhileStatement;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Plans the body of a SQL routine into the routine IR ({@link IrRoutine}).
 *
 * <p>Control statements ({@code BEGIN ... END}, {@code IF}, {@code CASE}, {@code WHILE},
 * {@code REPEAT}, {@code LOOP}, {@code RETURN}, {@code SET}, {@code ITERATE}, {@code LEAVE})
 * become IR nodes. Every SQL expression they contain is translated into a {@link RowExpression}
 * in which routine variables appear as field references whose index is the variable's
 * {@code field()}.
 */
public final class SqlRoutinePlanner {
    private final PlannerContext plannerContext;
    private final WarningCollector warningCollector;

    public SqlRoutinePlanner(PlannerContext plannerContext, WarningCollector warningCollector) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.warningCollector = Objects.requireNonNull(warningCollector, "warningCollector is null");
    }

    /**
     * Plans a SQL function into IR.
     *
     * <p>All variables of the routine share one field-index space. Each variable's index is its
     * position in a single list that is appended to as variables are created. Arguments are
     * allocated first, in the iteration order of {@code routineAnalysis.arguments()}, and each
     * gets a typed {@code NULL} as its default value.
     *
     * @param session session used when analyzing, interpreting and optimizing expressions
     * @param function the parsed function whose statement is planned
     * @param routineAnalysis analysis of {@code function}; supplies the argument names and types,
     *        the return type and the {@link Analysis} of the embedded expressions
     * @return the planned routine, whose parameters are the argument variables in allocation order
     */
    public IrRoutine planSqlFunction(Session session, FunctionSpecification function, SqlRoutineAnalysis routineAnalysis) {
        List<IrVariable> allVariables = new ArrayList<>();
        Map<String, IrVariable> scopeVariables = new LinkedHashMap<>();
        ImmutableList.Builder<IrVariable> parameters = ImmutableList.builder();
        routineAnalysis.arguments().forEach((name, type) -> {
            IrVariable variable = new IrVariable(allVariables.size(), type, Expressions.constantNull(type));
            allVariables.add(variable);
            scopeVariables.put(name, variable);
            parameters.add(variable);
        });

        StatementVisitor visitor = new StatementVisitor(session, allVariables, routineAnalysis.analysis());
        IrStatement body = visitor.process(function.getStatement(), new Context(scopeVariables, Map.of()));
        return new IrRoutine(routineAnalysis.returnType(), parameters.build(), body);
    }

    /**
     * Translates one control statement (recursively) into IR.
     *
     * <p>New variables are allocated only through {@link #declareVariable}. That helper appends
     * to the routine-wide {@code allVariables} list, so field indexes stay unique.
     */
    private class StatementVisitor
            extends AstVisitor<IrStatement, Context> {
        private final Session session;
        private final List<IrVariable> allVariables;
        private final Analysis analysis;
        private final StandardFunctionResolution resolution;

        public StatementVisitor(Session session, List<IrVariable> allVariables, Analysis analysis) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.resolution = new StandardFunctionResolution(plannerContext.getMetadata());
            this.allVariables = Objects.requireNonNull(allVariables, "allVariables is null");
            this.analysis = Objects.requireNonNull(analysis, "analysis is null");
        }

        @Override
        protected IrStatement visitNode(Node node, Context context) {
            throw new UnsupportedOperationException("Not implemented: " + node);
        }

        /**
         * Plans {@code BEGIN ... END}.
         *
         * <p>Declared variables live in a fresh scope. A declaration's default value is translated
         * in that scope, which already holds the variables of earlier declarations in the same
         * block. All names in one declaration share the same translated default.
         */
        @Override
        protected IrStatement visitCompoundStatement(CompoundStatement node, Context context) {
            Context newContext = context.newScope();
            ImmutableList.Builder<IrVariable> blockVariables = ImmutableList.builder();
            for (VariableDeclaration declaration : node.getVariableDeclarations()) {
                Type type = analysis.getType(declaration.getType());
                RowExpression defaultValue = declaration.getDefaultValue()
                        .map(expression -> toRowExpression(newContext, expression))
                        .orElse(Expressions.constantNull(type));
                for (Identifier name : declaration.getNames()) {
                    IrVariable variable = declareVariable(type, defaultValue);
                    Verify.verify(newContext.variables().put(identifierValue(name), variable) == null, "Variable already declared in scope: %s", name);
                    blockVariables.add(variable);
                }
            }
            return new IrBlock(blockVariables.build(), statements(node.getStatements(), newContext));
        }

        /**
         * Plans {@code IF ... ELSEIF ... ELSE ... END IF} as a right-nested chain of {@link IrIf}.
         * The chain is built from the innermost end: the ELSE block (if any), then the ELSEIF
         * clauses in reverse order, then the leading IF condition.
         */
        @Override
        protected IrStatement visitIfStatement(IfStatement node, Context context) {
            // Start from the ELSE block so it is kept even when there are no ELSEIF clauses
            IrStatement statement = node.getElseClause()
                    .map(elseClause -> (IrStatement) block(statements(elseClause.getStatements(), context)))
                    .orElse(null);
            for (ElseIfClause elseIf : Lists.reverse(node.getElseIfClauses())) {
                RowExpression condition = toRowExpression(context, elseIf.getExpression());
                IrBlock ifTrue = block(statements(elseIf.getStatements(), context));
                statement = new IrIf(condition, ifTrue, Optional.ofNullable(statement));
            }
            RowExpression condition = toRowExpression(context, node.getExpression());
            IrBlock ifTrue = block(statements(node.getStatements(), context));
            return new IrIf(condition, ifTrue, Optional.ofNullable(statement));
        }

        /**
         * Plans a simple or searched {@code CASE} statement as a right-nested chain of
         * {@link IrIf}, ending in the ELSE block (or an empty block when there is no ELSE).
         *
         * <p>For a simple CASE, the operand becomes the default value of a block-local variable.
         * Each WHEN value is compared for equality with a field reference to that variable, after
         * coercing the reference to the WHEN value's type when the types differ.
         */
        @Override
        protected IrStatement visitCaseStatement(CaseStatement node, Context context) {
            IrStatement statement;
            if (node.getExpression().isPresent()) {
                RowExpression valueExpression = toRowExpression(context, node.getExpression().get());
                // Registered in allVariables like every other variable so its field index is unique
                IrVariable valueVariable = declareVariable(valueExpression.getType(), valueExpression);
                statement = elseBlock(node, context);
                for (CaseStatementWhenClause whenClause : Lists.reverse(node.getWhenClauses())) {
                    RowExpression conditionValue = toRowExpression(context, whenClause.getExpression());
                    RowExpression testValue = Expressions.field(valueVariable.field(), valueVariable.type());
                    if (!testValue.getType().equals(conditionValue.getType())) {
                        ResolvedFunction castFunction = plannerContext.getMetadata().getCoercion(testValue.getType(), conditionValue.getType());
                        testValue = Expressions.call(castFunction, testValue);
                    }
                    ResolvedFunction equals = resolution.comparisonFunction(ComparisonExpression.Operator.EQUAL, testValue.getType(), conditionValue.getType());
                    RowExpression condition = Expressions.call(equals, testValue, conditionValue);
                    IrBlock ifTrue = block(statements(whenClause.getStatements(), context));
                    statement = new IrIf(condition, ifTrue, Optional.of(statement));
                }
                return new IrBlock(ImmutableList.of(valueVariable), ImmutableList.of(statement));
            }

            statement = elseBlock(node, context);
            for (CaseStatementWhenClause whenClause : Lists.reverse(node.getWhenClauses())) {
                RowExpression condition = toRowExpression(context, whenClause.getExpression());
                IrBlock ifTrue = block(statements(whenClause.getStatements(), context));
                statement = new IrIf(condition, ifTrue, Optional.of(statement));
            }
            return statement;
        }

        /** Returns the ELSE block of a CASE statement, or an empty block when it has no ELSE. */
        private IrStatement elseBlock(CaseStatement node, Context context) {
            return block(node.getElseClause()
                    .map(elseClause -> statements(elseClause.getStatements(), context))
                    .orElse(ImmutableList.of()));
        }

        /** Plans {@code WHILE}; the optional label is visible only inside the loop's scope. */
        @Override
        protected IrStatement visitWhileStatement(WhileStatement node, Context context) {
            Context newContext = context.newScope();
            Optional<IrLabel> label = getSqlLabel(newContext, node.getLabel());
            RowExpression condition = toRowExpression(newContext, node.getExpression());
            List<IrStatement> statements = statements(node.getStatements(), newContext);
            return new IrWhile(label, condition, block(statements));
        }

        /** Plans {@code REPEAT ... UNTIL}; the optional label is visible only inside the loop's scope. */
        @Override
        protected IrStatement visitRepeatStatement(RepeatStatement node, Context context) {
            Context newContext = context.newScope();
            Optional<IrLabel> label = getSqlLabel(newContext, node.getLabel());
            RowExpression condition = toRowExpression(newContext, node.getCondition());
            List<IrStatement> statements = statements(node.getStatements(), newContext);
            return new IrRepeat(label, condition, block(statements));
        }

        /** Plans {@code LOOP}; the optional label is visible only inside the loop's scope. */
        @Override
        protected IrStatement visitLoopStatement(LoopStatement node, Context context) {
            Context newContext = context.newScope();
            Optional<IrLabel> label = getSqlLabel(newContext, node.getLabel());
            List<IrStatement> statements = statements(node.getStatements(), newContext);
            return new IrLoop(label, block(statements));
        }

        @Override
        protected IrStatement visitReturnStatement(ReturnStatement node, Context context) {
            return new IrReturn(toRowExpression(context, node.getValue()));
        }

        /**
         * Plans {@code SET name = value}.
         *
         * @throws IllegalArgumentException if {@code name} is not a variable visible in this scope
         */
        @Override
        protected IrStatement visitAssignmentStatement(AssignmentStatement node, Context context) {
            Identifier name = node.getTarget();
            IrVariable target = context.variables().get(identifierValue(name));
            Preconditions.checkArgument(target != null, "Variable not declared in scope: %s", name);
            return new IrSet(target, toRowExpression(context, node.getValue()));
        }

        /** Plans {@code ITERATE label} as a continue to the named loop. */
        @Override
        protected IrStatement visitIterateStatement(IterateStatement node, Context context) {
            return new IrContinue(label(context, node.getLabel()));
        }

        /** Plans {@code LEAVE label} as a break out of the named loop. */
        @Override
        protected IrStatement visitLeaveStatement(LeaveStatement node, Context context) {
            return new IrBreak(label(context, node.getLabel()));
        }

        /** Allocates a new variable with the next free field index and records it in {@code allVariables}. */
        private IrVariable declareVariable(Type type, RowExpression defaultValue) {
            IrVariable variable = new IrVariable(allVariables.size(), type, defaultValue);
            allVariables.add(variable);
            return variable;
        }

        /**
         * Creates an {@link IrLabel} for an optional loop label and registers it in {@code context}.
         *
         * @throws VerifyException if a label with the same name is already visible in {@code context}
         */
        private static Optional<IrLabel> getSqlLabel(Context context, Optional<Identifier> labelName) {
            return labelName.map(name -> {
                IrLabel label = new IrLabel(identifierValue(name));
                Verify.verify(context.labels().put(identifierValue(name), label) == null, "Label already declared in this scope: %s", name);
                return label;
            });
        }

        /**
         * Looks up a label visible in {@code context}.
         *
         * @throws IllegalArgumentException if no such label is visible
         */
        private static IrLabel label(Context context, Identifier name) {
            IrLabel label = context.labels().get(identifierValue(name));
            Preconditions.checkArgument(label != null, "Label not defined: %s", name);
            return label;
        }

        /**
         * Converts a SQL expression into an optimized {@link RowExpression} in the given scope.
         *
         * <p>The visible variables are exposed to the expression as a relation with one
         * unqualified field per variable. The expression then goes through these steps:
         * <ol>
         *   <li>It is rewritten through a {@link TranslationMap}, wrapped in any coercion that the
         *       analysis recorded.</li>
         *   <li>Its lambda captures are desugared.</li>
         *   <li>It is analyzed and optimized by the {@link ExpressionInterpreter}, then re-encoded
         *       as an expression.</li>
         *   <li>The result is analyzed again and translated by {@link TranslationVisitor}, which
         *       turns variables into field references.</li>
         *   <li>The row expression is simplified by {@link ExpressionOptimizer}.</li>
         * </ol>
         */
        private RowExpression toRowExpression(Context context, Expression expression) {
            Map<String, IrVariable> variables = context.variables();
            Map<Symbol, Type> symbolTypes = variables.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(entry -> new Symbol(entry.getKey()), entry -> entry.getValue().type()));
            TypeProvider typeProvider = TypeProvider.viewOf(symbolTypes);

            List<Field> fields = variables.entrySet().stream()
                    .map(entry -> Field.newUnqualified(entry.getKey(), entry.getValue().type()))
                    .collect(ImmutableList.toImmutableList());
            Scope scope = Scope.builder()
                    .withRelationType(RelationId.of(expression), new RelationType(fields))
                    .build();

            SymbolAllocator symbolAllocator = new SymbolAllocator();
            List<Symbol> fieldSymbols = fields.stream()
                    .map(symbolAllocator::newSymbol)
                    .collect(ImmutableList.toImmutableList());
            Map<NodeRef<LambdaArgumentDeclaration>, Symbol> lambdaArguments = LogicalPlanner.buildLambdaDeclarationToSymbolMap(analysis, symbolAllocator);
            TranslationMap translationMap = new TranslationMap(Optional.empty(), scope, analysis, lambdaArguments, fieldSymbols, session, plannerContext);

            Expression translated = coerceIfNecessary(analysis, expression, translationMap.rewrite(expression));
            Expression desugared = LambdaCaptureDesugaringRewriter.rewrite(translated, typeProvider, symbolAllocator);

            ExpressionAnalyzer analyzer = createExpressionAnalyzer(session, typeProvider);
            analyzer.analyze(desugared, scope);
            ExpressionInterpreter interpreter = new ExpressionInterpreter(desugared, plannerContext, session, analyzer.getExpressionTypes());
            Type desugaredType = analyzer.getExpressionTypes().get(NodeRef.of(desugared));
            Expression optimized = new LiteralEncoder(plannerContext).toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), desugaredType);
            SugarFreeChecker.validate(optimized);

            // The encoder produced a new expression tree, so its node types come from a second analysis
            ExpressionAnalyzer optimizedAnalyzer = createExpressionAnalyzer(session, typeProvider);
            optimizedAnalyzer.analyze(optimized, scope);
            TranslationVisitor translator = new TranslationVisitor(plannerContext.getMetadata(), optimizedAnalyzer.getExpressionTypes(), ImmutableMap.of(), variables);
            RowExpression rowExpression = translator.process(optimized, null);

            ExpressionOptimizer optimizer = new ExpressionOptimizer(plannerContext.getMetadata(), plannerContext.getFunctionManager(), session);
            return optimizer.optimize(rowExpression);
        }

        /**
         * Wraps {@code rewritten} in a {@link Cast} to the coercion type that {@code analysis}
         * recorded for {@code original}. Returns {@code rewritten} unchanged when no coercion
         * was recorded.
         */
        private static Expression coerceIfNecessary(Analysis analysis, Expression original, Expression rewritten) {
            Type coercion = analysis.getCoercion(original);
            if (coercion == null) {
                return rewritten;
            }
            return new Cast(rewritten, TypeSignatureTranslator.toSqlType(coercion), false, analysis.isTypeOnlyCoercion(original));
        }

        /** Creates an analyzer that permits everything, has no parameters, and rejects subqueries. */
        private ExpressionAnalyzer createExpressionAnalyzer(Session session, TypeProvider typeProvider) {
            AccessControl accessControl = new AllowAllAccessControl();
            return ExpressionAnalyzer.createWithoutSubqueries(
                    plannerContext,
                    accessControl,
                    session,
                    typeProvider,
                    ImmutableMap.<NodeRef<Parameter>, Expression>of(),
                    node -> new VerifyException("Unexpected subquery"),
                    warningCollector,
                    false);
        }

        /** Plans each statement in order within the same {@code context}. */
        private List<IrStatement> statements(List<ControlStatement> statements, Context context) {
            return statements.stream()
                    .map(statement -> process(statement, context))
                    .collect(ImmutableList.toImmutableList());
        }

        /** Wraps statements in a block that declares no variables. */
        private static IrBlock block(List<IrStatement> statements) {
            return new IrBlock(ImmutableList.of(), statements);
        }

        private static String identifierValue(Identifier name) {
            return name.getValue();
        }
    }

    /**
     * Planning scope: the variables and loop labels visible by name.
     *
     * <p>The canonical constructor copies both maps into new {@link LinkedHashMap}s. As a result,
     * names registered in a scope obtained from {@link #newScope()} are not visible in the scope
     * it was derived from.
     */
    private record Context(Map<String, IrVariable> variables, Map<String, IrLabel> labels) {
        public Context {
            variables = new LinkedHashMap<>(variables);
            labels = new LinkedHashMap<>(labels);
        }

        /** Returns a child scope that starts with copies of this scope's variables and labels. */
        public Context newScope() {
            return new Context(variables, labels);
        }
    }

    /**
     * Row-expression translator that turns references to routine variables into field
     * references using the variable's field index. All other symbol references are handled
     * by the superclass.
     */
    private static class TranslationVisitor
            extends SqlToRowExpressionTranslator.Visitor {
        private final Map<String, IrVariable> variables;

        public TranslationVisitor(Metadata metadata, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Map<String, IrVariable> variables) {
            super(metadata, types, layout);
            this.variables = Objects.requireNonNull(variables, "variables is null");
        }

        @Override
        protected RowExpression visitSymbolReference(SymbolReference node, Void context) {
            IrVariable variable = variables.get(node.getName());
            if (variable != null) {
                return Expressions.field(variable.field(), variable.type());
            }
            return super.visitSymbolReference(node, context);
        }
    }
}

