/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.FilterStatsCalculator;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.PlanNodeStatsEstimateMath;
import io.trino.cost.SimpleStatsRule;
import io.trino.cost.StatisticRange;
import io.trino.cost.StatsNormalizer;
import io.trino.cost.StatsProvider;
import io.trino.cost.SymbolStatsEstimate;
import io.trino.cost.TableStatsProvider;
import io.trino.matching.Pattern;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.util.MoreMath;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * Stats rule that estimates the output of a {@link JoinNode} for INNER, LEFT, RIGHT and FULL joins.
 *
 * <p>The inner-join estimate starts from the cross product of both inputs, is narrowed by the
 * equi-join criteria and then by the optional residual filter. Outer joins add an estimate of the
 * "join complement" (rows of the preserved side that find no match) on top of the inner-join
 * estimate; in those complement rows the other side's symbols are counted as nulls.
 */
public class JoinStatsRule extends SimpleStatsRule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();
    // Default fraction of the other side's NDV assumed to find a match when estimating the join complement.
    private static final double DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT = 0.5;
    // Selectivity assumed for each predicate whose effect cannot be estimated.
    // NOTE(review): presumably mirrors the unknown-filter coefficient used by FilterStatsCalculator; keep in sync.
    private static final double UNKNOWN_FILTER_COEFFICIENT = 0.9;

    private final FilterStatsCalculator filterStatsCalculator;
    private final StatsNormalizer normalizer;
    private final double unmatchedJoinComplementNdvsCoefficient;

    public JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer) {
        this(filterStatsCalculator, normalizer, DEFAULT_UNMATCHED_JOIN_COMPLEMENT_NDVS_COEFFICIENT);
    }

    /**
     * @param unmatchedJoinComplementNdvsCoefficient multiplier applied to the non-preserved side's NDV to
     *        obtain the number of preserved-side distinct values assumed to have a match
     */
    @VisibleForTesting
    JoinStatsRule(FilterStatsCalculator filterStatsCalculator, StatsNormalizer normalizer, double unmatchedJoinComplementNdvsCoefficient) {
        super(normalizer);
        this.filterStatsCalculator = Objects.requireNonNull(filterStatsCalculator, "filterStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(normalizer, "normalizer is null");
        this.unmatchedJoinComplementNdvsCoefficient = unmatchedJoinComplementNdvsCoefficient;
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Dispatches to the estimator for the node's join type.
     *
     * @throws IllegalStateException if the join type is not INNER, LEFT, RIGHT or FULL
     */
    @Override
    protected Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) {
        PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
        PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
        PlanNodeStatsEstimate crossJoinStats = crossJoinStats(node, leftStats, rightStats, types);

        switch (node.getType()) {
            case INNER:
                return Optional.of(computeInnerJoinStats(node, crossJoinStats, session, types));
            case LEFT:
                return Optional.of(computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case RIGHT:
                return Optional.of(computeRightJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
            case FULL:
                return Optional.of(computeFullJoinStats(node, leftStats, rightStats, crossJoinStats, session, types));
        }
        throw new IllegalStateException("Unknown join type: " + node.getType());
    }

    /**
     * FULL join: the LEFT join estimate plus the right side's unmatched rows.
     */
    private PlanNodeStatsEstimate computeFullJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        // The right side is the preserved side here, so the criteria are flipped to put its symbols on the left.
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats, types);
        PlanNodeStatsEstimate leftJoinStats = computeLeftJoinStats(node, leftStats, rightStats, crossJoinStats, session, types);
        return addJoinComplementStats(rightStats, leftJoinStats, rightJoinComplementStats);
    }

    /**
     * LEFT join: the inner join estimate plus the left side's unmatched rows.
     */
    private PlanNodeStatsEstimate computeLeftJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        PlanNodeStatsEstimate leftJoinComplementStats = calculateJoinComplementStats(node.getFilter(), node.getCriteria(), leftStats, rightStats, types);
        return addJoinComplementStats(leftStats, innerJoinStats, leftJoinComplementStats);
    }

    /**
     * RIGHT join: the inner join estimate plus the right side's unmatched rows.
     */
    private PlanNodeStatsEstimate computeRightJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        PlanNodeStatsEstimate innerJoinStats = computeInnerJoinStats(node, crossJoinStats, session, types);
        // The right side is the preserved side here, so the criteria are flipped to put its symbols on the left.
        PlanNodeStatsEstimate rightJoinComplementStats = calculateJoinComplementStats(node.getFilter(), flippedCriteria(node), rightStats, leftStats, types);
        return addJoinComplementStats(rightStats, innerJoinStats, rightJoinComplementStats);
    }

    /**
     * Estimates an inner join by narrowing the cross-join stats with the equi-join criteria and then
     * with the optional residual filter.
     *
     * @return the estimate; unknown when the equi-join estimate is unknown
     */
    private PlanNodeStatsEstimate computeInnerJoinStats(JoinNode node, PlanNodeStatsEstimate crossJoinStats, Session session, TypeProvider types) {
        List<JoinNode.EquiJoinClause> equiJoinCriteria = node.getCriteria();
        Optional<Expression> filter = node.getFilter();

        if (equiJoinCriteria.isEmpty()) {
            // No equi-join criteria: only the residual filter (if any) narrows the cross product.
            if (filter.isEmpty()) {
                return crossJoinStats;
            }
            return filterStatsCalculator.filterStats(crossJoinStats, filter.get(), session, types);
        }

        PlanNodeStatsEstimate equiJoinEstimate = filterByEquiJoinClauses(crossJoinStats, equiJoinCriteria, session, types);
        if (equiJoinEstimate.isOutputRowCountUnknown()) {
            return PlanNodeStatsEstimate.unknown();
        }
        if (filter.isEmpty()) {
            return equiJoinEstimate;
        }

        PlanNodeStatsEstimate filteredEquiJoinEstimate = filterStatsCalculator.filterStats(equiJoinEstimate, filter.get(), session, types);
        if (filteredEquiJoinEstimate.isOutputRowCountUnknown()) {
            // The residual filter cannot be estimated: apply the default selectivity instead of giving up.
            return normalizer.normalize(equiJoinEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT), types);
        }
        return filteredEquiJoinEstimate;
    }

    /**
     * Applies the equi-join clauses to {@code stats}.
     *
     * <p>Each clause is estimated on its own as an {@code left = right} filter. The row count comes
     * from combining those estimates with the session's multi-clause independence factor. The symbol
     * stats are intersected clause by clause.
     *
     * @return the normalized estimate, or unknown when the combined row count is NaN
     * @throws IllegalArgumentException if {@code clauses} is empty
     */
    private PlanNodeStatsEstimate filterByEquiJoinClauses(PlanNodeStatsEstimate stats, Collection<JoinNode.EquiJoinClause> clauses, Session session, TypeProvider types) {
        Preconditions.checkArgument(!clauses.isEmpty(), "clauses is empty");

        List<PlanNodeStatsEstimateWithClause> knownEstimates = clauses.stream()
                .map(clause -> {
                    Expression predicate = new ComparisonExpression(
                            ComparisonExpression.Operator.EQUAL,
                            clause.getLeft().toSymbolReference(),
                            clause.getRight().toSymbolReference());
                    return new PlanNodeStatsEstimateWithClause(filterStatsCalculator.filterStats(stats, predicate, session, types), clause);
                })
                .collect(ImmutableList.toImmutableList());

        List<PlanNodeStatsEstimate> clauseEstimates = knownEstimates.stream()
                .map(PlanNodeStatsEstimateWithClause::getEstimate)
                .collect(ImmutableList.toImmutableList());
        double outputRowCount = PlanNodeStatsEstimateMath.estimateCorrelatedConjunctionRowCount(
                stats,
                clauseEstimates,
                SystemSessionProperties.getJoinMultiClauseIndependenceFactor(session));
        if (Double.isNaN(outputRowCount)) {
            return PlanNodeStatsEstimate.unknown();
        }
        return normalizer.normalize(new PlanNodeStatsEstimate(outputRowCount, intersectCorrelatedJoinClause(stats, knownEstimates)), types);
    }

    /**
     * Returns the symbol stats of {@code stats} with both symbols of every equi-join clause narrowed.
     *
     * <p>Narrowing means: no nulls, the range is the intersection of the two ranges, and the NDV is the
     * smaller of the two sides' NDVs that fall inside that intersection. Later clauses overwrite earlier
     * ones for shared symbols, but every clause reads the original {@code stats}.
     */
    private static Map<Symbol, SymbolStatsEstimate> intersectCorrelatedJoinClause(PlanNodeStatsEstimate stats, List<PlanNodeStatsEstimateWithClause> equiJoinClauseEstimates) {
        PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder().addSymbolStatistics(stats.getSymbolStatistics());
        for (PlanNodeStatsEstimateWithClause estimateWithClause : equiJoinClauseEstimates) {
            JoinNode.EquiJoinClause clause = estimateWithClause.getClause();
            SymbolStatsEstimate leftStats = stats.getSymbolStatistics(clause.getLeft());
            SymbolStatsEstimate rightStats = stats.getSymbolStatistics(clause.getRight());
            StatisticRange leftRange = StatisticRange.from(leftStats);
            StatisticRange rightRange = StatisticRange.from(rightStats);
            StatisticRange intersect = leftRange.intersect(rightRange);

            // Fraction of each side's range that lies in the intersection; treated as 1.0 when it is NaN.
            double leftFilterValue = MoreMath.firstNonNaN(leftRange.overlapPercentWith(intersect), 1.0);
            double rightFilterValue = MoreMath.firstNonNaN(rightRange.overlapPercentWith(intersect), 1.0);
            double leftNdvInRange = leftFilterValue * leftRange.getDistinctValuesCount();
            double rightNdvInRange = rightFilterValue * rightRange.getDistinctValuesCount();
            double retainedNdv = MoreMath.min(leftNdvInRange, rightNdvInRange);

            // An equality match never produces nulls on either join symbol.
            SymbolStatsEstimate newLeftStats = SymbolStatsEstimate.buildFrom(leftStats)
                    .setNullsFraction(0.0)
                    .setStatisticsRange(intersect)
                    .setDistinctValuesCount(retainedNdv)
                    .build();
            SymbolStatsEstimate newRightStats = SymbolStatsEstimate.buildFrom(rightStats)
                    .setNullsFraction(0.0)
                    .setStatisticsRange(intersect)
                    .setDistinctValuesCount(retainedNdv)
                    .build();
            result.addSymbolStatistics(clause.getLeft(), newLeftStats)
                    .addSymbolStatistics(clause.getRight(), newRightStats);
        }
        return result.build().getSymbolStatistics();
    }

    /**
     * Estimates the rows of {@code leftStats} (the preserved side) that have no match on the other side.
     *
     * <p>Behavior by case:
     * <ul>
     *   <li>If the other side has zero rows, every left row is unmatched, so {@code leftStats} is returned.</li>
     *   <li>If there are no equi-join criteria but a filter is present, the result is unknown.</li>
     *   <li>If there are no criteria and no filter (a cross join), no left row is unmatched.</li>
     *   <li>Otherwise each clause in turn is treated as the "driving" clause. The largest known estimate
     *       wins; the result is unknown when no clause gives a known estimate.</li>
     * </ul>
     *
     * @param criteria equi-join clauses oriented so that {@code getLeft()} refers to {@code leftStats}
     */
    @VisibleForTesting
    PlanNodeStatsEstimate calculateJoinComplementStats(Optional<Expression> filter, List<JoinNode.EquiJoinClause> criteria, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, TypeProvider types) {
        if (rightStats.getOutputRowCount() == 0.0) {
            return leftStats;
        }
        if (criteria.isEmpty()) {
            if (filter.isPresent()) {
                return PlanNodeStatsEstimate.unknown();
            }
            return normalizer.normalize(leftStats.mapOutputRowCount(rowCount -> 0.0), types);
        }

        int numberOfFilterClauses = filter.map(expression -> ExpressionUtils.extractConjuncts(expression).size()).orElse(0);
        // Every other equi-join clause and every filter conjunct counts as a "remaining" clause.
        int numberOfRemainingClauses = criteria.size() - 1 + numberOfFilterClauses;
        return criteria.stream()
                .map(drivingClause -> calculateJoinComplementStats(leftStats, rightStats, drivingClause, numberOfRemainingClauses))
                .filter(estimate -> !estimate.isOutputRowCountUnknown())
                .max(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount))
                .map(estimate -> normalizer.normalize(estimate, types))
                .orElse(PlanNodeStatsEstimate.unknown());
    }

    /**
     * Estimates the unmatched left rows when {@code drivingClause} decides whether a row matches.
     *
     * <p>Assume {@code rightNDV * unmatchedJoinComplementNdvsCoefficient} left distinct values find a match.
     * <ul>
     *   <li>If the left NDV exceeds that number, the excess distinct values and all left null rows are
     *       unmatched.</li>
     *   <li>Otherwise only the left null rows are unmatched.</li>
     * </ul>
     * The result is then scaled up for the remaining clauses (each could reject a match) and capped at
     * the left row count.
     *
     * @return the estimate, or unknown when either NDV is NaN
     */
    private PlanNodeStatsEstimate calculateJoinComplementStats(PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, JoinNode.EquiJoinClause drivingClause, int numberOfRemainingClauses) {
        PlanNodeStatsEstimate result = leftStats;
        SymbolStatsEstimate leftColumnStats = leftStats.getSymbolStatistics(drivingClause.getLeft());
        SymbolStatsEstimate rightColumnStats = rightStats.getSymbolStatistics(drivingClause.getRight());
        double leftNDV = leftColumnStats.getDistinctValuesCount();
        double matchingRightNDV = rightColumnStats.getDistinctValuesCount() * unmatchedJoinComplementNdvsCoefficient;

        if (leftNDV > matchingRightNDV) {
            // "Excess" left distinct values and left null rows are assumed unmatched.
            double unmatchedLeftNDV = leftNDV - matchingRightNDV;
            double nonMatchingLeftValuesFraction = leftColumnStats.getValuesFraction() * unmatchedLeftNDV / leftNDV;
            double scaleFactor = nonMatchingLeftValuesFraction + leftColumnStats.getNullsFraction();
            double newLeftNullsFraction = leftColumnStats.getNullsFraction() / scaleFactor;
            result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                    .setLowValue(leftColumnStats.getLowValue())
                    .setHighValue(leftColumnStats.getHighValue())
                    .setNullsFraction(newLeftNullsFraction)
                    .setDistinctValuesCount(unmatchedLeftNDV)
                    .build());
            result = result.mapOutputRowCount(rowCount -> rowCount * scaleFactor);
        }
        else if (leftNDV <= matchingRightNDV) {
            // Every non-null left value is assumed matched, so only the left null rows remain.
            result = result.mapSymbolColumnStatistics(drivingClause.getLeft(), columnStats -> SymbolStatsEstimate.buildFrom(columnStats)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setNullsFraction(1.0)
                    .setDistinctValuesCount(0.0)
                    .build());
            result = result.mapOutputRowCount(rowCount -> rowCount * leftColumnStats.getNullsFraction());
        }
        else {
            // Both comparisons are false only when an NDV is NaN.
            return PlanNodeStatsEstimate.unknown();
        }

        // Each remaining clause may reject otherwise-matching rows, which grows the complement.
        // The complement can never exceed the preserved side's row count.
        return result.mapOutputRowCount(rowCount -> Math.min(leftStats.getOutputRowCount(), rowCount / Math.pow(UNKNOWN_FILTER_COEFFICIENT, numberOfRemainingClauses)));
    }

    /**
     * Combines an inner-join estimate with a join-complement estimate into an outer-join estimate.
     *
     * <p>How each part of the result is built:
     * <ul>
     *   <li>Row count: the sum of the two row counts.</li>
     *   <li>Preserved-side symbols (those with known stats in {@code joinComplementStats}): low, high and
     *       NDV come from {@code sourceStats}; the nulls fraction is the row-weighted average of the two
     *       inputs.</li>
     *   <li>All other inner-join symbols: every complement row counts as null.</li>
     * </ul>
     * If the complement has exactly zero rows, {@code innerJoinStats} is returned unchanged.
     * NOTE(review): a NaN (unknown) complement row count propagates into a NaN output row count.
     *
     * @param sourceStats stats of the preserved input (left for LEFT joins, right for RIGHT/FULL complements)
     */
    @VisibleForTesting
    PlanNodeStatsEstimate addJoinComplementStats(PlanNodeStatsEstimate sourceStats, PlanNodeStatsEstimate innerJoinStats, PlanNodeStatsEstimate joinComplementStats) {
        double innerJoinRowCount = innerJoinStats.getOutputRowCount();
        double joinComplementRowCount = joinComplementStats.getOutputRowCount();
        if (joinComplementRowCount == 0.0) {
            return innerJoinStats;
        }

        double outputRowCount = innerJoinRowCount + joinComplementRowCount;
        PlanNodeStatsEstimate.Builder outputStats = PlanNodeStatsEstimate.buildFrom(innerJoinStats);
        outputStats.setOutputRowCount(outputRowCount);

        for (Symbol symbol : joinComplementStats.getSymbolsWithKnownStatistics()) {
            SymbolStatsEstimate sourceSymbolStats = sourceStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
            SymbolStatsEstimate joinComplementSymbolStats = joinComplementStats.getSymbolStatistics(symbol);
            double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount + joinComplementSymbolStats.getNullsFraction() * joinComplementRowCount) / outputRowCount;
            outputStats.addSymbolStatistics(symbol, SymbolStatsEstimate.buildFrom(innerJoinSymbolStats)
                    .setLowValue(sourceSymbolStats.getLowValue())
                    .setHighValue(sourceSymbolStats.getHighValue())
                    .setDistinctValuesCount(sourceSymbolStats.getDistinctValuesCount())
                    .setNullsFraction(newNullsFraction)
                    .build());
        }

        // Symbols absent from the complement belong to the non-preserved side: they are null in complement rows.
        for (Symbol symbol : Sets.difference(innerJoinStats.getSymbolsWithKnownStatistics(), joinComplementStats.getSymbolsWithKnownStatistics())) {
            SymbolStatsEstimate innerJoinSymbolStats = innerJoinStats.getSymbolStatistics(symbol);
            double newNullsFraction = (innerJoinSymbolStats.getNullsFraction() * innerJoinRowCount + joinComplementRowCount) / outputRowCount;
            outputStats.addSymbolStatistics(symbol, innerJoinSymbolStats.mapNullsFraction(nullsFraction -> newNullsFraction));
        }
        return outputStats.build();
    }

    /**
     * Cross product of both inputs: the row count is the product of the two row counts, and each side's
     * output symbols keep their input statistics. The result is normalized.
     */
    private PlanNodeStatsEstimate crossJoinStats(JoinNode node, PlanNodeStatsEstimate leftStats, PlanNodeStatsEstimate rightStats, TypeProvider types) {
        PlanNodeStatsEstimate.Builder builder = PlanNodeStatsEstimate.builder()
                .setOutputRowCount(leftStats.getOutputRowCount() * rightStats.getOutputRowCount());
        node.getLeft().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, leftStats.getSymbolStatistics(symbol)));
        node.getRight().getOutputSymbols().forEach(symbol -> builder.addSymbolStatistics(symbol, rightStats.getSymbolStatistics(symbol)));
        return normalizer.normalize(builder.build(), types);
    }

    /**
     * Returns the node's equi-join criteria with each clause flipped, so that {@code getLeft()} refers
     * to the right input.
     */
    private static List<JoinNode.EquiJoinClause> flippedCriteria(JoinNode node) {
        return node.getCriteria().stream()
                .map(JoinNode.EquiJoinClause::flip)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Pairs the estimate obtained for one equi-join clause with that clause.
     */
    private static class PlanNodeStatsEstimateWithClause {
        private final PlanNodeStatsEstimate estimate;
        private final JoinNode.EquiJoinClause clause;

        private PlanNodeStatsEstimateWithClause(PlanNodeStatsEstimate estimate, JoinNode.EquiJoinClause clause) {
            this.estimate = Objects.requireNonNull(estimate, "estimate is null");
            this.clause = Objects.requireNonNull(clause, "clause is null");
        }

        private PlanNodeStatsEstimate getEstimate() {
            return this.estimate;
        }

        private JoinNode.EquiJoinClause getClause() {
            return this.clause;
        }
    }
}

