/*
 * Decompiled with CFR 0.152.
 */
package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.SystemSessionProperties;
import io.trino.cost.RuntimeInfoProvider;
import io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.optimizations.AdaptivePlanOptimizer;
import io.trino.sql.planner.optimizations.PlanOptimizer;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.SimplePlanRewriter;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * Adaptive plan optimizer that, when runtime adaptive partitioning is enabled for fault-tolerant
 * execution, increases the partition count of hash-partitioned stages whose runtime-estimated memory
 * consumption exceeds the configured per-task size multiplied by the stage's current partition count.
 */
public class AdaptivePartitioning
implements AdaptivePlanOptimizer {
    private static final Logger log = Logger.get(AdaptivePartitioning.class);

    /**
     * Looks for the first fragment that needs more partitions and, if one is found, rewrites the whole
     * plan so that hash-partitioned exchanges use the runtime adaptive partition count.
     *
     * @param plan the root of the plan to optimize
     * @param context optimizer context supplying the session, the plan node id allocator and runtime info
     * @return the rewritten plan plus the ids of the plan nodes that were changed, or the original plan
     *         and an empty set when the feature is disabled or no fragment qualifies
     */
    @Override
    public AdaptivePlanOptimizer.Result optimizeAndMarkPlanChanges(PlanNode plan, PlanOptimizer.Context context) {
        if (!SystemSessionProperties.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(context.session())) {
            return new AdaptivePlanOptimizer.Result(plan, ImmutableSet.of());
        }
        int maxPartitionCount = SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount(context.session());
        int runtimeAdaptivePartitioningPartitionCount = SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(context.session());
        long runtimeAdaptivePartitioningMaxTaskSizeInBytes = SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(context.session()).toBytes();
        RuntimeInfoProvider runtimeInfoProvider = context.runtimeInfoProvider();

        for (PlanFragment fragment : runtimeInfoProvider.getAllPlanFragments()) {
            // Only fragments consuming hash-partitioned input are candidates; fragments whose runtime
            // output stats are already reported as accurate are skipped.
            if (!consumesHashPartitionedInput(fragment)
                    || runtimeInfoProvider.getRuntimeOutputStats(fragment.getId()).isAccurate()) {
                continue;
            }

            // Fragments without an explicit partition count fall back to the session maximum.
            int partitionCount = fragment.getPartitionCount().orElse(maxPartitionCount);
            if (partitionCount >= runtimeAdaptivePartitioningPartitionCount) {
                // Already at or above the adaptive partition count; raising it is not possible.
                continue;
            }

            List<Long> partitionedInputBytes = getPartitionedInputBytes(fragment, runtimeInfoProvider);
            if (partitionedInputBytes.isEmpty()) {
                // No non-replicated remote sources: there is no input to estimate from, and
                // Collections.min would throw NoSuchElementException on an empty list.
                continue;
            }

            long estimatedMemoryConsumptionInBytes = estimateMemoryConsumptionInBytes(partitionedInputBytes);
            if (estimatedMemoryConsumptionInBytes <= runtimeAdaptivePartitioningMaxTaskSizeInBytes * (long) partitionCount) {
                continue;
            }

            log.info("Stage %s has an estimated memory consumption of %s, changing partition count from %s to %s",
                    fragment.getId(), DataSize.succinctBytes(estimatedMemoryConsumptionInBytes), partitionCount, runtimeAdaptivePartitioningPartitionCount);
            Rewriter rewriter = new Rewriter(runtimeAdaptivePartitioningPartitionCount, context.idAllocator(), runtimeInfoProvider);
            PlanNode planNode = SimplePlanRewriter.rewriteWith(rewriter, plan);
            return new AdaptivePlanOptimizer.Result(planNode, rewriter.getChangedPlanIds());
        }
        return new AdaptivePlanOptimizer.Result(plan, ImmutableSet.of());
    }

    /**
     * Returns, for each non-REPLICATE remote source of {@code fragment}, the sum of the runtime
     * estimated output data size (in bytes) of all of that remote source's source fragments.
     */
    private static List<Long> getPartitionedInputBytes(PlanFragment fragment, RuntimeInfoProvider runtimeInfoProvider) {
        return fragment.getRemoteSourceNodes().stream()
                // replicated (broadcast) inputs are not counted toward the partitioned input size
                .filter(remoteSourceNode -> remoteSourceNode.getExchangeType() != ExchangeNode.Type.REPLICATE)
                .map(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().stream()
                        .mapToLong(sourceFragmentId -> {
                            OutputStatsEstimator.OutputStatsEstimateResult runtimeStats = runtimeInfoProvider.getRuntimeOutputStats(sourceFragmentId);
                            return runtimeStats.outputDataSizeEstimate().getTotalSizeInBytes();
                        })
                        .sum())
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Estimates memory consumption from per-remote-source input sizes.
     *
     * <p>With a single input, its full size is used. With several inputs, the sum of all inputs minus
     * the smallest one is used (presumably the smallest side is assumed to be streamed through rather
     * than held in memory).
     *
     * @param partitionedInputBytes non-empty list of per-remote-source input sizes in bytes
     */
    private static long estimateMemoryConsumptionInBytes(List<Long> partitionedInputBytes) {
        if (partitionedInputBytes.size() == 1) {
            return partitionedInputBytes.get(0);
        }
        long totalBytes = partitionedInputBytes.stream().mapToLong(Long::longValue).sum();
        return totalBytes - Collections.min(partitionedInputBytes);
    }

    /**
     * Returns whether the fragment's partitioning is FIXED_HASH_DISTRIBUTION or
     * SCALED_WRITER_HASH_DISTRIBUTION.
     */
    public static boolean consumesHashPartitionedInput(PlanFragment fragment) {
        return isPartitioned(fragment.getPartitioning());
    }

    private static boolean isPartitioned(PartitioningHandle partitioningHandle) {
        return partitioningHandle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)
                || partitioningHandle.equals(SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION);
    }

    /**
     * Plan rewriter that sets a fixed partition count on hash-partitioned exchanges and records the ids
     * of every plan node it changes or creates.
     */
    private static final class Rewriter
    extends SimplePlanRewriter<Void> {
        private final int partitionCount;
        private final PlanNodeIdAllocator idAllocator;
        private final RuntimeInfoProvider runtimeInfoProvider;
        private final Set<PlanNodeId> changedPlanIds = new HashSet<>();

        private Rewriter(int partitionCount, PlanNodeIdAllocator idAllocator, RuntimeInfoProvider runtimeInfoProvider) {
            this.partitionCount = partitionCount;
            this.idAllocator = Objects.requireNonNull(idAllocator, "idAllocator is null");
            this.runtimeInfoProvider = Objects.requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
        }

        /**
         * Leaves broadcast exchanges (and their whole subtree) untouched; otherwise rewrites the sources
         * and, for FIXED_HASH_DISTRIBUTION exchanges, applies the new partition count.
         */
        @Override
        public PlanNode visitExchange(ExchangeNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            PartitioningScheme partitioningScheme = node.getPartitioningScheme();
            PartitioningHandle handle = partitioningScheme.getPartitioning().getHandle();
            if (handle.equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION)) {
                return node;
            }
            List<PlanNode> sources = node.getSources().stream()
                    .map(context::rewrite)
                    .collect(ImmutableList.toImmutableList());
            // equals() for consistency with the broadcast check above and with isPartitioned()
            if (handle.equals(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION)) {
                partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(this.partitionCount));
                this.changedPlanIds.add(node.getId());
            }
            return new ExchangeNode(node.getId(), node.getType(), node.getScope(), partitioningScheme, sources, node.getInputs(), node.getOrderingScheme());
        }

        /**
         * Wraps a REPARTITION remote source whose source fragment output is hash-partitioned in a new
         * remote REPARTITION exchange that uses FIXED_HASH_DISTRIBUTION with the new partition count.
         * All other remote sources are returned unchanged.
         */
        @Override
        public PlanNode visitRemoteSource(RemoteSourceNode node, SimplePlanRewriter.RewriteContext<Void> context) {
            if (node.getExchangeType() != ExchangeNode.Type.REPARTITION) {
                return node;
            }
            Optional<PartitioningScheme> sourcePartitioningScheme = node.getSourceFragmentIds().stream()
                    .map(this.runtimeInfoProvider::getPlanFragment)
                    .map(PlanFragment::getOutputPartitioningScheme)
                    .filter(scheme -> isPartitioned(scheme.getPartitioning().getHandle()))
                    .findFirst();
            if (sourcePartitioningScheme.isEmpty()) {
                return node;
            }
            PartitioningScheme newPartitioningSchema = sourcePartitioningScheme.get()
                    .withPartitionCount(Optional.of(this.partitionCount))
                    .withPartitioningHandle(SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION);
            PlanNodeId nodeId = this.idAllocator.getNextId();
            this.changedPlanIds.add(nodeId);
            return new ExchangeNode(
                    nodeId,
                    ExchangeNode.Type.REPARTITION,
                    ExchangeNode.Scope.REMOTE,
                    newPartitioningSchema,
                    ImmutableList.<PlanNode>of(node),
                    ImmutableList.<List<Symbol>>of(node.getOutputSymbols()),
                    node.getOrderingScheme());
        }

        /** Returns an immutable snapshot of the ids of all plan nodes changed or created so far. */
        public Set<PlanNodeId> getChangedPlanIds() {
            return ImmutableSet.copyOf(this.changedPlanIds);
        }
    }
}

