/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Preconditions;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.cache.CacheUtils;
import io.trino.cache.SafeCaches;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.FunctionResolver;
import io.trino.metadata.ResolvedFunction;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.function.BoundSignature;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimeWithTimeZoneType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.Array;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BinaryLiteral;
import io.trino.sql.tree.BindExpression;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.CoalesceExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.DecimalLiteral;
import io.trino.sql.tree.DoubleLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IntervalLiteral;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LambdaArgumentDeclaration;
import io.trino.sql.tree.LambdaExpression;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullIfExpression;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.SimpleCaseExpression;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.SubscriptExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.type.DateTimes;
import io.trino.type.FunctionType;
import io.trino.type.IntervalDayTimeType;
import io.trino.type.IntervalYearMonthType;
import io.trino.type.UnknownType;
import java.lang.runtime.SwitchBootstraps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Computes the types of IR expressions.
 *
 * <p>Each call to one of the {@code getTypes}/{@code getType} methods walks the given
 * expression tree(s) with a fresh {@link Visitor} and records a {@link Type} for every
 * visited expression, keyed by {@link NodeRef}. Expression kinds that have no dedicated
 * {@code visit*} override are rejected with {@link UnsupportedOperationException}.
 */
public class IrTypeAnalyzer {
    private final PlannerContext plannerContext;

    /**
     * @param plannerContext source of metadata, the type manager and the function resolver; must not be null
     * @throws NullPointerException if {@code plannerContext} is null
     */
    @Inject
    public IrTypeAnalyzer(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
    }

