/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.Traverser;
import io.trino.Session;
import io.trino.cost.CachingTableStatsProvider;
import io.trino.cost.RuntimeInfoProvider;
import io.trino.cost.StatsAndCosts;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.Partitioning;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.PlanFragmentIdAllocator;
import io.trino.sql.planner.PlanFragmenter;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SimplePlanVisitor;
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.AdaptivePlanNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import io.trino.sql.planner.sanity.PlanSanityChecker;
import io.trino.tracing.ScopedSpan;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Re-optimizes a query plan that is already executing, using runtime statistics obtained from a
 * {@link RuntimeInfoProvider}.
 * <p>
 * Fragments whose runtime output statistics are not yet accurate are stitched back into a single plan tree
 * (their {@link RemoteSourceNode}s are rewritten into remote {@link ExchangeNode}s). The configured
 * {@link AdaptivePlanOptimizer}s are then run over that tree. Every subtree whose root the optimizers reported as
 * changed is wrapped in an {@link AdaptivePlanNode} that keeps the initial plan next to the new one. The result is
 * sanity-checked and re-fragmented; sub plans that are not affected by this round's changes are passed to the
 * fragmenter together with the new plan.
 * <p>
 * Instances are stateful: the ids of plan nodes changed by every previous call to {@link #optimize} are remembered
 * and re-wrapped on subsequent calls.
 */
public class AdaptivePlanner {
    private final Session session;
    private final PlannerContext plannerContext;
    private final List<AdaptivePlanOptimizer> planOptimizers;
    private final PlanFragmenter planFragmenter;
    private final PlanSanityChecker planSanityChecker;
    private final WarningCollector warningCollector;
    private final PlanOptimizersStatsCollector planOptimizersStatsCollector;
    private final CachingTableStatsProvider tableStatsProvider;
    // Ids of plan nodes changed in any optimization round so far. The whole set (not only the latest round's
    // changes) is handed to addAdaptivePlanNode, so nodes changed earlier are wrapped again when the plan is rebuilt.
    private final Set<PlanNodeId> cumulativeChangedPlanNodes = new HashSet<>();

    public AdaptivePlanner(Session session, PlannerContext plannerContext, List<AdaptivePlanOptimizer> planOptimizers, PlanFragmenter planFragmenter, PlanSanityChecker planSanityChecker, WarningCollector warningCollector, PlanOptimizersStatsCollector planOptimizersStatsCollector, CachingTableStatsProvider tableStatsProvider) {
        this.session = Objects.requireNonNull(session, "session is null");
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.planOptimizers = Objects.requireNonNull(planOptimizers, "planOptimizers is null");
        this.planFragmenter = Objects.requireNonNull(planFragmenter, "planFragmenter is null");
        this.planSanityChecker = Objects.requireNonNull(planSanityChecker, "planSanityChecker is null");
        this.warningCollector = Objects.requireNonNull(warningCollector, "warningCollector is null");
        this.planOptimizersStatsCollector = Objects.requireNonNull(planOptimizersStatsCollector, "planOptimizersStatsCollector is null");
        this.tableStatsProvider = Objects.requireNonNull(tableStatsProvider, "tableStatsProvider is null");
    }

    /**
     * Runs the adaptive optimizers over the not-yet-settled part of {@code root} and re-fragments the result.
     *
     * @param root the current root sub plan of the query
     * @param runtimeInfoProvider source of runtime output statistics and plan fragments
     * @return {@code root} itself when the root fragment's output statistics are already accurate or when no
     *         optimizer changed the plan; otherwise the newly fragmented adaptive plan
     */
    public SubPlan optimize(SubPlan root, RuntimeInfoProvider runtimeInfoProvider) {
        // Nothing to re-plan when the root fragment's output statistics are already accurate.
        if (runtimeInfoProvider.getRuntimeOutputStats(root.getFragment().getId()).isAccurate()) {
            return root;
        }

        List<SubPlan> subPlans = traverse(root).collect(ImmutableList.toImmutableList());
        // Fragments created by re-fragmentation get ids above every existing fragment id.
        PlanFragmentIdAllocator fragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(subPlans) + 1);
        SymbolAllocator symbolAllocator = createSymbolAllocator(subPlans);

        // Stitch fragments without accurate stats back into one tree (remote sources -> remote exchanges).
        PlanNode currentAdaptivePlan = SimplePlanRewriter.rewriteWith(
                new ReplaceRemoteSourcesWithExchanges(runtimeInfoProvider),
                root.getFragment().getRoot(),
                root.getChildren());
        // Remove AdaptivePlanNodes from earlier rounds, keeping either their initial or their current alternative.
        PlanNode initialPlan = getInitialPlan(currentAdaptivePlan);
        PlanNode currentPlan = getCurrentPlan(currentAdaptivePlan);

        // Remember which existing sub plan feeds each remote exchange / remote source so unchanged ones can be reused.
        ExchangeSourceIdToSubPlanCollector collector = new ExchangeSourceIdToSubPlanCollector();
        currentAdaptivePlan.accept(collector, subPlans);
        Map<ExchangeSourceId, SubPlan> exchangeSourceIdToSubPlan = collector.getExchangeSourceIdToSubPlan();

        PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(getMaxPlanId(currentPlan) + 1);
        AdaptivePlanOptimizer.Result optimizationResult = optimizePlan(currentPlan, symbolAllocator, runtimeInfoProvider, idAllocator);
        if (optimizationResult.changedPlanNodes().isEmpty()) {
            return root;
        }

        cumulativeChangedPlanNodes.addAll(optimizationResult.changedPlanNodes());
        PlanNode adaptivePlan = addAdaptivePlanNode(idAllocator, initialPlan, optimizationResult.plan(), cumulativeChangedPlanNodes);
        try (ScopedSpan ignored = ScopedSpan.scopedSpan(plannerContext.getTracer(), "validate-adaptive-plan")) {
            planSanityChecker.validateAdaptivePlan(adaptivePlan, session, plannerContext, warningCollector);
        }

        return planFragmenter.createSubPlans(
                session,
                new Plan(adaptivePlan, StatsAndCosts.empty()),
                false,
                warningCollector,
                fragmentIdAllocator,
                new PartitioningScheme(Partitioning.create(SystemPartitioningHandle.SINGLE_DISTRIBUTION, ImmutableList.<Symbol>of()), adaptivePlan.getOutputSymbols()),
                // Only the changes of this round decide which existing sub plans are still valid.
                getUnchangedSubPlans(adaptivePlan, optimizationResult.changedPlanNodes(), exchangeSourceIdToSubPlan));
    }

    /**
     * Applies every adaptive optimizer in sequence, each one to the previous one's output plan.
     *
     * @return the final plan together with the union of node ids reported as changed by all optimizers
     */
    private AdaptivePlanOptimizer.Result optimizePlan(PlanNode plan, SymbolAllocator symbolAllocator, RuntimeInfoProvider runtimeInfoProvider, PlanNodeIdAllocator idAllocator) {
        AdaptivePlanOptimizer.Result result = new AdaptivePlanOptimizer.Result(plan, Set.of());
        ImmutableSet.Builder<PlanNodeId> changedPlanNodes = ImmutableSet.builder();
        for (AdaptivePlanOptimizer optimizer : planOptimizers) {
            result = optimizer.optimizeAndMarkPlanChanges(
                    result.plan(),
                    new PlanOptimizer.Context(session, symbolAllocator, idAllocator, warningCollector, planOptimizersStatsCollector, tableStatsProvider, runtimeInfoProvider));
            changedPlanNodes.addAll(result.changedPlanNodes());
        }
        return new AdaptivePlanOptimizer.Result(result.plan(), changedPlanNodes.build());
    }

    /**
     * Walks {@code initialPlan} and {@code optimizedPlanNode} in lock step. Wherever the optimized node's id is in
     * {@code changedPlanNodes}, it returns an {@link AdaptivePlanNode} holding the initial subtree and the optimized
     * subtree. Above such points the optimized node is kept with its children rebuilt recursively.
     *
     * @throws com.google.common.base.VerifyException if an unchanged node has a different number of sources in the
     *         initial and the optimized plan
     */
    private PlanNode addAdaptivePlanNode(PlanNodeIdAllocator idAllocator, PlanNode initialPlan, PlanNode optimizedPlanNode, Set<PlanNodeId> changedPlanNodes) {
        if (changedPlanNodes.contains(optimizedPlanNode.getId())) {
            return new AdaptivePlanNode(idAllocator.getNextId(), initialPlan, SymbolsExtractor.extractOutputSymbols(initialPlan), optimizedPlanNode);
        }

        List<PlanNode> initialSources = initialPlan.getSources();
        List<PlanNode> optimizedSources = optimizedPlanNode.getSources();
        Verify.verify(
                initialSources.size() == optimizedSources.size(),
                "Unchanged plan node %s has a different number of sources than initial plan node %s",
                optimizedPlanNode.getId(),
                initialPlan.getId());

        ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();
        for (int i = 0; i < initialSources.size(); i++) {
            sources.add(addAdaptivePlanNode(idAllocator, initialSources.get(i), optimizedSources.get(i), changedPlanNodes));
        }
        return optimizedPlanNode.replaceChildren(sources.build());
    }

    /**
     * Filters {@code exchangeSourceIdToSubPlan} down to entries whose exchange id and source id both lie off every
     * path from the root of {@code adaptivePlan} to a changed node.
     */
    private Map<ExchangeSourceId, SubPlan> getUnchangedSubPlans(PlanNode adaptivePlan, Set<PlanNodeId> changedPlanIds, Map<ExchangeSourceId, SubPlan> exchangeSourceIdToSubPlan) {
        Set<PlanNodeId> changedPlanIdsWithDownstream = new HashSet<>();
        for (PlanNodeId changedId : changedPlanIds) {
            changedPlanIdsWithDownstream.addAll(getDownstreamPlanNodeIds(adaptivePlan, changedId));
        }
        return exchangeSourceIdToSubPlan.entrySet().stream()
                .filter(entry -> !changedPlanIdsWithDownstream.contains(entry.getKey().exchangeId())
                        && !changedPlanIdsWithDownstream.contains(entry.getKey().sourceId()))
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Returns {@code id} plus the ids of every node on the path from {@code root} down to it (following
     * {@link PlanNode#getSources()}), i.e. all nodes that consume the node's output. Returns an empty set when
     * {@code id} is not reachable from {@code root}.
     */
    private Set<PlanNodeId> getDownstreamPlanNodeIds(PlanNode root, PlanNodeId id) {
        if (root.getId().equals(id)) {
            return ImmutableSet.of(id);
        }
        Set<PlanNodeId> pathToTarget = new HashSet<>();
        for (PlanNode source : root.getSources()) {
            pathToTarget.addAll(getDownstreamPlanNodeIds(source, id));
        }
        // A non-empty result means the target lies below this node, so this node is on the path as well.
        if (!pathToTarget.isEmpty()) {
            pathToTarget.add(root.getId());
        }
        return pathToTarget;
    }

    /** Replaces every {@link AdaptivePlanNode} in {@code node} with its current plan. */
    private PlanNode getCurrentPlan(PlanNode node) {
        return SimplePlanRewriter.rewriteWith(new CurrentPlanRewriter(), node);
    }

    /** Replaces every {@link AdaptivePlanNode} in {@code node} with its initial plan. */
    private PlanNode getInitialPlan(PlanNode node) {
        return SimplePlanRewriter.rewriteWith(new InitialPlanRewriter(), node);
    }

    /** Creates a symbol allocator seeded with the symbols of all given fragments so new symbols do not clash. */
    private SymbolAllocator createSymbolAllocator(List<SubPlan> subPlans) {
        return new SymbolAllocator(subPlans.stream()
                .map(SubPlan::getFragment)
                .map(PlanFragment::getSymbols)
                .flatMap(Collection::stream)
                .collect(ImmutableSet.toImmutableSet()));
    }

    /** Returns the largest fragment id, assuming fragment ids render as decimal integers. */
    private int getMaxPlanFragmentId(List<SubPlan> subPlans) {
        return subPlans.stream()
                .map(SubPlan::getFragment)
                .map(PlanFragment::getId)
                .mapToInt(fragmentId -> Integer.parseInt(fragmentId.toString()))
                .max()
                .orElseThrow();
    }

    /** Returns the largest plan node id reachable from {@code node}, assuming ids render as decimal integers. */
    private int getMaxPlanId(PlanNode node) {
        return traverse(node)
                .map(PlanNode::getId)
                .mapToInt(planNodeId -> Integer.parseInt(planNodeId.toString()))
                .max()
                .orElseThrow();
    }

    /** Streams {@code node} and all of its transitive sources in depth-first pre-order. */
    private Stream<PlanNode> traverse(PlanNode node) {
        Iterable<PlanNode> nodes = Traverser.<PlanNode>forTree(PlanNode::getSources).depthFirstPreOrder(node);
        return StreamSupport.stream(nodes.spliterator(), false);
    }

    /** Streams {@code subPlan} and all of its transitive children in depth-first pre-order. */
    private Stream<SubPlan> traverse(SubPlan subPlan) {
        Iterable<SubPlan> plans = Traverser.<SubPlan>forTree(SubPlan::getChildren).depthFirstPreOrder(subPlan);
        return StreamSupport.stream(plans.spliterator(), false);
    }

    private static boolean containsAdaptivePlanNode(PlanNode node) {
        return PlanNodeSearcher.searchFrom(node).whereIsInstanceOfAny(AdaptivePlanNode.class).matches();
    }

    /**
     * Inlines child fragments into their consumer. A {@link RemoteSourceNode} whose source fragments all lack
     * accurate runtime output stats is replaced by a remote {@link ExchangeNode} over the (recursively rewritten)
     * roots of those fragments. If any source fragment already has accurate stats, the remote source is kept as is.
     * The rewrite context is the list of child sub plans of the fragment being rewritten.
     */
    private static class ReplaceRemoteSourcesWithExchanges
    extends SimplePlanRewriter<List<SubPlan>> {
        private final RuntimeInfoProvider runtimeInfoProvider;

        private ReplaceRemoteSourcesWithExchanges(RuntimeInfoProvider runtimeInfoProvider) {
            this.runtimeInfoProvider = Objects.requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
        }

        @Override
        public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            PlanNode initialPlan = context.rewrite(node.getInitialPlan(), context.get());
            PlanNode currentPlan = context.rewrite(node.getCurrentPlan(), context.get());
            return new AdaptivePlanNode(node.getId(), initialPlan, SymbolsExtractor.extractOutputSymbols(initialPlan), currentPlan);
        }

        @Override
        public PlanNode visitRemoteSource(RemoteSourceNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            List<PlanFragmentId> sourceFragmentIds = node.getSourceFragmentIds();
            if (sourceFragmentIds.stream().anyMatch(fragmentId -> runtimeInfoProvider.getRuntimeOutputStats(fragmentId).isAccurate())) {
                return node;
            }

            // Resolve the source sub plans in getSourceFragmentIds() order. The exchange pairs sources.get(i) with
            // inputs.get(i), and inputs are derived in that same order below, so the two must line up.
            List<SubPlan> childSubPlans = context.get();
            List<SubPlan> sourceSubPlans = sourceFragmentIds.stream()
                    .flatMap(fragmentId -> childSubPlans.stream().filter(subPlan -> subPlan.getFragment().getId().equals(fragmentId)))
                    .collect(ImmutableList.toImmutableList());

            ImmutableList.Builder<PlanNode> sourceNodesBuilder = ImmutableList.builder();
            for (SubPlan sourceSubPlan : sourceSubPlans) {
                sourceNodesBuilder.add(context.rewrite(sourceSubPlan.getFragment().getRoot(), sourceSubPlan.getChildren()));
            }
            List<PlanNode> sourceNodes = sourceNodesBuilder.build();

            List<PartitioningScheme> outputPartitioningSchemes = sourceFragmentIds.stream()
                    .map(runtimeInfoProvider::getPlanFragment)
                    .map(PlanFragment::getOutputPartitioningScheme)
                    .collect(ImmutableList.toImmutableList());
            Verify.verify(outputPartitioningSchemes.size() == sourceNodes.size(), "Output partitioning schemes size does not match source nodes size");

            List<List<Symbol>> inputs = outputPartitioningSchemes.stream()
                    .map(PartitioningScheme::getOutputLayout)
                    .collect(ImmutableList.toImmutableList());
            return new ExchangeNode(
                    node.getId(),
                    node.getExchangeType(),
                    ExchangeNode.Scope.REMOTE,
                    outputPartitioningSchemes.getFirst().translateOutputLayout(node.getOutputSymbols()),
                    sourceNodes,
                    inputs,
                    node.getOrderingScheme());
        }
    }

    /**
     * Builds a map from (remote exchange or remote source id, source root id) to the existing sub plan that
     * produces that source. The visitor context is the list of all sub plans of the query.
     */
    private static class ExchangeSourceIdToSubPlanCollector
    extends SimplePlanVisitor<List<SubPlan>> {
        private final Map<ExchangeSourceId, SubPlan> exchangeSourceIdToSubPlan = new HashMap<>();

        private ExchangeSourceIdToSubPlanCollector() {
        }

        @Override
        public Void visitExchange(ExchangeNode node, List<SubPlan> context) {
            visitPlan(node, context);
            if (node.getScope() != ExchangeNode.Scope.REMOTE) {
                return null;
            }
            // A remote exchange was produced from inlined fragments, so each source is the root of some sub plan.
            Set<PlanNodeId> sourceIds = node.getSources().stream()
                    .map(PlanNode::getId)
                    .collect(ImmutableSet.toImmutableSet());
            List<SubPlan> sourceSubPlans = context.stream()
                    .filter(subPlan -> sourceIds.contains(subPlan.getFragment().getRoot().getId()))
                    .collect(ImmutableList.toImmutableList());
            if (sourceSubPlans.size() != sourceIds.size()) {
                throw new IllegalStateException(String.format(
                        "Source subPlans not found for exchange node %s; sourceIds: %s; filteredSubPlans: %s; allSubPlans: %s",
                        node.getId(),
                        sourceIds,
                        describe(sourceSubPlans),
                        describe(context)));
            }
            recordSources(node.getId(), sourceSubPlans);
            return null;
        }

        @Override
        public Void visitRemoteSource(RemoteSourceNode node, List<SubPlan> context) {
            List<SubPlan> sourceSubPlans = context.stream()
                    .filter(subPlan -> node.getSourceFragmentIds().contains(subPlan.getFragment().getId()))
                    .collect(ImmutableList.toImmutableList());
            recordSources(node.getId(), sourceSubPlans);
            return null;
        }

        public Map<ExchangeSourceId, SubPlan> getExchangeSourceIdToSubPlan() {
            return ImmutableMap.copyOf(exchangeSourceIdToSubPlan);
        }

        /** Records each sub plan under (exchangeId, id of the sub plan's root node). */
        private void recordSources(PlanNodeId exchangeId, List<SubPlan> sourceSubPlans) {
            for (SubPlan sourceSubPlan : sourceSubPlans) {
                PlanNodeId sourceId = sourceSubPlan.getFragment().getRoot().getId();
                exchangeSourceIdToSubPlan.put(new ExchangeSourceId(exchangeId, sourceId), sourceSubPlan);
            }
        }

        /** Renders sub plans as "fragmentId->rootNodeId" strings for diagnostics. */
        private static List<String> describe(List<SubPlan> subPlans) {
            return subPlans.stream()
                    .map(subPlan -> subPlan.getFragment().getId() + "->" + subPlan.getFragment().getRoot().getId())
                    .collect(ImmutableList.toImmutableList());
        }
    }

    /** Replaces each {@link AdaptivePlanNode} with its current plan; nested adaptive nodes are rejected. */
    private static class CurrentPlanRewriter
    extends SimplePlanRewriter<List<SubPlan>> {
        private CurrentPlanRewriter() {
        }

        @Override
        public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            Verify.verify(!containsAdaptivePlanNode(node.getCurrentPlan()), "Adaptive plan node cannot have a nested adaptive plan node");
            return node.getCurrentPlan();
        }
    }

    /** Replaces each {@link AdaptivePlanNode} with its initial plan; nested adaptive nodes are rejected. */
    private static class InitialPlanRewriter
    extends SimplePlanRewriter<List<SubPlan>> {
        private InitialPlanRewriter() {
        }

        @Override
        public PlanNode visitAdaptivePlanNode(AdaptivePlanNode node, SimplePlanRewriter.RewriteContext<List<SubPlan>> context) {
            Verify.verify(!containsAdaptivePlanNode(node.getInitialPlan()), "Adaptive plan node cannot have a nested adaptive plan node");
            return node.getInitialPlan();
        }
    }

    /**
     * Key identifying one input of a remote exchange or remote source.
     *
     * @param exchangeId id of the remote exchange or remote source node
     * @param sourceId id of the root node of the sub plan feeding that node
     */
    public record ExchangeSourceId(PlanNodeId exchangeId, PlanNodeId sourceId) {
        public ExchangeSourceId {
            Objects.requireNonNull(exchangeId, "exchangeId is null");
            Objects.requireNonNull(sourceId, "sourceId is null");
        }
    }
}

