/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import io.trino.Session;
import io.trino.metadata.OperatorNotFoundException;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.type.DateType;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.InterpretedFunctionInvoker;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.ExpressionRewriter;
import io.trino.sql.tree.ExpressionTreeRewriter;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.NullLiteral;
import java.util.Objects;
import java.util.Optional;

public class UnwrapTimestampToDateCastInComparison
extends ExpressionRewriteRuleSet {
    public UnwrapTimestampToDateCastInComparison(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        super(UnwrapTimestampToDateCastInComparison.createRewrite(plannerContext, typeAnalyzer));
    }

    private static ExpressionRewriteRuleSet.ExpressionRewriter createRewrite(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        Objects.requireNonNull(plannerContext, "plannerContext is null");
        Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
        return (expression, context) -> UnwrapTimestampToDateCastInComparison.unwrapCasts(context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression);
    }

    public static Expression unwrapCasts(Session session, PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, TypeProvider types, Expression expression) {
        return ExpressionTreeRewriter.rewriteWith((ExpressionRewriter)new Visitor(plannerContext, typeAnalyzer, session, types), (Expression)expression);
    }

    private static class Visitor
    extends ExpressionRewriter<Void> {
        private final PlannerContext plannerContext;
        private final TypeAnalyzer typeAnalyzer;
        private final Session session;
        private final TypeProvider types;
        private final InterpretedFunctionInvoker functionInvoker;
        private final LiteralEncoder literalEncoder;

        public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.types = Objects.requireNonNull(types, "types is null");
            this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getMetadata());
            this.literalEncoder = new LiteralEncoder(plannerContext);
        }

        public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            ComparisonExpression expression = (ComparisonExpression)treeRewriter.defaultRewrite((Expression)node, null);
            return this.unwrapCast(expression);
        }

        private Expression unwrapCast(ComparisonExpression expression) {
            if (!(expression.getLeft() instanceof Cast)) {
                return expression;
            }
            Object right = new ExpressionInterpreter(expression.getRight(), this.plannerContext, this.session, this.typeAnalyzer.getTypes(this.session, this.types, expression.getRight())).optimize(NoOpSymbolResolver.INSTANCE);
            Cast cast = (Cast)expression.getLeft();
            ComparisonExpression.Operator operator = expression.getOperator();
            if (right == null || right instanceof NullLiteral) {
                return expression;
            }
            if (right instanceof Expression) {
                return expression;
            }
            Type sourceType = this.typeAnalyzer.getType(this.session, this.types, cast.getExpression());
            Type targetType = this.typeAnalyzer.getType(this.session, this.types, expression.getRight());
            if (sourceType instanceof TimestampType && targetType == DateType.DATE) {
                return this.unwrapTimestampToDateCast(this.session, (TimestampType)sourceType, (DateType)targetType, operator, cast.getExpression(), (Long)right).orElse((Expression)expression);
            }
            return expression;
        }

        private Optional<Expression> unwrapTimestampToDateCast(Session session, TimestampType sourceType, DateType targetType, ComparisonExpression.Operator operator, Expression timestampExpression, long date) {
            ResolvedFunction targetToSource;
            try {
                targetToSource = this.plannerContext.getMetadata().getCoercion(session, (Type)targetType, (Type)sourceType);
            }
            catch (OperatorNotFoundException e) {
                throw new TrinoException((ErrorCodeSupplier)StandardErrorCode.GENERIC_INTERNAL_ERROR, (Throwable)((Object)e));
            }
            Expression dateTimestamp = this.literalEncoder.toExpression(session, this.coerce(date, targetToSource), (Type)sourceType);
            Expression nextDateTimestamp = this.literalEncoder.toExpression(session, this.coerce(date + 1L, targetToSource), (Type)sourceType);
            switch (operator) {
                case EQUAL: {
                    return Optional.of(ExpressionUtils.and(new Expression[]{new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp), new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, timestampExpression, nextDateTimestamp)}));
                }
                case NOT_EQUAL: {
                    return Optional.of(ExpressionUtils.or(new Expression[]{new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, timestampExpression, dateTimestamp), new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)}));
                }
                case LESS_THAN: {
                    return Optional.of(new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, timestampExpression, dateTimestamp));
                }
                case LESS_THAN_OR_EQUAL: {
                    return Optional.of(new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, timestampExpression, nextDateTimestamp));
                }
                case GREATER_THAN: {
                    return Optional.of(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp));
                }
                case GREATER_THAN_OR_EQUAL: {
                    return Optional.of(new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp));
                }
                case IS_DISTINCT_FROM: {
                    return Optional.of(ExpressionUtils.or(new Expression[]{new IsNullPredicate(timestampExpression), new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, timestampExpression, dateTimestamp), new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)}));
                }
            }
            throw new TrinoException((ErrorCodeSupplier)StandardErrorCode.GENERIC_INTERNAL_ERROR, "Unsupported operator: " + operator);
        }

        private Object coerce(Object value, ResolvedFunction coercion) {
            return this.functionInvoker.invoke(coercion, this.session.toConnectorSession(), value);
        }
    }
}

