/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.inject.Inject;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.ArithmeticBinaryExpression;
import io.trino.sql.ir.ArithmeticUnaryExpression;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.CoalesceExpression;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.Literal;
import io.trino.sql.ir.NodeRef;
import io.trino.sql.ir.NullLiteral;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.IrExpressionInterpreter;
import io.trino.sql.planner.IrLiteralInterpreter;
import io.trino.sql.planner.IrTypeAnalyzer;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.util.MoreMath;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;

/**
 * Derives a {@link SymbolStatsEstimate} for a scalar expression from the statistics of the
 * plan node the expression is evaluated against.
 *
 * <p>The expression tree is walked by a private visitor. Expression kinds that have no
 * dedicated visit method yield {@link SymbolStatsEstimate#unknown()}.
 */
public class ScalarStatsCalculator {
    /**
     * Every double whose magnitude is at least 2^52 is already a whole number, so rounding it
     * is a no-op. Skipping the rounding above this threshold avoids {@link Math#round(double)},
     * which returns a {@code long} and therefore saturates at {@code Long.MIN_VALUE} /
     * {@code Long.MAX_VALUE}.
     */
    private static final double MIN_MAGNITUDE_WITHOUT_FRACTION = 0x1p52;

    private final PlannerContext plannerContext;
    private final IrTypeAnalyzer typeAnalyzer;

    /**
     * @param plannerContext planner context handed to the literal and expression interpreters
     * @param typeAnalyzer analyzer used to resolve expression types
     * @throws NullPointerException if either argument is null
     */
    @Inject
    public ScalarStatsCalculator(PlannerContext plannerContext, IrTypeAnalyzer typeAnalyzer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext cannot be null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    /**
     * Estimates the statistics of the values produced by {@code scalarExpression}.
     *
     * @param scalarExpression expression to estimate
     * @param inputStatistics statistics of the input; source of per-symbol statistics and of the
     *        output row count used to cap distinct-value estimates
     * @param session session passed to type analysis and expression interpretation
     * @param types types of the symbols referenced by the expression
     * @return the estimate computed by the visitor for the root of the expression
     */
    public SymbolStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session, TypeProvider types) {
        return new Visitor(inputStatistics, session, types).process(scalarExpression);
    }

    /** Estimate for a value that is always NULL: nulls fraction 1 and no distinct values. */
    private static SymbolStatsEstimate nullStatsEstimate() {
        return SymbolStatsEstimate.builder()
                .setDistinctValuesCount(0.0)
                .setNullsFraction(1.0)
                .build();
    }

    /** Visitor computing a {@link SymbolStatsEstimate} bottom-up for each supported expression kind. */
    private class Visitor
    extends IrVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final IrLiteralInterpreter literalInterpreter;
        private final TypeProvider types;

        Visitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.literalInterpreter = new IrLiteralInterpreter(ScalarStatsCalculator.this.plannerContext, session);
            this.types = types;
        }

        /** Fallback for expression kinds without a dedicated estimate. */
        @Override
        protected SymbolStatsEstimate visitExpression(Expression node, Void context) {
            return SymbolStatsEstimate.unknown();
        }

        /** A symbol reference takes the statistics the input estimate reports for that symbol. */
        @Override
        protected SymbolStatsEstimate visitSymbolReference(SymbolReference node, Void context) {
            return this.input.getSymbolStatistics(Symbol.from(node));
        }

        @Override
        protected SymbolStatsEstimate visitNullLiteral(NullLiteral node, Void context) {
            return ScalarStatsCalculator.nullStatsEstimate();
        }

        /**
         * A literal is a single non-null value. When the value has a numeric statistics
         * representation, it is used as both the low and the high bound.
         */
        @Override
        protected SymbolStatsEstimate visitLiteral(Literal node, Void context) {
            // The type is resolved against an empty TypeProvider, exactly as before.
            Type type = ScalarStatsCalculator.this.typeAnalyzer.getType(this.session, TypeProvider.empty(), node);
            Object value = this.literalInterpreter.evaluate(node, type);
            OptionalDouble doubleValue = StatsUtil.toStatsRepresentation(type, value);

            SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(1.0);
            if (doubleValue.isPresent()) {
                estimate.setLowValue(doubleValue.getAsDouble());
                estimate.setHighValue(doubleValue.getAsDouble());
            }
            return estimate.build();
        }

        /**
         * Runs the call through the expression optimizer and classifies the outcome: a null
         * result gives the all-NULL estimate, a result that is still a non-literal expression
         * gives an unknown estimate, and anything else is treated as a single non-null constant
         * (no value range is derived for it).
         */
        @Override
        protected SymbolStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            Map<NodeRef<Expression>, Type> expressionTypes = ScalarStatsCalculator.this.typeAnalyzer.getTypes(this.session, this.types, node);
            IrExpressionInterpreter interpreter = new IrExpressionInterpreter(node, ScalarStatsCalculator.this.plannerContext, this.session, expressionTypes);
            // NOTE(review): NoOpSymbolResolver presumably leaves symbol references unresolved — confirm.
            Object value = interpreter.optimize(NoOpSymbolResolver.INSTANCE);

            if (value == null || value instanceof NullLiteral) {
                return ScalarStatsCalculator.nullStatsEstimate();
            }
            if (value instanceof Expression && !IrUtils.isEffectivelyLiteral(ScalarStatsCalculator.this.plannerContext, this.session, (Expression) value)) {
                // The call did not reduce to a constant.
                return SymbolStatsEstimate.unknown();
            }
            // The call reduced to a constant.
            return SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(1.0)
                    .build();
        }

        /**
         * A cast keeps the source's nulls fraction, range and distinct-value count. For integral
         * target types the finite range bounds are rounded to whole numbers and the
         * distinct-value count is capped by the number of integers in the resulting range.
         */
        @Override
        protected SymbolStatsEstimate visitCast(Cast node, Void context) {
            SymbolStatsEstimate sourceStats = this.process(node.getExpression());
            double distinctValuesCount = sourceStats.getDistinctValuesCount();
            double lowValue = sourceStats.getLowValue();
            double highValue = sourceStats.getHighValue();

            if (this.isIntegralType(ScalarStatsCalculator.this.typeAnalyzer.getType(this.session, this.types, node))) {
                // Non-finite bounds (NaN, +/-Infinity) pass through unchanged.
                lowValue = this.roundToWholeNumber(lowValue);
                highValue = this.roundToWholeNumber(highValue);
                if (Double.isFinite(lowValue) && Double.isFinite(highValue)) {
                    double integersInRange = highValue - lowValue + 1.0;
                    if (!Double.isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
                        distinctValuesCount = integersInRange;
                    }
                }
            }

            return SymbolStatsEstimate.builder()
                    .setNullsFraction(sourceStats.getNullsFraction())
                    .setLowValue(lowValue)
                    .setHighValue(highValue)
                    .setDistinctValuesCount(distinctValuesCount)
                    .build();
        }

        /**
         * Rounds a range bound to the nearest whole number without clamping it to the
         * {@code long} range.
         *
         * <p>{@link Math#round(double)} saturates at {@code Long.MIN_VALUE}/{@code Long.MAX_VALUE},
         * which would corrupt bounds of scale-0 decimals wider than a {@code long}. Values whose
         * magnitude is at least 2^52 have no fractional part, so they are returned as-is; for
         * all smaller finite values the result is exactly that of {@code Math.round}.
         */
        private double roundToWholeNumber(double value) {
            if (!Double.isFinite(value) || Math.abs(value) >= MIN_MAGNITUDE_WITHOUT_FRACTION) {
                return value;
            }
            return Math.round(value);
        }

        /** True for the fixed-width integer types and for decimals with scale 0. */
        private boolean isIntegralType(Type type) {
            if (type instanceof BigintType || type instanceof IntegerType || type instanceof SmallintType || type instanceof TinyintType) {
                return true;
            }
            if (type instanceof DecimalType) {
                return ((DecimalType) type).getScale() == 0;
            }
            return false;
        }

        /**
         * Unary plus leaves the operand's statistics untouched; unary minus mirrors the range
         * around zero (the negated high bound becomes the low bound and vice versa).
         *
         * @throws IllegalStateException for a sign other than PLUS or MINUS
         */
        @Override
        protected SymbolStatsEstimate visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) {
            SymbolStatsEstimate stats = this.process(node.getValue());
            switch (node.getSign()) {
                case PLUS: {
                    return stats;
                }
                case MINUS: {
                    return SymbolStatsEstimate.buildFrom(stats)
                            .setLowValue(-stats.getHighValue())
                            .setHighValue(-stats.getLowValue())
                            .build();
                }
            }
            throw new IllegalStateException("Unexpected sign: " + node.getSign());
        }

        /**
         * Combines the operand estimates of a binary arithmetic expression. If either operand
         * is unknown the result is unknown. Otherwise the value range is derived per operator
         * from the operand ranges.
         */
        @Override
        protected SymbolStatsEstimate visitArithmeticBinary(ArithmeticBinaryExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate left = this.process(node.getLeft());
            SymbolStatsEstimate right = this.process(node.getRight());
            if (left.isUnknown() || right.isUnknown()) {
                return SymbolStatsEstimate.unknown();
            }

            // Nulls fraction: chance that at least one operand is NULL, treating the operands as independent.
            double nullsFraction = left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction();
            // Distinct values: product of the operand counts, capped by the input row count.
            double distinctValuesCount = MoreMath.min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), this.input.getOutputRowCount());
            SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder()
                    .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    .setNullsFraction(nullsFraction)
                    .setDistinctValuesCount(distinctValuesCount);

            double leftLow = left.getLowValue();
            double leftHigh = left.getHighValue();
            double rightLow = right.getLowValue();
            double rightHigh = right.getHighValue();
            ArithmeticBinaryExpression.Operator operator = node.getOperator();

            if (Double.isNaN(leftLow) || Double.isNaN(leftHigh) || Double.isNaN(rightLow) || Double.isNaN(rightHigh)) {
                // Any unknown bound makes the resulting range unknown.
                result.setLowValue(Double.NaN).setHighValue(Double.NaN);
            } else if (operator == ArithmeticBinaryExpression.Operator.DIVIDE && rightLow < 0.0 && rightHigh > 0.0) {
                // The divisor range straddles zero, so the quotient is unbounded in both directions.
                result.setLowValue(Double.NEGATIVE_INFINITY).setHighValue(Double.POSITIVE_INFINITY);
            } else if (operator == ArithmeticBinaryExpression.Operator.MODULUS) {
                // The remainder range is limited by the dividend range and by the largest divisor
                // magnitude; the branches below place it on the same side of zero as the dividend.
                double maxDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
                if (leftHigh <= 0.0) {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(0.0);
                } else if (leftLow >= 0.0) {
                    result.setLowValue(0.0).setHighValue(MoreMath.min(maxDivisor, leftHigh));
                } else {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
            } else {
                // Evaluate the operator at the four corners of the operand ranges and take the extremes.
                double v1 = this.operate(operator, leftLow, rightLow);
                double v2 = this.operate(operator, leftLow, rightHigh);
                double v3 = this.operate(operator, leftHigh, rightLow);
                double v4 = this.operate(operator, leftHigh, rightHigh);
                result.setLowValue(MoreMath.min(v1, v2, v3, v4))
                        .setHighValue(MoreMath.max(v1, v2, v3, v4));
            }
            return result.build();
        }

        /**
         * Applies {@code operator} to two doubles using Java floating-point arithmetic.
         *
         * @throws IllegalStateException for an operator not handled by the switch
         */
        private double operate(ArithmeticBinaryExpression.Operator operator, double left, double right) {
            switch (operator) {
                case ADD: {
                    return left + right;
                }
                case SUBTRACT: {
                    return left - right;
                }
                case MULTIPLY: {
                    return left * right;
                }
                case DIVIDE: {
                    return left / right;
                }
                case MODULUS: {
                    return left % right;
                }
            }
            throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator);
        }

        /**
         * Folds the operand estimates from left to right with {@link #estimateCoalesce}.
         *
         * @throws NullPointerException if the node is null or has no operands
         */
        @Override
        protected SymbolStatsEstimate visitCoalesceExpression(CoalesceExpression node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate result = null;
            for (Expression operand : node.getOperands()) {
                SymbolStatsEstimate operandEstimates = this.process(operand);
                // The first operand seeds the fold; later ones are merged into the running estimate.
                result = (result == null) ? operandEstimates : this.estimateCoalesce(result, operandEstimates);
            }
            return Objects.requireNonNull(result, "result is null");
        }

        /**
         * Estimate for {@code COALESCE(left, right)}. A never-null {@code left} decides the
         * result alone, an always-null {@code left} defers entirely to {@code right}, and
         * otherwise the two estimates are blended.
         */
        private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right) {
            double leftNullsFraction = left.getNullsFraction();
            if (leftNullsFraction == 0.0) {
                return left;
            }
            if (leftNullsFraction == 1.0) {
                return right;
            }

            // Right only contributes on rows where left is NULL, so its distinct values are
            // capped by the estimated number of such rows.
            double rightDistinctValues = MoreMath.min(right.getDistinctValuesCount(), this.input.getOutputRowCount() * leftNullsFraction);
            return SymbolStatsEstimate.builder()
                    .setLowValue(MoreMath.min(left.getLowValue(), right.getLowValue()))
                    .setHighValue(MoreMath.max(left.getHighValue(), right.getHighValue()))
                    .setDistinctValuesCount(left.getDistinctValuesCount() + rightDistinctValues)
                    // The result is NULL only when both operands are NULL.
                    .setNullsFraction(leftNullsFraction * right.getNullsFraction())
                    .setAverageRowSize(MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    .build();
        }
    }
}

