/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatisticRange;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.sql.planner.Symbol;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.util.MoreMath;
import java.util.Optional;
import java.util.OptionalDouble;

/**
 * Derives filtered plan statistics for a single comparison predicate.
 *
 * <p>Two predicate shapes are supported: an expression compared with a literal
 * ({@link #estimateExpressionToLiteralComparison}) and an expression compared with another
 * expression ({@link #estimateExpressionToExpressionComparison}). Comparisons that are not
 * modelled here produce {@link PlanNodeStatsEstimate#unknown()}.
 */
public final class ComparisonStatsCalculator {
    private ComparisonStatsCalculator() {
    }

    /**
     * Estimates the statistics of the rows satisfying {@code expression <operator> literal}.
     *
     * <p>{@code <} and {@code <=} are estimated identically, as are {@code >} and {@code >=}.
     *
     * @param inputStatistics statistics of the rows before the predicate is applied
     * @param expressionStatistics statistics of the compared expression
     * @param expressionSymbol when present, the symbol whose statistics are updated in the result
     * @param literalValue the literal as a double; when empty, the literal side places no finite
     *     bound on the filter range
     * @param operator the comparison operator
     * @return the estimated statistics after filtering; {@link PlanNodeStatsEstimate#unknown()}
     *     for {@code IS_DISTINCT_FROM}
     * @throws IllegalArgumentException if {@code operator} is not one of the handled operators
     */
    public static PlanNodeStatsEstimate estimateExpressionToLiteralComparison(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate expressionStatistics, Optional<Symbol> expressionSymbol, OptionalDouble literalValue, ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return estimateExpressionLessThanLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return estimateExpressionGreaterThanLiteral(inputStatistics, expressionStatistics, expressionSymbol, literalValue);
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
        }
        throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
    }

    /**
     * Builds the filter range used by {@code =} and {@code <>} against a literal.
     *
     * <p>The range is the single point {@code [literal, literal]} with one distinct value. When
     * {@code literalValue} is empty, the range is unbounded on both sides, still with one distinct
     * value.
     */
    private static StatisticRange literalPointRange(OptionalDouble literalValue) {
        if (literalValue.isPresent()) {
            double value = literalValue.getAsDouble();
            return new StatisticRange(value, value, 1.0);
        }
        return new StatisticRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
    }

    /** Estimates {@code expression = literal} by filtering the expression to the literal's point range. */
    private static PlanNodeStatsEstimate estimateExpressionEqualToLiteral(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate expressionStatistics, Optional<Symbol> expressionSymbol, OptionalDouble literalValue) {
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, literalPointRange(literalValue));
    }

    /**
     * Estimates {@code expression <> literal}.
     *
     * <p>The estimate keeps the non-null rows that fall outside the literal's share of the
     * expression range. It also removes one value from the expression's distinct-value count.
     */
    private static PlanNodeStatsEstimate estimateExpressionNotEqualToLiteral(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate expressionStatistics, Optional<Symbol> expressionSymbol, OptionalDouble literalValue) {
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange intersectRange = expressionRange.intersect(literalPointRange(literalValue));
        // Complement of the fraction of the expression range that the literal point covers.
        double filterFactor = 1.0 - expressionRange.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
        // Null rows are dropped as well: a comparison against NULL is never true.
        estimate.setOutputRowCount(filterFactor * (1.0 - expressionStatistics.getNullsFraction()) * inputStatistics.getOutputRowCount());
        if (expressionSymbol.isPresent()) {
            // NOTE(review): the NDV is decremented unconditionally, even when the literal lies
            // outside the expression's known range.
            SymbolStatsEstimate symbolNewEstimate = SymbolStatsEstimate.buildFrom(expressionStatistics)
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(MoreMath.max(expressionStatistics.getDistinctValuesCount() - 1.0, 0.0))
                    .build();
            estimate = estimate.addSymbolStatistics(expressionSymbol.get(), symbolNewEstimate);
        }
        return estimate.build();
    }

    /**
     * Estimates {@code expression < literal} and {@code expression <= literal}.
     *
     * <p>The expression is filtered to the range from negative infinity up to the literal, or to
     * an unbounded range when the literal is empty.
     */
    private static PlanNodeStatsEstimate estimateExpressionLessThanLiteral(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate expressionStatistics, Optional<Symbol> expressionSymbol, OptionalDouble literalValue) {
        // NaN distinct-value count for the filter side; presumably StatisticRange.intersect then
        // derives NDV from the expression side - confirm in StatisticRange.
        double upperBound = literalValue.orElse(Double.POSITIVE_INFINITY);
        StatisticRange filterRange = new StatisticRange(Double.NEGATIVE_INFINITY, upperBound, Double.NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange);
    }

    /**
     * Estimates {@code expression > literal} and {@code expression >= literal}.
     *
     * <p>The expression is filtered to the range from the literal up to positive infinity, or to
     * an unbounded range when the literal is empty.
     */
    private static PlanNodeStatsEstimate estimateExpressionGreaterThanLiteral(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate expressionStatistics, Optional<Symbol> expressionSymbol, OptionalDouble literalValue) {
        double lowerBound = literalValue.orElse(Double.NEGATIVE_INFINITY);
        StatisticRange filterRange = new StatisticRange(lowerBound, Double.POSITIVE_INFINITY, Double.NaN);
        return estimateFilterRange(inputStatistics, expressionStatistics, expressionSymbol, filterRange);
    }

    /**
     * Applies {@code filterRange} to the expression.
     *
     * <p>The output row count is scaled by two factors: the fraction of the expression's range
     * that overlaps the filter, and the expression's non-null fraction.
     *
     * <p>When {@code expressionSymbol} is present, that symbol's statistics are replaced with
     * fresh ones. These carry only the original average row size, the intersected range and a
     * zero nulls fraction.
     */
    private static PlanNodeStatsEstimate estimateFilterRange(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate expressionStatistics, Optional<Symbol> expressionSymbol, StatisticRange filterRange) {
        StatisticRange expressionRange = StatisticRange.from(expressionStatistics);
        StatisticRange intersectRange = expressionRange.intersect(filterRange);
        double filterFactor = expressionRange.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(rowCount -> filterFactor * (1.0 - expressionStatistics.getNullsFraction()) * rowCount);
        if (expressionSymbol.isPresent()) {
            SymbolStatsEstimate symbolNewEstimate = SymbolStatsEstimate.builder()
                    .setAverageRowSize(expressionStatistics.getAverageRowSize())
                    .setStatisticsRange(intersectRange)
                    .setNullsFraction(0.0)
                    .build();
            estimate = estimate.mapSymbolColumnStatistics(expressionSymbol.get(), oldStats -> symbolNewEstimate);
        }
        return estimate;
    }

    /**
     * Estimates the statistics of the rows satisfying {@code left <operator> right}, where both
     * sides are expressions.
     *
     * <p>Only {@code EQUAL} and {@code NOT_EQUAL} are estimated. The ordering comparisons and
     * {@code IS_DISTINCT_FROM} return {@link PlanNodeStatsEstimate#unknown()}.
     *
     * @param inputStatistics statistics of the rows before the predicate is applied
     * @param leftExpressionStatistics statistics of the left-hand expression
     * @param leftExpressionSymbol when present, the symbol whose statistics are updated for the left side
     * @param rightExpressionStatistics statistics of the right-hand expression
     * @param rightExpressionSymbol when present, the symbol whose statistics are updated for the right side
     * @param operator the comparison operator
     * @return the estimated statistics after filtering, or unknown for unmodelled operators
     * @throws IllegalArgumentException if {@code operator} is not one of the handled operators
     */
    public static PlanNodeStatsEstimate estimateExpressionToExpressionComparison(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate leftExpressionStatistics, Optional<Symbol> leftExpressionSymbol, SymbolStatsEstimate rightExpressionStatistics, Optional<Symbol> rightExpressionSymbol, ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return estimateExpressionEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionSymbol, rightExpressionStatistics, rightExpressionSymbol);
            case NOT_EQUAL:
                return estimateExpressionNotEqualToExpression(inputStatistics, leftExpressionStatistics, leftExpressionSymbol, rightExpressionStatistics, rightExpressionSymbol);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
            case IS_DISTINCT_FROM:
                return PlanNodeStatsEstimate.unknown();
        }
        throw new IllegalArgumentException("Unexpected comparison operator: " + operator);
    }

    /**
     * Estimates {@code left = right} for two expressions.
     *
     * <p>Rows are kept when both sides are non-null, with selectivity
     * {@code 1 / max(leftNdv, rightNdv, 1)}. Both symbols (if present) receive shared statistics
     * built from the intersected range, the smaller NDV and a zero nulls fraction.
     *
     * <p>Returns {@link PlanNodeStatsEstimate#unknown()} when either side's distinct-value count
     * is NaN.
     */
    private static PlanNodeStatsEstimate estimateExpressionEqualToExpression(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate leftExpressionStatistics, Optional<Symbol> leftExpressionSymbol, SymbolStatsEstimate rightExpressionStatistics, Optional<Symbol> rightExpressionSymbol) {
        // Without an NDV on both sides the selectivity formula below cannot be evaluated.
        if (Double.isNaN(leftExpressionStatistics.getDistinctValuesCount()) || Double.isNaN(rightExpressionStatistics.getDistinctValuesCount())) {
            return PlanNodeStatsEstimate.unknown();
        }
        StatisticRange leftExpressionRange = StatisticRange.from(leftExpressionStatistics);
        StatisticRange rightExpressionRange = StatisticRange.from(rightExpressionStatistics);
        StatisticRange intersect = leftExpressionRange.intersect(rightExpressionRange);

        // A row can only match when neither side is null.
        double nullsFilterFactor = (1.0 - leftExpressionStatistics.getNullsFraction()) * (1.0 - rightExpressionStatistics.getNullsFraction());
        double leftNdv = leftExpressionRange.getDistinctValuesCount();
        double rightNdv = rightExpressionRange.getDistinctValuesCount();
        // The floor of 1 keeps the divisor positive when both NDVs are below 1.
        double filterFactor = 1.0 / MoreMath.max(leftNdv, rightNdv, 1.0);
        double retainedNdv = MoreMath.min(leftNdv, rightNdv);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics)
                .setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor);
        SymbolStatsEstimate equalityStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(MoreMath.averageExcludingNaNs(leftExpressionStatistics.getAverageRowSize(), rightExpressionStatistics.getAverageRowSize()))
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        leftExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats));
        rightExpressionSymbol.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats));
        return estimate.build();
    }

    /**
     * Estimates {@code left <> right} for two expressions.
     *
     * <p>The estimate keeps the rows where both sides are non-null and that the equality estimate
     * would not keep. Each symbol (if present) keeps its own statistics with the nulls fraction
     * set to zero.
     *
     * <p>Returns {@link PlanNodeStatsEstimate#unknown()} when the equality estimate has an
     * unknown row count.
     */
    private static PlanNodeStatsEstimate estimateExpressionNotEqualToExpression(PlanNodeStatsEstimate inputStatistics, SymbolStatsEstimate leftExpressionStatistics, Optional<Symbol> leftExpressionSymbol, SymbolStatsEstimate rightExpressionStatistics, Optional<Symbol> rightExpressionSymbol) {
        // Remove null rows up front so the equality estimate below operates on non-null input only.
        double nullsFilterFactor = (1.0 - leftExpressionStatistics.getNullsFraction()) * (1.0 - rightExpressionStatistics.getNullsFraction());
        PlanNodeStatsEstimate inputNullsFiltered = inputStatistics.mapOutputRowCount(size -> size * nullsFilterFactor);
        SymbolStatsEstimate leftNullsFiltered = leftExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);
        SymbolStatsEstimate rightNullsFiltered = rightExpressionStatistics.mapNullsFraction(nullsFraction -> 0.0);

        PlanNodeStatsEstimate equalityStats = estimateExpressionEqualToExpression(inputNullsFiltered, leftNullsFiltered, leftExpressionSymbol, rightNullsFiltered, rightExpressionSymbol);
        if (equalityStats.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }

        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(inputNullsFiltered);
        double equalityFilterFactor = equalityStats.getOutputRowCount() / inputNullsFiltered.getOutputRowCount();
        // A zero (or otherwise degenerate) non-null row count makes the ratio NaN/infinite; treat as
        // "nothing matches equality".
        if (!Double.isFinite(equalityFilterFactor)) {
            equalityFilterFactor = 0.0;
        }
        result.setOutputRowCount(inputNullsFiltered.getOutputRowCount() * (1.0 - equalityFilterFactor));
        leftExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, leftNullsFiltered));
        rightExpressionSymbol.ifPresent(symbol -> result.addSymbolStatistics(symbol, rightNullsFiltered));
        return result.build();
    }
}

