/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Throwables;
import com.google.common.base.Verify;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceUtf8;
import io.trino.Session;
import io.trino.metadata.OperatorNotFoundException;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DateTimeEncoding;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.RealType;
import io.trino.spi.type.TimeWithTimeZoneType;
import io.trino.spi.type.TimeZoneKey;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeManager;
import io.trino.spi.type.TypeUtils;
import io.trino.spi.type.VarcharType;
import io.trino.sql.InterpretedFunctionInvoker;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IsNull;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.type.TypeCoercion;
import java.lang.invoke.MethodHandle;
import java.time.Instant;
import java.time.ZoneId;
import java.time.temporal.ChronoUnit;
import java.time.zone.ZoneOffsetTransition;
import java.util.Objects;
import java.util.Optional;

public class UnwrapCastInComparison
extends ExpressionRewriteRuleSet {
    public UnwrapCastInComparison(PlannerContext plannerContext) {
        super(UnwrapCastInComparison.createRewrite(plannerContext));
    }

    private static ExpressionRewriteRuleSet.ExpressionRewriter createRewrite(PlannerContext plannerContext) {
        Objects.requireNonNull(plannerContext, "plannerContext is null");
        return (expression, context) -> UnwrapCastInComparison.unwrapCasts(context.getSession(), plannerContext, expression);
    }

    public static Expression unwrapCasts(Session session, PlannerContext plannerContext, Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, session), expression);
    }

    private static Object withTimeZone(TimestampWithTimeZoneType type, Object value, TimeZoneKey newZone) {
        if (type.isShort()) {
            return DateTimeEncoding.packDateTimeWithZone((long)DateTimeEncoding.unpackMillisUtc((long)((Long)value)), (TimeZoneKey)newZone);
        }
        LongTimestampWithTimeZone longTimestampWithTimeZone = (LongTimestampWithTimeZone)value;
        return LongTimestampWithTimeZone.fromEpochMillisAndFraction((long)longTimestampWithTimeZone.getEpochMillis(), (int)longTimestampWithTimeZone.getPicosOfMilli(), (TimeZoneKey)newZone);
    }

    private static TimeZoneKey getTimeZone(TimestampWithTimeZoneType type, Object value) {
        if (type.isShort()) {
            return DateTimeEncoding.unpackZoneKey((long)((Long)value));
        }
        return TimeZoneKey.getTimeZoneKey((short)((LongTimestampWithTimeZone)value).getTimeZoneKey());
    }

    @VisibleForTesting
    static boolean isTimestampToTimestampWithTimeZoneInjectiveAt(ZoneId zone, Instant instant) {
        ZoneOffsetTransition transition = zone.getRules().previousTransition(instant.plusNanos(1L));
        return transition == null || transition.getDuration().isNegative() || transition.getDateTimeAfter().minusNanos(1L).atZone(zone).toInstant().isBefore(instant);
    }

    private static Instant getInstantWithTruncation(TimestampWithTimeZoneType type, Object value) {
        if (type.isShort()) {
            return Instant.ofEpochMilli(DateTimeEncoding.unpackMillisUtc((long)((Long)value)));
        }
        LongTimestampWithTimeZone longTimestampWithTimeZone = (LongTimestampWithTimeZone)value;
        return Instant.ofEpochMilli(longTimestampWithTimeZone.getEpochMillis()).plus((long)(longTimestampWithTimeZone.getPicosOfMilli() / 1000), ChronoUnit.NANOS);
    }

    public static Expression falseIfNotNull(Expression argument) {
        return IrUtils.and(new IsNull(argument), new Constant((Type)BooleanType.BOOLEAN, null));
    }

    private static class Visitor
    extends ExpressionRewriter<Void> {
        private final PlannerContext plannerContext;
        private final Session session;
        private final InterpretedFunctionInvoker functionInvoker;

        public Visitor(PlannerContext plannerContext, Session session) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager());
        }

        @Override
        public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            Comparison expression = treeRewriter.defaultRewrite(node, null);
            return this.unwrapCast(expression);
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        private Expression unwrapCast(Comparison expression) {
            Expression expression2;
            Object literalInSourceType;
            ResolvedFunction targetToSource;
            Object object;
            Constant constant;
            Expression expression22 = expression.left();
            if (!(expression22 instanceof Cast)) {
                return expression;
            }
            Cast cast = (Cast)expression22;
            Expression right = new IrExpressionInterpreter(expression.right(), this.plannerContext, this.session).optimize();
            Comparison.Operator operator = expression.operator();
            if (right instanceof Constant && (constant = (Constant)right).value() == null) {
                Record record;
                switch (operator) {
                    default: {
                        throw new MatchException(null, null);
                    }
                    case EQUAL: 
                    case NOT_EQUAL: 
                    case LESS_THAN: 
                    case LESS_THAN_OR_EQUAL: 
                    case GREATER_THAN: 
                    case GREATER_THAN_OR_EQUAL: {
                        record = new Constant((Type)BooleanType.BOOLEAN, null);
                        return record;
                    }
                    case IDENTICAL: {
                        record = new IsNull(cast);
                    }
                }
                return record;
            }
            if (!(right instanceof Constant)) return expression;
            Constant constant2 = (Constant)right;
            try {
                Type type = object = constant2.type();
            }
            catch (Throwable throwable) {
                throw new MatchException(throwable.toString(), throwable);
            }
            Object rightValue = object = constant2.value();
            Type sourceType = cast.expression().type();
            Type targetType = expression.right().type();
            if (sourceType instanceof TimestampType && targetType == DateType.DATE) {
                return this.unwrapTimestampToDateCast((TimestampType)sourceType, operator, cast.expression(), (Long)rightValue).orElse(expression);
            }
            if (targetType instanceof TimestampWithTimeZoneType) {
                rightValue = UnwrapCastInComparison.withTimeZone((TimestampWithTimeZoneType)targetType, rightValue, this.session.getTimeZoneKey());
            }
            if (!this.hasInjectiveImplicitCoercion(sourceType, targetType, rightValue)) {
                return expression;
            }
            if (TypeUtils.isFloatingPointNaN((Type)targetType, (Object)rightValue)) {
                switch (operator) {
                    case EQUAL: 
                    case LESS_THAN: 
                    case LESS_THAN_OR_EQUAL: 
                    case GREATER_THAN: 
                    case GREATER_THAN_OR_EQUAL: {
                        return UnwrapCastInComparison.falseIfNotNull(cast.expression());
                    }
                    case NOT_EQUAL: {
                        return this.trueIfNotNull(cast.expression());
                    }
                    case IDENTICAL: {
                        if (this.typeHasNaN(sourceType)) break;
                        return Booleans.FALSE;
                    }
                    default: {
                        throw new UnsupportedOperationException("Not yet implemented: " + String.valueOf((Object)operator));
                    }
                }
            }
            ResolvedFunction sourceToTarget = this.plannerContext.getMetadata().getCoercion(sourceType, targetType);
            Optional sourceRange = sourceType.getRange();
            if (sourceRange.isPresent()) {
                Object max = ((Type.Range)sourceRange.get()).getMax();
                Object maxInTargetType = null;
                try {
                    maxInTargetType = this.coerce(max, sourceToTarget);
                }
                catch (RuntimeException runtimeException) {
                    // empty catch block
                }
                if (maxInTargetType != null) {
                    int upperBoundComparison = this.compare(targetType, rightValue, maxInTargetType);
                    if (upperBoundComparison > 0) {
                        Expression expression3;
                        switch (operator) {
                            default: {
                                throw new MatchException(null, null);
                            }
                            case EQUAL: 
                            case GREATER_THAN: 
                            case GREATER_THAN_OR_EQUAL: {
                                expression3 = UnwrapCastInComparison.falseIfNotNull(cast.expression());
                                return expression3;
                            }
                            case NOT_EQUAL: 
                            case LESS_THAN: 
                            case LESS_THAN_OR_EQUAL: {
                                expression3 = this.trueIfNotNull(cast.expression());
                                return expression3;
                            }
                            case IDENTICAL: {
                                expression3 = Booleans.FALSE;
                            }
                        }
                        return expression3;
                    }
                    if (upperBoundComparison == 0) {
                        Expression expression4;
                        switch (operator) {
                            default: {
                                throw new MatchException(null, null);
                            }
                            case GREATER_THAN: {
                                expression4 = UnwrapCastInComparison.falseIfNotNull(cast.expression());
                                return expression4;
                            }
                            case GREATER_THAN_OR_EQUAL: {
                                expression4 = new Comparison(Comparison.Operator.EQUAL, cast.expression(), new Constant(sourceType, max));
                                return expression4;
                            }
                            case LESS_THAN_OR_EQUAL: {
                                expression4 = this.trueIfNotNull(cast.expression());
                                return expression4;
                            }
                            case LESS_THAN: {
                                expression4 = new Comparison(Comparison.Operator.NOT_EQUAL, cast.expression(), new Constant(sourceType, max));
                                return expression4;
                            }
                            case EQUAL: 
                            case NOT_EQUAL: 
                            case IDENTICAL: {
                                expression4 = new Comparison(operator, cast.expression(), new Constant(sourceType, max));
                            }
                        }
                        return expression4;
                    }
                    Object min = ((Type.Range)sourceRange.get()).getMin();
                    Object minInTargetType = this.coerce(min, sourceToTarget);
                    int lowerBoundComparison = this.compare(targetType, rightValue, minInTargetType);
                    if (lowerBoundComparison < 0) {
                        Expression expression5;
                        switch (operator) {
                            default: {
                                throw new MatchException(null, null);
                            }
                            case NOT_EQUAL: 
                            case GREATER_THAN: 
                            case GREATER_THAN_OR_EQUAL: {
                                expression5 = this.trueIfNotNull(cast.expression());
                                return expression5;
                            }
                            case EQUAL: 
                            case LESS_THAN: 
                            case LESS_THAN_OR_EQUAL: {
                                expression5 = UnwrapCastInComparison.falseIfNotNull(cast.expression());
                                return expression5;
                            }
                            case IDENTICAL: {
                                expression5 = Booleans.FALSE;
                            }
                        }
                        return expression5;
                    }
                    if (lowerBoundComparison == 0) {
                        Expression expression6;
                        switch (operator) {
                            default: {
                                throw new MatchException(null, null);
                            }
                            case LESS_THAN: {
                                expression6 = UnwrapCastInComparison.falseIfNotNull(cast.expression());
                                return expression6;
                            }
                            case LESS_THAN_OR_EQUAL: {
                                expression6 = new Comparison(Comparison.Operator.EQUAL, cast.expression(), new Constant(sourceType, min));
                                return expression6;
                            }
                            case GREATER_THAN_OR_EQUAL: {
                                expression6 = this.trueIfNotNull(cast.expression());
                                return expression6;
                            }
                            case GREATER_THAN: {
                                expression6 = new Comparison(Comparison.Operator.NOT_EQUAL, cast.expression(), new Constant(sourceType, min));
                                return expression6;
                            }
                            case EQUAL: 
                            case NOT_EQUAL: 
                            case IDENTICAL: {
                                expression6 = new Comparison(operator, cast.expression(), new Constant(sourceType, min));
                            }
                        }
                        return expression6;
                    }
                }
            }
            try {
                targetToSource = this.plannerContext.getMetadata().getCoercion(targetType, sourceType);
            }
            catch (OperatorNotFoundException e) {
                return expression;
            }
            try {
                literalInSourceType = this.coerce(rightValue, targetToSource);
            }
            catch (TrinoException e) {
                return expression;
            }
            if (!targetType.isOrderable()) return new Comparison(operator, cast.expression(), new Constant(sourceType, literalInSourceType));
            Object roundtripLiteral = this.coerce(literalInSourceType, sourceToTarget);
            int literalVsRoundtripped = this.compare(targetType, rightValue, roundtripLiteral);
            if (literalVsRoundtripped > 0) {
                Expression expression7;
                switch (operator) {
                    default: {
                        throw new MatchException(null, null);
                    }
                    case EQUAL: {
                        expression7 = UnwrapCastInComparison.falseIfNotNull(cast.expression());
                        return expression7;
                    }
                    case NOT_EQUAL: {
                        expression7 = this.trueIfNotNull(cast.expression());
                        return expression7;
                    }
                    case IDENTICAL: {
                        expression7 = Booleans.FALSE;
                        return expression7;
                    }
                    case LESS_THAN: 
                    case LESS_THAN_OR_EQUAL: {
                        if (sourceRange.isPresent() && this.compare(sourceType, ((Type.Range)sourceRange.get()).getMin(), literalInSourceType) == 0) {
                            expression7 = new Comparison(Comparison.Operator.EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType));
                            return expression7;
                        }
                        expression7 = new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType));
                        return expression7;
                    }
                    case GREATER_THAN: 
                    case GREATER_THAN_OR_EQUAL: {
                        expression7 = new Comparison(Comparison.Operator.GREATER_THAN, cast.expression(), new Constant(sourceType, literalInSourceType));
                    }
                }
                return expression7;
            }
            if (literalVsRoundtripped >= 0) return new Comparison(operator, cast.expression(), new Constant(sourceType, literalInSourceType));
            switch (operator) {
                default: {
                    throw new MatchException(null, null);
                }
                case EQUAL: {
                    expression2 = UnwrapCastInComparison.falseIfNotNull(cast.expression());
                    return expression2;
                }
                case NOT_EQUAL: {
                    expression2 = this.trueIfNotNull(cast.expression());
                    return expression2;
                }
                case IDENTICAL: {
                    expression2 = Booleans.FALSE;
                    return expression2;
                }
                case LESS_THAN: 
                case LESS_THAN_OR_EQUAL: {
                    expression2 = new Comparison(Comparison.Operator.LESS_THAN, cast.expression(), new Constant(sourceType, literalInSourceType));
                    return expression2;
                }
                case GREATER_THAN: 
                case GREATER_THAN_OR_EQUAL: {
                    expression2 = sourceRange.isPresent() && this.compare(sourceType, ((Type.Range)sourceRange.get()).getMax(), literalInSourceType) == 0 ? new Comparison(Comparison.Operator.EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType)) : new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, cast.expression(), new Constant(sourceType, literalInSourceType));
                }
            }
            return expression2;
        }

        private Optional<Expression> unwrapTimestampToDateCast(TimestampType sourceType, Comparison.Operator operator, Expression timestampExpression, long date) {
            ResolvedFunction targetToSource;
            try {
                targetToSource = this.plannerContext.getMetadata().getCoercion((Type)DateType.DATE, (Type)sourceType);
            }
            catch (OperatorNotFoundException e) {
                throw new TrinoException((ErrorCodeSupplier)StandardErrorCode.GENERIC_INTERNAL_ERROR, (Throwable)((Object)e));
            }
            Constant dateTimestamp = new Constant((Type)sourceType, this.coerce(date, targetToSource));
            Constant nextDateTimestamp = new Constant((Type)sourceType, this.coerce(date + 1L, targetToSource));
            return switch (operator) {
                default -> throw new MatchException(null, null);
                case Comparison.Operator.EQUAL -> Optional.of(IrUtils.and(new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp), new Comparison(Comparison.Operator.LESS_THAN, timestampExpression, nextDateTimestamp)));
                case Comparison.Operator.NOT_EQUAL -> Optional.of(IrUtils.or(new Comparison(Comparison.Operator.LESS_THAN, timestampExpression, dateTimestamp), new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp)));
                case Comparison.Operator.LESS_THAN -> Optional.of(new Comparison(Comparison.Operator.LESS_THAN, timestampExpression, dateTimestamp));
                case Comparison.Operator.LESS_THAN_OR_EQUAL -> Optional.of(new Comparison(Comparison.Operator.LESS_THAN, timestampExpression, nextDateTimestamp));
                case Comparison.Operator.GREATER_THAN -> Optional.of(new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, nextDateTimestamp));
                case Comparison.Operator.GREATER_THAN_OR_EQUAL -> Optional.of(new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp));
                case Comparison.Operator.IDENTICAL -> Optional.of(IrUtils.and(IrExpressions.not(this.plannerContext.getMetadata(), new IsNull(timestampExpression)), new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, timestampExpression, dateTimestamp), new Comparison(Comparison.Operator.LESS_THAN, timestampExpression, nextDateTimestamp)));
            };
        }

        private boolean hasInjectiveImplicitCoercion(Type source, Type target, Object value) {
            if (source.equals((Object)BigintType.BIGINT) && target.equals((Object)DoubleType.DOUBLE) || source.equals((Object)BigintType.BIGINT) && target.equals((Object)RealType.REAL) || source.equals((Object)IntegerType.INTEGER) && target.equals((Object)RealType.REAL)) {
                if (target.equals((Object)DoubleType.DOUBLE)) {
                    double doubleValue = (Double)value;
                    return doubleValue > 9.223372036854776E18 || doubleValue < -9.223372036854776E18 || Double.isNaN(doubleValue) || doubleValue > -9.007199254740992E15 && doubleValue < 9.007199254740992E15;
                }
                float realValue = Float.intBitsToFloat(Math.toIntExact((Long)value));
                return source.equals((Object)BigintType.BIGINT) && (realValue > 9.223372E18f || realValue < -9.223372E18f) || source.equals((Object)IntegerType.INTEGER) && (realValue > 2.1474836E9f || realValue < -2.1474836E9f) || Float.isNaN(realValue) || realValue > -8388608.0f && realValue < 8388608.0f;
            }
            if (source instanceof DecimalType) {
                int precision = ((DecimalType)source).getPrecision();
                if (precision > 15 && target.equals((Object)DoubleType.DOUBLE)) {
                    return false;
                }
                if (precision > 7 && target.equals((Object)RealType.REAL)) {
                    return false;
                }
            }
            if (target instanceof TimestampWithTimeZoneType) {
                TimestampWithTimeZoneType timestampWithTimeZoneType = (TimestampWithTimeZoneType)target;
                if (source instanceof DateType) {
                    if (!UnwrapCastInComparison.getTimeZone(timestampWithTimeZoneType, value).equals((Object)this.session.getTimeZoneKey())) {
                        return false;
                    }
                    return UnwrapCastInComparison.isTimestampToTimestampWithTimeZoneInjectiveAt(this.session.getTimeZoneKey().getZoneId(), UnwrapCastInComparison.getInstantWithTruncation(timestampWithTimeZoneType, value));
                }
                if (source instanceof TimestampType) {
                    if (!UnwrapCastInComparison.getTimeZone(timestampWithTimeZoneType, value).equals((Object)this.session.getTimeZoneKey())) {
                        return false;
                    }
                    return UnwrapCastInComparison.isTimestampToTimestampWithTimeZoneInjectiveAt(this.session.getTimeZoneKey().getZoneId(), UnwrapCastInComparison.getInstantWithTruncation(timestampWithTimeZoneType, value));
                }
                return false;
            }
            if (target instanceof TimeWithTimeZoneType) {
                return false;
            }
            boolean coercible = new TypeCoercion(arg_0 -> ((TypeManager)this.plannerContext.getTypeManager()).getType(arg_0)).canCoerce(source, target);
            if (source instanceof VarcharType) {
                VarcharType sourceVarchar = (VarcharType)source;
                if (target instanceof CharType) {
                    CharType targetChar = (CharType)target;
                    if (sourceVarchar.isUnbounded() || sourceVarchar.getBoundedLength() > targetChar.getLength()) {
                        return false;
                    }
                    Verify.verify((boolean)coercible, (String)"%s was expected to be coercible to %s", (Object)source, (Object)target);
                    if (sourceVarchar.getBoundedLength() == 0) {
                        return true;
                    }
                    int actualLengthWithoutSpaces = SliceUtf8.countCodePoints((Slice)((Slice)value));
                    Verify.verify((actualLengthWithoutSpaces <= targetChar.getLength() ? 1 : 0) != 0, (String)"Incorrect char value [%s] for %s", (Object)((Slice)value).toStringUtf8(), (Object)targetChar);
                    return sourceVarchar.getBoundedLength() == actualLengthWithoutSpaces;
                }
            }
            return coercible;
        }

        private Object coerce(Object value, ResolvedFunction coercion) {
            return this.functionInvoker.invoke(coercion, this.session.toConnectorSession(), value);
        }

        private boolean typeHasNaN(Type type) {
            return type instanceof DoubleType || type instanceof RealType;
        }

        private int compare(Type type, Object first, Object second) {
            Objects.requireNonNull(first, "first is null");
            Objects.requireNonNull(second, "second is null");
            MethodHandle comparisonOperator = this.plannerContext.getTypeOperators().getComparisonUnorderedLastOperator(type, InvocationConvention.simpleConvention((InvocationConvention.InvocationReturnConvention)InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL, (InvocationConvention.InvocationArgumentConvention[])new InvocationConvention.InvocationArgumentConvention[]{InvocationConvention.InvocationArgumentConvention.NEVER_NULL, InvocationConvention.InvocationArgumentConvention.NEVER_NULL}));
            try {
                return Math.toIntExact(comparisonOperator.invoke(first, second));
            }
            catch (Throwable throwable) {
                Throwables.throwIfUnchecked((Throwable)throwable);
                throw new TrinoException((ErrorCodeSupplier)StandardErrorCode.GENERIC_INTERNAL_ERROR, throwable);
            }
        }

        public Expression trueIfNotNull(Expression argument) {
            return IrUtils.or(IrExpressions.not(this.plannerContext.getMetadata(), new IsNull(argument)), new Constant((Type)BooleanType.BOOLEAN, null));
        }
    }
}