    /**
     * Computes the types of all given expressions and of every sub-expression visited
     * while analyzing them.
     *
     * @param session session handed to the function resolver
     * @param inputTypes types of the symbols the expressions may reference
     * @param expressions root expressions to analyze; all share a single visitor
     * @return map from each visited expression to its type
     */
    public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Iterable<Expression> expressions) {
        Visitor visitor = new Visitor(plannerContext, session, inputTypes);
        for (Expression expression : expressions) {
            // Root expressions are analyzed with no lambda argument bindings.
            visitor.process(expression, new Context(ImmutableMap.of()));
        }
        return visitor.getTypes();
    }

    /**
     * Computes the types of a single expression and its visited sub-expressions.
     *
     * @see #getTypes(Session, TypeProvider, Iterable)
     */
    public Map<NodeRef<Expression>, Type> getTypes(Session session, TypeProvider inputTypes, Expression expression) {
        return getTypes(session, inputTypes, ImmutableList.of(expression));
    }

    /**
     * Computes the type of the given root expression.
     *
     * @return the type recorded for {@code expression}
     */
    public Type getType(Session session, TypeProvider inputTypes, Expression expression) {
        return getTypes(session, inputTypes, expression).get(NodeRef.of(expression));
    }

    /**
     * Visitor that derives the type of each expression node and stores it in
     * {@link #expressionTypes}. The {@link Context} carries the types of lambda
     * arguments that are in scope for the node being processed.
     */
    private static class Visitor
            extends AstVisitor<Type, Context> {
        private static final AccessControl ALLOW_ALL_ACCESS_CONTROL = new AllowAllAccessControl();
        private final PlannerContext plannerContext;
        private final Session session;
        private final TypeProvider symbolTypes;
        private final FunctionResolver functionResolver;
        // Caches types resolved from generic-literal type names via the type manager.
        private final Cache<String, Type> varcharCastableTypeCache = SafeCaches.buildNonEvictableCache(CacheBuilder.newBuilder().maximumSize(1000L));
        // LinkedHashMap: entries are kept in the order in which types were recorded.
        private final Map<NodeRef<Expression>, Type> expressionTypes = new LinkedHashMap<>();

        public Visitor(PlannerContext plannerContext, Session session, TypeProvider symbolTypes) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.symbolTypes = Objects.requireNonNull(symbolTypes, "symbolTypes is null");
            this.functionResolver = plannerContext.getFunctionResolver(WarningCollector.NOOP);
        }

        /**
         * Returns the types recorded so far. This is the visitor's own map, not a copy.
         */
        public Map<NodeRef<Expression>, Type> getTypes() {
            return expressionTypes;
        }

        /**
         * Records {@code type} for {@code expression} and returns it, so callers can
         * write {@code return setExpressionType(node, type)}.
         */
        private Type setExpressionType(Expression expression, Type type) {
            Objects.requireNonNull(expression, "expression cannot be null");
            Objects.requireNonNull(type, "type cannot be null");
            expressionTypes.put(NodeRef.of(expression), type);
            return type;
        }

        /**
         * Dispatches to the matching {@code visit*} method unless a type has already
         * been recorded for the expression, in which case the recorded type is returned.
         */
        @Override
        public Type process(Node node, Context context) {
            if (node instanceof Expression expression) {
                // Don't process the same node twice.
                Type knownType = expressionTypes.get(NodeRef.of(expression));
                if (knownType != null) {
                    return knownType;
                }
            }
            return super.process(node, context);
        }

        @Override
        protected Type visitRow(Row node, Context context) {
            List<Type> fieldTypes = node.getItems().stream()
                    .map(item -> process(item, context))
                    .collect(ImmutableList.toImmutableList());
            return setExpressionType(node, RowType.anonymous(fieldTypes));
        }

        /**
         * Resolves a symbol against the lambda arguments in {@code context} first and
         * falls back to the visitor's input symbol types.
         *
         * @throws IllegalArgumentException if neither source knows the symbol
         */
        @Override
        protected Type visitSymbolReference(SymbolReference node, Context context) {
            Symbol symbol = Symbol.from(node);
            Type type = context.argumentTypes().get(symbol);
            if (type == null) {
                type = symbolTypes.get(symbol);
            }
            Preconditions.checkArgument(type != null, "No type for: %s", node.getName());
            return setExpressionType(node, type);
        }

        @Override
        protected Type visitNotExpression(NotExpression node, Context context) {
            process(node.getValue(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitLogicalExpression(LogicalExpression node, Context context) {
            node.getTerms().forEach(term -> process(term, context));
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitComparisonExpression(ComparisonExpression node, Context context) {
            process(node.getLeft(), context);
            process(node.getRight(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitIsNullPredicate(IsNullPredicate node, Context context) {
            process(node.getValue(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        @Override
        protected Type visitIsNotNullPredicate(IsNotNullPredicate node, Context context) {
            process(node.getValue(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        /**
         * The result takes the type of the first operand; the second operand is still
         * processed so that its sub-expressions get types recorded.
         */
        @Override
        protected Type visitNullIfExpression(NullIfExpression node, Context context) {
            Type firstType = process(node.getFirst(), context);
            process(node.getSecond(), context);
            return setExpressionType(node, firstType);
        }

        /**
         * @throws IllegalArgumentException if the condition is not boolean, or if the
         *         false branch is present and its type differs from the true branch
         */
        @Override
        protected Type visitIfExpression(IfExpression node, Context context) {
            Type conditionType = process(node.getCondition(), context);
            Preconditions.checkArgument(conditionType.equals(BooleanType.BOOLEAN), "Condition must be boolean: %s", conditionType);
            Type trueType = process(node.getTrueValue(), context);
            node.getFalseValue().ifPresent(falseValue -> {
                Type falseType = process(falseValue, context);
                Preconditions.checkArgument(trueType.equals(falseType), "Types must be equal: %s vs %s", trueType, falseType);
            });
            return setExpressionType(node, trueType);
        }

        /**
         * @throws IllegalArgumentException if a WHEN operand is not boolean, if the WHEN
         *         results do not all share exactly one type, or if the default value has a
         *         different type
         */
        @Override
        protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
            Set<Type> resultTypes = node.getWhenClauses().stream()
                    .map(clause -> {
                        Type operandType = process(clause.getOperand(), context);
                        Preconditions.checkArgument(operandType.equals(BooleanType.BOOLEAN), "When clause operand must be boolean: %s", operandType);
                        // The clause itself is recorded with the type of its result.
                        return setExpressionType(clause, process(clause.getResult(), context));
                    })
                    .collect(Collectors.toSet());
            Preconditions.checkArgument(resultTypes.size() == 1, "All result types must be the same: %s", resultTypes);
            Type resultType = resultTypes.iterator().next();
            node.getDefaultValue().ifPresent(defaultValue -> {
                Type defaultType = process(defaultValue, context);
                Preconditions.checkArgument(defaultType.equals(resultType), "Default result type must be the same as WHEN result types: %s vs %s", defaultType, resultType);
            });
            return setExpressionType(node, resultType);
        }

        /**
         * @throws IllegalArgumentException if a WHEN operand type differs from the CASE
         *         operand type, if the WHEN results do not all share exactly one type, or if
         *         the default value has a different type
         */
        @Override
        protected Type visitSimpleCaseExpression(SimpleCaseExpression node, Context context) {
            Type operandType = process(node.getOperand(), context);
            Set<Type> resultTypes = node.getWhenClauses().stream()
                    .map(clause -> {
                        Type clauseOperandType = process(clause.getOperand(), context);
                        Preconditions.checkArgument(clauseOperandType.equals(operandType), "WHEN clause operand type must match CASE operand type: %s vs %s", clauseOperandType, operandType);
                        // The clause itself is recorded with the type of its result.
                        return setExpressionType(clause, process(clause.getResult(), context));
                    })
                    .collect(Collectors.toSet());
            Preconditions.checkArgument(resultTypes.size() == 1, "All result types must be the same: %s", resultTypes);
            Type resultType = resultTypes.iterator().next();
            node.getDefaultValue().ifPresent(defaultValue -> {
                Type defaultType = process(defaultValue, context);
                Preconditions.checkArgument(defaultType.equals(resultType), "Default result type must be the same as WHEN result types: %s vs %s", defaultType, resultType);
            });
            return setExpressionType(node, resultType);
        }

        /**
         * @throws IllegalArgumentException unless all operands share exactly one type
         */
        @Override
        protected Type visitCoalesceExpression(CoalesceExpression node, Context context) {
            Set<Type> operandTypes = node.getOperands().stream()
                    .map(operand -> process(operand, context))
                    .collect(Collectors.toSet());
            Preconditions.checkArgument(operandTypes.size() == 1, "All operands must have the same type: %s", operandTypes);
            return setExpressionType(node, operandTypes.iterator().next());
        }

        @Override
        protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, Context context) {
            return setExpressionType(node, process(node.getValue(), context));
        }

        /**
         * The result type is the return type of the operator resolved for the two
         * operand types.
         */
        @Override
        protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, Context context) {
            Type leftType = process(node.getLeft(), context);
            Type rightType = process(node.getRight(), context);
            BoundSignature operatorSignature = plannerContext.getMetadata()
                    .resolveOperator(OperatorType.valueOf(node.getOperator().name()), ImmutableList.of(leftType, rightType))
                    .getSignature();
            return setExpressionType(node, operatorSignature.getReturnType());
        }

        /**
         * Row subscripts yield the type of the field selected by the literal index
         * (treated as 1-based), array subscripts the element type, and map subscripts
         * the value type.
         *
         * @throws IllegalStateException if the base is not a row, array or map
         */
        @Override
        protected Type visitSubscriptExpression(SubscriptExpression node, Context context) {
            Type baseType = process(node.getBase(), context);
            process(node.getIndex(), context);
            Type resultType = switch (baseType) {
                // The index is cast to LongLiteral; any other index expression fails with ClassCastException.
                case RowType rowType -> rowType.getFields()
                        .get(Math.toIntExact(((LongLiteral) node.getIndex()).getParsedValue()) - 1)
                        .getType();
                case ArrayType arrayType -> arrayType.getElementType();
                case MapType mapType -> mapType.getValueType();
                default -> throw new IllegalStateException("Unexpected type: " + baseType);
            };
            return setExpressionType(node, resultType);
        }

        /**
         * An empty array is typed as an array of {@code UNKNOWN}.
         *
         * @throws IllegalArgumentException if the entries have more than one type
         */
        @Override
        protected Type visitArray(Array node, Context context) {
            Set<Type> entryTypes = node.getValues().stream()
                    .map(entry -> process(entry, context))
                    .collect(Collectors.toSet());
            if (entryTypes.isEmpty()) {
                return setExpressionType(node, new ArrayType(UnknownType.UNKNOWN));
            }
            Preconditions.checkArgument(entryTypes.size() == 1, "All entries must have the same type: %s", entryTypes);
            return setExpressionType(node, new ArrayType(entryTypes.iterator().next()));
        }

        @Override
        protected Type visitStringLiteral(StringLiteral node, Context context) {
            return setExpressionType(node, VarcharType.createVarcharType(node.length()));
        }

        @Override
        protected Type visitBinaryLiteral(BinaryLiteral node, Context context) {
            return setExpressionType(node, VarbinaryType.VARBINARY);
        }

        /**
         * Values that fit in an {@code int} are typed INTEGER, all others BIGINT.
         */
        @Override
        protected Type visitLongLiteral(LongLiteral node, Context context) {
            long value = node.getParsedValue();
            boolean fitsInInteger = value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
            return setExpressionType(node, fitsInInteger ? IntegerType.INTEGER : BigintType.BIGINT);
        }

        @Override
        protected Type visitDoubleLiteral(DoubleLiteral node, Context context) {
            return setExpressionType(node, DoubleType.DOUBLE);
        }

        @Override
        protected Type visitDecimalLiteral(DecimalLiteral node, Context context) {
            return setExpressionType(node, Decimals.parse(node.getValue()).getType());
        }

        @Override
        protected Type visitBooleanLiteral(BooleanLiteral node, Context context) {
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        /**
         * CHAR, TIMESTAMP and TIME literals (type name compared case-insensitively) get
         * a type derived from the literal's value; every other type name is resolved
         * through the type manager, with the result cached per type name.
         */
        @Override
        protected Type visitGenericLiteral(GenericLiteral node, Context context) {
            // Cases are tried in order, so the "with time zone" variants must precede the plain ones.
            Type type = switch (node.getType()) {
                case String name when name.equalsIgnoreCase("CHAR") ->
                        CharType.createCharType(node.getValue().length());
                case String name when name.equalsIgnoreCase("TIMESTAMP") && DateTimes.timestampHasTimeZone(node.getValue()) ->
                        TimestampWithTimeZoneType.createTimestampWithTimeZoneType(DateTimes.extractTimestampPrecision(node.getValue()));
                case String name when name.equalsIgnoreCase("TIMESTAMP") ->
                        TimestampType.createTimestampType(DateTimes.extractTimestampPrecision(node.getValue()));
                case String name when name.equalsIgnoreCase("TIME") && DateTimes.timeHasTimeZone(node.getValue()) ->
                        TimeWithTimeZoneType.createTimeWithTimeZoneType(DateTimes.extractTimePrecision(node.getValue()));
                case String name when name.equalsIgnoreCase("TIME") ->
                        TimeType.createTimeType(DateTimes.extractTimePrecision(node.getValue()));
                default -> CacheUtils.uncheckedCacheGet(
                        varcharCastableTypeCache,
                        node.getType(),
                        () -> plannerContext.getTypeManager().fromSqlType(node.getType()));
            };
            return setExpressionType(node, type);
        }

        @Override
        protected Type visitIntervalLiteral(IntervalLiteral node, Context context) {
            Type type = node.isYearToMonth() ? IntervalYearMonthType.INTERVAL_YEAR_MONTH : IntervalDayTimeType.INTERVAL_DAY_TIME;
            return setExpressionType(node, type);
        }

        @Override
        protected Type visitNullLiteral(NullLiteral node, Context context) {
            return setExpressionType(node, UnknownType.UNKNOWN);
        }

        /**
         * Types a function call from the signature of the resolved function. Lambda and
         * bind arguments are analyzed against the formal {@link FunctionType} of their
         * position; all other arguments are processed normally.
         */
        @Override
        protected Type visitFunctionCall(FunctionCall node, Context context) {
            // No argument types are passed to the resolver here.
            // NOTE(review): presumably the function name in IR already identifies a resolved function — confirm.
            ResolvedFunction function = functionResolver.resolveFunction(session, node.getName(), null, ALLOW_ALL_ACCESS_CONTROL);
            BoundSignature signature = function.getSignature();
            for (int i = 0; i < node.getArguments().size(); i++) {
                Expression argument = node.getArguments().get(i);
                Type formalType = signature.getArgumentTypes().get(i);
                switch (argument) {
                    case LambdaExpression lambda -> processLambdaExpression(lambda, ((FunctionType) formalType).getArgumentTypes());
                    case BindExpression bind -> processBindExpression(bind, (FunctionType) formalType, context);
                    default -> process(argument, context);
                }
            }
            return setExpressionType(node, signature.getReturnType());
        }

        /**
         * Analyzes a bind expression: the bound values' types followed by the formal
         * argument types become the argument types of the bound lambda. The bind
         * expression itself is recorded with {@code formalType}.
         *
         * @throws UnsupportedOperationException if the bound function is not a lambda
         */
        private Type processBindExpression(BindExpression bind, FunctionType formalType, Context context) {
            List<Type> argumentTypes = new ArrayList<>();
            bind.getValues().forEach(value -> argumentTypes.add(process(value, context)));
            argumentTypes.addAll(formalType.getArgumentTypes());
            if (!(bind.getFunction() instanceof LambdaExpression lambda)) {
                throw new UnsupportedOperationException("not yet implemented");
            }
            processLambdaExpression(lambda, argumentTypes);
            return setExpressionType(bind, formalType);
        }

        /**
         * Analyzes a lambda body with the lambda's declared arguments bound, by
         * position, to {@code argumentTypes}. The body sees only these bindings (plus
         * the visitor's input symbols), not the bindings of any enclosing context.
         *
         * @return the function type built from {@code argumentTypes} and the body's type
         */
        private Type processLambdaExpression(LambdaExpression lambda, List<Type> argumentTypes) {
            ImmutableMap.Builder<Symbol, Type> typeBindings = ImmutableMap.builder();
            for (int i = 0; i < argumentTypes.size(); i++) {
                LambdaArgumentDeclaration declaration = lambda.getArguments().get(i);
                typeBindings.put(new Symbol(declaration.getName().getValue()), argumentTypes.get(i));
            }
            Type returnType = process(lambda.getBody(), new Context(typeBindings.buildOrThrow()));
            return setExpressionType(lambda, new FunctionType(argumentTypes, returnType));
        }

        @Override
        protected Type visitBetweenPredicate(BetweenPredicate node, Context context) {
            process(node.getValue(), context);
            process(node.getMin(), context);
            process(node.getMax(), context);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        /**
         * The result type is the cast's declared target type, looked up in the type manager.
         */
        @Override
        public Type visitCast(Cast node, Context context) {
            process(node.getExpression(), context);
            return setExpressionType(node, plannerContext.getTypeManager().getType(TypeSignatureTranslator.toTypeSignature(node.getType())));
        }

        /**
         * The value list is cast to {@link InListExpression} and recorded with the type
         * of the tested value; the predicate itself is boolean.
         *
         * @throws IllegalArgumentException if a list item's type differs from the value's type
         */
        @Override
        protected Type visitInPredicate(InPredicate node, Context context) {
            Expression value = node.getValue();
            InListExpression valueList = (InListExpression) node.getValueList();
            Type valueType = process(value, context);
            for (Expression item : valueList.getValues()) {
                Type itemType = process(item, context);
                Preconditions.checkArgument(itemType.equals(valueType), "Types must be equal: %s vs %s", itemType, valueType);
            }
            setExpressionType(valueList, valueType);
            return setExpressionType(node, BooleanType.BOOLEAN);
        }

        /**
         * Fallback for expression kinds without a dedicated override.
         */
        @Override
        protected Type visitExpression(Expression node, Context context) {
            throw new UnsupportedOperationException("Not a valid IR expression: " + node.getClass().getName());
        }

        /**
         * Fallback for non-expression nodes.
         */
        @Override
        protected Type visitNode(Node node, Context context) {
            throw new UnsupportedOperationException("Not a valid IR expression: " + node.getClass().getName());
        }
    }

    /**
     * Types of the lambda arguments in scope for the node being processed.
     */
    private record Context(Map<Symbol, Type> argumentTypes) {
    }
}

