/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.collect.ImmutableMap;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.metadata.GlobalFunctionCatalog;
import io.trino.spi.function.CatalogSchemaFunctionName;
import io.trino.spi.function.OperatorType;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.Coalesce;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrVisitor;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.optimizer.IrExpressionOptimizer;
import io.trino.sql.planner.Symbol;
import io.trino.util.MoreMath;
import java.lang.runtime.SwitchBootstraps;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;

/**
 * Derives {@link SymbolStatsEstimate}s for scalar IR expressions from the statistics of the
 * input they are evaluated over.
 *
 * <p>Handled shapes: constants, symbol references, casts, {@code COALESCE}, unary negation and the
 * binary arithmetic operators ({@code +}, {@code -}, {@code *}, {@code /}, {@code %}). Any other
 * call is estimated only if it constant-folds; everything else yields
 * {@link SymbolStatsEstimate#unknown()}.
 */
public class ScalarStatsCalculator {
    private final PlannerContext plannerContext;

    @Inject
    public ScalarStatsCalculator(PlannerContext plannerContext) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext cannot be null");
    }

    /**
     * Estimates statistics for the value produced by {@code scalarExpression}.
     *
     * @param scalarExpression expression to analyze
     * @param inputStatistics statistics used to resolve symbol references and the input row count
     * @param session session used when constant-folding function calls
     * @return the estimated statistics of the expression's result
     */
    public SymbolStatsEstimate calculate(Expression scalarExpression, PlanNodeStatsEstimate inputStatistics, Session session) {
        return new Visitor(inputStatistics, session).process(scalarExpression);
    }

    /** Statistics of an always-NULL expression: zero distinct values and every row null. */
    private static SymbolStatsEstimate nullStatsEstimate() {
        return SymbolStatsEstimate.builder()
                .setDistinctValuesCount(0.0)
                .setNullsFraction(1.0)
                .build();
    }

    /** Returns whether {@code function} names the built-in implementation of {@code operator}. */
    private static boolean isOperator(CatalogSchemaFunctionName function, OperatorType operator) {
        return function.equals(GlobalFunctionCatalog.builtinFunctionName(operator));
    }

    /** Returns whether {@code function} is one of the binary arithmetic operators estimated by range arithmetic. */
    private static boolean isBinaryArithmetic(CatalogSchemaFunctionName function) {
        return isOperator(function, OperatorType.ADD)
                || isOperator(function, OperatorType.SUBTRACT)
                || isOperator(function, OperatorType.MULTIPLY)
                || isOperator(function, OperatorType.DIVIDE)
                || isOperator(function, OperatorType.MODULUS);
    }

    /** Expression visitor computing the estimate bottom-up, one node at a time. */
    private class Visitor
            extends IrVisitor<SymbolStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;

        Visitor(PlanNodeStatsEstimate input, Session session) {
            this.input = input;
            this.session = session;
        }

        /** Fallback for any expression kind without a dedicated estimate. */
        @Override
        protected SymbolStatsEstimate visitExpression(Expression node, Void context) {
            return SymbolStatsEstimate.unknown();
        }

        /** A symbol reference simply reuses the input statistics of that symbol. */
        @Override
        protected SymbolStatsEstimate visitReference(Reference node, Void context) {
            return input.getSymbolStatistics(Symbol.from(node));
        }

        /**
         * A non-NULL literal has exactly one distinct value. Its low/high range is filled in only
         * when its type has a numeric stats representation.
         */
        @Override
        protected SymbolStatsEstimate visitConstant(Constant node, Void context) {
            Object value = node.value();
            if (value == null) {
                return nullStatsEstimate();
            }
            SymbolStatsEstimate.Builder estimate = SymbolStatsEstimate.builder()
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(1.0);
            OptionalDouble doubleValue = StatsUtil.toStatsRepresentation(node.type(), value);
            if (doubleValue.isPresent()) {
                estimate.setLowValue(doubleValue.getAsDouble());
                estimate.setHighValue(doubleValue.getAsDouble());
            }
            return estimate.build();
        }

        /**
         * Negation and binary arithmetic are estimated from their operands. Any other call is
         * constant-folded; if it folds, it yields a single value (or NULL), otherwise unknown.
         */
        @Override
        protected SymbolStatsEstimate visitCall(Call node, Void context) {
            CatalogSchemaFunctionName function = node.function().name();
            if (isOperator(function, OperatorType.NEGATION)) {
                // Negation mirrors the range around zero: [low, high] -> [-high, -low].
                SymbolStatsEstimate stats = process(node.arguments().getFirst());
                return SymbolStatsEstimate.buildFrom(stats)
                        .setLowValue(-stats.getHighValue())
                        .setHighValue(-stats.getLowValue())
                        .build();
            }
            if (isBinaryArithmetic(function)) {
                return processArithmetic(node);
            }

            Expression value = IrExpressionOptimizer.newOptimizer(plannerContext)
                    .process(node, session, ImmutableMap.<Symbol, Expression>of())
                    .orElse(node);
            if (value instanceof Constant constant) {
                if (constant.value() == null) {
                    return nullStatsEstimate();
                }
                return SymbolStatsEstimate.builder()
                        .setNullsFraction(0.0)
                        .setDistinctValuesCount(1.0)
                        .build();
            }
            // The call did not fold to a constant; no estimate is propagated through it.
            return SymbolStatsEstimate.unknown();
        }

        /**
         * Propagates the source statistics through a cast. For integral target types the range
         * endpoints are rounded, and the distinct-value count is capped by the number of integers
         * in the rounded range.
         */
        @Override
        protected SymbolStatsEstimate visitCast(Cast node, Void context) {
            SymbolStatsEstimate sourceStats = process(node.expression());
            double distinctValuesCount = sourceStats.getDistinctValuesCount();
            double lowValue = sourceStats.getLowValue();
            double highValue = sourceStats.getHighValue();

            if (isIntegralType(node.type())) {
                if (Double.isFinite(lowValue)) {
                    lowValue = Math.round(lowValue);
                }
                if (Double.isFinite(highValue)) {
                    highValue = Math.round(highValue);
                }
                if (Double.isFinite(lowValue) && Double.isFinite(highValue)) {
                    double integersInRange = highValue - lowValue + 1.0;
                    if (!Double.isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) {
                        distinctValuesCount = integersInRange;
                    }
                }
            }

            return SymbolStatsEstimate.builder()
                    .setNullsFraction(sourceStats.getNullsFraction())
                    .setLowValue(lowValue)
                    .setHighValue(highValue)
                    .setDistinctValuesCount(distinctValuesCount)
                    .build();
        }

        /** Integer-valued types: the fixed-width integer types and decimals with scale 0. */
        private boolean isIntegralType(Type type) {
            return type instanceof BigintType
                    || type instanceof IntegerType
                    || type instanceof SmallintType
                    || type instanceof TinyintType
                    || (type instanceof DecimalType decimalType && decimalType.getScale() == 0);
        }

        /**
         * Estimates a binary arithmetic call from its two operand estimates.
         *
         * <p>The range is computed as follows:
         * <ul>
         *   <li>It is unknown (NaN) if any operand endpoint is unknown.</li>
         *   <li>It is unbounded for a division whose divisor range straddles zero.</li>
         *   <li>For modulus, it is clamped by the largest divisor magnitude.</li>
         *   <li>Otherwise, it spans the min/max of the operator applied to the four endpoint combinations.</li>
         * </ul>
         *
         * @param node a call to ADD, SUBTRACT, MULTIPLY, DIVIDE or MODULUS with two arguments
         */
        protected SymbolStatsEstimate processArithmetic(Call node) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate left = process(node.arguments().get(0));
            SymbolStatsEstimate right = process(node.arguments().get(1));
            if (left.isUnknown() || right.isUnknown()) {
                return SymbolStatsEstimate.unknown();
            }

            SymbolStatsEstimate.Builder result = SymbolStatsEstimate.builder()
                    .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    // P(left IS NULL OR right IS NULL), treating the operands as independent
                    .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction())
                    // at most one output value per pair of operand values, and never more than the row count
                    .setDistinctValuesCount(MoreMath.min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount()));

            CatalogSchemaFunctionName function = node.function().name();
            double leftLow = left.getLowValue();
            double leftHigh = left.getHighValue();
            double rightLow = right.getLowValue();
            double rightHigh = right.getHighValue();

            if (Double.isNaN(leftLow) || Double.isNaN(leftHigh) || Double.isNaN(rightLow) || Double.isNaN(rightHigh)) {
                result.setLowValue(Double.NaN).setHighValue(Double.NaN);
            }
            else if (isOperator(function, OperatorType.DIVIDE) && rightLow < 0.0 && rightHigh > 0.0) {
                // Divisors arbitrarily close to zero of either sign make the quotient unbounded.
                result.setLowValue(Double.NEGATIVE_INFINITY).setHighValue(Double.POSITIVE_INFINITY);
            }
            else if (isOperator(function, OperatorType.MODULUS)) {
                // The remainder is bounded by the largest divisor magnitude and by the dividend,
                // and is bounded here as if its sign follows the dividend's.
                double maxDivisor = MoreMath.max(Math.abs(rightLow), Math.abs(rightHigh));
                if (leftHigh <= 0.0) {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(0.0);
                }
                else if (leftLow >= 0.0) {
                    result.setLowValue(0.0).setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
                else {
                    result.setLowValue(MoreMath.max(-maxDivisor, leftLow)).setHighValue(MoreMath.min(maxDivisor, leftHigh));
                }
            }
            else {
                // For +, -, * and / (divisor not straddling zero) the extremes lie at the range corners.
                double v1 = operate(function, leftLow, rightLow);
                double v2 = operate(function, leftLow, rightHigh);
                double v3 = operate(function, leftHigh, rightLow);
                double v4 = operate(function, leftHigh, rightHigh);
                result.setLowValue(MoreMath.min(v1, v2, v3, v4))
                        .setHighValue(MoreMath.max(v1, v2, v3, v4));
            }
            return result.build();
        }

        /**
         * Applies a binary arithmetic operator to two range endpoints.
         *
         * @throws IllegalStateException if {@code function} is not ADD, SUBTRACT, MULTIPLY, DIVIDE or MODULUS
         */
        private double operate(CatalogSchemaFunctionName function, double left, double right) {
            Objects.requireNonNull(function, "function is null");
            if (isOperator(function, OperatorType.ADD)) {
                return left + right;
            }
            if (isOperator(function, OperatorType.SUBTRACT)) {
                return left - right;
            }
            if (isOperator(function, OperatorType.MULTIPLY)) {
                return left * right;
            }
            if (isOperator(function, OperatorType.DIVIDE)) {
                return left / right;
            }
            if (isOperator(function, OperatorType.MODULUS)) {
                return left % right;
            }
            throw new IllegalStateException("Unsupported binary arithmetic operation: " + String.valueOf(function));
        }

        /** Folds the operand estimates left to right with {@link #estimateCoalesce}. */
        @Override
        protected SymbolStatsEstimate visitCoalesce(Coalesce node, Void context) {
            Objects.requireNonNull(node, "node is null");
            SymbolStatsEstimate result = null;
            for (Expression operand : node.operands()) {
                SymbolStatsEstimate operandEstimate = process(operand);
                result = (result == null) ? operandEstimate : estimateCoalesce(result, operandEstimate);
            }
            return Objects.requireNonNull(result, "result is null");
        }

        /**
         * Combines the estimate of {@code COALESCE(left, right)}. {@code right} only contributes on
         * rows where {@code left} is null, so its distinct values are capped by that row count.
         */
        private SymbolStatsEstimate estimateCoalesce(SymbolStatsEstimate left, SymbolStatsEstimate right) {
            if (left.getNullsFraction() == 0.0) {
                // left is never null, so right is never selected
                return left;
            }
            if (left.getNullsFraction() == 1.0) {
                // left is always null, so the result is entirely right
                return right;
            }
            return SymbolStatsEstimate.builder()
                    .setLowValue(MoreMath.min(left.getLowValue(), right.getLowValue()))
                    .setHighValue(MoreMath.max(left.getHighValue(), right.getHighValue()))
                    .setDistinctValuesCount(left.getDistinctValuesCount() + MoreMath.min(right.getDistinctValuesCount(), input.getOutputRowCount() * left.getNullsFraction()))
                    .setNullsFraction(left.getNullsFraction() * right.getNullsFraction())
                    .setAverageRowSize(MoreMath.max(left.getAverageRowSize(), right.getAverageRowSize()))
                    .build();
        }
    }
}

