/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatisticRange;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.util.MoreMath;
import java.util.Optional;
import java.util.OptionalDouble;

/**
 * Static helpers that estimate plan-node statistics after filtering on a single comparison
 * predicate, either between an expression and a numeric literal or between two expressions.
 *
 * <p>Every entry point returns {@link Optional#empty()} when the given comparison is not handled
 * (or, for expression-to-expression equality, when distinct-value counts are unknown).
 */
public final class ComparisonStatsCalculator {
    private ComparisonStatsCalculator() {
    }

    /**
     * Estimates statistics after filtering {@code inputStatistics} with
     * {@code expression <operator> literal}.
     *
     * <p>{@code LESS_THAN} and {@code LESS_THAN_OR_EQUAL} are estimated identically, as are
     * {@code GREATER_THAN} and {@code GREATER_THAN_OR_EQUAL}: the boundary value is not treated
     * specially.
     *
     * @param inputStatistics statistics of the rows before the filter
     * @param symbol symbol whose column statistics are replaced in the result, if present
     *     (presumably present when the compared expression is a bare symbol; confirm with callers)
     * @param expressionStats statistics of the compared expression
     * @param doubleLiteral the literal as a double; when empty the literal value is treated as
     *     unknown (the whole real line for equality, an unbounded side for ordering operators)
     * @param operator the comparison operator
     * @return the estimate, or empty for any operator other than EQUAL, NOT_EQUAL, LESS_THAN,
     *     LESS_THAN_OR_EQUAL, GREATER_THAN and GREATER_THAN_OR_EQUAL
     */
    public static Optional<PlanNodeStatsEstimate> comparisonExpressionToLiteralStats(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> symbol, SymbolStatsEstimate expressionStats, OptionalDouble doubleLiteral, ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return expressionToLiteralEquality(inputStatistics, symbol, expressionStats, doubleLiteral);
            case NOT_EQUAL:
                return expressionToLiteralNonEquality(inputStatistics, symbol, expressionStats, doubleLiteral);
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return expressionToLiteralLessThan(inputStatistics, symbol, expressionStats, doubleLiteral);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
                return expressionToLiteralGreaterThan(inputStatistics, symbol, expressionStats, doubleLiteral);
            default:
                return Optional.empty();
        }
    }

    /**
     * Shared estimator for comparisons that keep only values inside {@code literalRange}.
     *
     * <p>The output row count is the input row count scaled by
     * {@code range.overlapPercentWith(range.intersect(literalRange))} (presumably the fraction of the
     * expression's values that fall inside {@code literalRange}) and by the non-null fraction,
     * because null values cannot satisfy the comparison. When {@code symbol} is present, its
     * statistics are replaced by the intersected range with a nulls fraction of 0, keeping the
     * expression's average row size.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralRangeComparison(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> symbol, SymbolStatsEstimate expressionStats, StatisticRange literalRange) {
        StatisticRange range = StatisticRange.from(expressionStats);
        StatisticRange intersectRange = range.intersect(literalRange);
        double filterFactor = range.overlapPercentWith(intersectRange);
        double nonNullFraction = 1.0 - expressionStats.getNullsFraction();

        PlanNodeStatsEstimate estimate = inputStatistics.mapOutputRowCount(rowCount -> filterFactor * nonNullFraction * rowCount);
        if (symbol.isPresent()) {
            SymbolStatsEstimate symbolNewEstimate = SymbolStatsEstimate.builder()
                    .setAverageRowSize(expressionStats.getAverageRowSize())
                    .setStatisticsRange(intersectRange)
                    .setNullsFraction(0.0)
                    .build();
            estimate = estimate.mapSymbolColumnStatistics(symbol.get(), oldStats -> symbolNewEstimate);
        }
        return Optional.of(estimate);
    }

    /**
     * Builds the single-point range used for (non-)equality against a literal.
     *
     * <p>A known literal becomes {@code [value, value]} with one distinct value. An unknown literal
     * becomes the whole real line, still with one distinct value.
     */
    private static StatisticRange literalPointRange(OptionalDouble literal) {
        if (literal.isPresent()) {
            double value = literal.getAsDouble();
            return new StatisticRange(value, value, 1.0);
        }
        return new StatisticRange(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 1.0);
    }

    /** Estimates {@code expression = literal} as a range comparison against a single point. */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralEquality(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> symbol, SymbolStatsEstimate expressionStats, OptionalDouble literal) {
        return expressionToLiteralRangeComparison(inputStatistics, symbol, expressionStats, literalPointRange(literal));
    }

    /**
     * Estimates {@code expression <> literal} as the complement of the equality overlap, restricted
     * to non-null rows.
     *
     * <p>When {@code symbol} is present, it keeps its input statistics, but with a nulls fraction of
     * 0 and one fewer distinct value (never below 0).
     * NOTE(review): the distinct count is decremented unconditionally, even when the literal lies
     * outside the expression's range; confirm that this is intended.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralNonEquality(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> symbol, SymbolStatsEstimate expressionStats, OptionalDouble literal) {
        StatisticRange range = StatisticRange.from(expressionStats);
        StatisticRange intersectRange = range.intersect(literalPointRange(literal));
        double filterFactor = 1.0 - range.overlapPercentWith(intersectRange);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics);
        estimate.setOutputRowCount(filterFactor * (1.0 - expressionStats.getNullsFraction()) * inputStatistics.getOutputRowCount());
        if (symbol.isPresent()) {
            SymbolStatsEstimate symbolNewEstimate = SymbolStatsEstimate.buildFrom(expressionStats)
                    .setNullsFraction(0.0)
                    .setDistinctValuesCount(MoreMath.max(expressionStats.getDistinctValuesCount() - 1.0, 0.0))
                    .build();
            estimate = estimate.addSymbolStatistics(symbol.get(), symbolNewEstimate);
        }
        return Optional.of(estimate.build());
    }

    /**
     * Estimates {@code expression < literal} (and {@code <=}) as the range {@code (-inf, literal]}.
     * An unknown literal leaves the range unbounded above.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralLessThan(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> symbol, SymbolStatsEstimate expressionStats, OptionalDouble literal) {
        StatisticRange literalRange = new StatisticRange(Double.NEGATIVE_INFINITY, literal.orElse(Double.POSITIVE_INFINITY), Double.NaN);
        return expressionToLiteralRangeComparison(inputStatistics, symbol, expressionStats, literalRange);
    }

    /**
     * Estimates {@code expression > literal} (and {@code >=}) as the range {@code [literal, +inf)}.
     * An unknown literal leaves the range unbounded below.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToLiteralGreaterThan(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> symbol, SymbolStatsEstimate expressionStats, OptionalDouble literal) {
        StatisticRange literalRange = new StatisticRange(literal.orElse(Double.NEGATIVE_INFINITY), Double.POSITIVE_INFINITY, Double.NaN);
        return expressionToLiteralRangeComparison(inputStatistics, symbol, expressionStats, literalRange);
    }

    /**
     * Estimates statistics after filtering {@code inputStatistics} with {@code left <operator> right},
     * where both sides are expressions.
     *
     * @param inputStatistics statistics of the rows before the filter
     * @param left symbol of the left side whose statistics are replaced in the result, if present
     * @param leftStats statistics of the left expression
     * @param right symbol of the right side whose statistics are replaced in the result, if present
     * @param rightStats statistics of the right expression
     * @param operator the comparison operator
     * @return the estimate, or empty for operators other than EQUAL and NOT_EQUAL, or when either
     *     side's distinct-value count is NaN
     */
    public static Optional<PlanNodeStatsEstimate> comparisonExpressionToExpressionStats(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> left, SymbolStatsEstimate leftStats, Optional<Symbol> right, SymbolStatsEstimate rightStats, ComparisonExpression.Operator operator) {
        switch (operator) {
            case EQUAL:
                return expressionToExpressionEquality(inputStatistics, left, leftStats, right, rightStats);
            case NOT_EQUAL:
                return expressionToExpressionNonEquality(inputStatistics, left, leftStats, right, rightStats);
            default:
                return Optional.empty();
        }
    }

    /**
     * Estimates {@code left = right}.
     *
     * <p>Selectivity is {@code 1 / max(leftNdv, rightNdv, 1)}, applied to the rows where both sides
     * are non-null. Both symbols, when present, receive the same statistics: the intersected range,
     * {@code min(leftNdv, rightNdv)} distinct values, no nulls, and the mean of the non-NaN average
     * row sizes. Returns empty when either side's distinct-value count is NaN.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToExpressionEquality(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> left, SymbolStatsEstimate leftStats, Optional<Symbol> right, SymbolStatsEstimate rightStats) {
        if (Double.isNaN(leftStats.getDistinctValuesCount()) || Double.isNaN(rightStats.getDistinctValuesCount())) {
            return Optional.empty();
        }
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersect = leftRange.intersect(rightRange);

        // Probability that both sides are non-null; a null on either side never satisfies '='.
        double nullsFilterFactor = (1.0 - leftStats.getNullsFraction()) * (1.0 - rightStats.getNullsFraction());
        double leftNdv = leftRange.getDistinctValuesCount();
        double rightNdv = rightRange.getDistinctValuesCount();
        // The floor of 1 keeps the divisor positive when both NDVs are 0.
        double filterFactor = 1.0 / MoreMath.max(leftNdv, rightNdv, 1.0);
        double retainedNdv = MoreMath.min(leftNdv, rightNdv);

        PlanNodeStatsEstimate.Builder estimate = PlanNodeStatsEstimate.buildFrom(inputStatistics)
                .setOutputRowCount(inputStatistics.getOutputRowCount() * nullsFilterFactor * filterFactor);
        // Range is set before the NDV so that the explicit retainedNdv is what ends up in the stats.
        SymbolStatsEstimate equalityStats = SymbolStatsEstimate.builder()
                .setAverageRowSize(averageExcludingNaNs(leftStats.getAverageRowSize(), rightStats.getAverageRowSize()))
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        left.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats));
        right.ifPresent(symbol -> estimate.addSymbolStatistics(symbol, equalityStats));
        return Optional.of(estimate.build());
    }

    /**
     * Returns the mean of the two values, ignoring NaN inputs.
     *
     * @return NaN if both are NaN, the non-NaN one if only one is NaN, otherwise their average
     */
    private static double averageExcludingNaNs(double first, double second) {
        if (Double.isNaN(first) && Double.isNaN(second)) {
            return Double.NaN;
        }
        if (!Double.isNaN(first) && !Double.isNaN(second)) {
            return (first + second) / 2.0;
        }
        return MoreMath.firstNonNaN(first, second);
    }

    /**
     * Estimates {@code left <> right} as the rows where both sides are non-null, minus the fraction
     * that the equality estimate says would match.
     *
     * <p>Both symbols, when present, keep their input statistics with the nulls fraction set to 0.
     * Returns empty when the equality estimate is unavailable.
     */
    private static Optional<PlanNodeStatsEstimate> expressionToExpressionNonEquality(PlanNodeStatsEstimate inputStatistics, Optional<Symbol> left, SymbolStatsEstimate leftStats, Optional<Symbol> right, SymbolStatsEstimate rightStats) {
        double nullsFilterFactor = (1.0 - leftStats.getNullsFraction()) * (1.0 - rightStats.getNullsFraction());
        PlanNodeStatsEstimate inputNullsFiltered = inputStatistics.mapOutputRowCount(size -> size * nullsFilterFactor);
        SymbolStatsEstimate leftNullsFiltered = leftStats.mapNullsFraction(nullsFraction -> 0.0);
        SymbolStatsEstimate rightNullsFiltered = rightStats.mapNullsFraction(nullsFraction -> 0.0);

        Optional<PlanNodeStatsEstimate> equalityStats = expressionToExpressionEquality(inputNullsFiltered, left, leftNullsFiltered, right, rightNullsFiltered);
        if (!equalityStats.isPresent()) {
            return Optional.empty();
        }

        // A zero (or otherwise degenerate) null-filtered row count makes the ratio non-finite;
        // in that case treat the equality fraction as 0.
        double equalityRatio = equalityStats.get().getOutputRowCount() / inputNullsFiltered.getOutputRowCount();
        double equalityFilterFactor = Double.isFinite(equalityRatio) ? equalityRatio : 0.0;

        PlanNodeStatsEstimate resultStats = inputNullsFiltered.mapOutputRowCount(rowCount -> rowCount * (1.0 - equalityFilterFactor));
        if (left.isPresent()) {
            resultStats = resultStats.mapSymbolColumnStatistics(left.get(), stats -> leftNullsFiltered);
        }
        if (right.isPresent()) {
            resultStats = resultStats.mapSymbolColumnStatistics(right.get(), stats -> rightNullsFiltered);
        }
        return Optional.of(resultStats);
    }
}

