/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.relational;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.trino.Session;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DecimalParseResult;
import io.trino.spi.type.Decimals;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.TimeType;
import io.trino.spi.type.TimeWithTimeZoneType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.ArithmeticBinaryExpression;
import io.trino.sql.ir.ArithmeticUnaryExpression;
import io.trino.sql.ir.BetweenPredicate;
import io.trino.sql.ir.BinaryLiteral;
import io.trino.sql.ir.BindExpression;
import io.trino.sql.ir.BooleanLiteral;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.DecimalLiteral;
import io.trino.sql.ir.DoubleLiteral;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.IfExpression;
import io.trino.sql.ir.InPredicate;
import io.trino.sql.ir.IntervalLiteral;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.IsNotNullPredicate;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LambdaExpression;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.LongLiteral;
import io.trino.sql.ir.NodeRef;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.NullIfExpression;
import io.trino.sql.ir.NullLiteral;
import io.trino.sql.ir.Row;
import io.trino.sql.ir.SearchedCaseExpression;
import io.trino.sql.ir.SimpleCaseExpression;
import io.trino.sql.ir.StringLiteral;
import io.trino.sql.ir.SubscriptExpression;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.Symbol;
import io.trino.sql.relational.CallExpression;
import io.trino.sql.relational.ConstantExpression;
import io.trino.sql.relational.Expressions;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.LambdaDefinitionExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.RowExpressionVisitor;
import io.trino.sql.relational.SpecialForm;
import io.trino.sql.relational.StandardFunctionResolution;
import io.trino.sql.relational.VariableReferenceExpression;
import io.trino.sql.relational.optimizer.ExpressionOptimizer;
import io.trino.type.DateTimes;
import io.trino.type.JsonType;
import io.trino.type.TypeCoercion;
import io.trino.type.UnknownType;
import io.trino.util.DateTimeUtils;
import java.lang.runtime.SwitchBootstraps;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * Translates IR {@link Expression} trees into {@link RowExpression} trees.
 *
 * <p>Symbol references found in the supplied layout become input (field) references; all other
 * symbol references become {@link VariableReferenceExpression}s. Result types are taken from the
 * precomputed {@code types} map rather than re-derived here.
 */
public final class SqlToRowExpressionTranslator {
    private SqlToRowExpressionTranslator() {
    }

    /**
     * Translates {@code expression} into a {@link RowExpression}, optionally running the
     * {@link ExpressionOptimizer} over the result.
     *
     * @param expression the IR expression to translate
     * @param types the analyzed type of each sub-expression, keyed by node identity
     * @param layout maps symbols to input field positions; symbols absent from the map are
     *        translated as variable references
     * @param metadata used to resolve functions, operators and coercions
     * @param functionManager handed to the optimizer; only used when {@code optimize} is true
     * @param typeManager used to build the {@link TypeCoercion} that detects type-only casts
     * @param session handed to the optimizer; only used when {@code optimize} is true
     * @param optimize whether to run the optimizer over the translated expression
     * @return the translated expression, optimized if {@code optimize} is true
     * @throws UnsupportedOperationException if the tree contains a node kind this translator
     *         does not handle
     */
    public static RowExpression translate(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Metadata metadata, FunctionManager functionManager, TypeManager typeManager, Session session, boolean optimize) {
        Visitor visitor = new Visitor(metadata, typeManager, types, layout);
        RowExpression result = Objects.requireNonNull(visitor.process(expression, null), "result is null");
        if (!optimize) {
            return result;
        }
        ExpressionOptimizer optimizer = new ExpressionOptimizer(metadata, functionManager, session);
        return optimizer.optimize(result);
    }

