/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionResolver;
import io.trino.metadata.ResolvedFunction;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.ArithmeticBinaryExpression;
import io.trino.sql.ir.ArithmeticUnaryExpression;
import io.trino.sql.ir.Array;
import io.trino.sql.ir.BetweenPredicate;
import io.trino.sql.ir.BinaryLiteral;
import io.trino.sql.ir.BindExpression;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.DecimalLiteral;
import io.trino.sql.ir.DoubleLiteral;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.IfExpression;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.IntervalLiteral;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNotNullPredicate;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LambdaExpression;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.LongLiteral;
import io.trino.sql.ir.NodeRef;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.NullIfExpression;
import io.trino.sql.ir.NullLiteral;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.SearchedCaseExpression;
import io.trino.sql.ir.SimpleCaseExpression;
import io.trino.sql.ir.StringLiteral;
import io.trino.sql.ir.SubscriptExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.type.FunctionType;
import io.trino.type.IntervalDayTimeType;
import io.trino.type.IntervalYearMonthType;
import io.trino.type.UnknownType;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Computes the types of IR {@link Expression} trees.
 *
 * <p>Types are derived bottom-up:
 * <ul>
 *   <li>literals carry their own type;</li>
 *   <li>symbol references are looked up first in the lambda-argument bindings currently in scope, then in the
 *       supplied {@link TypeProvider};</li>
 *   <li>function calls and arithmetic operators take the return type of their resolved signature;</li>
 *   <li>casts take their target type.</li>
 * </ul>
 * While walking the tree the analyzer also checks invariants that well-formed IR must satisfy (boolean
 * conditions, identical branch/operand types). When one does not hold it throws {@link IllegalArgumentException}.
 */
public class IrTypeAnalyzer {
    private final PlannerContext plannerContext;

    @Inject
    public IrTypeAnalyzer(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /**
     * Analyzes every expression in {@code expressions} with a single shared visitor.
     *
     * @param session session used when resolving functions
     * @param inputTypes types of the symbols the expressions may reference
     * @param expressions expressions to analyze
     * @return a map, keyed by node identity, from every analyzed node (sub-expressions included) to its type
     */
    public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Iterable<Expression> expressions) {
        Visitor visitor = new Visitor(this.plannerContext, session, inputTypes);
        for (Expression expression : expressions) {
            // Top-level expressions start with no lambda arguments in scope.
            visitor.process(expression, new Context(ImmutableMap.of()));
        }
        return visitor.getTypes();
    }

