/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import com.google.common.graph.Traverser;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.PlanFragmentIdAllocator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * Rewrites a {@link SubPlan} tree so that fragments which have not started yet use a new
 * partition count for their hash-partitioned input and output.
 *
 * <p>Fragments that have already started, and subtrees whose output is broadcast, are left
 * unchanged. When a fragment with hash-partitioned input reads from a source fragment that has
 * already started, an extra fragment is inserted between them. That fragment reads the source's
 * output with the old partition count and re-partitions it with the new one.
 */
public final class RuntimeAdaptivePartitioningRewriter {
    private RuntimeAdaptivePartitioningRewriter() {
    }

    /**
     * Recursively overrides the partition count of {@code subPlan} and its descendants.
     *
     * @param subPlan the plan subtree to rewrite
     * @param oldPartitionCount partition count used by an inserted fragment to read the output of
     *     an already-started source fragment
     * @param newPartitionCount partition count applied to rewritten fragments and to the output of
     *     inserted fragments
     * @param planFragmentIdAllocator allocator for the ids of inserted fragments
     * @param planNodeIdAllocator allocator for the ids of inserted remote source nodes
     * @param startedFragments ids of fragments that have already started; these and their
     *     descendants are never modified
     * @return the rewritten subtree, or {@code subPlan} itself when nothing needs to change
     */
    public static SubPlan overridePartitionCountRecursively(SubPlan subPlan, int oldPartitionCount, int newPartitionCount, PlanFragmentIdAllocator planFragmentIdAllocator, PlanNodeIdAllocator planNodeIdAllocator, Set<PlanFragmentId> startedFragments) {
        PlanFragment fragment = subPlan.getFragment();
        if (startedFragments.contains(fragment.getId())) {
            // Already started: neither this fragment nor anything below it can be changed.
            return subPlan;
        }
        PartitioningScheme outputPartitioningScheme = fragment.getOutputPartitioningScheme();
        if (outputPartitioningScheme.getPartitioning().getHandle().equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION)) {
            // The output of this subtree is broadcast, so its partition count is left untouched.
            return subPlan;
        }
        if (producesHashPartitionedOutput(fragment)) {
            fragment = fragment.withOutputPartitioningScheme(outputPartitioningScheme.withPartitionCount(Optional.of(newPartitionCount)));
        }
        if (!consumesHashPartitionedInput(fragment)) {
            // Input is not hash partitioned, so no extra fragments are needed; just rewrite the children.
            List<SubPlan> rewrittenChildren = subPlan.getChildren().stream()
                    .map(child -> overridePartitionCountRecursively(child, oldPartitionCount, newPartitionCount, planFragmentIdAllocator, planNodeIdAllocator, startedFragments))
                    .collect(ImmutableList.toImmutableList());
            return new SubPlan(fragment, rewrittenChildren);
        }

