/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.CostCalculatorUsingExchanges;
import com.facebook.presto.cost.PlanNodeCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
import com.facebook.presto.metadata.InternalNodeManager;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.plan.AggregationNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.PlanNode;
import com.facebook.presto.sql.planner.plan.PlanVisitor;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.IntSupplier;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * {@link CostCalculator} decorator that adds an estimated exchange cost to the cost computed by a
 * delegate calculator.
 *
 * <p>The extra cost comes from the nested {@code ExchangeCostEstimator}, which examines only the
 * node being costed (it does not recurse into sources):
 * <ul>
 *   <li>aggregation: remote + local repartition of the aggregation's source;</li>
 *   <li>replicated join / semi-join: remote replicate + local repartition of the build side;</li>
 *   <li>other joins / semi-joins: remote repartition of the probe side plus remote and local
 *       repartition of the build side;</li>
 *   <li>any other node: zero.</li>
 * </ul>
 * Each individual exchange cost is computed by
 * {@link CostCalculatorUsingExchanges#calculateExchangeCost}.
 *
 * <p>NOTE(review): presumably intended for plans where exchanges have not been inserted yet, so
 * their cost is estimated rather than taken from existing {@code ExchangeNode}s; confirm.
 */
@ThreadSafe
public class CostCalculatorWithEstimatedExchanges
implements CostCalculator {
    private final CostCalculator costCalculator;
    private final IntSupplier numberOfNodes;

    /**
     * Injection constructor. The node count is obtained from
     * {@link CostCalculatorUsingExchanges#currentNumberOfWorkerNodes}, parameterized by whether the
     * scheduler configuration includes the coordinator.
     *
     * @param costCalculator delegate whose cost is augmented with the estimated exchange cost
     * @param nodeSchedulerConfig scheduler configuration (read for {@code isIncludeCoordinator})
     * @param nodeManager node manager handed to the worker-count supplier
     */
    @Inject
    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, NodeSchedulerConfig nodeSchedulerConfig, InternalNodeManager nodeManager) {
        this(costCalculator, CostCalculatorUsingExchanges.currentNumberOfWorkerNodes(nodeSchedulerConfig.isIncludeCoordinator(), nodeManager));
    }

    /**
     * @param costCalculator delegate whose cost is augmented with the estimated exchange cost
     * @param numberOfNodes supplier of the node count used for exchange cost estimation; it is
     *        queried once per {@link #calculateCost} call
     * @throws NullPointerException if either argument is null
     */
    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, IntSupplier numberOfNodes) {
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.numberOfNodes = Objects.requireNonNull(numberOfNodes, "numberOfNodes is null");
    }

    /**
     * Returns the delegate's cost for {@code node} plus the exchange cost estimated for that node.
     */
    @Override
    public PlanNodeCostEstimate calculateCost(PlanNode node, StatsProvider stats, Lookup lookup, Session session, TypeProvider types) {
        // Read the node count once so the whole estimation for this node uses a single value.
        ExchangeCostEstimator exchangeCostEstimator = new ExchangeCostEstimator(this.numberOfNodes.getAsInt(), stats, lookup, types);
        PlanNodeCostEstimate estimatedExchangeCost = node.accept(exchangeCostEstimator, null);
        return this.costCalculator.calculateCost(node, stats, lookup, session, types).add(estimatedExchangeCost);
    }

    /**
     * Plan visitor returning the estimated exchange cost of a single plan node. Only the visited
     * node is examined; its children are consulted solely for their stats and output symbols.
     */
    private static final class ExchangeCostEstimator
    extends PlanVisitor<PlanNodeCostEstimate, Void> {
        private final int numberOfNodes;
        private final StatsProvider stats;
        // NOTE(review): null-checked in the constructor but not read by any visit method below.
        private final Lookup lookup;
        private final TypeProvider types;

        ExchangeCostEstimator(int numberOfNodes, StatsProvider stats, Lookup lookup, TypeProvider types) {
            this.numberOfNodes = numberOfNodes;
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.lookup = Objects.requireNonNull(lookup, "lookup is null");
            this.types = Objects.requireNonNull(types, "types is null");
        }

        /** Nodes without dedicated handling contribute no exchange cost. */
        @Override
        protected PlanNodeCostEstimate visitPlan(PlanNode node, Void context) {
            return PlanNodeCostEstimate.ZERO_COST;
        }

        /**
         * Group references are rejected outright.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public PlanNodeCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException("ExchangeCostEstimator does not support GroupReference nodes");
        }

        /** Charges a remote and a local repartition of the aggregation's source. */
        @Override
        public PlanNodeCostEstimate visitAggregation(AggregationNode node, Void context) {
            PlanNode source = node.getSource();
            PlanNodeStatsEstimate sourceStats = this.getStats(source);
            List<Symbol> sourceSymbols = source.getOutputSymbols();
            PlanNodeCostEstimate remoteRepartitionCost = this.exchangeCost(sourceStats, sourceSymbols, ExchangeNode.Type.REPARTITION, ExchangeNode.Scope.REMOTE);
            PlanNodeCostEstimate localRepartitionCost = this.exchangeCost(sourceStats, sourceSymbols, ExchangeNode.Type.REPARTITION, ExchangeNode.Scope.LOCAL);
            return remoteRepartitionCost.add(localRepartitionCost);
        }

        /** Left input is the probe side, right input the build side. */
        @Override
        public PlanNodeCostEstimate visitJoin(JoinNode node, Void context) {
            // An absent distribution type is treated as non-replicated.
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED));
            return this.calculateJoinCost(node.getLeft(), node.getRight(), replicated);
        }

        /** Source is the probe side, filtering source the build side. */
        @Override
        public PlanNodeCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            // An absent distribution type is treated as non-replicated.
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(SemiJoinNode.DistributionType.REPLICATED));
            return this.calculateJoinCost(node.getSource(), node.getFilteringSource(), replicated);
        }

        /**
         * Estimates the exchange cost of a (semi-)join.
         *
         * @param probe probe-side input
         * @param build build-side input
         * @param replicated whether the build side is replicated rather than repartitioned
         * @return replicated: remote replicate + local repartition of the build side; otherwise
         *         remote repartition of the probe plus remote and local repartition of the build
         */
        private PlanNodeCostEstimate calculateJoinCost(PlanNode probe, PlanNode build, boolean replicated) {
            // Build-side stats and symbols are needed by both strategies; look them up once.
            PlanNodeStatsEstimate buildStats = this.getStats(build);
            List<Symbol> buildSymbols = build.getOutputSymbols();
            PlanNodeCostEstimate buildLocalRepartitionCost = this.exchangeCost(buildStats, buildSymbols, ExchangeNode.Type.REPARTITION, ExchangeNode.Scope.LOCAL);
            if (replicated) {
                PlanNodeCostEstimate replicateCost = this.exchangeCost(buildStats, buildSymbols, ExchangeNode.Type.REPLICATE, ExchangeNode.Scope.REMOTE);
                return replicateCost.add(buildLocalRepartitionCost);
            }
            PlanNodeCostEstimate probeCost = this.exchangeCost(this.getStats(probe), probe.getOutputSymbols(), ExchangeNode.Type.REPARTITION, ExchangeNode.Scope.REMOTE);
            PlanNodeCostEstimate buildRemoteRepartitionCost = this.exchangeCost(buildStats, buildSymbols, ExchangeNode.Type.REPARTITION, ExchangeNode.Scope.REMOTE);
            return probeCost.add(buildRemoteRepartitionCost).add(buildLocalRepartitionCost);
        }

        /**
         * Computes one exchange's cost via
         * {@link CostCalculatorUsingExchanges#calculateExchangeCost}, using this estimator's node
         * count and type provider.
         */
        private PlanNodeCostEstimate exchangeCost(PlanNodeStatsEstimate inputStats, List<Symbol> symbols, ExchangeNode.Type type, ExchangeNode.Scope scope) {
            return CostCalculatorUsingExchanges.calculateExchangeCost(this.numberOfNodes, inputStats, symbols, type, scope, this.types);
        }

        /** Returns the stats estimate for {@code node} from the stats provider. */
        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return this.stats.getStats(node);
        }
    }
}