    /**
     * Visitor that performs the IR-to-{@link RowExpression} translation. Each override
     * translates the node's children and then builds the matching call or special form.
     */
    public static class Visitor
            extends IrVisitor<RowExpression, Void> {
        private final Metadata metadata;
        private final TypeCoercion typeCoercion;
        private final Map<NodeRef<Expression>, Type> types;
        private final Map<Symbol, Integer> layout;
        private final StandardFunctionResolution standardFunctionResolution;

        protected Visitor(Metadata metadata, TypeManager typeManager, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout) {
            this.metadata = Objects.requireNonNull(metadata, "metadata is null");
            Objects.requireNonNull(typeManager, "typeManager is null");
            this.typeCoercion = new TypeCoercion(typeManager::getType);
            this.types = ImmutableMap.copyOf(Objects.requireNonNull(types, "types is null"));
            // Not copied: the visitor reads the caller's layout map as-is.
            this.layout = Objects.requireNonNull(layout, "layout is null");
            this.standardFunctionResolution = new StandardFunctionResolution(metadata);
        }

        /** Returns the analyzed type of {@code node}, or {@code null} if the types map has no entry for it. */
        private Type getType(Expression node) {
            return types.get(NodeRef.of(node));
        }

        /** Translates each expression in order and returns the results as an immutable list. */
        private List<RowExpression> translateAll(List<? extends Expression> expressions, Void context) {
            return expressions.stream()
                    .map(expression -> process(expression, context))
                    .collect(ImmutableList.toImmutableList());
        }

        @Override
        protected RowExpression visitExpression(Expression node, Void context) {
            throw new UnsupportedOperationException("not yet implemented: expression translator for " + node.getClass().getName());
        }

        @Override
        protected RowExpression visitNullLiteral(NullLiteral node, Void context) {
            return Expressions.constantNull(UnknownType.UNKNOWN);
        }

        @Override
        protected RowExpression visitBooleanLiteral(BooleanLiteral node, Void context) {
            return Expressions.constant(node.getValue(), BooleanType.BOOLEAN);
        }

        /** Literals that fit in 32 bits are typed INTEGER; anything wider is typed BIGINT. */
        @Override
        protected RowExpression visitLongLiteral(LongLiteral node, Void context) {
            long value = node.getValue();
            if (value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE) {
                return Expressions.constant(value, IntegerType.INTEGER);
            }
            return Expressions.constant(value, BigintType.BIGINT);
        }

        @Override
        protected RowExpression visitDoubleLiteral(DoubleLiteral node, Void context) {
            return Expressions.constant(node.getValue(), DoubleType.DOUBLE);
        }

        /** The decimal's precision/scale type and its stored value both come from {@link Decimals#parse}. */
        @Override
        protected RowExpression visitDecimalLiteral(DecimalLiteral node, Void context) {
            DecimalParseResult parseResult = Decimals.parse(node.getValue());
            return Expressions.constant(parseResult.getObject(), parseResult.getType());
        }

        /** String literals become bounded varchars whose bound is {@code StringLiteral.length()}. */
        @Override
        protected RowExpression visitStringLiteral(StringLiteral node, Void context) {
            return Expressions.constant(Slices.utf8Slice(node.getValue()), VarcharType.createVarcharType((int) node.length()));
        }

        @Override
        protected RowExpression visitBinaryLiteral(BinaryLiteral node, Void context) {
            return Expressions.constant(Slices.wrappedBuffer(node.getValue()), VarbinaryType.VARBINARY);
        }

        /**
         * Translates a typed literal such as {@code TIMESTAMP '...'}. Char and date/time types are
         * parsed into constants directly; JSON goes through {@code json_parse}; every other type
         * is produced by coercing the varchar text to the target type.
         */
        @Override
        protected RowExpression visitGenericLiteral(GenericLiteral node, Void context) {
            Type type = Objects.requireNonNull(getType(node), "type is null");
            String text = node.getValue();
            return switch (type) {
                case CharType charType -> Expressions.constant(Slices.utf8Slice(text), charType);
                case TimeType timeType -> Expressions.constant(DateTimes.parseTime(text), timeType);
                case TimeWithTimeZoneType timeWithTimeZoneType ->
                        Expressions.constant(DateTimes.parseTimeWithTimeZone(timeWithTimeZoneType.getPrecision(), text), timeWithTimeZoneType);
                case TimestampType timestampType ->
                        Expressions.constant(DateTimes.parseTimestamp(timestampType.getPrecision(), text), timestampType);
                case TimestampWithTimeZoneType timestampWithTimeZoneType ->
                        Expressions.constant(DateTimes.parseTimestampWithTimeZone(timestampWithTimeZoneType.getPrecision(), text), timestampWithTimeZoneType);
                case JsonType jsonType -> Expressions.call(
                        metadata.resolveBuiltinFunction("json_parse", TypeSignatureProvider.fromTypes(VarcharType.VARCHAR)),
                        Expressions.constant(Slices.utf8Slice(text), VarcharType.VARCHAR));
                default -> Expressions.call(
                        metadata.getCoercion(VarcharType.VARCHAR, type),
                        Expressions.constant(Slices.utf8Slice(text), VarcharType.VARCHAR));
            };
        }

        /** Parses the interval text and applies the literal's sign; the constant takes the analyzed type. */
        @Override
        protected RowExpression visitIntervalLiteral(IntervalLiteral node, Void context) {
            long magnitude = node.isYearToMonth()
                    ? DateTimeUtils.parseYearMonthInterval(node.getValue(), node.getStartField(), node.getEndField())
                    : DateTimeUtils.parseDayTimeInterval(node.getValue(), node.getStartField(), node.getEndField());
            long value = (long) node.getSign().multiplier() * magnitude;
            return Expressions.constant(value, getType(node));
        }

        /**
         * Normalizes comparisons to the EQUAL / LESS_THAN / LESS_THAN_OR_EQUAL family:
         * {@code a <> b} becomes {@code not(a = b)}, and {@code a > b} / {@code a >= b} swap
         * their operands into {@code b < a} / {@code b <= a}.
         */
        @Override
        protected RowExpression visitComparisonExpression(ComparisonExpression node, Void context) {
            RowExpression left = process(node.getLeft(), context);
            RowExpression right = process(node.getRight(), context);
            ComparisonExpression.Operator operator = node.getOperator();
            return switch (operator) {
                case NOT_EQUAL -> notExpression(visitComparisonExpression(ComparisonExpression.Operator.EQUAL, left, right));
                case GREATER_THAN -> visitComparisonExpression(ComparisonExpression.Operator.LESS_THAN, right, left);
                case GREATER_THAN_OR_EQUAL -> visitComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, right, left);
                default -> visitComparisonExpression(operator, left, right);
            };
        }