        fragment = fragment.withPartitionCount(Optional.of(newPartitionCount));
        ImmutableList.Builder<SubPlan> newSources = ImmutableList.builder();
        // Maps each non-replicated source fragment id to the fragment id the rewritten root must read from.
        ImmutableMap.Builder<PlanFragmentId, PlanFragmentId> runtimeAdaptivePlanFragmentIdMapping = ImmutableMap.builder();
        for (SubPlan source : subPlan.getChildren()) {
            PlanFragment sourceFragment = source.getFragment();
            PlanFragmentId sourceFragmentId = sourceFragment.getId();
            // Exactly one remote source node of this fragment is expected to read from the child.
            RemoteSourceNode sourceRemoteSourceNode = Iterators.getOnlyElement(fragment.getRemoteSourceNodes().stream()
                    .filter(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().contains(sourceFragmentId))
                    .iterator());
            Objects.requireNonNull(sourceRemoteSourceNode, "sourceRemoteSourceNode is null");
            if (sourceRemoteSourceNode.getExchangeType() == ExchangeNode.Type.REPLICATE) {
                // Replicated input is read in full by every partition, so the subtree is kept as-is.
                newSources.add(source);
                continue;
            }
            if (!startedFragments.contains(sourceFragmentId)) {
                // Not started yet: the source can adopt the new partition count directly.
                newSources.add(overridePartitionCountRecursively(source, oldPartitionCount, newPartitionCount, planFragmentIdAllocator, planNodeIdAllocator, startedFragments));
                runtimeAdaptivePlanFragmentIdMapping.put(sourceFragmentId, sourceFragmentId);
                continue;
            }
            // Already started: insert a fragment that re-partitions the source's output.
            PlanFragment runtimeAdaptivePlanFragment = createRuntimeAdaptivePlanFragment(sourceFragment, sourceRemoteSourceNode, oldPartitionCount, newPartitionCount, planFragmentIdAllocator, planNodeIdAllocator);
            SubPlan newSource = new SubPlan(runtimeAdaptivePlanFragment, ImmutableList.of(overridePartitionCountRecursively(source, oldPartitionCount, newPartitionCount, planFragmentIdAllocator, planNodeIdAllocator, startedFragments)));
            newSources.add(newSource);
            runtimeAdaptivePlanFragmentIdMapping.put(sourceFragmentId, runtimeAdaptivePlanFragment.getId());
        }
        PlanNode rewrittenRoot = SimplePlanRewriter.rewriteWith(new UpdateRemoteSourceFragmentIdsRewriter(runtimeAdaptivePlanFragmentIdMapping.buildOrThrow()), fragment.getRoot());
        return new SubPlan(fragment.withRoot(rewrittenRoot), newSources.build());
    }

    /**
     * Builds the fragment inserted above an already-started source fragment.
     *
     * <p>The new fragment consists of a single remote source node reading {@code sourceFragment}.
     * It is {@code FIXED_HASH_DISTRIBUTION}-partitioned with {@code oldPartitionCount} partitions.
     * Its output uses the source's output partitioning scheme with {@code newPartitionCount}.
     * The node id is allocated before the fragment id.
     */
    private static PlanFragment createRuntimeAdaptivePlanFragment(PlanFragment sourceFragment, RemoteSourceNode sourceRemoteSourceNode, int oldPartitionCount, int newPartitionCount, PlanFragmentIdAllocator planFragmentIdAllocator, PlanNodeIdAllocator planNodeIdAllocator) {
        RemoteSourceNode runtimeAdaptiveRemoteSourceNode = new RemoteSourceNode(
                planNodeIdAllocator.getNextId(),
                sourceFragment.getId(),
                sourceFragment.getOutputPartitioningScheme().getOutputLayout(),
                sourceRemoteSourceNode.getOrderingScheme(),
                sourceRemoteSourceNode.getExchangeType(),
                sourceRemoteSourceNode.getRetryPolicy());
        return new PlanFragment(
                planFragmentIdAllocator.getNextId(),
                runtimeAdaptiveRemoteSourceNode,
                sourceFragment.getSymbols(),
                SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION,
                Optional.of(oldPartitionCount),
                // The fragment only reads from the remote source above, so it has no partitioned sources.
                ImmutableList.<PlanNodeId>of(),
                sourceFragment.getOutputPartitioningScheme().withPartitionCount(Optional.of(newPartitionCount)),
                sourceFragment.getStatsAndCosts(),
                sourceFragment.getActiveCatalogs(),
                sourceFragment.getLanguageFunctions(),
                sourceFragment.getJsonRepresentation());
    }

    /**
     * Returns whether the fragment's own partitioning is a hash distribution
     * ({@code FIXED_HASH_DISTRIBUTION} or {@code SCALED_WRITER_HASH_DISTRIBUTION}).
     */
    public static boolean consumesHashPartitionedInput(PlanFragment fragment) {
        return isPartitioned(fragment.getPartitioning());
    }

    /**
     * Returns whether the fragment's output partitioning is a hash distribution
     * ({@code FIXED_HASH_DISTRIBUTION} or {@code SCALED_WRITER_HASH_DISTRIBUTION}).
     */
    public static boolean producesHashPartitionedOutput(PlanFragment fragment) {
        return isPartitioned(fragment.getOutputPartitioningScheme().getPartitioning().getHandle());
    }

    /**
     * Returns the largest fragment id among {@code subPlans}.
     *
     * <p>Ids are parsed from their string form with {@link Integer#parseInt}.
     *
     * @throws java.util.NoSuchElementException if {@code subPlans} is empty
     */
    public static int getMaxPlanFragmentId(List<SubPlan> subPlans) {
        return subPlans.stream()
                .map(SubPlan::getFragment)
                .map(PlanFragment::getId)
                .mapToInt(fragmentId -> Integer.parseInt(fragmentId.toString()))
                .max()
                .orElseThrow();
    }

    /**
     * Returns the largest plan node id across every node reachable from the roots of {@code subPlans}'s fragments.
     *
     * <p>Ids are parsed from their string form with {@link Integer#parseInt}.
     *
     * @throws java.util.NoSuchElementException if {@code subPlans} is empty
     */
    public static int getMaxPlanId(List<SubPlan> subPlans) {
        // Every traversal yields at least its root, so flattening is equivalent to a per-root max.
        return subPlans.stream()
                .map(SubPlan::getFragment)
                .map(PlanFragment::getRoot)
                .flatMap(RuntimeAdaptivePartitioningRewriter::traverse)
                .map(PlanNode::getId)
                .mapToInt(planNodeId -> Integer.parseInt(planNodeId.toString()))
                .max()
                .orElseThrow();
    }

    private static boolean isPartitioned(PartitioningHandle partitioningHandle) {
        return partitioningHandle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION) || partitioningHandle.equals(SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION);
    }

    /** Streams {@code node} and all of its descendants in depth-first pre-order. */
    private static Stream<PlanNode> traverse(PlanNode node) {
        Iterable<PlanNode> nodes = Traverser.forTree(PlanNode::getSources).depthFirstPreOrder(node);
        return StreamSupport.stream(nodes.spliterator(), false);
    }

    /**
     * Redirects non-replicated remote source nodes to the fragment ids recorded in the mapping.
     * Replicated remote sources are returned unchanged.
     */
    private static class UpdateRemoteSourceFragmentIdsRewriter
    extends SimplePlanRewriter<Void> {
        private final Map<PlanFragmentId, PlanFragmentId> runtimeAdaptivePlanFragmentIdMapping;

        public UpdateRemoteSourceFragmentIdsRewriter(Map<PlanFragmentId, PlanFragmentId> runtimeAdaptivePlanFragmentIdMapping) {
            this.runtimeAdaptivePlanFragmentIdMapping = Objects.requireNonNull(runtimeAdaptivePlanFragmentIdMapping, "runtimeAdaptivePlanFragmentIdMapping is null");
        }

        @Override
        public PlanNode visitRemoteSource(RemoteSourceNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            if (node.getExchangeType() == ExchangeNode.Type.REPLICATE) {
                return node;
            }
            List<PlanFragmentId> newSourceFragmentIds = node.getSourceFragmentIds().stream()
                    .map(this::mapFragmentId)
                    .collect(ImmutableList.toImmutableList());
            return node.withSourceFragmentIds(newSourceFragmentIds);
        }

        /** Looks up the replacement id, failing with a descriptive message when it is missing. */
        private PlanFragmentId mapFragmentId(PlanFragmentId fragmentId) {
            return Objects.requireNonNull(
                    runtimeAdaptivePlanFragmentIdMapping.get(fragmentId),
                    () -> "no runtime adaptive fragment id mapping for source fragment " + fragmentId);
        }
    }
}

