/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Enums;
import com.google.common.base.Throwables;
import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.trino.Session;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.function.InvocationConvention;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.LongTimestamp;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.TimestampWithTimeZoneType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.InterpretedFunctionInvoker;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Between;
import io.trino.sql.ir.Booleans;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.ExpressionRewriter;
import io.trino.sql.ir.ExpressionTreeRewriter;
import io.trino.sql.ir.IsNull;
import io.trino.sql.ir.Logical;
import io.trino.sql.ir.Not;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.iterative.rule.ExpressionRewriteRuleSet;
import io.trino.sql.planner.iterative.rule.UnwrapCastInComparison;
import io.trino.type.DateTimes;
import java.lang.invoke.MethodHandle;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

public class UnwrapDateTruncInComparison
extends ExpressionRewriteRuleSet {
    public UnwrapDateTruncInComparison(PlannerContext plannerContext) {
        super(UnwrapDateTruncInComparison.createRewrite(plannerContext));
    }

    private static ExpressionRewriteRuleSet.ExpressionRewriter createRewrite(PlannerContext plannerContext) {
        Objects.requireNonNull(plannerContext, "plannerContext is null");
        return (expression, context) -> UnwrapDateTruncInComparison.unwrapDateTrunc(context.getSession(), plannerContext, expression);
    }

    private static Expression unwrapDateTrunc(Session session, PlannerContext plannerContext, Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, session), expression);
    }

    private static class Visitor
    extends ExpressionRewriter<Void> {
        private final PlannerContext plannerContext;
        private final Session session;
        private final InterpretedFunctionInvoker functionInvoker;

        public Visitor(PlannerContext plannerContext, Session session) {
            this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
            this.session = Objects.requireNonNull(session, "session is null");
            this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager());
        }

        @Override
        public Expression rewriteComparison(Comparison node, Void context, ExpressionTreeRewriter<Void> treeRewriter) {
            Comparison expression = treeRewriter.defaultRewrite(node, null);
            return this.unwrapDateTrunc(expression);
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        private Expression unwrapDateTrunc(Comparison expression) {
            Expression expression2;
            Type rightType;
            Object object;
            Constant constant;
            Call call;
            Expression expression22 = expression.left();
            if (!(expression22 instanceof Call) || !(call = (Call)expression22).function().name().equals((Object)GlobalFunctionCatalog.builtinFunctionName("date_trunc")) || call.arguments().size() != 2) {
                return expression;
            }
            Expression unitExpression = call.arguments().get(0);
            if (!(unitExpression.type() instanceof VarcharType) || !(unitExpression instanceof Constant)) {
                return expression;
            }
            Slice unitName = (Slice)new IrExpressionInterpreter(unitExpression, this.plannerContext, this.session).evaluate();
            if (unitName == null) {
                return expression;
            }
            Expression argument = call.arguments().get(1);
            Expression right = new IrExpressionInterpreter(expression.right(), this.plannerContext, this.session).optimize();
            if (right instanceof Constant && (constant = (Constant)right).value() == null) {
                Record record;
                switch (expression.operator()) {
                    default: {
                        throw new MatchException(null, null);
                    }
                    case EQUAL: 
                    case NOT_EQUAL: 
                    case LESS_THAN: 
                    case LESS_THAN_OR_EQUAL: 
                    case GREATER_THAN: 
                    case GREATER_THAN_OR_EQUAL: {
                        record = new Constant((Type)BooleanType.BOOLEAN, null);
                        return record;
                    }
                    case IDENTICAL: {
                        record = new IsNull(argument);
                    }
                }
                return record;
            }
            if (!(right instanceof Constant)) return expression;
            Constant constant2 = (Constant)right;
            try {
                rightType = object = constant2.type();
            }
            catch (Throwable throwable) {
                throw new MatchException(throwable.toString(), throwable);
            }
            Object rightValue = object = constant2.value();
            if (rightType instanceof TimestampWithTimeZoneType) {
                return expression;
            }
            ResolvedFunction resolvedFunction = call.function();
            Optional unitIfSupported = Enums.getIfPresent(SupportedUnit.class, (String)unitName.toStringUtf8().toUpperCase(Locale.ENGLISH)).toJavaUtil();
            if (unitIfSupported.isEmpty()) {
                return expression;
            }
            SupportedUnit unit = (SupportedUnit)((Object)unitIfSupported.get());
            if (rightType == DateType.DATE && (unit == SupportedUnit.DAY || unit == SupportedUnit.HOUR)) {
                return expression;
            }
            Object rangeLow = this.functionInvoker.invoke(resolvedFunction, this.session.toConnectorSession(), (List<Object>)ImmutableList.of((Object)unitName, (Object)rightValue));
            int compare = this.compare(rightType, rangeLow, rightValue);
            Verify.verify((compare <= 0 ? 1 : 0) != 0, (String)"Truncation of %s value %s resulted in a bigger value %s", (Object)rightType, (Object)rightValue, (Object)rangeLow);
            boolean rightValueAtRangeLow = compare == 0;
            switch (expression.operator()) {
                default: {
                    throw new MatchException(null, null);
                }
                case EQUAL: {
                    if (!rightValueAtRangeLow) {
                        expression2 = UnwrapCastInComparison.falseIfNotNull(argument);
                        return expression2;
                    }
                    expression2 = this.between(argument, rightType, rangeLow, this.calculateRangeEndInclusive(rangeLow, rightType, unit));
                    return expression2;
                }
                case NOT_EQUAL: {
                    if (!rightValueAtRangeLow) {
                        expression2 = UnwrapCastInComparison.trueIfNotNull(argument);
                        return expression2;
                    }
                    expression2 = new Not(this.between(argument, rightType, rangeLow, this.calculateRangeEndInclusive(rangeLow, rightType, unit)));
                    return expression2;
                }
                case IDENTICAL: {
                    if (!rightValueAtRangeLow) {
                        expression2 = Booleans.FALSE;
                        return expression2;
                    }
                    expression2 = Logical.and(new Not(new IsNull(argument)), this.between(argument, rightType, rangeLow, this.calculateRangeEndInclusive(rangeLow, rightType, unit)));
                    return expression2;
                }
                case LESS_THAN: {
                    if (rightValueAtRangeLow) {
                        expression2 = new Comparison(Comparison.Operator.LESS_THAN, argument, new Constant(rightType, rangeLow));
                        return expression2;
                    }
                    expression2 = new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, argument, new Constant(rightType, this.calculateRangeEndInclusive(rangeLow, rightType, unit)));
                    return expression2;
                }
                case LESS_THAN_OR_EQUAL: {
                    expression2 = new Comparison(Comparison.Operator.LESS_THAN_OR_EQUAL, argument, new Constant(rightType, this.calculateRangeEndInclusive(rangeLow, rightType, unit)));
                    return expression2;
                }
                case GREATER_THAN: {
                    expression2 = new Comparison(Comparison.Operator.GREATER_THAN, argument, new Constant(rightType, this.calculateRangeEndInclusive(rangeLow, rightType, unit)));
                    return expression2;
                }
                case GREATER_THAN_OR_EQUAL: {
                    expression2 = rightValueAtRangeLow ? new Comparison(Comparison.Operator.GREATER_THAN_OR_EQUAL, argument, new Constant(rightType, rangeLow)) : new Comparison(Comparison.Operator.GREATER_THAN, argument, new Constant(rightType, this.calculateRangeEndInclusive(rangeLow, rightType, unit)));
                }
            }
            return expression2;
        }

        private Object calculateRangeEndInclusive(Object rangeStart, Type type, SupportedUnit rangeUnit) {
            if (type == DateType.DATE) {
                LocalDate date = LocalDate.ofEpochDay((Long)rangeStart);
                LocalDate endExclusive = switch (rangeUnit.ordinal()) {
                    default -> throw new MatchException(null, null);
                    case 0, 1 -> throw new UnsupportedOperationException("Unsupported type and unit: %s, %s".formatted(new Object[]{type, rangeUnit}));
                    case 2 -> date.plusMonths(1L);
                    case 3 -> date.plusYears(1L);
                };
                return endExclusive.toEpochDay() - 1L;
            }
            if (type instanceof TimestampType) {
                TimestampType timestampType = (TimestampType)type;
                if (timestampType.isShort()) {
                    long epochMicros = (Long)rangeStart;
                    long epochSecond = Math.floorDiv(epochMicros, 1000000);
                    int microOfSecond = Math.floorMod(epochMicros, 1000000);
                    Verify.verify((microOfSecond == 0 ? 1 : 0) != 0, (String)"Unexpected micros, value should be rounded to %s: %s", (Object)((Object)rangeUnit), (int)microOfSecond);
                    LocalDateTime dateTime = LocalDateTime.ofEpochSecond(epochSecond, 0, ZoneOffset.UTC);
                    LocalDateTime endExclusive = switch (rangeUnit.ordinal()) {
                        default -> throw new MatchException(null, null);
                        case 0 -> dateTime.plusHours(1L);
                        case 1 -> dateTime.plusDays(1L);
                        case 2 -> dateTime.plusMonths(1L);
                        case 3 -> dateTime.plusYears(1L);
                    };
                    Verify.verify((endExclusive.getNano() == 0 ? 1 : 0) != 0, (String)"Unexpected nanos in %s, value not rounded to %s", (Object)endExclusive, (Object)((Object)rangeUnit));
                    long endExclusiveMicros = endExclusive.toEpochSecond(ZoneOffset.UTC) * 1000000L;
                    return endExclusiveMicros - DateTimes.scaleFactor(timestampType.getPrecision(), 6);
                }
                LongTimestamp longTimestamp = (LongTimestamp)rangeStart;
                Verify.verify((longTimestamp.getPicosOfMicro() == 0 ? 1 : 0) != 0, (String)"Unexpected picos in %s, value not rounded to %s", (Object)rangeStart, (Object)((Object)rangeUnit));
                long endInclusiveMicros = (Long)this.calculateRangeEndInclusive(longTimestamp.getEpochMicros(), (Type)TimestampType.createTimestampType((int)6), rangeUnit);
                return new LongTimestamp(endInclusiveMicros, Math.toIntExact(1000000L - DateTimes.scaleFactor(timestampType.getPrecision(), 12)));
            }
            throw new UnsupportedOperationException("Unsupported type: " + String.valueOf(type));
        }

        private Between between(Expression argument, Type type, Object minInclusive, Object maxInclusive) {
            return new Between(argument, new Constant(type, minInclusive), new Constant(type, maxInclusive));
        }

        private int compare(Type type, Object first, Object second) {
            Objects.requireNonNull(first, "first is null");
            Objects.requireNonNull(second, "second is null");
            MethodHandle comparisonOperator = this.plannerContext.getTypeOperators().getComparisonUnorderedLastOperator(type, InvocationConvention.simpleConvention((InvocationConvention.InvocationReturnConvention)InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL, (InvocationConvention.InvocationArgumentConvention[])new InvocationConvention.InvocationArgumentConvention[]{InvocationConvention.InvocationArgumentConvention.NEVER_NULL, InvocationConvention.InvocationArgumentConvention.NEVER_NULL}));
            try {
                return Math.toIntExact(comparisonOperator.invoke(first, second));
            }
            catch (Throwable throwable) {
                Throwables.throwIfUnchecked((Throwable)throwable);
                throw new TrinoException((ErrorCodeSupplier)StandardErrorCode.GENERIC_INTERNAL_ERROR, throwable);
            }
        }
    }

    private static enum SupportedUnit {
        HOUR,
        DAY,
        MONTH,
        YEAR;

    }
}

