/*
 * Decompiled with CFR 0.152.
 */
package io.trino.cost;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.cost.CostCalculator;
import io.trino.cost.CostCalculatorWithEstimatedExchanges;
import io.trino.cost.CostProvider;
import io.trino.cost.LocalCostEstimate;
import io.trino.cost.PlanCostEstimate;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider;
import io.trino.cost.TaskCountEstimator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.EnforceSingleRowNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.FilterNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.LimitNode;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanVisitor;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.RowNumberNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import io.trino.sql.planner.plan.SpatialJoinNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.planner.plan.UnionNode;
import io.trino.sql.planner.plan.ValuesNode;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * {@link CostCalculator} that estimates the cumulative cost of a plan node as its own local cost
 * combined with the already-computed costs of its sources (looked up through a {@link CostProvider}).
 * <p>
 * Exchanges are costed where they appear in the plan, as {@link ExchangeNode}s; join input costs are
 * delegated to the static helpers in {@link CostCalculatorWithEstimatedExchanges}. Plan node types
 * without a dedicated visitor method yield {@link PlanCostEstimate#unknown()}.
 */
@ThreadSafe
public class CostCalculatorUsingExchanges
implements CostCalculator {
    private final TaskCountEstimator taskCountEstimator;

    @Inject
    public CostCalculatorUsingExchanges(TaskCountEstimator taskCountEstimator) {
        this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
    }

    /**
     * Estimates the cost of {@code node} including the costs of its sources.
     *
     * @param node the plan node to cost
     * @param stats statistics used to size the node's and its sources' outputs
     * @param sourcesCosts provider of the already-computed costs of the node's sources
     * @param session session used to estimate task counts
     * @return the cumulative cost estimate, or {@link PlanCostEstimate#unknown()} for unsupported nodes
     */
    @Override
    public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session) {
        // A fresh visitor per call keeps this calculator itself free of mutable state.
        CostEstimator costEstimator = new CostEstimator(stats, sourcesCosts, taskCountEstimator, session);
        return node.accept(costEstimator, null);
    }

    /**
     * Combines the costs of two sibling sources: every component (CPU, peak memory, memory while
     * outputting, network) is summed.
     */
    private static PlanCostEstimate addParallelSiblingsCost(PlanCostEstimate a, PlanCostEstimate b) {
        return new PlanCostEstimate(
                a.getCpuCost() + b.getCpuCost(),
                a.getMaxMemory() + b.getMaxMemory(),
                a.getMaxMemoryWhenOutputting() + b.getMaxMemoryWhenOutputting(),
                a.getNetworkCost() + b.getNetworkCost());
    }

    /**
     * Plan visitor that computes the cost of a single node; source costs are read from
     * {@code sourcesCosts} rather than recomputed.
     */
    private static class CostEstimator
    extends PlanVisitor<PlanCostEstimate, Void> {
        private final StatsProvider stats;
        private final CostProvider sourcesCosts;
        private final TaskCountEstimator taskCountEstimator;
        private final Session session;

        CostEstimator(StatsProvider stats, CostProvider sourcesCosts, TaskCountEstimator taskCountEstimator, Session session) {
            this.stats = Objects.requireNonNull(stats, "stats is null");
            this.sourcesCosts = Objects.requireNonNull(sourcesCosts, "sourcesCosts is null");
            this.taskCountEstimator = Objects.requireNonNull(taskCountEstimator, "taskCountEstimator is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        /** Fallback for node types without a dedicated cost model. */
        @Override
        protected PlanCostEstimate visitPlan(PlanNode node, Void context) {
            return PlanCostEstimate.unknown();
        }

        /** Group references are not expected to be costed directly by this visitor. */
        @Override
        public PlanCostEstimate visitGroupReference(GroupReference node, Void context) {
            throw new UnsupportedOperationException();
        }

        @Override
        public PlanCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context) {
            // CPU cost counts only the size of the generated id column.
            LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn())));
            return costForStreaming(node, localCost);
        }

        @Override
        public PlanCostEstimate visitRowNumber(RowNumberNode node, Void context) {
            List<Symbol> symbols = node.getOutputSymbols();
            if (node.getMaxRowCountPerPartition().isEmpty()) {
                // Without a per-partition row limit, CPU cost counts only the partitioning columns and
                // the row number column; with a limit, all output columns are counted.
                symbols = ImmutableList.<Symbol>builder()
                        .addAll(node.getPartitionBy())
                        .add(node.getRowNumberSymbol())
                        .build();
            }
            PlanNodeStatsEstimate nodeStats = getStats(node);
            double cpuCost = nodeStats.getOutputSizeInBytes(symbols);
            // Memory is charged only when partitioning is used, sized by the whole source output.
            double memoryCost = node.getPartitionBy().isEmpty() ? 0.0 : nodeStats.getOutputSizeInBytes(node.getSource().getOutputSymbols());
            LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0.0);
            return costForStreaming(node, localCost);
        }

        @Override
        public PlanCostEstimate visitOutput(OutputNode node, Void context) {
            return costForStreaming(node, LocalCostEstimate.zero());
        }

        @Override
        public PlanCostEstimate visitTableScan(TableScanNode node, Void context) {
            LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
            return costForSource(node, localCost);
        }

        @Override
        public PlanCostEstimate visitFilter(FilterNode node, Void context) {
            // Sized by the source's stats: the filter has to look at every input row.
            LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols()));
            return costForStreaming(node, localCost);
        }

        @Override
        public PlanCostEstimate visitProject(ProjectNode node, Void context) {
            LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
            return costForStreaming(node, localCost);
        }

        /**
         * Only FINAL and SINGLE aggregation steps are costed; every other step yields
         * {@link PlanCostEstimate#unknown()}.
         */
        @Override
        public PlanCostEstimate visitAggregation(AggregationNode node, Void context) {
            if (node.getStep() != AggregationNode.Step.FINAL && node.getStep() != AggregationNode.Step.SINGLE) {
                return PlanCostEstimate.unknown();
            }
            PlanNodeStatsEstimate aggregationStats = getStats(node);
            PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
            // CPU scales with the input consumed; memory with the aggregated output held.
            double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols());
            double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols());
            LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0.0);
            return costForAccumulation(node, localCost);
        }

        @Override
        public PlanCostEstimate visitJoin(JoinNode node, Void context) {
            boolean replicated = Objects.equals(node.getDistributionType(), Optional.of(JoinNode.DistributionType.REPLICATED));
            LocalCostEstimate localCost = calculateJoinCost(node, node.getLeft(), node.getRight(), replicated);
            return costForLookupJoin(node, localCost);
        }

        /**
         * Local cost of a lookup-style join: input cost, the replicated-join local exchange
         * adjustment, and output cost, combined with {@link LocalCostEstimate#addPartialComponents}.
         *
         * @param join the join node (its stats size the output)
         * @param probe the probe-side source
         * @param build the build-side source
         * @param replicated whether the build side is replicated
         */
        private LocalCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated) {
            int estimatedSourceDistributedTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(session);
            LocalCostEstimate joinInputCost = CostCalculatorWithEstimatedExchanges.calculateJoinInputCost(probe, build, stats, replicated, estimatedSourceDistributedTaskCount);
            LocalCostEstimate adjustedLocalExchangeCost = CostCalculatorWithEstimatedExchanges.adjustReplicatedJoinLocalExchangeCost(build, stats, replicated, estimatedSourceDistributedTaskCount);
            LocalCostEstimate joinOutputCost = calculateJoinOutputCost(join);
            return LocalCostEstimate.addPartialComponents(joinInputCost, adjustedLocalExchangeCost, joinOutputCost);
        }

        /** CPU-only cost proportional to the join's estimated output size. */
        private LocalCostEstimate calculateJoinOutputCost(PlanNode join) {
            PlanNodeStatsEstimate outputStats = getStats(join);
            double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols());
            return LocalCostEstimate.ofCpu(joinOutputSize);
        }

        @Override
        public PlanCostEstimate visitExchange(ExchangeNode node, Void context) {
            return costForStreaming(node, calculateExchangeCost(node));
        }

        /**
         * Local cost of an exchange, chosen by its scope (LOCAL/REMOTE) and type
         * (GATHER/REPARTITION/REPLICATE). Local gather and local replicate are free.
         *
         * @throws IllegalArgumentException for an unrecognized scope or type
         */
        private LocalCostEstimate calculateExchangeCost(ExchangeNode node) {
            double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols());
            switch (node.getScope()) {
                case LOCAL:
                    switch (node.getType()) {
                        case GATHER:
                            return LocalCostEstimate.zero();
                        case REPARTITION:
                            return CostCalculatorWithEstimatedExchanges.calculateLocalRepartitionCost(inputSizeInBytes);
                        case REPLICATE:
                            return LocalCostEstimate.zero();
                    }
                    throw new IllegalArgumentException("Unexpected type: " + node.getType());
                case REMOTE:
                    switch (node.getType()) {
                        case GATHER:
                            return CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost(inputSizeInBytes);
                        case REPARTITION:
                            return CostCalculatorWithEstimatedExchanges.calculateRemoteRepartitionCost(inputSizeInBytes);
                        case REPLICATE:
                            // The estimated source-distributed task count is passed as the replication factor input.
                            return CostCalculatorWithEstimatedExchanges.calculateRemoteReplicateCost(inputSizeInBytes, taskCountEstimator.estimateSourceDistributedTaskCount(session));
                    }
                    throw new IllegalArgumentException("Unexpected type: " + node.getType());
            }
            throw new IllegalArgumentException("Unexpected scope: " + node.getScope());
        }

        @Override
        public PlanCostEstimate visitSemiJoin(SemiJoinNode node, Void context) {
            // An absent distribution type is treated as PARTITIONED.
            boolean replicated = node.getDistributionType().orElse(SemiJoinNode.DistributionType.PARTITIONED) == SemiJoinNode.DistributionType.REPLICATED;
            LocalCostEstimate localCost = calculateJoinCost(node, node.getSource(), node.getFilteringSource(), replicated);
            return costForLookupJoin(node, localCost);
        }

        @Override
        public PlanCostEstimate visitSpatialJoin(SpatialJoinNode node, Void context) {
            boolean replicated = node.getDistributionType() == SpatialJoinNode.DistributionType.REPLICATED;
            LocalCostEstimate localCost = calculateJoinCost(node, node.getLeft(), node.getRight(), replicated);
            return costForLookupJoin(node, localCost);
        }

        @Override
        public PlanCostEstimate visitValues(ValuesNode node, Void context) {
            return costForSource(node, LocalCostEstimate.zero());
        }

        @Override
        public PlanCostEstimate visitEnforceSingleRow(EnforceSingleRowNode node, Void context) {
            // NOTE(review): costed as an accumulating node (so its memory-when-outputting drops to 0),
            // unlike the other zero-cost pass-through nodes which use costForStreaming — confirm intended.
            return costForAccumulation(node, LocalCostEstimate.zero());
        }

        @Override
        public PlanCostEstimate visitLimit(LimitNode node, Void context) {
            LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
            return costForStreaming(node, localCost);
        }

        @Override
        public PlanCostEstimate visitUnion(UnionNode node, Void context) {
            return costForStreaming(node, LocalCostEstimate.zero());
        }

        /**
         * Cost of a leaf node: the local cost alone, with its memory used as both the peak and the
         * memory while outputting.
         */
        private PlanCostEstimate costForSource(PlanNode node, LocalCostEstimate localCost) {
            Verify.verify(node.getSources().isEmpty(), "Unexpected sources for %s: %s", node, node.getSources());
            return new PlanCostEstimate(localCost.getCpuCost(), localCost.getMaxMemory(), localCost.getMaxMemory(), localCost.getNetworkCost(), localCost);
        }

        /**
         * Cost of a node that accumulates its input before producing output. Peak memory is the
         * larger of the sources' peak and (sources' memory while outputting + local memory); while
         * this node outputs, only its local memory is counted.
         */
        private PlanCostEstimate costForAccumulation(PlanNode node, LocalCostEstimate localCost) {
            PlanCostEstimate sourcesCost = sumOfSourcesCosts(node);
            return new PlanCostEstimate(
                    sourcesCost.getCpuCost() + localCost.getCpuCost(),
                    Math.max(sourcesCost.getMaxMemory(), sourcesCost.getMaxMemoryWhenOutputting() + localCost.getMaxMemory()),
                    localCost.getMaxMemory(),
                    sourcesCost.getNetworkCost() + localCost.getNetworkCost(),
                    localCost);
        }

        /**
         * Cost of a node that processes input as it arrives. Peak memory is the larger of the
         * sources' peak and (sources' memory while outputting + local memory); the latter sum is also
         * this node's memory while outputting, since its sources are still producing.
         */
        private PlanCostEstimate costForStreaming(PlanNode node, LocalCostEstimate localCost) {
            PlanCostEstimate sourcesCost = sumOfSourcesCosts(node);
            double memoryWhenOutputting = sourcesCost.getMaxMemoryWhenOutputting() + localCost.getMaxMemory();
            return new PlanCostEstimate(
                    sourcesCost.getCpuCost() + localCost.getCpuCost(),
                    Math.max(sourcesCost.getMaxMemory(), memoryWhenOutputting),
                    memoryWhenOutputting,
                    sourcesCost.getNetworkCost() + localCost.getNetworkCost(),
                    localCost);
        }

        /**
         * Cost of a two-input lookup join whose first source is the probe side and second the build
         * side. Peak memory is the larger of (probe peak + build peak) and (probe peak + build memory
         * while outputting + local memory); while outputting, only the probe side's memory while
         * outputting plus local memory is counted.
         */
        private PlanCostEstimate costForLookupJoin(PlanNode node, LocalCostEstimate localCost) {
            Verify.verify(node.getSources().size() == 2, "Unexpected number of sources for %s: %s", node, node.getSources());
            List<PlanCostEstimate> sourceEstimates = getSourcesEstimations(node).collect(ImmutableList.toImmutableList());
            Verify.verify(sourceEstimates.size() == 2);
            PlanCostEstimate probeCost = sourceEstimates.get(0);
            PlanCostEstimate buildCost = sourceEstimates.get(1);
            return new PlanCostEstimate(
                    probeCost.getCpuCost() + buildCost.getCpuCost() + localCost.getCpuCost(),
                    Math.max(
                            probeCost.getMaxMemory() + buildCost.getMaxMemory(),
                            probeCost.getMaxMemory() + buildCost.getMaxMemoryWhenOutputting() + localCost.getMaxMemory()),
                    probeCost.getMaxMemoryWhenOutputting() + localCost.getMaxMemory(),
                    probeCost.getNetworkCost() + buildCost.getNetworkCost() + localCost.getNetworkCost(),
                    localCost);
        }

        /** Sum of the costs of all sources of {@code node}; {@link PlanCostEstimate#zero()} when it has none. */
        private PlanCostEstimate sumOfSourcesCosts(PlanNode node) {
            return getSourcesEstimations(node).reduce(PlanCostEstimate.zero(), CostCalculatorUsingExchanges::addParallelSiblingsCost);
        }

        private PlanNodeStatsEstimate getStats(PlanNode node) {
            return stats.getStats(node);
        }

        /** Costs of the node's sources, in the order returned by {@code node.getSources()}. */
        private Stream<PlanCostEstimate> getSourcesEstimations(PlanNode node) {
            return node.getSources().stream().map(sourcesCosts::getCost);
        }
    }
}

