/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.Ordering;
import io.airlift.units.DataSize;
import io.trino.SystemSessionProperties;
import io.trino.cost.CostCalculatorWithEstimatedExchanges;
import io.trino.cost.CostComparator;
import io.trino.cost.LocalCostEstimate;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.OptimizerConfig;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.iterative.rule.DetermineJoinDistributionType;
import io.trino.sql.planner.iterative.rule.PlanNodeWithCost;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Planner rule that assigns a distribution type (PARTITIONED or REPLICATED) to a
 * {@link SemiJoinNode} that does not have one yet.
 *
 * <p>The choice is driven by the session's {@code join_distribution_type}:
 * PARTITIONED and BROADCAST force PARTITIONED and REPLICATED respectively, while
 * AUTOMATIC picks between the two based on estimated cost, falling back to a
 * table-size heuristic when the cost estimates are incomplete.
 */
public class DetermineSemiJoinDistributionType implements Rule<SemiJoinNode> {
    // Only semi joins whose distribution type has not been decided yet are matched.
    private static final Pattern<SemiJoinNode> PATTERN = Patterns.semiJoin()
            .matching(semiJoin -> semiJoin.getDistributionType().isEmpty());

    private final TaskCountEstimator taskCountEstimator;
    private final CostComparator costComparator;

    /**
     * Creates the rule.
     *
     * @param costComparator compares the estimated costs of the candidate plans
     * @param taskCountEstimator provides the source-distributed task count used when estimating cost
     * @throws NullPointerException if either argument is {@code null}
     */
    public DetermineSemiJoinDistributionType(CostComparator costComparator, TaskCountEstimator taskCountEstimator) {
        this.costComparator = Objects.requireNonNull(costComparator, "costComparator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    @Override
    public Pattern<SemiJoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the semi join with a concrete distribution type chosen from the session's
     * {@code join_distribution_type}.
     *
     * @throws IllegalArgumentException if the session holds a join distribution type this rule does not handle
     */
    @Override
    public Rule.Result apply(SemiJoinNode semiJoinNode, Captures captures, Rule.Context context) {
        OptimizerConfig.JoinDistributionType joinDistributionType = SystemSessionProperties.getJoinDistributionType(context.getSession());
        switch (joinDistributionType) {
            case AUTOMATIC:
                return Rule.Result.ofPlanNode(getCostBasedDistributionType(semiJoinNode, context));
            case PARTITIONED:
                return Rule.Result.ofPlanNode(semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED));
            case BROADCAST:
                return Rule.Result.ofPlanNode(semiJoinNode.withDistributionType(SemiJoinNode.DistributionType.REPLICATED));
        }
        // Reached only for an enum constant not covered above.
        throw new IllegalArgumentException("Unknown join_distribution_type: " + joinDistributionType);
    }

    /**
     * Chooses the cheaper of REPLICATED and PARTITIONED for {@code node}.
     *
     * <p>Returns PARTITIONED outright when the filtering source is too large to replicate, and
     * defers to {@link #getSizeBasedDistributionType} when either candidate's cost has unknown components.
     */
    private PlanNode getCostBasedDistributionType(SemiJoinNode node, Rule.Context context) {
        if (!canReplicate(node, context)) {
            return node.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED);
        }

        // REPLICATED is listed first on purpose: Guava's Ordering.min returns the first of several
        // equal minima, so a cost tie resolves deterministically to REPLICATED.
        List<PlanNodeWithCost> possibleJoinNodes = List.of(
                getSemiJoinNodeWithCost(node.withDistributionType(SemiJoinNode.DistributionType.REPLICATED), context),
                getSemiJoinNodeWithCost(node.withDistributionType(SemiJoinNode.DistributionType.PARTITIONED), context));

        // Costs with unknown components cannot be compared meaningfully; use the size heuristic instead.
        if (possibleJoinNodes.stream().anyMatch(candidate -> candidate.getCost().hasUnknownComponents())) {
            return getSizeBasedDistributionType(node, context);
        }

        Ordering<PlanNodeWithCost> planNodeOrdering = costComparator.forSession(context.getSession())
                .onResultOf(PlanNodeWithCost::getCost);
        return planNodeOrdering.min(possibleJoinNodes).getPlanNode();
    }

    /**
     * Fallback used when cost estimates are incomplete. Picks REPLICATED when the source-table
     * size of the filtering source is within {@code join_max_broadcast_table_size}, otherwise PARTITIONED.
     * A NaN size fails the {@code <=} comparison and therefore yields PARTITIONED.
     */
    private PlanNode getSizeBasedDistributionType(SemiJoinNode node, Rule.Context context) {
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        double maxBroadcastBytes = joinMaxBroadcastTableSize.toBytes();
        double sourceTablesBytes = DetermineJoinDistributionType.getSourceTablesSizeInBytes(node.getFilteringSource(), context);
        return node.withDistributionType(sourceTablesBytes <= maxBroadcastBytes
                ? SemiJoinNode.DistributionType.REPLICATED
                : SemiJoinNode.DistributionType.PARTITIONED);
    }

    /**
     * Returns whether the filtering (build) side is small enough to replicate. This holds when either
     * its estimated output size or its source-table size is within {@code join_max_broadcast_table_size}.
     */
    private boolean canReplicate(SemiJoinNode node, Rule.Context context) {
        DataSize joinMaxBroadcastTableSize = SystemSessionProperties.getJoinMaxBroadcastTableSize(context.getSession());
        double maxBroadcastBytes = joinMaxBroadcastTableSize.toBytes();
        PlanNode buildSide = node.getFilteringSource();
        PlanNodeStatsEstimate buildSideStatsEstimate = context.getStatsProvider().getStats(buildSide);
        double buildSideSizeInBytes = buildSideStatsEstimate.getOutputSizeInBytes(buildSide.getOutputSymbols(), context.getSymbolAllocator().getTypes());
        // Short-circuit: the source-table size is computed only when the output estimate does not fit
        // (a NaN estimate also fails the comparison).
        return buildSideSizeInBytes <= maxBroadcastBytes
                || DetermineJoinDistributionType.getSourceTablesSizeInBytes(buildSide, context) <= maxBroadcastBytes;
    }

    /**
     * Estimates the cost of {@code possibleJoinNode}, excluding its output, and pairs it with the node.
     *
     * @param possibleJoinNode a semi join whose distribution type has already been set
     */
    private PlanNodeWithCost getSemiJoinNodeWithCost(SemiJoinNode possibleJoinNode, Rule.Context context) {
        TypeProvider types = context.getSymbolAllocator().getTypes();
        StatsProvider stats = context.getStatsProvider();
        // Callers always set the distribution type before costing, so the Optional is populated here.
        boolean replicated = possibleJoinNode.getDistributionType().orElseThrow() == SemiJoinNode.DistributionType.REPLICATED;
        int estimatedSourceDistributedTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(context.getSession());
        LocalCostEstimate cost = CostCalculatorWithEstimatedExchanges.calculateJoinCostWithoutOutput(
                possibleJoinNode.getSource(),
                possibleJoinNode.getFilteringSource(),
                stats,
                types,
                replicated,
                estimatedSourceDistributedTaskCount);
        return new PlanNodeWithCost(cost.toPlanCost(), possibleJoinNode);
    }
}

