/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.cost;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.cost.CostCalculator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.LocalCostEstimate;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.PlanNodeStatsEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.cost.TaskCountEstimator;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanVisitor;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.planner.plan.JoinNode;
import com.facebook.presto.sql.planner.plan.SemiJoinNode;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.concurrent.ThreadSafe;
import javax.inject.Inject;

/**
 * {@link CostCalculator} decorator that adds an estimate of exchange costs to the cost produced by a
 * delegate calculator.
 *
 * <p>For every plan node the delegate's {@link PlanCostEstimate} is computed first. A visitor then
 * estimates the cost of the exchanges associated with that node:
 * <ul>
 *   <li>repartitioning for aggregations and partitioned joins;</li>
 *   <li>replication for replicated joins;</li>
 *   <li>remote gathering for unions and intersects.</li>
 * </ul>
 * The two estimates are summed.
 */
@ThreadSafe
public class CostCalculatorWithEstimatedExchanges
implements CostCalculator {
    private final CostCalculator costCalculator;
    private final TaskCountEstimator taskCountEstimator;

    /**
     * @param costCalculator delegate producing the base (exchange-free) cost estimate
     * @param taskCountEstimator used to estimate how many tasks a replicated build side is copied to
     * @throws NullPointerException if any argument is null
     */
    @Inject
    public CostCalculatorWithEstimatedExchanges(CostCalculator costCalculator, TaskCountEstimator taskCountEstimator) {
        this.costCalculator = Objects.requireNonNull(costCalculator, "costCalculator is null");
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Computes the delegate's cost for {@code node} and adds the estimated cost of its exchanges.
     *
     * @param node the plan node to cost
     * @param stats statistics provider used to size node inputs/outputs
     * @param sourcesCosts cost provider for the node's sources, forwarded to the delegate
     * @param session current session, forwarded to the delegate
     * @return the delegate's estimate plus the estimated exchange cost for {@code node}
     */
    @Override
    public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session) {
        // Constructed before calling the delegate so a null `stats` is rejected up front, as before.
        ExchangeCostEstimator exchangeCostEstimator = new ExchangeCostEstimator(stats, taskCountEstimator);
        PlanCostEstimate costEstimate = costCalculator.calculateCost(node, stats, sourcesCosts, session);
        LocalCostEstimate estimatedExchangeCost = node.accept(exchangeCostEstimator, null);
        return addExchangeCost(costEstimate, estimatedExchangeCost);
    }

    /**
     * Adds an exchange cost to a plan cost.
     *
     * <p>The exchange's memory is added both to the peak memory and to the memory held while
     * outputting.
     */
    private static PlanCostEstimate addExchangeCost(PlanCostEstimate costEstimate, LocalCostEstimate estimatedExchangeCost) {
        return new PlanCostEstimate(
                costEstimate.getCpuCost() + estimatedExchangeCost.getCpuCost(),
                costEstimate.getMaxMemory() + estimatedExchangeCost.getMaxMemory(),
                costEstimate.getMaxMemoryWhenOutputting() + estimatedExchangeCost.getMaxMemory(),
                costEstimate.getNetworkCost() + estimatedExchangeCost.getNetworkCost());
    }

    /**
     * Cost of gathering {@code inputSizeInBytes} over the network: network cost only.
     */
    public static LocalCostEstimate calculateRemoteGatherCost(double inputSizeInBytes) {
        return LocalCostEstimate.ofNetwork(inputSizeInBytes);
    }

    /**
     * Cost of a remote repartition of {@code inputSizeInBytes}.
     *
     * <p>CPU and network cost are both equal to the input size; there is no memory cost.
     */
    public static LocalCostEstimate calculateRemoteRepartitionCost(double inputSizeInBytes) {
        return LocalCostEstimate.of(inputSizeInBytes, 0.0, inputSizeInBytes);
    }

    /**
     * Cost of a CTE producer writing the output of {@code source}.
     *
     * <p>The CPU and network cost are the source's output size scaled by the session's CTE producer
     * replication coefficient.
     *
     * @param session session supplying the replication coefficient
     * @param statsProvider statistics provider used to size {@code source}
     * @param source the plan node whose output is produced into the CTE
     */
    public static LocalCostEstimate calculateCteProducerCost(Session session, StatsProvider statsProvider, PlanNode source) {
        double inputSizeInBytes = statsProvider.getStats(source).getOutputSizeInBytes(source);
        double cteProducerReplicationCoefficient = SystemSessionProperties.getCteProducerReplicationCoefficient(session);
        double replicatedSizeInBytes = cteProducerReplicationCoefficient * inputSizeInBytes;
        return LocalCostEstimate.of(replicatedSizeInBytes, 0.0, replicatedSizeInBytes);
    }

    /**
     * Cost of a local (in-task) repartition of {@code inputSizeInBytes}: CPU cost only.
     */
    public static LocalCostEstimate calculateLocalRepartitionCost(double inputSizeInBytes) {
        return LocalCostEstimate.ofCpu(inputSizeInBytes);
    }

    /**
     * Cost of replicating {@code inputSizeInBytes} to {@code destinationTaskCount} tasks.
     *
     * <p>The cost is network only, and equals one full copy per destination task.
     */
    public static LocalCostEstimate calculateRemoteReplicateCost(double inputSizeInBytes, int destinationTaskCount) {
        return LocalCostEstimate.ofNetwork(inputSizeInBytes * (double) destinationTaskCount);
    }

    /**
     * Cost of a join excluding the cost of producing its output.
     *
     * <p>This is the exchange cost plus the cost of consuming the probe and build inputs.
     *
     * @param probe probe-side source
     * @param build build-side source
     * @param stats statistics provider used to size both sides
     * @param replicated whether the build side is replicated (broadcast) rather than repartitioned
     * @param estimatedSourceDistributedTaskCount number of tasks the build side is replicated to
     */
    public static LocalCostEstimate calculateJoinCostWithoutOutput(PlanNode probe, PlanNode build, StatsProvider stats, boolean replicated, int estimatedSourceDistributedTaskCount) {
        LocalCostEstimate exchangesCost = calculateJoinExchangeCost(probe, build, stats, replicated, estimatedSourceDistributedTaskCount);
        LocalCostEstimate inputCost = calculateJoinInputCost(probe, build, stats, replicated, estimatedSourceDistributedTaskCount);
        return LocalCostEstimate.addPartialComponents(exchangesCost, inputCost);
    }

    /**
     * Exchange cost of a join.
     *
     * <ul>
     *   <li>Replicated: the build side is replicated to every task and locally repartitioned once.</li>
     *   <li>Otherwise: both sides are remotely repartitioned, and the build side is additionally
     *       repartitioned locally.</li>
     * </ul>
     */
    private static LocalCostEstimate calculateJoinExchangeCost(PlanNode probe, PlanNode build, StatsProvider stats, boolean replicated, int estimatedSourceDistributedTaskCount) {
        double probeSizeInBytes = stats.getStats(probe).getOutputSizeInBytes(probe);
        double buildSizeInBytes = stats.getStats(build).getOutputSizeInBytes(build);
        if (replicated) {
            LocalCostEstimate replicateCost = calculateRemoteReplicateCost(buildSizeInBytes, estimatedSourceDistributedTaskCount);
            // Only one copy's local repartition is charged here; the remaining copies are charged in
            // calculateJoinInputCost.
            LocalCostEstimate localRepartitionCost = calculateLocalRepartitionCost(buildSizeInBytes);
            return LocalCostEstimate.addPartialComponents(replicateCost, localRepartitionCost);
        }
        LocalCostEstimate probeCost = calculateRemoteRepartitionCost(probeSizeInBytes);
        LocalCostEstimate buildRemoteRepartitionCost = calculateRemoteRepartitionCost(buildSizeInBytes);
        LocalCostEstimate buildLocalRepartitionCost = calculateLocalRepartitionCost(buildSizeInBytes);
        return LocalCostEstimate.addPartialComponents(probeCost, buildRemoteRepartitionCost, buildLocalRepartitionCost);
    }

    /**
     * Cost of consuming a join's inputs.
     *
     * <p>CPU is the probe size plus the build size times the number of build copies (one when not
     * replicated). Memory is the build size times the number of copies. There is no network cost.
     *
     * @param probe probe-side source
     * @param build build-side source
     * @param stats statistics provider used to size both sides
     * @param replicated whether the build side is replicated to every task
     * @param estimatedSourceDistributedTaskCount number of build copies when {@code replicated}
     */
    public static LocalCostEstimate calculateJoinInputCost(PlanNode probe, PlanNode build, StatsProvider stats, boolean replicated, int estimatedSourceDistributedTaskCount) {
        int buildSizeMultiplier = replicated ? estimatedSourceDistributedTaskCount : 1;
        PlanNodeStatsEstimate probeStats = stats.getStats(probe);
        PlanNodeStatsEstimate buildStats = stats.getStats(build);
        double buildSideSize = buildStats.getOutputSizeInBytes(build);
        double probeSideSize = probeStats.getOutputSizeInBytes(probe);
        double cpuCost = probeSideSize + buildSideSize * (double) buildSizeMultiplier;
        if (replicated) {
            // Presumably the local repartitioning of the extra build copies. calculateJoinExchangeCost
            // already charges a single copy's local repartition, hence (multiplier - 1).
            cpuCost += buildSideSize * (double) (buildSizeMultiplier - 1);
        }
        double memoryCost = buildSideSize * (double) buildSizeMultiplier;
        return LocalCostEstimate.of(cpuCost, memoryCost, 0.0);
    }

    /**
     * Visitor estimating the exchange cost introduced by a single plan node.
     *
     * <p>Nodes without a specific rule contribute zero.
     */
    private static class ExchangeCostEstimator
    extends InternalPlanVisitor<LocalCostEstimate, Void> {
        private final StatsProvider stats;
        private final TaskCountEstimator taskCountEstimator;

        ExchangeCostEstimator(StatsProvider stats, TaskCountEstimator taskCountEstimator) {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
        }

        @Override
        public LocalCostEstimate visitPlan(PlanNode node, Void context) {
            return LocalCostEstimate.zero();
        }

        @Override
        public LocalCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException("Exchange cost estimation is not supported for GroupReference nodes");
        }

        /** Aggregation input is repartitioned remotely and then locally. */
        @Override
        public LocalCostEstimate visitAggregation(AggregationNode node, Void context) {
            PlanNode source = node.getSource();
            double inputSizeInBytes = getStats(source).getOutputSizeInBytes(source);
            LocalCostEstimate remoteRepartitionCost = calculateRemoteRepartitionCost(inputSizeInBytes);
            LocalCostEstimate localRepartitionCost = calculateLocalRepartitionCost(inputSizeInBytes);
            return LocalCostEstimate.addPartialComponents(remoteRepartitionCost, localRepartitionCost);
        }

        @Override
        public LocalCostEstimate visitJoin(JoinNode node, Void context) {
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinDistributionType.REPLICATED));
            return calculateJoinExchangeCost(node.getLeft(), node.getRight(), stats, replicated, taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public LocalCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(SemiJoinNode.DistributionType.REPLICATED));
            return calculateJoinExchangeCost(node.getSource(), node.getFilteringSource(), stats, replicated, taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        @Override
        public LocalCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) {
            boolean replicated = node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            return calculateJoinExchangeCost(node.getLeft(), node.getRight(), stats, replicated, taskCountEstimator.estimateSourceDistributedTaskCount());
        }

        /**
         * Charges the union's output size as a remote gather.
         *
         * <p>This assumes every input is gathered over the network.
         */
        @Override
        public LocalCostEstimate visitUnion(UnionNode node, Void context) {
            double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node);
            return calculateRemoteGatherCost(inputSizeInBytes);
        }

        /** Sums the exchange costs estimated for each of the sequence's sources. */
        @Override
        public LocalCostEstimate visitSequence(SequenceNode node, Void context) {
            List<LocalCostEstimate> sourceCosts = node.getSources().stream()
                    .map(source -> source.accept(this, context))
                    .collect(ImmutableList.toImmutableList());
            return LocalCostEstimate.addPartialComponents(sourceCosts);
        }

        /**
         * Charges the intersect's output size as a remote gather.
         *
         * <p>This is the same assumption as {@code visitUnion}.
         */
        @Override
        public LocalCostEstimate visitIntersect(IntersectNode node, Void context) {
            double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node);
            return calculateRemoteGatherCost(inputSizeInBytes);
        }

        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }
    }
}