        /** Builds a call to the comparison function resolved for {@code operator} and the operand types. */
        private RowExpression visitComparisonExpression(ComparisonExpression.Operator operator, RowExpression left, RowExpression right) {
            return Expressions.call(standardFunctionResolution.comparisonFunction(operator, left.getType(), right.getType()), left, right);
        }

        /** The function is decoded from the call's name; arguments are translated in order. */
        @Override
        protected RowExpression visitFunctionCall(FunctionCall node, Void context) {
            List<RowExpression> arguments = translateAll(node.getArguments(), context);
            return new CallExpression(metadata.decodeFunction(node.getName()), arguments);
        }

        /** Symbols present in the layout become field references; all others become variable references. */
        @Override
        protected RowExpression visitSymbolReference(SymbolReference node, Void context) {
            Integer field = layout.get(Symbol.from(node));
            if (field != null) {
                return Expressions.field(field, getType(node));
            }
            return new VariableReferenceExpression(node.getName(), getType(node));
        }

        @Override
        protected RowExpression visitLambdaExpression(LambdaExpression node, Void context) {
            RowExpression body = process(node.getBody(), context);
            List<Type> typeParameters = getType(node).getTypeParameters();
            // All type parameters except the last are taken as the argument types; the last one
            // (presumably the lambda's return type) is dropped.
            List<Type> argumentTypes = typeParameters.subList(0, typeParameters.size() - 1);
            List<String> argumentNames = node.getArguments();
            return new LambdaDefinitionExpression(argumentTypes, argumentNames, body);
        }

        /** Produces a BIND special form whose arguments are the bound values followed by the function. */
        @Override
        protected RowExpression visitBindExpression(BindExpression node, Void context) {
            ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
            for (Expression value : node.getValues()) {
                arguments.add(process(value, context));
            }
            arguments.add(process(node.getFunction(), context));
            return new SpecialForm(SpecialForm.Form.BIND, getType(node), arguments.build());
        }