    /**
     * Analyzes a single expression.
     *
     * @return a map, keyed by node identity, from {@code expression} and all of its sub-expressions to their types
     */
    public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) {
        return this.getTypes(session, inputTypes, ImmutableList.of(expression));
    }

    /**
     * Returns the type of {@code expression} itself.
     */
    public Type getType(Session session, TypeProvider inputTypes, Expression expression) {
        return this.getTypes(session, inputTypes, expression).get(NodeRef.of(expression));
    }

    /**
     * Tree walker that records the type of every visited node in {@link #expressionTypes}.
     */
    private static class Visitor
            extends IrVisitor<Type, Context> {
        private static final AccessControl ALLOW_ALL_ACCESS_CONTROL = new AllowAllAccessControl();
        private final PlannerContext plannerContext;
        private final Session session;
        private final TypeProvider symbolTypes;
        private final FunctionResolver functionResolver;
        // Insertion-ordered map of node (by identity) -> computed type.
        private final Map<NodeRef<Expression>, Type> expressionTypes = new LinkedHashMap<>();

        public Visitor(PlannerContext plannerContext, Session session, TypeProvider symbolTypes) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.symbolTypes = Objects.requireNonNull(symbolTypes, "symbolTypes is null");
            this.functionResolver = plannerContext.getFunctionResolver(WarningCollector.NOOP);
        }

        /**
         * Returns the live map of types computed so far (not a copy).
         */
        public Map<NodeRef<Expression>, Type> getTypes() {
            return this.expressionTypes;
        }

        /**
         * Records {@code type} for {@code expression} and returns it, so visitors can {@code return setExpressionType(...)}.
         */
        private Type setExpressionType(Expression expression, Type type) {
            Objects.requireNonNull(expression, "expression cannot be null");
            Objects.requireNonNull(type, "type cannot be null");
            this.expressionTypes.put(NodeRef.of(expression), type);
            return type;
        }

        /**
         * Returns the cached type if this exact node instance was already analyzed; otherwise visits it.
         */
        @Override
        public Type process(Expression node, Context context) {
            Type type = this.expressionTypes.get(NodeRef.of(node));
            if (type != null) {
                return type;
            }
            return super.process(node, context);
        }

        @Override
        protected Type visitRow(Row node, Context context) {
            List<Type> types = node.getItems().stream()
                    .map(item -> this.process(item, context))
                    .collect(ImmutableList.toImmutableList());
            return this.setExpressionType(node, RowType.anonymous(types));
        }

        @Override
        protected Type visitSymbolReference(SymbolReference node, Context context) {
            Symbol symbol = new Symbol(node.getName());
            // Lambda arguments in scope shadow the externally provided symbol types.
            Type type = context.argumentTypes().get(symbol);
            if (type == null) {
                type = this.symbolTypes.get(symbol);
            }
            Preconditions.checkArgument(type != null, "No type for: %s", node.getName());
            return this.setExpressionType(node, type);
        }

        @Override
        protected Type visitNotExpression(NotExpression node, Context context) {
            this.process(node.getValue(), context);
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitLogicalExpression(LogicalExpression node, Context context) {
            node.getTerms().forEach(term -> this.process(term, context));
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitComparisonExpression(ComparisonExpression node, Context context) {
            this.process(node.getLeft(), context);
            this.process(node.getRight(), context);
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitIsNullPredicate(IsNullPredicate node, Context context) {
            this.process(node.getValue(), context);
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitIsNotNullPredicate(IsNotNullPredicate node, Context context) {
            this.process(node.getValue(), context);
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitNullIfExpression(NullIfExpression node, Context context) {
            Type firstType = this.process(node.getFirst(), context);
            // The second operand is analyzed only so that its sub-expressions get typed.
            this.process(node.getSecond(), context);
            return this.setExpressionType(node, firstType);
        }

        @Override
        protected Type visitIfExpression(IfExpression node, Context context) {
            Type conditionType = this.process(node.getCondition(), context);
            Preconditions.checkArgument(conditionType.equals(BooleanType.BOOLEAN), "Condition must be boolean: %s", conditionType);
            Type trueType = this.process(node.getTrueValue(), context);
            if (node.getFalseValue().isPresent()) {
                Type falseType = this.process(node.getFalseValue().get(), context);
                Preconditions.checkArgument(trueType.equals(falseType), "Types must be equal: %s vs %s", trueType, falseType);
            }
            return this.setExpressionType(node, trueType);
        }

        @Override
        protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
            Set<Type> resultTypes = node.getWhenClauses().stream()
                    .map(clause -> {
                        Type operandType = this.process(clause.getOperand(), context);
                        Preconditions.checkArgument(operandType.equals(BooleanType.BOOLEAN), "When clause operand must be boolean: %s", operandType);
                        // Each WHEN clause is typed with the type of its result.
                        return this.setExpressionType(clause, this.process(clause.getResult(), context));
                    })
                    .collect(Collectors.toSet());
            Preconditions.checkArgument(resultTypes.size() == 1, "All result types must be the same: %s", resultTypes);
            Type resultType = resultTypes.iterator().next();
            node.getDefaultValue().ifPresent(defaultValue -> {
                Type defaultType = this.process(defaultValue, context);
                Preconditions.checkArgument(defaultType.equals(resultType), "Default result type must be the same as WHEN result types: %s vs %s", defaultType, resultType);
            });
            return this.setExpressionType(node, resultType);
        }

        @Override
        protected Type visitSimpleCaseExpression(SimpleCaseExpression node, Context context) {
            Type operandType = this.process(node.getOperand(), context);
            Set<Type> resultTypes = node.getWhenClauses().stream()
                    .map(clause -> {
                        Type clauseOperandType = this.process(clause.getOperand(), context);
                        Preconditions.checkArgument(clauseOperandType.equals(operandType), "WHEN clause operand type must match CASE operand type: %s vs %s", clauseOperandType, operandType);
                        // Each WHEN clause is typed with the type of its result.
                        return this.setExpressionType(clause, this.process(clause.getResult(), context));
                    })
                    .collect(Collectors.toSet());
            Preconditions.checkArgument(resultTypes.size() == 1, "All result types must be the same: %s", resultTypes);
            Type resultType = resultTypes.iterator().next();
            node.getDefaultValue().ifPresent(defaultValue -> {
                Type defaultType = this.process(defaultValue, context);
                Preconditions.checkArgument(defaultType.equals(resultType), "Default result type must be the same as WHEN result types: %s vs %s", defaultType, resultType);
            });
            return this.setExpressionType(node, resultType);
        }

        @Override
        protected Type visitCoalesceExpression(CoalesceExpression node, Context context) {
            Set<Type> types = node.getOperands().stream()
                    .map(operand -> this.process(operand, context))
                    .collect(Collectors.toSet());
            Preconditions.checkArgument(types.size() == 1, "All operands must have the same type: %s", types);
            return this.setExpressionType(node, types.iterator().next());
        }

        @Override
        protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, Context context) {
            return this.setExpressionType(node, this.process(node.getValue(), context));
        }

        @Override
        protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) {
            // Left is analyzed before right (Java evaluates arguments left to right).
            List<Type> argumentTypes = ImmutableList.of(
                    this.process(node.getLeft(), context),
                    this.process(node.getRight(), context));
            BoundSignature operatorSignature = this.plannerContext.getMetadata()
                    .resolveOperator(OperatorType.valueOf(node.getOperator().name()), argumentTypes)
                    .getSignature();
            return this.setExpressionType(node, operatorSignature.getReturnType());
        }

        @Override
        protected Type visitSubscriptExpression(SubscriptExpression node, Context context) {
            Type baseType = this.process(node.getBase(), context);
            this.process(node.getIndex(), context);
            Type elementType = switch (baseType) {
                case RowType rowType -> rowFieldType(rowType, node.getIndex());
                case ArrayType arrayType -> arrayType.getElementType();
                case MapType mapType -> mapType.getValueType();
                default -> throw new IllegalStateException("Unexpected type: " + baseType);
            };
            return this.setExpressionType(node, elementType);
        }

        /**
         * Returns the type of the row field addressed by a 1-based literal subscript.
         *
         * @throws IllegalArgumentException if {@code index} is not a {@link LongLiteral}
         */
        private static Type rowFieldType(RowType rowType, Expression index) {
            if (!(index instanceof LongLiteral literal)) {
                throw new IllegalArgumentException("Row subscript index must be an integer literal: " + index);
            }
            return rowType.getFields().get(Math.toIntExact(literal.getValue()) - 1).getType();
        }

        @Override
        protected Type visitArray(Array node, Context context) {
            Set<Type> types = node.getValues().stream()
                    .map(entry -> this.process(entry, context))
                    .collect(Collectors.toSet());
            if (types.isEmpty()) {
                // An empty array literal has no element to infer a type from.
                return this.setExpressionType(node, new ArrayType(UnknownType.UNKNOWN));
            }
            Preconditions.checkArgument(types.size() == 1, "All entries must have the same type: %s", types);
            return this.setExpressionType(node, new ArrayType(types.iterator().next()));
        }

        @Override
        protected Type visitStringLiteral(StringLiteral node, Context context) {
            return this.setExpressionType(node, VarcharType.createVarcharType((int) node.length()));
        }

        @Override
        protected Type visitBinaryLiteral(BinaryLiteral node, Context context) {
            return this.setExpressionType(node, VarbinaryType.VARBINARY);
        }

        @Override
        protected Type visitLongLiteral(LongLiteral node, Context context) {
            // Values that fit in 32 bits are typed INTEGER; anything wider is BIGINT.
            if (node.getValue() >= Integer.MIN_VALUE && node.getValue() <= Integer.MAX_VALUE) {
                return this.setExpressionType(node, IntegerType.INTEGER);
            }
            return this.setExpressionType(node, BigintType.BIGINT);
        }

        @Override
        protected Type visitDoubleLiteral(DoubleLiteral node, Context context) {
            return this.setExpressionType(node, DoubleType.DOUBLE);
        }

        @Override
        protected Type visitDecimalLiteral(DecimalLiteral node, Context context) {
            return this.setExpressionType(node, Decimals.parse(node.getValue()).getType());
        }

        @Override
        protected Type visitBooleanLiteral(BooleanLiteral node, Context context) {
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitGenericLiteral(GenericLiteral node, Context context) {
            return this.setExpressionType(node, node.getType());
        }

        @Override
        protected Type visitIntervalLiteral(IntervalLiteral node, Context context) {
            Type type = node.isYearToMonth() ? IntervalYearMonthType.INTERVAL_YEAR_MONTH : IntervalDayTimeType.INTERVAL_DAY_TIME;
            return this.setExpressionType(node, type);
        }

        @Override
        protected Type visitNullLiteral(NullLiteral node, Context context) {
            return this.setExpressionType(node, UnknownType.UNKNOWN);
        }

        @Override
        protected Type visitFunctionCall(FunctionCall node, Context context) {
            // NOTE(review): argument types are passed as null, so this relies on the IR function name
            // already identifying a resolved function; confirm against FunctionResolver.
            ResolvedFunction function = this.functionResolver.resolveFunction(this.session, node.getName(), null, ALLOW_ALL_ACCESS_CONTROL);
            BoundSignature signature = function.getSignature();
            for (int i = 0; i < node.getArguments().size(); i++) {
                Expression argument = node.getArguments().get(i);
                Type formalType = signature.getArgumentTypes().get(i);
                // Lambda and bind arguments cannot be typed in isolation. Their parameter types come from
                // the formal function type declared by the resolved signature.
                switch (argument) {
                    case LambdaExpression lambda -> this.processLambdaExpression(lambda, ((FunctionType) formalType).getArgumentTypes());
                    case BindExpression bind -> this.processBindExpression(bind, (FunctionType) formalType, context);
                    default -> this.process(argument, context);
                }
            }
            return this.setExpressionType(node, signature.getReturnType());
        }

        /**
         * Types a bind expression.
         *
         * <p>The bound values' types followed by {@code formalType}'s argument types form the
         * parameter list of the wrapped lambda. The bind node itself is typed as {@code formalType}.
         *
         * @throws UnsupportedOperationException if the bound function is not a lambda
         */
        private Type processBindExpression(BindExpression bind, FunctionType formalType, Context context) {
            List<Type> argumentTypes = new ArrayList<>();
            for (Expression value : bind.getValues()) {
                argumentTypes.add(this.process(value, context));
            }
            argumentTypes.addAll(formalType.getArgumentTypes());
            if (bind.getFunction() instanceof LambdaExpression lambda) {
                this.processLambdaExpression(lambda, argumentTypes);
                return this.setExpressionType(bind, formalType);
            }
            throw new UnsupportedOperationException("not yet implemented");
        }

        /**
         * Types a lambda whose parameters have the given types, positionally.
         *
         * <p>The body is analyzed in a fresh {@link Context} holding only this lambda's arguments.
         * Symbols not bound there fall back to the {@link TypeProvider}.
         *
         * @throws IllegalArgumentException if the lambda's arity differs from {@code argumentTypes.size()}
         */
        private Type processLambdaExpression(LambdaExpression lambda, List<Type> argumentTypes) {
            Preconditions.checkArgument(
                    lambda.getArguments().size() == argumentTypes.size(),
                    "Lambda has %s arguments but %s argument types were supplied",
                    lambda.getArguments().size(),
                    argumentTypes.size());
            ImmutableMap.Builder<Symbol, Type> typeBindings = ImmutableMap.builder();
            for (int i = 0; i < argumentTypes.size(); i++) {
                typeBindings.put(new Symbol(lambda.getArguments().get(i)), argumentTypes.get(i));
            }
            Type returnType = this.process(lambda.getBody(), new Context(typeBindings.buildOrThrow()));
            return this.setExpressionType(lambda, new FunctionType(argumentTypes, returnType));
        }

        @Override
        protected Type visitBetweenPredicate(BetweenPredicate node, Context context) {
            this.process(node.getValue(), context);
            this.process(node.getMin(), context);
            this.process(node.getMax(), context);
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        public Type visitCast(Cast node, Context context) {
            this.process(node.getExpression(), context);
            return this.setExpressionType(node, node.getType());
        }

        @Override
        protected Type visitInPredicate(InPredicate node, Context context) {
            Expression value = node.getValue();
            Type type = this.process(value, context);
            for (Expression item : node.getValueList()) {
                Type itemType = this.process(item, context);
                Preconditions.checkArgument(itemType.equals(type), "Types must be equal: %s vs %s", itemType, type);
            }
            return this.setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitExpression(Expression node, Context context) {
            throw new UnsupportedOperationException("Not a valid IR expression: " + node.getClass().getName());
        }
    }

    /**
     * Analysis scope.
     *
     * @param argumentTypes types of the lambda arguments currently in scope. They are consulted before the
     *        {@link TypeProvider} when resolving a symbol reference.
     */
    private record Context(Map<Symbol, Type> argumentTypes) {
    }
}

