/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.FilterStatsCalculator;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.SimpleStatsRule;
import com.facebook.presto.cost.StatisticRange;
import com.facebook.presto.cost.StatsNormalizer;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.SymbolStatsEstimate;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.util.MoreMath;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.IntStream;

/**
 * Estimates output statistics for a {@link JoinNode} of type INNER, LEFT, RIGHT or FULL.
 *
 * <p>Every estimate starts from the statistics of the cross product of both sources. The
 * inner-join estimate narrows that with the equi-join criteria and the optional join filter.
 * Outer-join estimates add a "join complement" on top of the inner-join estimate: the rows of
 * the outer side that are assumed to find no match.
 */
public class JoinStatsRule
extends SimpleStatsRule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    /** Default for {@link #unmatchedJoinComplementNdvsCoefficient}, used by the public constructor. */
    private static final double DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT = 0.5;

    /**
     * Row-count retention factor applied once per clause whose selectivity is not estimated
     * directly. This covers auxiliary equi-join clauses of an inner join, and the remaining
     * clauses / filter conjuncts when sizing a join complement. Kept in one constant so both
     * uses stay in sync.
     */
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;

    /**
     * Multiplier applied to the other side's distinct-value count to estimate how many of those
     * values match a value on this side when sizing a join complement.
     */
    private final double unmatchedJoinComplementNdvsCoefficient;

    public JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer) {
        this(filterStatsCalculator, normalizer, DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT);
    }

    /**
     * @param filterStatsCalculator used to estimate the driving equi-join predicate and the join filter
     * @param normalizer passed through to {@link SimpleStatsRule}
     * @param unmatchedJoinComplementNdvsCoefficient fraction of the other side's NDV assumed to match
     */
    @VisibleForTesting
    JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer, double unmatchedJoinComplementNdvsCoefficient) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator is null");
        this.unmatchedJoinComplementNdvsCoefficient = unmatchedJoinComplementNdvsCoefficient;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Dispatches on the join type.
     *
     * @throws IllegalStateException if the join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) {
        PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
        PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
        PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats);

        switch (node.getType()) {
            case INNER:
                return Optional.of(computeInnerJoinStats(node, crossJoinStats, session, types));
            case LEFT:
                return Optional.of(computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case RIGHT:
                return Optional.of(computeRightJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case FULL:
                return Optional.of(computeFullJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            default:
                throw new IllegalStateException("Unknown join type: " + node.getType());
        }
    }

    /**
     * FULL join = LEFT join estimate plus the complement of the right side (right rows with no
     * match). The right side's stats supply low/high/NDV for the complement's symbols.
     */
    private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats);
        PlanNodeStatsEstimate leftJoinStats = computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types);
        return addJoinComplementStats(rightStats, leftJoinStats, rightJoinComplementStats);
    }

    /** LEFT join = inner-join estimate plus the complement of the left side. */
    private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate leftJoinComplementStats = calculateJoinComplementStats(node.getFilter(), node.getCriteria(), leftStats, rightStats);
        return addJoinComplementStats(leftStats, innerJoinStats, leftJoinComplementStats);
    }

    /**
     * RIGHT join = inner-join estimate plus the complement of the right side. The criteria are
     * flipped so that the right side plays the "left" role in the complement computation.
     */
    private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats);
        return addJoinComplementStats(rightStats, innerJoinStats, rightJoinComplementStats);
    }

    /**
     * Estimates an inner join.
     *
     * <p>Each equi-join clause is tried in turn as the "driving" clause, which is estimated with
     * the filter calculator; the other clauses are applied as auxiliary clauses. The lowest
     * row-count estimate wins. Without criteria the cross-join stats are used. The optional join
     * filter is then applied on top.
     */
    private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        List<JoinNode.EquiJoinClause> equiJoinClauses = node.getCriteria();
        PlanNodeStatsEstimate equiJoinClausesFilteredStats = IntStream.range(0, equiJoinClauses.size())
                .mapToObj(drivingClauseId -> filterByEquiJoinClauses(
                        crossJoinStats,
                        equiJoinClauses.get(drivingClauseId),
                        copyWithout(equiJoinClauses, drivingClauseId),
                        session,
                        types))
                .min(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .orElse(crossJoinStats);

        return node.getFilter()
                .map(filter -> filterStatsCalculator.filterStats(equiJoinClausesFilteredStats, filter, session, types))
                .orElse(equiJoinClausesFilteredStats);
    }

    /** Returns a mutable copy of {@code list} with the element at {@code filteredOutIndex} removed. */
    private static <T> List<T> copyWithout(List<? extends T> list, int filteredOutIndex) {
        List<T> copy = new ArrayList<>(list);
        copy.remove(filteredOutIndex);
        return copy;
    }

    /**
     * Applies {@code drivingClause} as an equality predicate through the filter calculator, then
     * applies each auxiliary clause with {@link #filterByAuxiliaryClause}.
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause drivingClause, List<JoinNode.EquiJoinClause> auxiliaryClauses, Session session, TypeProvider types) {
        ComparisonExpression drivingPredicate = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                drivingClause.getLeft().toSymbolReference(),
                drivingClause.getRight().toSymbolReference());
        PlanNodeStatsEstimate filteredStats = filterStatsCalculator.filterStats(stats, drivingPredicate, session, types);
        for (JoinNode.EquiJoinClause clause : auxiliaryClauses) {
            filteredStats = filterByAuxiliaryClause(filteredStats, clause);
        }
        return filteredStats;
    }

    /**
     * Approximates one additional equi-join clause without the filter calculator.
     *
     * <p>Both symbols are narrowed to the intersection of their value ranges. Each keeps the
     * smaller of the two in-range NDV estimates and gets a nulls fraction of 0, since equality
     * never matches NULL. The row count is scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
     */
    private PlanNodeStatsEstimate filterByAuxiliaryClause(PlanNodeStatsEstimate stats, JoinNode.EquiJoinClause clause) {
        SymbolStatsEstimate leftStats = stats.getSymbolStatistics(clause.getLeft());
        SymbolStatsEstimate rightStats = stats.getSymbolStatistics(clause.getRight());
        StatisticRange leftRange = StatisticRange.from(leftStats);
        StatisticRange rightRange = StatisticRange.from(rightStats);
        StatisticRange intersect = leftRange.intersect(rightRange);

        // When the overlap cannot be computed (NaN), assume the whole range overlaps.
        double leftFilterValue = firstNonNaN(leftRange.overlapPercentWith(intersect), 1.0);
        double rightFilterValue = firstNonNaN(rightRange.overlapPercentWith(intersect), 1.0);
        double leftNdvInRange = leftFilterValue * leftRange.getDistinctValuesCount();
        double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount();
        double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);

        SymbolStatsEstimate newLeftStats = SymbolStatsEstimate.buildFrom(leftStats)
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();
        SymbolStatsEstimate newRightStats = SymbolStatsEstimate.buildFrom(rightStats)
                .setNullsFraction(0.0)
                .setStatisticsRange(intersect)
                .setDistinctValuesCount(retainedNdv)
                .build();

        return stats
                .mapSymbolColumnStatistics(clause.getLeft(), oldLeftStats -> newLeftStats)
                .mapSymbolColumnStatistics(clause.getRight(), oldRightStats -> newRightStats)
                .mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT);
    }

    /**
     * Returns the first argument that is not NaN.
     *
     * @throws IllegalArgumentException if every argument is NaN (or none are given)
     */
    private static double firstNonNaN(double... values) {
        for (double value : values) {
            if (!Double.isNaN(value)) {
                return value;
            }
        }
        throw new IllegalArgumentException("All values are NaN");
    }

    /**
     * Estimates the rows of {@code leftStats} that find no match in {@code rightStats}.
     *
     * <ul>
     *   <li>Right side has 0 rows: every left row is unmatched, so {@code leftStats} is returned.</li>
     *   <li>No equi-join criteria but a filter: the estimate is {@code UNKNOWN_STATS}.</li>
     *   <li>No criteria and no filter (a cross join against a non-empty right side): 0 rows.</li>
     *   <li>Otherwise each clause is tried as the driving clause and the largest estimate wins. The
     *       other clauses plus the filter's conjuncts count as "remaining clauses".</li>
     * </ul>
     */
    @VisibleForTesting
    PlanNodeStatsEstimate calculateJoinComplementStats(Optional<Expression> filter, List<JoinNode.EquiJoinClause> criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        if (rightStats.getOutputRowCount() == 0.0) {
            return leftStats;
        }
        if (criteria.isEmpty()) {
            if (filter.isPresent()) {
                return PlanNodeStatsEstimate.UNKNOWN_STATS;
            }
            return leftStats.mapOutputRowCount(rowCount -> 0.0);
        }

        int numberOfFilterClauses = filter.map(expression -> ExpressionUtils.extractConjuncts(expression).size()).orElse(0);
        int numberOfRemainingClauses = criteria.size() - 1 + numberOfFilterClauses;
        // criteria is non-empty here, so max(...) always yields a value.
        return criteria.stream()
                .map(drivingClause -> calculateJoinComplementStats(leftStats, rightStats, drivingClause, numberOfRemainingClauses))
                .max(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .get();
    }

    /**
     * Estimates the unmatched left rows for a single driving clause.
     *
     * <p>Only {@code rightNDV * unmatchedJoinComplementNdvsCoefficient} left values are assumed
     * to find a match. If the left side has more distinct values than that, the excess values and
     * the left NULL rows are unmatched. Otherwise only the left NULL rows are unmatched. If either
     * NDV is NaN, neither comparison holds and {@code UNKNOWN_STATS} is returned.
     *
     * <p>The estimate is then divided by {@code UNKNOWN_FILTER_COEFFICIENT^numberOfRemainingClauses},
     * because each extra clause leaves more rows unmatched. It is capped at the left row count.
     */
    private PlanNodeStatsEstimate calculateJoinComplementStats(PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, JoinNode.EquiJoinClause drivingClause, int numberOfRemainingClauses) {
        PlanNodeStatsEstimate result = leftStats;
        SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(drivingClause.getLeft());
        SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(drivingClause.getRight());

        double leftNDV = leftColumnStats.getDistinctValuesCount();
        double matchingRightNDV = rightColumnStats.getDistinctValuesCount() * unmatchedJoinComplementNdvsCoefficient;

        if (leftNDV > matchingRightNDV) {
            // Excess left distinct values, plus left NULL rows, are unmatched.
            double nonMatchingLeftValuesFraction = leftColumnStats.getValuesFraction() * (leftNDV - matchingRightNDV) / leftNDV;
            double scaleFactor = nonMatchingLeftValuesFraction + leftColumnStats.getNullsFraction();
            double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
            result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                    .setLowValue(leftColumnStats.getLowValue())
                    .setHighValue(leftColumnStats.getHighValue())
                    .setNullsFraction(newLeftNullsFraction)
                    .setDistinctValuesCount(leftNDV - matchingRightNDV)
                    .build());
            result = result.mapOutputRowCount(rowCount -> rowCount * scaleFactor);
        }
        else if (leftNDV <= matchingRightNDV) {
            // Every non-null left value is assumed matched; only left NULL rows remain.
            result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setNullsFraction(1.0)
                    .setDistinctValuesCount(0.0)
                    .build());
            result = result.mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
        }
        else {
            // leftNDV or matchingRightNDV is NaN.
            return PlanNodeStatsEstimate.UNKNOWN_STATS;
        }

        return result.mapOutputRowCount(rowCount -> Math.min(
                leftStats.getOutputRowCount(),
                rowCount / Math.pow(UNKNOWN_FILTER_COEFFICIENT, numberOfRemainingClauses)));
    }

    /**
     * Adds {@code joinComplementStats} (unmatched outer rows) to {@code baseJoinStats}.
     *
     * <p>For symbols with known stats in the complement, the nulls fraction is the row-weighted
     * average of both inputs, and low/high/NDV are taken from {@code sourceStats}. For every other
     * symbol in {@code baseJoinStats}, the complement rows count as NULLs.
     *
     * <p>NOTE(review): when both inputs have 0 rows, the nulls fractions become 0/0 = NaN. Confirm
     * downstream normalization handles this.
     *
     * @throws IllegalStateException if the complement has known stats for a symbol the base join lacks
     */
    @VisibleForTesting
    PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate baseJoinStats, PlanNodeStatsEstimate joinComplementStats) {
        Preconditions.checkState(baseJoinStats.getSymbolsWithKnownStatistics().containsAll(joinComplementStats.getSymbolsWithKnownStatistics()));

        double joinOutputRowCount = baseJoinStats.getOutputRowCount();
        double joinComplementOutputRowCount = joinComplementStats.getOutputRowCount();
        double totalRowCount = joinOutputRowCount + joinComplementOutputRowCount;

        PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(baseJoinStats);
        outputStats.setOutputRowCount(totalRowCount);

        for (Symbol symbol : joinComplementStats.getSymbolsWithKnownStatistics()) {
            SymbolStatsEstimate sourceSymbolStats = sourceStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate innerSymbolStats = baseJoinStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate joinComplementSymbolStats = joinComplementStats.getSymbolStatistics(symbol);
            double newNullsFraction = (innerSymbolStats.getNullsFraction() * joinOutputRowCount
                    + joinComplementSymbolStats.getNullsFraction() * joinComplementOutputRowCount) / totalRowCount;
            outputStats.addSymbolStatistics(symbol, SymbolStatsEstimate.buildFrom(innerSymbolStats)
                    .setLowValue(sourceSymbolStats.getLowValue())
                    .setHighValue(sourceSymbolStats.getHighValue())
                    .setDistinctValuesCount(sourceSymbolStats.getDistinctValuesCount())
                    .setNullsFraction(newNullsFraction)
                    .build());
        }

        // Symbols without complement stats: every complement row contributes a NULL.
        for (Symbol symbol : Sets.difference(baseJoinStats.getSymbolsWithKnownStatistics(), joinComplementStats.getSymbolsWithKnownStatistics())) {
            SymbolStatsEstimate innerSymbolStats = baseJoinStats.getSymbolStatistics(symbol);
            double newNullsFraction = (innerSymbolStats.getNullsFraction() * joinOutputRowCount + joinComplementOutputRowCount) / totalRowCount;
            outputStats.addSymbolStatistics(symbol, innerSymbolStats.mapNullsFraction(nullsFraction -> newNullsFraction));
        }

        return outputStats.build();
    }

    /**
     * Builds cross-product stats. The row count is left rows times right rows, and each output
     * symbol of either source keeps that source's symbol statistics.
     */
    private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
        node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol)));
        node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol)));
        return builder.build();
    }

    /** Returns the join criteria with the left and right symbols of every clause swapped. */
    private List<JoinNode.EquiJoinClause> flippedCriteria(JoinNode node) {
        return node.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(ImmutableList.toImmutableList());
    }
}

