/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.Patterns;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.google.common.collect.Ordering;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Planner rule that assigns a {@link JoinNode.DistributionType} to every join
 * that does not have one yet.
 *
 * <p>When the session's join distribution type is {@code AUTOMATIC}, the rule
 * costs up to four candidates (original and flipped child order, each as
 * {@code REPLICATED} and {@code PARTITIONED}) and keeps the cheapest one.
 * For any other setting, or when costs cannot be compared, the child order is
 * left as written and only the distribution type is chosen.
 */
public class DetermineJoinDistributionType
        implements Rule<JoinNode> {
    // Matches only joins whose distribution type is still unset.
    private static final Pattern<JoinNode> PATTERN = Patterns.join().matching(joinNode -> !joinNode.getDistributionType().isPresent());

    private final CostComparator costComparator;

    /**
     * @param costComparator source of the per-session cost ordering used to rank candidates
     * @throws NullPointerException if {@code costComparator} is null
     */
    public DetermineJoinDistributionType(CostComparator costComparator) {
        this.costComparator = Objects.requireNonNull(costComparator, "costComparator is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Rewrites the matched join with a distribution type.
     *
     * @param joinNode the join without a distribution type
     * @param captures unused by this rule
     * @param context supplies the session, cost provider and lookup
     * @return a result wrapping the join (possibly with flipped children) with its distribution type set
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        FeaturesConfig.JoinDistributionType joinDistributionType = SystemSessionProperties.getJoinDistributionType(context.getSession());
        if (joinDistributionType == FeaturesConfig.JoinDistributionType.AUTOMATIC) {
            return Rule.Result.ofPlanNode(getCostBasedJoin(joinNode, context));
        }
        return Rule.Result.ofPlanNode(getSyntacticOrderJoin(joinNode, context, joinDistributionType));
    }

    /**
     * Picks the cheapest admissible combination of child order and distribution type.
     *
     * <p>Falls back to {@link #getSyntacticOrderJoin} when no candidate is admissible
     * or when any candidate's cost has unknown components.
     */
    private PlanNode getCostBasedJoin(JoinNode joinNode, Rule.Context context) {
        CostProvider costProvider = context.getCostProvider();
        List<PlanNodeWithCost> possibleJoinNodes = new ArrayList<>();

        // Candidate order matters: original child order first, then flipped.
        addCandidates(possibleJoinNodes, costProvider, joinNode, context);
        JoinNode flipped = joinNode.flipChildren();
        addCandidates(possibleJoinNodes, costProvider, flipped, context);

        boolean anyCostUnknown = possibleJoinNodes.stream().anyMatch(result -> result.getCost().hasUnknownComponents());
        if (anyCostUnknown || possibleJoinNodes.isEmpty()) {
            return getSyntacticOrderJoin(joinNode, context, FeaturesConfig.JoinDistributionType.AUTOMATIC);
        }

        // Guava's Ordering.min returns the first of several least elements, so ties
        // resolve to the earliest-added candidate and the choice stays deterministic.
        Ordering<PlanNodeWithCost> planNodeOrderings = costComparator.forSession(context.getSession()).onResultOf(PlanNodeWithCost::getCost);
        return planNodeOrderings.min(possibleJoinNodes).getPlanNode();
    }

    /**
     * Appends the admissible candidates for one child order: {@code REPLICATED} first
     * (unless the join must be partitioned), then {@code PARTITIONED} (unless it must
     * be replicated).
     */
    private static void addCandidates(List<PlanNodeWithCost> candidates, CostProvider costProvider, JoinNode joinNode, Rule.Context context) {
        if (!mustPartition(joinNode)) {
            candidates.add(getJoinNodeWithCost(costProvider, joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED)));
        }
        if (!mustReplicate(joinNode, context)) {
            candidates.add(getJoinNodeWithCost(costProvider, joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED)));
        }
    }

    /**
     * Chooses a distribution type without reordering the join's children.
     *
     * <p>Precedence: a join that must be partitioned is {@code PARTITIONED}; otherwise a
     * join that must be replicated is {@code REPLICATED}; otherwise the result is
     * {@code PARTITIONED} if {@code joinDistributionType.canPartition()} and
     * {@code REPLICATED} if not.
     */
    private PlanNode getSyntacticOrderJoin(JoinNode joinNode, Rule.Context context, FeaturesConfig.JoinDistributionType joinDistributionType) {
        if (mustPartition(joinNode)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        if (mustReplicate(joinNode, context)) {
            return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
        }
        if (joinDistributionType.canPartition()) {
            return joinNode.withDistributionType(JoinNode.DistributionType.PARTITIONED);
        }
        return joinNode.withDistributionType(JoinNode.DistributionType.REPLICATED);
    }

    /**
     * Returns {@code true} for {@code RIGHT} and {@code FULL} joins.
     */
    // NOTE(review): presumably these join types cannot be executed with a replicated
    // right side; the reason is not visible in this file -- confirm against the executor.
    private static boolean mustPartition(JoinNode joinNode) {
        JoinNode.Type type = joinNode.getType();
        return type == JoinNode.Type.RIGHT || type == JoinNode.Type.FULL;
    }

    /**
     * Returns {@code true} when the join is an {@code INNER} or {@code LEFT} join with an
     * empty criteria list, or when its right child is at most scalar according to
     * {@link QueryCardinalityUtil#isAtMostScalar}.
     */
    private static boolean mustReplicate(JoinNode joinNode, Rule.Context context) {
        JoinNode.Type type = joinNode.getType();
        if (joinNode.getCriteria().isEmpty() && (type == JoinNode.Type.INNER || type == JoinNode.Type.LEFT)) {
            return true;
        }
        return QueryCardinalityUtil.isAtMostScalar(joinNode.getRight(), context.getLookup());
    }

    /** Pairs a candidate join with its cumulative cost as reported by {@code costProvider}. */
    private static PlanNodeWithCost getJoinNodeWithCost(CostProvider costProvider, JoinNode possibleJoinNode) {
        return new PlanNodeWithCost(costProvider.getCumulativeCost(possibleJoinNode), possibleJoinNode);
    }

    /** Immutable holder for a candidate plan node and its cost estimate. */
    private static class PlanNodeWithCost {
        private final PlanNode planNode;
        private final PlanNodeCostEstimate cost;

        /**
         * @throws NullPointerException if either argument is null
         */
        public PlanNodeWithCost(PlanNodeCostEstimate cost, PlanNode planNode) {
            this.cost = Objects.requireNonNull(cost, "cost is null");
            this.planNode = Objects.requireNonNull(planNode, "planNode is null");
        }

        public PlanNode getPlanNode() {
            return planNode;
        }

        public PlanNodeCostEstimate getCost() {
            return cost;
        }
    }
}

