/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import io.airlift.units.DataSize;
import io.trino.SystemSessionProperties;
import io.trino.cost.CostCalculatorWithEstimatedExchanges;
import io.trino.cost.CostComparator;
import io.trino.cost.LocalCostEstimate;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.PlanNodeWithCost;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;

/**
 * Iterative optimizer rule that chooses a distribution type ({@code PARTITIONED} or
 * {@code REPLICATED}) for every {@link JoinNode} that does not have one yet. The rule may also
 * flip the join's children so that the other input becomes the build (right) side.
 *
 * <p>When the session's join distribution type is {@code AUTOMATIC}, every permitted
 * combination of child order and distribution is compared by estimated cost. A size-based
 * heuristic is used instead when any candidate's cost has unknown components, or when there is
 * no candidate. For any other session setting, the syntactic child order is kept and the
 * distribution is derived from that setting and from the join's constraints.
 */
public class DetermineJoinDistributionType
implements Rule<JoinNode> {
    /** Matches joins whose distribution type has not been decided yet. */
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(joinNode -> joinNode.getDistributionType().isEmpty());
    /**
     * Node types treated as potentially row-expanding. When one of them appears in a subtree,
     * size estimates derived from that subtree's descendants are reported as unknown
     * ({@code NaN}).
     */
    private static final List<Class<? extends PlanNode>> EXPANDING_NODE_CLASSES = ImmutableList.of(JoinNode.class, UnnestNode.class);
    /**
     * Factor by which one side's first-known output size must exceed the other's before the
     * size-based heuristic uses the difference to pick the build side of a partitioned join.
     */
    private static final double SIZE_DIFFERENCE_THRESHOLD = 8.0;
    private final CostComparator costComparator;
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param costComparator compares candidate plan costs for the current session
     * @param taskCountEstimator supplies the source-distributed task count used in join costing
     * @throws NullPointerException if either argument is {@code null}
     */
    public DetermineJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator) {
        this.costComparator = Objects.requireNonNull(costComparator, "costComparator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Replaces the matched join with a copy that has a distribution type set. The copy may have
     * its children flipped.
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        OptimizerConfig.JoinDistributionType joinDistributionType = SystemSessionProperties.getJoinDistributionType(context.getSession());
        if (joinDistributionType == OptimizerConfig.JoinDistributionType.AUTOMATIC) {
            return Rule.Result.ofPlanNode(getCostBasedJoin(joinNode, context));
        }
        return Rule.Result.ofPlanNode(getSyntacticOrderJoin(joinNode, context, joinDistributionType));
    }

    /**
     * Returns whether the join's right (build) side may be broadcast.
     *
     * <p>Replication requires two things. First, the session's join distribution type must allow
     * it. Second, either the build side's estimated output size or the summed size of its source
     * tables must be at most the session's maximum broadcast table size.
     */
    public static boolean canReplicate(JoinNode joinNode, Rule.Context context) {
        OptimizerConfig.JoinDistributionType joinDistributionType = SystemSessionProperties.getJoinDistributionType(context.getSession());
        if (!joinDistributionType.canReplicate()) {
            return false;
        }
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        double maxBroadcastSizeInBytes = joinMaxBroadcastTableSize.toBytes();
        PlanNode buildSide = joinNode.getRight();
        PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide);
        double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide.getOutputSymbols());
        // The source-table sum is consulted only when the direct estimate is too large or NaN.
        // NaN compares false.
        return buildSideSizeInBytes <= maxBroadcastSizeInBytes
                || getSourceTablesSizeInBytes(buildSide, context) <= maxBroadcastSizeInBytes;
    }

    /**
     * Convenience overload of {@link #getSourceTablesSizeInBytes(PlanNode, Lookup, StatsProvider)}
     * that takes the lookup and stats provider from {@code context}.
     */
    public static double getSourceTablesSizeInBytes(PlanNode node, Rule.Context context) {
        return getSourceTablesSizeInBytes(node, context.getLookup(), context.getStatsProvider());
    }

    /**
     * Sums the estimated output sizes of all {@link TableScanNode} and {@link ValuesNode}
     * descendants of {@code node}.
     *
     * @return the summed size in bytes, or {@code NaN} if the subtree contains a join or unnest
     *         node. Such nodes may produce more rows than their inputs, so the leaf sizes would
     *         not bound the subtree's output.
     */
    @VisibleForTesting
    static double getSourceTablesSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) {
        boolean hasExpandingNodes = PlanNodeSearcher.searchFrom(node, lookup).whereIsInstanceOfAny(EXPANDING_NODE_CLASSES).matches();
        if (hasExpandingNodes) {
            return Double.NaN;
        }
        List<PlanNode> sourceNodes = PlanNodeSearcher.searchFrom(node, lookup).whereIsInstanceOfAny(TableScanNode.class, ValuesNode.class).findAll();
        return sourceNodes.stream()
                .mapToDouble(sourceNode -> statsProvider.getStats(sourceNode).getOutputSizeInBytes(sourceNode.getOutputSymbols()))
                .sum();
    }

    private static double getFirstKnownOutputSizeInBytes(PlanNode node, Rule.Context context) {
        return getFirstKnownOutputSizeInBytes(node, context.getLookup(), context.getStatsProvider());
    }

    /**
     * Returns the estimated output size of {@code node} after resolving it through
     * {@code lookup}.
     *
     * <p>If that estimate is unknown, the method returns the sum of the first known sizes of
     * the node's sources, computed recursively. It returns {@code NaN} in three cases: the
     * resolved node is a join or unnest, it has no sources, or any source's size is unknown.
     */
    @VisibleForTesting
    static double getFirstKnownOutputSizeInBytes(PlanNode node, Lookup lookup, StatsProvider statsProvider) {
        return Stream.of(node).map(lookup::resolve).mapToDouble(resolvedNode -> {
            double outputSizeInBytes = statsProvider.getStats(resolvedNode).getOutputSizeInBytes(resolvedNode.getOutputSymbols());
            if (!Double.isNaN(outputSizeInBytes)) {
                return outputSizeInBytes;
            }
            // The inputs of an expanding node do not bound its output, so do not sum them.
            if (EXPANDING_NODE_CLASSES.stream().anyMatch(clazz -> clazz.isInstance(resolvedNode))) {
                return Double.NaN;
            }
            List<PlanNode> sourceNodes = resolvedNode.getSources();
            if (sourceNodes.isEmpty()) {
                return Double.NaN;
            }
            double sourcesOutputSizeInBytes = 0.0;
            for (PlanNode sourceNode : sourceNodes) {
                double sourceOutputSizeInBytes = getFirstKnownOutputSizeInBytes(sourceNode, lookup, statsProvider);
                if (Double.isNaN(sourceOutputSizeInBytes)) {
                    return Double.NaN;
                }
                sourcesOutputSizeInBytes += sourceOutputSizeInBytes;
            }
            return sourcesOutputSizeInBytes;
        }).sum();
    }

    /**
     * Enumerates the permitted distributions for both child orders and returns the cheapest
     * one.
     *
     * <p>The method falls back to {@link #getSizeBasedJoin} when no candidate is permitted or
     * when any candidate's cost has unknown components.
     */
    private PlanNode getCostBasedJoin(JoinNode joinNode, Rule.Context context) {
        List<PlanNodeWithCost> possibleJoinNodes = new ArrayList<>();
        addJoinsWithDifferentDistributions(joinNode, possibleJoinNodes, context);
        addJoinsWithDifferentDistributions(joinNode.flipChildren(), possibleJoinNodes, context);
        if (possibleJoinNodes.isEmpty() || possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents())) {
            return getSizeBasedJoin(joinNode, context);
        }
        Ordering<PlanNodeWithCost> planNodeOrderings = costComparator.forSession(context.getSession()).onResultOf(PlanNodeWithCost::getCost);
        return planNodeOrderings.min(possibleJoinNodes).getPlanNode();
    }

    /**
     * Picks a distribution and child order from size estimates alone. It is used when cost
     * estimates are incomplete.
     *
     * <p>Preference order:
     * <ol>
     *   <li>replicate the side whose source tables fit under the broadcast limit, right side
     *       first;</li>
     *   <li>if that side must be partitioned, partition with it as the build side;</li>
     *   <li>otherwise, make a side the partitioned build side if it is more than
     *       {@link #SIZE_DIFFERENCE_THRESHOLD} times smaller than the other;</li>
     *   <li>otherwise, keep the syntactic order.</li>
     * </ol>
     */
    private JoinNode getSizeBasedJoin(JoinNode joinNode, Rule.Context context) {
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        double maxBroadcastSizeInBytes = joinMaxBroadcastTableSize.toBytes();

        boolean isRightSideSmall = getSourceTablesSizeInBytes(joinNode.getRight(), context) <= maxBroadcastSizeInBytes;
        if (isRightSideSmall && !mustPartition(joinNode)) {
            // The right side's source tables are small: replicate it as the build side.
            return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
        }

        boolean isLeftSideSmall = getSourceTablesSizeInBytes(joinNode.getLeft(), context) <= maxBroadcastSizeInBytes;
        JoinNode flippedJoin = joinNode.flipChildren();
        if (isLeftSideSmall && !mustPartition(flippedJoin)) {
            // The left side's source tables are small: flip and replicate it as the build side.
            return flippedJoin.withDistributionType(JoinNode.DistributionType.REPLICATED);
        }

        if (isRightSideSmall) {
            // The right side is small, but the join type forbids replication.
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        if (isLeftSideSmall) {
            // The left side is small, but the flipped join type forbids replication.
            return flippedJoin.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }

        // First-known sizes may be sums over descendants rather than the node's own estimate.
        // Presumably that is why only a large ratio is trusted for choosing the build side.
        double leftOutputSizeInBytes = getFirstKnownOutputSizeInBytes(joinNode.getLeft(), context);
        double rightOutputSizeInBytes = getFirstKnownOutputSizeInBytes(joinNode.getRight(), context);
        if (rightOutputSizeInBytes * SIZE_DIFFERENCE_THRESHOLD < leftOutputSizeInBytes && !mustReplicate(joinNode, context)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        if (leftOutputSizeInBytes * SIZE_DIFFERENCE_THRESHOLD < rightOutputSizeInBytes && !mustReplicate(flippedJoin, context)) {
            return flippedJoin.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        return getSyntacticOrderJoin(joinNode, context, OptimizerConfig.JoinDistributionType.AUTOMATIC);
    }

    /**
     * Appends a costed candidate to {@code possibleJoinNodes} for each distribution that
     * {@code joinNode}, in its current child order, is allowed to use.
     */
    private void addJoinsWithDifferentDistributions(JoinNode joinNode, List<PlanNodeWithCost> possibleJoinNodes, Rule.Context context) {
        if (!mustPartition(joinNode) && canReplicate(joinNode, context)) {
            possibleJoinNodes.add(getJoinNodeWithCost(context, joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED)));
        }
        if (!mustReplicate(joinNode, context)) {
            possibleJoinNodes.add(getJoinNodeWithCost(context, joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED)));
        }
    }

    /**
     * Keeps the syntactic child order and chooses a distribution in this order: one forced by
     * the join itself; otherwise {@code PARTITIONED} if {@code joinDistributionType} permits
     * partitioning; otherwise {@code REPLICATED}.
     */
    private JoinNode getSyntacticOrderJoin(JoinNode joinNode, Rule.Context context, OptimizerConfig.JoinDistributionType joinDistributionType) {
        if (mustPartition(joinNode)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        if (mustReplicate(joinNode, context)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
        }
        if (joinDistributionType.canPartition()) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
    }

    /**
     * Returns true for RIGHT and FULL joins. These must not replicate their build side,
     * presumably because unmatched build rows would then be emitted once per task.
     */
    private static boolean mustPartition(JoinNode joinNode) {
        JoinType type = joinNode.getType();
        return type == JoinType.RIGHT || type == JoinType.FULL;
    }

    /**
     * Returns true when the join must not be partitioned. That is the case for an INNER or LEFT
     * join with no join criteria, or when {@link QueryCardinalityUtil#isAtMostScalar} holds for
     * the right side.
     */
    private static boolean mustReplicate(JoinNode joinNode, Rule.Context context) {
        JoinType type = joinNode.getType();
        if (joinNode.getCriteria().isEmpty() && (type == JoinType.INNER || type == JoinType.LEFT)) {
            return true;
        }
        return QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), context.getLookup());
    }

    /**
     * Estimates the local cost of a candidate join, excluding the cost of its output. The
     * estimate is based on the children's statistics, the candidate's distribution type, and
     * the estimated source-distributed task count.
     *
     * @param possibleJoinNode a join whose distribution type has already been set
     */
    private PlanNodeWithCost getJoinNodeWithCost(Rule.Context context, JoinNode possibleJoinNode) {
        StatsProvider stats = context.getStatsProvider();
        boolean replicated = possibleJoinNode.getDistributionType().orElseThrow() == JoinNode.DistributionType.REPLICATED;
        int estimatedSourceDistributedTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(context.getSession());
        LocalCostEstimate cost = CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput(possibleJoinNode.getLeft(), possibleJoinNode.getRight(), stats, replicated, estimatedSourceDistributedTaskCount);
        return new PlanNodeWithCost(cost.toPlanCost(), possibleJoinNode);
    }
}

