/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatisticRange;
import io.trino.cost.StatsProvider;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.util.MoreMath;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

/**
 * Static helpers for combining {@link PlanNodeStatsEstimate}s: subtracting, capping, intersecting
 * and adding estimates, plus helpers that estimate the output size in bytes of a plan subtree.
 *
 * <p>Note that {@code Double.min}/{@code Double.max} return NaN when either argument is NaN, so an
 * unknown statistic on either side of an operation generally produces an unknown result.
 */
public final class PlanNodeStatsEstimateMath {
    /**
     * Plan node types treated as "expanding". Both byte-size helpers below give up (return NaN)
     * when they encounter one of these nodes — presumably because such nodes can emit more data
     * than their inputs, so input sizes are not a usable proxy.
     */
    private static final List<Class<? extends PlanNode>> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class);

    /**
     * Multiplier applied by {@link #estimateCorrelatedConjunctionRowCount} when at least one of the
     * conjunct estimates has an unknown row count.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private PlanNodeStatsEstimateMath() {
    }

    /**
     * Estimates the statistics of the rows of {@code superset} that are not part of {@code subset}.
     *
     * <p>Row count and per-symbol null count are subtracted (clamped at zero), the distinct-value
     * count is adjusted heuristically, and average row size and low/high values are copied from
     * {@code superset}. Only symbols with known statistics in {@code superset} appear in the result.
     *
     * @param superset statistics of the full row set
     * @param subset statistics of rows assumed to be contained in {@code superset}
     * @return {@link PlanNodeStatsEstimate#unknown()} if either row count is unknown; zero-row
     *     statistics if {@code subset} has at least as many rows as {@code superset}; otherwise the
     *     estimated difference
     */
    public static PlanNodeStatsEstimate subtractSubsetStats(PlanNodeStatsEstimate superset, PlanNodeStatsEstimate subset) {
        if (superset.isOutputRowCountUnknown() || subset.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        double supersetRowCount = superset.getOutputRowCount();
        double subsetRowCount = subset.getOutputRowCount();
        double outputRowCount = Double.max(supersetRowCount - subsetRowCount, 0.0);
        if (outputRowCount == 0.0) {
            // The subset accounts for every row of the superset.
            return createZeroStats(superset);
        }
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        result.setOutputRowCount(outputRowCount);
        superset.getSymbolsWithKnownStatistics().forEach(symbol -> {
            SymbolStatsEstimate supersetSymbolStats = superset.getSymbolStatistics(symbol);
            SymbolStatsEstimate subsetSymbolStats = subset.getSymbolStatistics(symbol);
            SymbolStatsEstimate.Builder newSymbolStats = SymbolStatsEstimate.builder();

            // Average row size is carried over from the superset unchanged.
            newSymbolStats.setAverageRowSize(supersetSymbolStats.getAverageRowSize());

            // Nulls: subtract absolute null counts, then re-normalize against the remaining rows.
            double supersetNullsCount = supersetSymbolStats.getNullsFraction() * supersetRowCount;
            double subsetNullsCount = subsetSymbolStats.getNullsFraction() * subsetRowCount;
            double newNullsCount = Double.max(supersetNullsCount - subsetNullsCount, 0.0);
            newSymbolStats.setNullsFraction(Double.min(newNullsCount, outputRowCount) / outputRowCount);

            newSymbolStats.setDistinctValuesCount(subtractDistinctValues(
                    supersetSymbolStats.getDistinctValuesCount(),
                    supersetRowCount - supersetNullsCount,
                    subsetSymbolStats.getDistinctValuesCount(),
                    subsetRowCount - subsetNullsCount));

            // The value range is carried over from the superset unchanged.
            newSymbolStats.setLowValue(supersetSymbolStats.getLowValue());
            newSymbolStats.setHighValue(supersetSymbolStats.getHighValue());
            result.addSymbolStatistics(symbol, newSymbolStats.build());
        });
        return result.build();
    }

    /**
     * Estimates the distinct-value count left after removing a subset's rows from a superset.
     *
     * @param supersetDistinctValues distinct values in the superset (may be NaN)
     * @param supersetNonNullsCount non-null rows in the superset
     * @param subsetDistinctValues distinct values in the subset (may be NaN)
     * @param subsetNonNullsCount non-null rows in the subset
     * @return NaN if either distinct count is NaN, otherwise the heuristic estimate
     */
    private static double subtractDistinctValues(
            double supersetDistinctValues,
            double supersetNonNullsCount,
            double subsetDistinctValues,
            double subsetNonNullsCount) {
        if (Double.isNaN(supersetDistinctValues) || Double.isNaN(subsetDistinctValues)) {
            return Double.NaN;
        }
        if (supersetDistinctValues == 0.0) {
            return 0.0;
        }
        if (subsetDistinctValues == 0.0) {
            return supersetDistinctValues;
        }
        double supersetValuesPerDistinctValue = supersetNonNullsCount / supersetDistinctValues;
        double subsetValuesPerDistinctValue = subsetNonNullsCount / subsetDistinctValues;
        // Heuristic: if the subset has at least as many rows per distinct value as the superset,
        // its distinct values are assumed to be removed entirely; otherwise every superset value is
        // assumed to keep at least one row.
        if (supersetValuesPerDistinctValue <= subsetValuesPerDistinctValue) {
            return Double.max(supersetDistinctValues - subsetDistinctValues, 0.0);
        }
        return supersetDistinctValues;
    }

    /**
     * Bounds {@code stats} by {@code cap}, statistic by statistic.
     *
     * <p>Row count, distinct-value count, high value and null count take the minimum of the two
     * inputs; low value takes the maximum. Average row size is copied from {@code stats}. The nulls
     * fraction is recomputed from the capped null count, and is set to {@code 1.0} when the capped
     * row count is zero. Only symbols with known statistics in {@code stats} appear in the result.
     *
     * @param stats the estimate to cap
     * @param cap the upper-bound estimate
     * @return the capped estimate, or {@link PlanNodeStatsEstimate#unknown()} if either row count is
     *     unknown
     */
    public static PlanNodeStatsEstimate capStats(PlanNodeStatsEstimate stats, PlanNodeStatsEstimate cap) {
        if (stats.isOutputRowCountUnknown() || cap.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        double cappedRowCount = Double.min(stats.getOutputRowCount(), cap.getOutputRowCount());
        result.setOutputRowCount(cappedRowCount);
        stats.getSymbolsWithKnownStatistics().forEach(symbol -> {
            SymbolStatsEstimate symbolStats = stats.getSymbolStatistics(symbol);
            SymbolStatsEstimate capSymbolStats = cap.getSymbolStatistics(symbol);
            SymbolStatsEstimate.Builder newSymbolStats = SymbolStatsEstimate.builder();
            newSymbolStats.setAverageRowSize(symbolStats.getAverageRowSize());
            newSymbolStats.setDistinctValuesCount(Double.min(symbolStats.getDistinctValuesCount(), capSymbolStats.getDistinctValuesCount()));
            // Narrow the value range to the overlap of both ranges.
            newSymbolStats.setLowValue(Double.max(symbolStats.getLowValue(), capSymbolStats.getLowValue()));
            newSymbolStats.setHighValue(Double.min(symbolStats.getHighValue(), capSymbolStats.getHighValue()));

            // Cap the absolute number of nulls, then express it as a fraction of the capped rows.
            double numberOfNulls = stats.getOutputRowCount() * symbolStats.getNullsFraction();
            double capNumberOfNulls = cap.getOutputRowCount() * capSymbolStats.getNullsFraction();
            double cappedNumberOfNulls = Double.min(numberOfNulls, capNumberOfNulls);
            double cappedNullsFraction = cappedRowCount == 0.0 ? 1.0 : cappedNumberOfNulls / cappedRowCount;
            newSymbolStats.setNullsFraction(cappedNullsFraction);
            result.addSymbolStatistics(symbol, newSymbolStats.build());
        });
        return result.build();
    }

    /**
     * Combines per-symbol statistics from several estimates that describe correlated predicates
     * over the same input.
     *
     * <p>For every symbol known to any estimate, the value ranges are intersected
     * ({@link StatisticRange#intersect}), the nulls fraction is reduced with
     * {@code MoreMath.minExcludeNaN}, and the average row size is reduced with
     * {@code MoreMath.averageExcludingNaNs}.
     *
     * @param estimates the estimates to combine; must not be empty
     * @return per-symbol statistics; with a single estimate, that estimate's statistics as-is
     * @throws IllegalArgumentException if {@code estimates} is empty
     */
    public static Map<Symbol, SymbolStatsEstimate> intersectCorrelatedStats(List<PlanNodeStatsEstimate> estimates) {
        Preconditions.checkArgument(!estimates.isEmpty(), "estimates is empty");
        if (estimates.size() == 1) {
            return estimates.get(0).getSymbolStatistics();
        }
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        estimates.stream()
                .flatMap(estimate -> estimate.getSymbolsWithKnownStatistics().stream())
                .distinct()
                .forEach(symbol -> {
                    // NOTE(review): estimates that lack this symbol contribute whatever
                    // getSymbolStatistics returns for an absent symbol — presumably unknown (NaN) stats.
                    List<SymbolStatsEstimate> symbolStatsEstimates = estimates.stream()
                            .map(estimate -> estimate.getSymbolStatistics(symbol))
                            .collect(ImmutableList.toImmutableList());
                    StatisticRange intersect = symbolStatsEstimates.stream()
                            .map(StatisticRange::from)
                            .reduce(StatisticRange::intersect)
                            .orElseThrow();
                    double nullsFraction = symbolStatsEstimates.stream()
                            .map(SymbolStatsEstimate::getNullsFraction)
                            .reduce(MoreMath::minExcludeNaN)
                            .orElseThrow();
                    // NOTE(review): this is a pairwise left fold, so with three or more estimates the
                    // result is likely not a plain mean of all values — confirm this is intended.
                    double averageRowSize = symbolStatsEstimates.stream()
                            .map(SymbolStatsEstimate::getAverageRowSize)
                            .reduce(MoreMath::averageExcludingNaNs)
                            .orElseThrow();
                    result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder()
                            .setStatisticsRange(intersect)
                            .setNullsFraction(nullsFraction)
                            .setAverageRowSize(averageRowSize)
                            .build());
                });
        return result.build().getSymbolStatistics();
    }

    /**
     * Estimates the row count of a conjunction of possibly correlated predicates.
     *
     * <p>Known estimates are sorted by ascending row count (most selective first). With selectivity
     * {@code s_i = rows_i / inputRows} and factor {@code f}, the combined selectivity is
     * {@code s_0 * s_1^f * s_2^(f^2) * ...}. Thus {@code f = 1} multiplies selectivities as if
     * independent, and {@code f = 0} keeps only the most selective term. If any estimate has an
     * unknown row count, the result is additionally multiplied by {@link #UNKNOWN_FILTER_COEFFICIENT}.
     *
     * @param input statistics of the conjunction's input
     * @param estimates per-term estimates; must not be empty
     * @param independenceFactor exponent decay applied to each successive term's selectivity
     * @return the estimated row count; the input row count when it is unknown or zero; NaN when no
     *     term has a known row count
     * @throws IllegalArgumentException if {@code estimates} is empty
     */
    public static double estimateCorrelatedConjunctionRowCount(PlanNodeStatsEstimate input, List<PlanNodeStatsEstimate> estimates, double independenceFactor) {
        Preconditions.checkArgument(!estimates.isEmpty(), "estimates is empty");
        if (input.isOutputRowCountUnknown() || input.getOutputRowCount() == 0.0) {
            return input.getOutputRowCount();
        }
        List<PlanNodeStatsEstimate> knownSortedEstimates = estimates.stream()
                .filter(estimate -> !estimate.isOutputRowCountUnknown())
                .sorted(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .collect(ImmutableList.toImmutableList());
        if (knownSortedEstimates.isEmpty()) {
            return Double.NaN;
        }
        double inputRowCount = input.getOutputRowCount();
        double combinedSelectivity = knownSortedEstimates.get(0).getOutputRowCount() / inputRowCount;
        double combinedIndependenceFactor = 1.0;
        for (int i = 1; i < knownSortedEstimates.size(); i++) {
            PlanNodeStatsEstimate term = knownSortedEstimates.get(i);
            combinedIndependenceFactor *= independenceFactor;
            combinedSelectivity *= Math.pow(term.getOutputRowCount() / inputRowCount, combinedIndependenceFactor);
        }
        double outputRowCount = inputRowCount * combinedSelectivity;
        boolean hasUnestimatedTerm = estimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown);
        return hasUnestimatedTerm ? outputRowCount * UNKNOWN_FILTER_COEFFICIENT : outputRowCount;
    }

    /**
     * Builds an estimate with zero output rows and {@link SymbolStatsEstimate#zero()} statistics
     * for every symbol with known statistics in {@code stats}.
     */
    private static PlanNodeStatsEstimate createZeroStats(PlanNodeStatsEstimate stats) {
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
        result.setOutputRowCount(0.0);
        stats.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero()));
        return result.build();
    }

    /**
     * Returns the output size in bytes of {@code node} if its statistics know it. Otherwise returns
     * the sum of the first known sizes found by recursing into its sources.
     *
     * @param node the plan node to size; it is resolved through {@code lookup} first
     * @param lookup used to resolve plan node references
     * @param statsProvider source of per-node statistics
     * @return the size in bytes, or NaN if the node is an expanding node, has no sources, or any
     *     source's size cannot be determined
     */
    public static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) {
        PlanNode resolvedNode = lookup.resolve(node);
        double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode.getOutputSymbols());
        if (!Double.isNaN(outputSizeInBytes)) {
            return outputSizeInBytes;
        }
        if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) {
            return Double.NaN;
        }
        List<PlanNode> sourceNodes = resolvedNode.getSources();
        if (sourceNodes.isEmpty()) {
            return Double.NaN;
        }
        double sourcesOutputSizeInBytes = 0.0;
        for (PlanNode sourceNode : sourceNodes) {
            double sourceOutputSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider);
            if (Double.isNaN(sourceOutputSizeInBytes)) {
                // One unknown source makes the whole sum unknown.
                return Double.NaN;
            }
            sourcesOutputSizeInBytes += sourceOutputSizeInBytes;
        }
        return sourcesOutputSizeInBytes;
    }

    /**
     * Sums the estimated output size in bytes of every table scan, values and remote source node
     * under {@code node}.
     *
     * @return the summed size, or NaN if the subtree contains an expanding node
     */
    public static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) {
        boolean hasExpandingNodes = PlanNodeSearcher.searchFrom(node, lookup).whereIsInstanceOfAny(EXPANDING_NODE_CLASSES).matches();
        if (hasExpandingNodes) {
            return Double.NaN;
        }
        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(node, lookup).whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class, RemoteSourceNode.class).findAll();
        return sourceNodes.stream()
                .mapToDouble(sourceNode -> statsProvider.getStats(sourceNode).getOutputSizeInBytes(sourceNode.getOutputSymbols()))
                .sum();
    }

    /** Adds two estimates, combining value ranges with {@code StatisticRange::addAndSumDistinctValues}. */
    public static PlanNodeStatsEstimate addStatsAndSumDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) {
        return addStats(left, right, StatisticRange::addAndSumDistinctValues);
    }

    /** Adds two estimates, combining value ranges with {@code StatisticRange::addAndMaxDistinctValues}. */
    public static PlanNodeStatsEstimate addStatsAndMaxDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) {
        return addStats(left, right, StatisticRange::addAndMaxDistinctValues);
    }

    /** Adds two estimates, combining value ranges with {@code StatisticRange::addAndCollapseDistinctValues}. */
    public static PlanNodeStatsEstimate addStatsAndCollapseDistinctValues(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right) {
        return addStats(left, right, StatisticRange::addAndCollapseDistinctValues);
    }

    /**
     * Adds two estimates: row counts are summed, and per-symbol statistics are merged for every
     * symbol known to either side.
     *
     * @return the combined estimate, or {@link PlanNodeStatsEstimate#unknown()} if either row count
     *     is unknown
     */
    private static PlanNodeStatsEstimate addStats(PlanNodeStatsEstimate left, PlanNodeStatsEstimate right, RangeAdditionStrategy strategy) {
        if (left.isOutputRowCountUnknown() || right.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder();
        double newRowCount = left.getOutputRowCount() + right.getOutputRowCount();
        Stream.concat(left.getSymbolsWithKnownStatistics().stream(), right.getSymbolsWithKnownStatistics().stream())
                .distinct()
                .forEach(symbol -> {
                    SymbolStatsEstimate symbolStats = SymbolStatsEstimate.zero();
                    // addColumnStats requires a positive row count; with zero rows keep zero stats.
                    if (newRowCount > 0.0) {
                        symbolStats = addColumnStats(
                                left.getSymbolStatistics(symbol),
                                left.getOutputRowCount(),
                                right.getSymbolStatistics(symbol),
                                right.getOutputRowCount(),
                                newRowCount,
                                strategy);
                    }
                    statsBuilder.addSymbolStatistics(symbol, symbolStats);
                });
        return statsBuilder.setOutputRowCount(newRowCount).build();
    }

    /**
     * Merges one symbol's statistics from two inputs being added together.
     *
     * <p>The range is combined by {@code strategy}. The nulls fraction is the total null count over
     * {@code newRowCount}. The average row size is the total non-null bytes over the new non-null
     * row count, or 0 when there are no non-null rows.
     *
     * @throws IllegalArgumentException if {@code newRowCount} is not positive
     */
    private static SymbolStatsEstimate addColumnStats(SymbolStatsEstimate leftStats, double leftRows, SymbolStatsEstimate rightStats, double rightRows, double newRowCount, RangeAdditionStrategy strategy) {
        Preconditions.checkArgument(newRowCount > 0.0, "newRowCount must be greater than zero");
        StatisticRange sum = strategy.add(StatisticRange.from(leftStats), StatisticRange.from(rightStats));
        double nullsCountLeft = leftStats.getNullsFraction() * leftRows;
        double nullsCountRight = rightStats.getNullsFraction() * rightRows;
        // Average row size describes non-null values, so weight it by non-null rows only.
        double totalSizeLeft = (leftRows - nullsCountLeft) * leftStats.getAverageRowSize();
        double totalSizeRight = (rightRows - nullsCountRight) * rightStats.getAverageRowSize();
        double newNullsFraction = (nullsCountLeft + nullsCountRight) / newRowCount;
        double newNonNullsRowCount = newRowCount * (1.0 - newNullsFraction);
        double newAverageRowSize = newNonNullsRowCount == 0.0 ? 0.0 : (totalSizeLeft + totalSizeRight) / newNonNullsRowCount;
        return SymbolStatsEstimate.builder()
                .setStatisticsRange(sum)
                .setAverageRowSize(newAverageRowSize)
                .setNullsFraction(newNullsFraction)
                .build();
    }

    /** Combines the value ranges of two inputs whose rows are being added together. */
    @FunctionalInterface
    private interface RangeAdditionStrategy {
        StatisticRange add(StatisticRange leftRange, StatisticRange rightRange);
    }
}

