/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.routine;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.ResolvedFunction;
import io.trino.security.AccessControl;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.function.FunctionId;
import io.trino.spi.function.FunctionMetadata;
import io.trino.spi.function.Signature;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeNotFoundException;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.Analysis;
import io.trino.sql.analyzer.CorrelationSupport;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.Field;
import io.trino.sql.analyzer.QueryType;
import io.trino.sql.analyzer.RelationId;
import io.trino.sql.analyzer.RelationType;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.analyzer.SemanticExceptions;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.routine.SqlRoutineAnalysis;
import io.trino.sql.tree.AssignmentStatement;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.CaseStatement;
import io.trino.sql.tree.CaseStatementWhenClause;
import io.trino.sql.tree.CommentCharacteristic;
import io.trino.sql.tree.CompoundStatement;
import io.trino.sql.tree.ControlStatement;
import io.trino.sql.tree.DataType;
import io.trino.sql.tree.DeterministicCharacteristic;
import io.trino.sql.tree.ElseClause;
import io.trino.sql.tree.ElseIfClause;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionSpecification;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.IfStatement;
import io.trino.sql.tree.IterateStatement;
import io.trino.sql.tree.LanguageCharacteristic;
import io.trino.sql.tree.LeaveStatement;
import io.trino.sql.tree.LoopStatement;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NullInputCharacteristic;
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.ParameterDeclaration;
import io.trino.sql.tree.RepeatStatement;
import io.trino.sql.tree.ReturnStatement;
import io.trino.sql.tree.ReturnsClause;
import io.trino.sql.tree.SecurityCharacteristic;
import io.trino.sql.tree.VariableDeclaration;
import io.trino.sql.tree.WhileStatement;
import io.trino.type.TypeCoercion;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;

/**
 * Validates and analyzes SQL routines: functions whose LANGUAGE clause is absent or is {@code sql}
 * (case-insensitive).
 *
 * <p>{@link #extractFunctionMetadata} validates a {@link FunctionSpecification} and derives its
 * {@link FunctionMetadata} from the declaration alone. {@link #analyze} additionally resolves the
 * declared types through the planner's type manager, type-checks every expression in the routine
 * body, and verifies that the declared determinism matches the functions the body resolves to.
 */
public class SqlRoutineAnalyzer {
    private final PlannerContext plannerContext;
    private final WarningCollector warningCollector;