        @Override
        protected RowExpression visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            RowExpression left = process(node.getLeft(), context);
            RowExpression right = process(node.getRight(), context);
            return Expressions.call(standardFunctionResolution.arithmeticFunction(node.getOperator(), left.getType(), right.getType()), left, right);
        }

        /** Unary plus is dropped; unary minus becomes a call to the NEGATION operator. */
        @Override
        protected RowExpression visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) {
            RowExpression operand = process(node.getValue(), context);
            return switch (node.getSign()) {
                case PLUS -> operand;
                case MINUS -> Expressions.call(metadata.resolveOperator(OperatorType.NEGATION, ImmutableList.of(operand.getType())), operand);
                default -> throw new UnsupportedOperationException("Unsupported unary operator: " + node.getSign());
            };
        }

        @Override
        protected RowExpression visitLogicalExpression(LogicalExpression node, Void context) {
            SpecialForm.Form form = switch (node.getOperator()) {
                case AND -> SpecialForm.Form.AND;
                case OR -> SpecialForm.Form.OR;
                default -> throw new IllegalStateException("Unknown logical operator: " + node.getOperator());
            };
            return new SpecialForm(form, BooleanType.BOOLEAN, translateAll(node.getTerms(), context));
        }

        /**
         * Casts that {@link TypeCoercion#isTypeOnlyCoercion} reports as type-only are applied by
         * retyping the operand in place. Safe casts resolve the {@code TRY_CAST} coercion; all
         * others use the regular coercion.
         */
        @Override
        protected RowExpression visitCast(Cast node, Void context) {
            RowExpression value = process(node.getExpression(), context);
            Type returnType = getType(node);
            if (typeCoercion.isTypeOnlyCoercion(value.getType(), returnType)) {
                return changeType(value, returnType);
            }
            if (node.isSafe()) {
                return Expressions.call(metadata.getCoercion(GlobalFunctionCatalog.builtinFunctionName("TRY_CAST"), value.getType(), returnType), value);
            }
            return Expressions.call(metadata.getCoercion(value.getType(), returnType), value);
        }

        /** Rebuilds {@code value} so that it carries {@code targetType} (see {@link ChangeTypeVisitor}). */
        private static RowExpression changeType(RowExpression value, Type targetType) {
            return value.accept(new ChangeTypeVisitor(targetType), null);
        }

        @Override
        protected RowExpression visitCoalesceExpression(CoalesceExpression node, Void context) {
            return new SpecialForm(SpecialForm.Form.COALESCE, getType(node), translateAll(node.getOperands(), context));
        }

        /**
         * Produces a SWITCH special form: the operand, one WHEN form per clause, then the default
         * value (a typed NULL when absent). The EQUAL operator resolved for each clause is recorded
         * as a function dependency.
         */
        @Override
        protected RowExpression visitSimpleCaseExpression(SimpleCaseExpression node, Void context) {
            ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
            ImmutableList.Builder<ResolvedFunction> functionDependencies = ImmutableList.builder();

            RowExpression value = process(node.getOperand(), context);
            arguments.add(value);
            for (WhenClause clause : node.getWhenClauses()) {
                RowExpression operand = process(clause.getOperand(), context);
                RowExpression result = process(clause.getResult(), context);
                functionDependencies.add(metadata.resolveOperator(OperatorType.EQUAL, ImmutableList.of(value.getType(), operand.getType())));
                arguments.add(new SpecialForm(SpecialForm.Form.WHEN, getType(clause), operand, result));
            }

            Type returnType = getType(node);
            RowExpression defaultValue = node.getDefaultValue()
                    .map(defaultExpression -> process(defaultExpression, context))
                    .orElseGet(() -> Expressions.constantNull(returnType));
            arguments.add(defaultValue);
            return new SpecialForm(SpecialForm.Form.SWITCH, returnType, arguments.build(), functionDependencies.build());
        }

        /**
         * Lowers a searched CASE into nested IF forms. Clauses are folded from last to first, so
         * the first WHEN clause ends up outermost. The innermost else branch is the default value,
         * or a typed NULL when the default is absent.
         */
        @Override
        protected RowExpression visitSearchedCaseExpression(SearchedCaseExpression node, Void context) {
            Type returnType = getType(node);
            RowExpression expression = node.getDefaultValue()
                    .map(defaultExpression -> process(defaultExpression, context))
                    .orElseGet(() -> Expressions.constantNull(returnType));
            for (WhenClause clause : node.getWhenClauses().reversed()) {
                RowExpression condition = process(clause.getOperand(), context);
                RowExpression result = process(clause.getResult(), context);
                expression = new SpecialForm(SpecialForm.Form.IF, returnType, condition, result, expression);
            }
            return expression;
        }

        /** Produces IF(condition, trueValue, falseValue); a missing false branch becomes a typed NULL. */
        @Override
        protected RowExpression visitIfExpression(IfExpression node, Void context) {
            Type returnType = getType(node);
            RowExpression condition = process(node.getCondition(), context);
            RowExpression trueValue = process(node.getTrueValue(), context);
            RowExpression falseValue = node.getFalseValue()
                    .map(falseExpression -> process(falseExpression, context))
                    .orElseGet(() -> Expressions.constantNull(returnType));
            return new SpecialForm(SpecialForm.Form.IF, returnType, ImmutableList.of(condition, trueValue, falseValue));
        }

        /**
         * Produces an IN special form whose arguments are the tested value followed by the list
         * items. The EQUAL, HASH_CODE and INDETERMINATE operators for the value's type are recorded
         * as function dependencies, in that order.
         */
        @Override
        protected RowExpression visitInPredicate(InPredicate node, Void context) {
            ImmutableList.Builder<RowExpression> arguments = ImmutableList.builder();
            RowExpression value = process(node.getValue(), context);
            arguments.add(value);
            for (Expression testValue : node.getValueList()) {
                arguments.add(process(testValue, context));
            }

            Type valueType = value.getType();
            List<ResolvedFunction> functionDependencies = ImmutableList.of(
                    metadata.resolveOperator(OperatorType.EQUAL, ImmutableList.of(valueType, valueType)),
                    metadata.resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(valueType)),
                    metadata.resolveOperator(OperatorType.INDETERMINATE, ImmutableList.of(valueType)));
            return new SpecialForm(SpecialForm.Form.IN, BooleanType.BOOLEAN, arguments.build(), functionDependencies);
        }

        /** {@code x IS NOT NULL} is translated as {@code not(IS_NULL(x))}. */
        @Override
        protected RowExpression visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            RowExpression expression = process(node.getValue(), context);
            return notExpression(new SpecialForm(SpecialForm.Form.IS_NULL, BooleanType.BOOLEAN, ImmutableList.of(expression)));
        }

        @Override
        protected RowExpression visitIsNullPredicate(IsNullPredicate node, Void context) {
            RowExpression expression = process(node.getValue(), context);
            return new SpecialForm(SpecialForm.Form.IS_NULL, BooleanType.BOOLEAN, expression);
        }

        @Override
        protected RowExpression visitNotExpression(NotExpression node, Void context) {
            return notExpression(process(node.getValue(), context));
        }

        /** Wraps {@code value} in a call to the builtin boolean {@code not} function. */
        private RowExpression notExpression(RowExpression value) {
            return new CallExpression(metadata.resolveBuiltinFunction("not", TypeSignatureProvider.fromTypes(BooleanType.BOOLEAN)), ImmutableList.of(value));
        }

        /**
         * Produces a NULL_IF special form. Its function dependencies are the EQUAL operator for
         * the two operand types plus the coercions of each operand to that operator's argument type.
         */
        @Override
        protected RowExpression visitNullIfExpression(NullIfExpression node, Void context) {
            RowExpression first = process(node.getFirst(), context);
            RowExpression second = process(node.getSecond(), context);
            ResolvedFunction equalOperator = metadata.resolveOperator(OperatorType.EQUAL, ImmutableList.of(first.getType(), second.getType()));
            // NOTE(review): both operands are coerced to the operator's FIRST argument type, which
            // assumes EQUAL resolves both arguments to the same type — confirm.
            Type comparisonType = equalOperator.getSignature().getArgumentTypes().get(0);
            List<ResolvedFunction> functionDependencies = ImmutableList.of(
                    equalOperator,
                    metadata.getCoercion(first.getType(), comparisonType),
                    metadata.getCoercion(second.getType(), comparisonType));
            return new SpecialForm(SpecialForm.Form.NULL_IF, getType(node), ImmutableList.of(first, second), functionDependencies);
        }

        /**
         * Produces a BETWEEN special form over (value, min, max). Only the LESS_THAN_OR_EQUAL
         * operator for (value type, max type) is recorded as a function dependency.
         */
        @Override
        protected RowExpression visitBetweenPredicate(BetweenPredicate node, Void context) {
            RowExpression value = process(node.getValue(), context);
            RowExpression min = process(node.getMin(), context);
            RowExpression max = process(node.getMax(), context);
            List<ResolvedFunction> functionDependencies = ImmutableList.of(
                    metadata.resolveOperator(OperatorType.LESS_THAN_OR_EQUAL, ImmutableList.of(value.getType(), max.getType())));
            return new SpecialForm(SpecialForm.Form.BETWEEN, BooleanType.BOOLEAN, ImmutableList.of(value, min, max), functionDependencies);
        }

        /**
         * Subscripting a row becomes a DEREFERENCE with a 0-based field index. Any other base
         * type uses the SUBSCRIPT operator.
         */
        @Override
        protected RowExpression visitSubscriptExpression(SubscriptExpression node, Void context) {
            RowExpression base = process(node.getBase(), context);
            RowExpression index = process(node.getIndex(), context);
            if (getType(node.getBase()) instanceof RowType) {
                // Assumes the row index was translated to a BIGINT/INTEGER constant holding a
                // 1-based field position; a non-constant index fails with ClassCastException here.
                long fieldPosition = (Long) ((ConstantExpression) index).getValue();
                return new SpecialForm(SpecialForm.Form.DEREFERENCE, getType(node), base, Expressions.constant(fieldPosition - 1L, IntegerType.INTEGER));
            }
            return Expressions.call(metadata.resolveOperator(OperatorType.SUBSCRIPT, ImmutableList.of(base.getType(), index.getType())), base, index);
        }

        @Override
        protected RowExpression visitRow(Row node, Void context) {
            List<RowExpression> arguments = translateAll(node.getItems(), context);
            return new SpecialForm(SpecialForm.Form.ROW_CONSTRUCTOR, getType(node), arguments);
        }

        /**
         * Rebuilds a {@link RowExpression} so that it reports {@code targetType}; used for
         * type-only casts. Calls are rebuilt from their resolved function and arguments without
         * applying {@code targetType}, since {@link CallExpression} is constructed here without
         * an explicit type. Lambdas are not supported.
         */
        private static class ChangeTypeVisitor
                implements RowExpressionVisitor<RowExpression, Void> {
            private final Type targetType;

            private ChangeTypeVisitor(Type targetType) {
                this.targetType = targetType;
            }

            @Override
            public RowExpression visitCall(CallExpression call, Void context) {
                return new CallExpression(call.getResolvedFunction(), call.getArguments());
            }

            @Override
            public RowExpression visitSpecialForm(SpecialForm specialForm, Void context) {
                return new SpecialForm(specialForm.getForm(), targetType, specialForm.getArguments(), specialForm.getFunctionDependencies());
            }

            @Override
            public RowExpression visitInputReference(InputReferenceExpression reference, Void context) {
                return Expressions.field(reference.getField(), targetType);
            }

            @Override
            public RowExpression visitConstant(ConstantExpression literal, Void context) {
                return Expressions.constant(literal.getValue(), targetType);
            }

            @Override
            public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context) {
                throw new UnsupportedOperationException();
            }

            @Override
            public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context) {
                return new VariableReferenceExpression(reference.getName(), targetType);
            }
        }
    }
}