    /**
     * @param plannerContext source of the type manager and context passed to expression analysis
     * @param warningCollector collector passed to expression analysis of the routine body
     */
    public SqlRoutineAnalyzer(PlannerContext plannerContext, WarningCollector warningCollector) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.warningCollector = Objects.requireNonNull(warningCollector, "warningCollector is null");
    }

    /**
     * Builds the metadata of a SQL routine from its declaration only; body expressions are not analyzed.
     *
     * <p>The routine is described as a nullable scalar function. Every argument's nullability follows
     * the null-call clause (treated as "called on null input" when absent). A blank or missing comment
     * yields "no description". The function is marked non-deterministic only when its deterministic
     * clause explicitly says so.
     *
     * <p>A semantic exception is thrown for an unsupported language, a body that does not end in RETURN,
     * a function name containing '@' or '$', unnamed or duplicate parameters, or a repeated clause.
     *
     * @param functionId identifier recorded in the resulting metadata
     * @param function the parsed routine declaration
     * @return the routine's function metadata
     */
    public static FunctionMetadata extractFunctionMetadata(FunctionId functionId, FunctionSpecification function) {
        validateLanguage(function);
        validateReturn(function);
        String functionName = getFunctionName(function);

        Signature.Builder signatureBuilder = Signature.builder()
                .returnType(TypeSignatureTranslator.toTypeSignature(function.getReturnsClause().getReturnType()));
        validateArguments(function);
        function.getParameters().stream()
                .map(ParameterDeclaration::getType)
                .map(TypeSignatureTranslator::toTypeSignature)
                .forEach(signatureBuilder::argumentType);
        Signature signature = signatureBuilder.build();

        FunctionMetadata.Builder builder = FunctionMetadata.scalarBuilder(functionName)
                .functionId(functionId)
                .signature(signature)
                .nullable()
                .argumentNullability(Collections.nCopies(signature.getArgumentTypes().size(), isCalledOnNull(function)));

        getComment(function)
                .filter(Predicate.not(String::isBlank))
                .ifPresentOrElse(builder::description, builder::noDescription);

        if (!getDeterministic(function).orElse(true)) {
            builder.nondeterministic();
        }

        validateSecurity(function);
        return builder.build();
    }

    /**
     * Fully analyzes a SQL routine: validates its declaration, resolves the return and parameter types,
     * and type-checks the body with the parameters in scope.
     *
     * <p>The determinism of the body is computed from the functions it resolves. A semantic exception
     * (INVALID_ARGUMENTS) is thrown when that disagrees with the declared determinism in either direction.
     * The declared determinism defaults to deterministic when no clause is given.
     *
     * @param session session used for expression analysis
     * @param accessControl access control used for expression analysis
     * @param function the parsed routine declaration
     * @return the analysis of the routine, including the expression {@link Analysis} of its body
     */
    public SqlRoutineAnalysis analyze(Session session, AccessControl accessControl, FunctionSpecification function) {
        String functionName = getFunctionName(function);
        validateLanguage(function);
        boolean calledOnNull = isCalledOnNull(function);
        Optional<String> comment = getComment(function);
        validateSecurity(function);

        ReturnsClause returnsClause = function.getReturnsClause();
        Type returnType = getType(returnsClause, returnsClause.getReturnType());
        Map<String, Type> arguments = getArguments(function);
        validateReturn(function);

        StatementVisitor visitor = new StatementVisitor(session, accessControl, returnType);
        visitor.process(function.getStatement(), new Context(arguments, Set.of()));
        Analysis analysis = visitor.getAnalysis();

        // The body is deterministic only if every function it resolved to is deterministic
        boolean actuallyDeterministic = analysis.getResolvedFunctions().stream().allMatch(ResolvedFunction::isDeterministic);
        boolean declaredDeterministic = getDeterministic(function).orElse(true);
        if (!declaredDeterministic && actuallyDeterministic) {
            throw SemanticExceptions.semanticException(StandardErrorCode.INVALID_ARGUMENTS, function, "Deterministic function declared NOT DETERMINISTIC");
        }
        if (declaredDeterministic && !actuallyDeterministic) {
            throw SemanticExceptions.semanticException(StandardErrorCode.INVALID_ARGUMENTS, function, "Non-deterministic function declared DETERMINISTIC");
        }

        return new SqlRoutineAnalysis(functionName, arguments, returnType, calledOnNull, actuallyDeterministic, comment, analysis);
    }

    /**
     * Returns the unqualified function name (the suffix of the declared name).
     * Names containing '@' or '$' are rejected with NOT_SUPPORTED.
     */
    private static String getFunctionName(FunctionSpecification function) {
        String name = function.getName().getSuffix();
        if (name.contains("@") || name.contains("$")) {
            throw SemanticExceptions.semanticException(StandardErrorCode.NOT_SUPPORTED, function, "Function name cannot contain '@' or '$'");
        }
        return name;
    }

    /**
     * Resolves a declared data type through the planner's type manager.
     * An unknown type is reported as TYPE_MISMATCH located at {@code node}.
     */
    private Type getType(Node node, DataType type) {
        try {
            return plannerContext.getTypeManager().getType(TypeSignatureTranslator.toTypeSignature(type));
        }
        catch (TypeNotFoundException e) {
            // The type is passed as a format argument so that a '%' in its text cannot break formatting
            throw SemanticExceptions.semanticException(StandardErrorCode.TYPE_MISMATCH, node, "Unknown type: %s", type);
        }
    }

    /**
     * Validates the parameters and returns their resolved types keyed by name, in declaration order.
     */
    private Map<String, Type> getArguments(FunctionSpecification function) {
        validateArguments(function);
        Map<String, Type> arguments = new LinkedHashMap<>();
        for (ParameterDeclaration parameter : function.getParameters()) {
            // Presence of the name is guaranteed by validateArguments above
            String name = identifierValue(parameter.getName().orElseThrow());
            arguments.put(name, getType(parameter, parameter.getType()));
        }
        return arguments;
    }

    /**
     * Rejects unnamed parameters and duplicate parameter names (INVALID_ARGUMENTS).
     */
    private static void validateArguments(FunctionSpecification function) {
        Set<String> argumentNames = new LinkedHashSet<>();
        for (ParameterDeclaration parameter : function.getParameters()) {
            if (parameter.getName().isEmpty()) {
                throw SemanticExceptions.semanticException(StandardErrorCode.INVALID_ARGUMENTS, parameter, "Function parameters must have a name");
            }
            String name = identifierValue(parameter.getName().get());
            if (!argumentNames.add(name)) {
                // The name is a format argument: a quoted identifier may legitimately contain '%'
                throw SemanticExceptions.semanticException(StandardErrorCode.INVALID_ARGUMENTS, parameter, "Duplicate function parameter name: %s", name);
            }
        }
    }

    /**
     * Returns the single routine characteristic of the given class, if one is declared.
     * Declaring the same kind of characteristic more than once is rejected with SYNTAX_ERROR.
     *
     * @param function the routine declaration
     * @param type the characteristic class to look for
     * @param clauseName human-readable clause name used in the error message
     */
    private static <T> Optional<T> getCharacteristic(FunctionSpecification function, Class<T> type, String clauseName) {
        List<T> characteristics = function.getRoutineCharacteristics().stream()
                .filter(type::isInstance)
                .map(type::cast)
                .collect(ImmutableList.toImmutableList());
        if (characteristics.size() > 1) {
            throw SemanticExceptions.semanticException(StandardErrorCode.SYNTAX_ERROR, function, "Multiple %s clauses specified", clauseName);
        }
        return characteristics.stream().findAny();
    }

    /** Returns the declared LANGUAGE identifier value, if any. */
    private static Optional<String> getLanguage(FunctionSpecification function) {
        return getCharacteristic(function, LanguageCharacteristic.class, "language")
                .map(LanguageCharacteristic::getLanguage)
                .map(Identifier::getValue);
    }

    /** Accepts a missing language or "sql" (any case); anything else is NOT_SUPPORTED. */
    private static void validateLanguage(FunctionSpecification function) {
        Optional<String> language = getLanguage(function);
        if (language.isPresent() && !language.get().equalsIgnoreCase("sql")) {
            throw SemanticExceptions.semanticException(StandardErrorCode.NOT_SUPPORTED, function, "Unsupported language: %s", language.get());
        }
    }

    /** Returns the declared determinism, or empty when no deterministic clause is present. */
    private static Optional<Boolean> getDeterministic(FunctionSpecification function) {
        return getCharacteristic(function, DeterministicCharacteristic.class, "deterministic")
                .map(DeterministicCharacteristic::isDeterministic);
    }

    /** Returns whether the routine is called on null input; defaults to {@code true} when undeclared. */
    private static boolean isCalledOnNull(FunctionSpecification function) {
        return getCharacteristic(function, NullInputCharacteristic.class, "null-call")
                .map(NullInputCharacteristic::isCalledOnNull)
                .orElse(true);
    }

    /**
     * Returns whether the routine declares INVOKER security; defaults to {@code false} when undeclared.
     * A repeated security clause is rejected with SYNTAX_ERROR.
     */
    public static boolean isRunAsInvoker(FunctionSpecification function) {
        return getCharacteristic(function, SecurityCharacteristic.class, "security")
                .map(SecurityCharacteristic::getSecurity)
                .map(SecurityCharacteristic.Security.INVOKER::equals)
                .orElse(false);
    }

    /** Validates the security clause; the result is discarded, only the duplicate-clause check matters. */
    private static void validateSecurity(FunctionSpecification function) {
        isRunAsInvoker(function);
    }

    /** Returns the declared COMMENT text, if any. */
    private static Optional<String> getComment(FunctionSpecification function) {
        return getCharacteristic(function, CommentCharacteristic.class, "comment")
                .map(CommentCharacteristic::getComment);
    }

    /**
     * Ensures the routine body either is a single RETURN statement or is a compound statement whose
     * last statement is RETURN (otherwise MISSING_RETURN). Any other body kind fails the precondition
     * with an IllegalArgumentException.
     */
    private static void validateReturn(FunctionSpecification function) {
        ControlStatement statement = function.getStatement();
        if (statement instanceof ReturnStatement) {
            return;
        }
        Preconditions.checkArgument(statement instanceof CompoundStatement, "invalid function statement: %s", statement);
        CompoundStatement body = (CompoundStatement) statement;
        // An empty body yields null here, which is not a ReturnStatement
        if (!(Iterables.getLast(body.getStatements(), null) instanceof ReturnStatement)) {
            throw SemanticExceptions.semanticException(StandardErrorCode.MISSING_RETURN, body, "Function must end in a RETURN statement");
        }
    }

    /** Canonical key used for parameter, variable and label names. */
    private static String identifierValue(Identifier name) {
        return name.getValue();
    }

    /**
     * Walks the routine body, type-checking every expression against the variables in scope and
     * recording types, coercions and resolved functions in a shared {@link Analysis}.
     * Statement kinds without a dedicated visit method are rejected by {@link #visitNode}.
     */
    private class StatementVisitor
            extends AstVisitor<Void, Context> {
        private final Session session;
        private final AccessControl accessControl;
        private final Type returnType;
        private final Analysis analysis = new Analysis(null, ImmutableMap.of(), QueryType.OTHERS);
        private final TypeCoercion typeCoercion = new TypeCoercion(signature -> plannerContext.getTypeManager().getType(signature));

        public StatementVisitor(Session session, AccessControl accessControl, Type returnType) {
            this.session = Objects.requireNonNull(session, "session is null");
            this.accessControl = Objects.requireNonNull(accessControl, "accessControl is null");
            this.returnType = Objects.requireNonNull(returnType, "returnType is null");
        }

        public Analysis getAnalysis() {
            return analysis;
        }

        @Override
        protected Void visitNode(Node node, Context context) {
            throw new UnsupportedOperationException("Analysis not yet implemented: " + node);
        }

        @Override
        protected Void visitCompoundStatement(CompoundStatement node, Context context) {
            Context newContext = context.newScope();
            for (VariableDeclaration declaration : node.getVariableDeclarations()) {
                Type type = SqlRoutineAnalyzer.this.getType(declaration, declaration.getType());
                analysis.addType(declaration.getType(), type);
                // The DEFAULT value is analyzed before this declaration's own names enter the scope
                declaration.getDefaultValue().ifPresent(value -> analyzeExpression(newContext, value, type, "Value of DEFAULT"));
                for (Identifier name : declaration.getNames()) {
                    // newContext also contains inherited names, so redeclaring an outer variable is rejected too
                    if (newContext.variables().put(SqlRoutineAnalyzer.identifierValue(name), type) != null) {
                        throw SemanticExceptions.semanticException(StandardErrorCode.ALREADY_EXISTS, name, "Variable already declared in this scope: %s", name);
                    }
                }
            }
            analyzeNodes(newContext, node.getStatements());
            return null;
        }

        @Override
        protected Void visitIfStatement(IfStatement node, Context context) {
            analyzeExpression(context, node.getExpression(), BooleanType.BOOLEAN, "Condition of IF statement");
            analyzeNodes(context, node.getStatements());
            analyzeNodes(context, node.getElseIfClauses());
            node.getElseClause().ifPresent(elseClause -> process(elseClause, context));
            return null;
        }

        @Override
        protected Void visitElseIfClause(ElseIfClause node, Context context) {
            analyzeExpression(context, node.getExpression(), BooleanType.BOOLEAN, "Condition of ELSEIF clause");
            analyzeNodes(context, node.getStatements());
            return null;
        }

        @Override
        protected Void visitElseClause(ElseClause node, Context context) {
            analyzeNodes(context, node.getStatements());
            return null;
        }

        @Override
        protected Void visitCaseStatement(CaseStatement node, Context context) {
            if (node.getExpression().isPresent()) {
                // Simple CASE: each WHEN value must share a common super type with the CASE operand
                Type valueType = analyzeExpression(context, node.getExpression().get());
                for (CaseStatementWhenClause whenClause : node.getWhenClauses()) {
                    Type whenType = analyzeExpression(context, whenClause.getExpression());
                    Optional<Type> superType = typeCoercion.getCommonSuperType(valueType, whenType);
                    if (superType.isEmpty()) {
                        throw SemanticExceptions.semanticException(StandardErrorCode.TYPE_MISMATCH, whenClause.getExpression(), "WHEN clause value must evaluate to CASE value type %s (actual: %s)", valueType, whenType);
                    }
                    // NOTE(review): only the WHEN value is coerced; when the super type differs from
                    // valueType the operand gets no coercion here — confirm the planner handles that.
                    if (!whenType.equals(superType.get())) {
                        addCoercion(whenClause.getExpression(), whenType, superType.get());
                    }
                }
            }
            else {
                // Searched CASE: each WHEN is a boolean condition
                for (CaseStatementWhenClause whenClause : node.getWhenClauses()) {
                    analyzeExpression(context, whenClause.getExpression(), BooleanType.BOOLEAN, "Condition of WHEN clause");
                }
            }
            for (CaseStatementWhenClause whenClause : node.getWhenClauses()) {
                analyzeNodes(context, whenClause.getStatements());
            }
            node.getElseClause().ifPresent(elseClause -> process(elseClause, context));
            return null;
        }

        @Override
        protected Void visitWhileStatement(WhileStatement node, Context context) {
            Context newContext = context.newScope();
            node.getLabel().ifPresent(name -> defineLabel(newContext, name));
            analyzeExpression(newContext, node.getExpression(), BooleanType.BOOLEAN, "Condition of WHILE statement");
            analyzeNodes(newContext, node.getStatements());
            return null;
        }

        @Override
        protected Void visitRepeatStatement(RepeatStatement node, Context context) {
            Context newContext = context.newScope();
            node.getLabel().ifPresent(name -> defineLabel(newContext, name));
            analyzeExpression(newContext, node.getCondition(), BooleanType.BOOLEAN, "Condition of REPEAT statement");
            analyzeNodes(newContext, node.getStatements());
            return null;
        }

        @Override
        protected Void visitLoopStatement(LoopStatement node, Context context) {
            Context newContext = context.newScope();
            node.getLabel().ifPresent(name -> defineLabel(newContext, name));
            analyzeNodes(newContext, node.getStatements());
            return null;
        }

        @Override
        protected Void visitReturnStatement(ReturnStatement node, Context context) {
            analyzeExpression(context, node.getValue(), returnType, "Value of RETURN");
            return null;
        }

        @Override
        protected Void visitAssignmentStatement(AssignmentStatement node, Context context) {
            Identifier name = node.getTarget();
            Type targetType = context.variables().get(SqlRoutineAnalyzer.identifierValue(name));
            if (targetType == null) {
                throw SemanticExceptions.semanticException(StandardErrorCode.NOT_FOUND, name, "Variable cannot be resolved: %s", name);
            }
            analyzeExpression(context, node.getValue(), targetType, String.format("Value of SET '%s'", name));
            return null;
        }

        @Override
        protected Void visitIterateStatement(IterateStatement node, Context context) {
            verifyLabelExists(context, node.getLabel());
            return null;
        }

        @Override
        protected Void visitLeaveStatement(LeaveStatement node, Context context) {
            verifyLabelExists(context, node.getLabel());
            return null;
        }

        /**
         * Analyzes {@code expression} and ensures it has, or can be coerced to, {@code expectedType}.
         * A needed coercion is recorded in the analysis; otherwise TYPE_MISMATCH is thrown.
         *
         * @param message description of the expression's role, used as the prefix of the error message
         */
        private void analyzeExpression(Context context, Expression expression, Type expectedType, String message) {
            Type actualType = analyzeExpression(context, expression);
            if (actualType.equals(expectedType)) {
                return;
            }
            if (!typeCoercion.canCoerce(actualType, expectedType)) {
                // message is passed as an argument rather than concatenated into the format string,
                // because it can embed user identifiers (e.g. SET targets) that may contain '%'
                throw SemanticExceptions.semanticException(StandardErrorCode.TYPE_MISMATCH, expression, "%s must evaluate to %s (actual: %s)", message, expectedType, actualType);
            }
            addCoercion(expression, actualType, expectedType);
        }

        /**
         * Analyzes {@code expression} with the in-scope variables exposed as unqualified fields, and
         * returns its type. Subqueries are rejected (NOT_SUPPORTED) and correlation is disallowed.
         */
        private Type analyzeExpression(Context context, Expression expression) {
            List<Field> fields = context.variables().entrySet().stream()
                    .map(entry -> Field.newUnqualified(entry.getKey(), entry.getValue()))
                    .collect(ImmutableList.toImmutableList());
            Scope scope = Scope.builder()
                    .withRelationType(RelationId.of(expression), new RelationType(fields))
                    .build();
            ExpressionAnalyzer.analyzeExpressionWithoutSubqueries(
                    session,
                    plannerContext,
                    accessControl,
                    scope,
                    analysis,
                    expression,
                    StandardErrorCode.NOT_SUPPORTED,
                    "Queries are not allowed in functions",
                    warningCollector,
                    CorrelationSupport.DISALLOWED);
            return analysis.getType(expression);
        }

        /** Records a coercion of {@code expression} from {@code actualType} to {@code expectedType}. */
        private void addCoercion(Expression expression, Type actualType, Type expectedType) {
            analysis.addCoercion(expression, expectedType, typeCoercion.isTypeOnlyCoercion(actualType, expectedType));
        }

        /** Visits each node in order within the given context. */
        private void analyzeNodes(Context context, List<? extends Node> statements) {
            for (Node statement : statements) {
                process(statement, context);
            }
        }

        /** Adds a loop label to the scope; a label already visible (including from outer loops) is rejected. */
        private static void defineLabel(Context context, Identifier name) {
            if (!context.labels().add(SqlRoutineAnalyzer.identifierValue(name))) {
                throw SemanticExceptions.semanticException(StandardErrorCode.ALREADY_EXISTS, name, "Label already declared in this scope: %s", name);
            }
        }

        /** Ensures an ITERATE/LEAVE target label is visible in the current scope (otherwise NOT_FOUND). */
        private static void verifyLabelExists(Context context, Identifier name) {
            if (!context.labels().contains(SqlRoutineAnalyzer.identifierValue(name))) {
                throw SemanticExceptions.semanticException(StandardErrorCode.NOT_FOUND, name, "Label not defined: %s", name);
            }
        }
    }

    /**
     * Lexical scope: visible variables (name to type) and loop labels.
     *
     * <p>The compact constructor copies both collections, so each scope owns mutable copies:
     * declarations added to a nested scope do not leak out, while outer names remain visible inside.
     */
    private record Context(Map<String, Type> variables, Set<String> labels) {
        private Context {
            variables = new LinkedHashMap<>(variables);
            labels = new LinkedHashSet<>(labels);
        }

        /** Creates a nested scope that starts with copies of this scope's variables and labels. */
        public Context newScope() {
            return new Context(variables, labels);
        }
    }
}

