/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.airlift.concurrent.MoreFutures;
import io.airlift.concurrent.SetThreadName;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.log.Logger;
import io.airlift.stats.TimeStat;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.exchange.DirectExchangeInput;
import io.trino.exchange.ExchangeInput;
import io.trino.execution.BasicStageStats;
import io.trino.execution.ExecutionFailureInfo;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.QueryState;
import io.trino.execution.QueryStateMachine;
import io.trino.execution.RemoteTask;
import io.trino.execution.RemoteTaskFactory;
import io.trino.execution.SqlStage;
import io.trino.execution.SqlTaskManager;
import io.trino.execution.StageId;
import io.trino.execution.StageInfo;
import io.trino.execution.StateMachine;
import io.trino.execution.TableExecuteContextManager;
import io.trino.execution.TaskFailureListener;
import io.trino.execution.TaskId;
import io.trino.execution.TaskStatus;
import io.trino.execution.scheduler.BroadcastPipelinedOutputBufferManager;
import io.trino.execution.scheduler.BucketNodeMap;
import io.trino.execution.scheduler.DynamicSplitPlacementPolicy;
import io.trino.execution.scheduler.FixedCountScheduler;
import io.trino.execution.scheduler.FixedSourcePartitionedScheduler;
import io.trino.execution.scheduler.MultiSourcePartitionedScheduler;
import io.trino.execution.scheduler.NodeScheduler;
import io.trino.execution.scheduler.NodeSelector;
import io.trino.execution.scheduler.PartitionedPipelinedOutputBufferManager;
import io.trino.execution.scheduler.PipelinedOutputBufferManager;
import io.trino.execution.scheduler.PipelinedStageExecution;
import io.trino.execution.scheduler.QueryScheduler;
import io.trino.execution.scheduler.ScaledPipelinedOutputBufferManager;
import io.trino.execution.scheduler.ScaledWriterScheduler;
import io.trino.execution.scheduler.ScheduleResult;
import io.trino.execution.scheduler.SourcePartitionedScheduler;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.execution.scheduler.StageExecution;
import io.trino.execution.scheduler.StageManager;
import io.trino.execution.scheduler.StageScheduler;
import io.trino.execution.scheduler.TaskLifecycleListener;
import io.trino.execution.scheduler.policy.ExecutionPolicy;
import io.trino.execution.scheduler.policy.ExecutionSchedule;
import io.trino.execution.scheduler.policy.StagesScheduleResult;
import io.trino.failuredetector.FailureDetector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Metadata;
import io.trino.metadata.Split;
import io.trino.operator.RetryPolicy;
import io.trino.server.DynamicFilterService;
import io.trino.spi.ErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.ErrorType;
import io.trino.spi.QueryId;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogHandle;
import io.trino.split.SplitSource;
import io.trino.sql.planner.NodePartitionMap;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.SplitSourceFactory;
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.SystemPartitioningHandle;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.util.Failures;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class PipelinedQueryScheduler
implements QueryScheduler {
    private static final Logger log = Logger.get(PipelinedQueryScheduler.class);

    private final QueryStateMachine queryStateMachine;
    private final NodePartitioningManager nodePartitioningManager;
    private final NodeScheduler nodeScheduler;
    private final int splitBatchSize;
    private final ExecutorService executor;
    private final ScheduledExecutorService schedulerExecutor;
    private final FailureDetector failureDetector;
    private final ExecutionPolicy executionPolicy;
    private final SplitSchedulerStats schedulerStats;
    private final DynamicFilterService dynamicFilterService;
    private final TableExecuteContextManager tableExecuteContextManager;
    private final SplitSourceFactory splitSourceFactory;
    private final StageManager stageManager;
    private final CoordinatorStagesScheduler coordinatorStagesScheduler;

    // Query-level retry configuration, read from the session in the constructor.
    private final RetryPolicy retryPolicy;
    private final int maxQueryRetryAttempts;
    // Number of query-level retries performed so far; 0 while running the initial attempt.
    private final AtomicInteger currentAttempt = new AtomicInteger();
    private final Duration retryInitialDelay;
    private final Duration retryMaxDelay;
    private final double retryDelayScaleFactor;

    @GuardedBy("this")
    private boolean started;
    // Writes happen under this object's lock, but the same reference is handed to CoordinatorStagesScheduler,
    // which reads it without holding this lock; the AtomicReference keeps those reads visible.
    @GuardedBy("this")
    private final AtomicReference<DistributedStagesScheduler> distributedStagesScheduler = new AtomicReference<>();
    // Future of the scheduling loop submitted for the current distributed attempt.
    @GuardedBy("this")
    private Future<Void> distributedStagesSchedulingTask;

    public PipelinedQueryScheduler(QueryStateMachine queryStateMachine, SubPlan plan, NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, boolean summarizeTaskInfo, int splitBatchSize, ExecutorService queryExecutor, ScheduledExecutorService schedulerExecutor, FailureDetector failureDetector, NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy, Tracer tracer, SplitSchedulerStats schedulerStats, DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, Metadata metadata, SplitSourceFactory splitSourceFactory, SqlTaskManager coordinatorTaskManager) {
        this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
        this.nodePartitioningManager = Objects.requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
        this.nodeScheduler = Objects.requireNonNull(nodeScheduler, "nodeScheduler is null");
        this.splitBatchSize = splitBatchSize;
        this.executor = Objects.requireNonNull(queryExecutor, "queryExecutor is null");
        this.schedulerExecutor = Objects.requireNonNull(schedulerExecutor, "schedulerExecutor is null");
        this.failureDetector = Objects.requireNonNull(failureDetector, "failureDetector is null");
        this.executionPolicy = Objects.requireNonNull(executionPolicy, "executionPolicy is null");
        this.schedulerStats = Objects.requireNonNull(schedulerStats, "schedulerStats is null");
        this.dynamicFilterService = Objects.requireNonNull(dynamicFilterService, "dynamicFilterService is null");
        this.tableExecuteContextManager = Objects.requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null");
        this.splitSourceFactory = Objects.requireNonNull(splitSourceFactory, "splitSourceFactory is null");
        this.stageManager = StageManager.create(queryStateMachine, metadata, remoteTaskFactory, nodeTaskMap, tracer, schedulerStats, plan, summarizeTaskInfo);
        this.coordinatorStagesScheduler = CoordinatorStagesScheduler.create(queryStateMachine, nodeScheduler, this.stageManager, failureDetector, schedulerExecutor, this.distributedStagesScheduler, coordinatorTaskManager);
        this.retryPolicy = SystemSessionProperties.getRetryPolicy(queryStateMachine.getSession());
        Verify.verify((this.retryPolicy == RetryPolicy.NONE || this.retryPolicy == RetryPolicy.QUERY ? 1 : 0) != 0, (String)"unexpected retry policy: %s", (Object)((Object)this.retryPolicy));
        this.maxQueryRetryAttempts = SystemSessionProperties.getQueryRetryAttempts(queryStateMachine.getSession());
        this.retryInitialDelay = SystemSessionProperties.getRetryInitialDelay(queryStateMachine.getSession());
        this.retryMaxDelay = SystemSessionProperties.getRetryMaxDelay(queryStateMachine.getSession());
        this.retryDelayScaleFactor = SystemSessionProperties.getRetryDelayScaleFactor(queryStateMachine.getSession());
    }

    /**
     * Starts scheduling the query.
     *
     * <p>Only the first call has an effect. If the query is already done at that point, nothing is
     * scheduled. Otherwise the method does three things:
     * <ul>
     *   <li>registers a listener that cancels (on FINISHED) or aborts (on FAILED) all stages once the
     *       query reaches a done state;</li>
     *   <li>schedules the coordinator stages;</li>
     *   <li>submits the first distributed scheduling attempt to the query executor.</li>
     * </ul>
     */
    @Override
    public synchronized void start() {
        if (started) {
            return;
        }
        started = true;

        if (queryStateMachine.isDone()) {
            return;
        }

        queryStateMachine.addStateChangeListener(newState -> {
            if (!newState.isDone()) {
                return;
            }
            DistributedStagesScheduler currentScheduler;
            synchronized (this) {
                currentScheduler = distributedStagesScheduler.get();
            }
            switch (newState) {
                case FINISHED -> {
                    coordinatorStagesScheduler.cancel();
                    if (currentScheduler != null) {
                        currentScheduler.cancel();
                    }
                    stageManager.finish();
                }
                case FAILED -> {
                    coordinatorStagesScheduler.abort();
                    if (currentScheduler != null) {
                        currentScheduler.abort();
                    }
                    stageManager.abort();
                }
                default -> {
                    // any other done state: no stage teardown, only the query info refresh below
                }
            }
            queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()));
        });

        Optional<DistributedStagesScheduler> firstAttempt = createDistributedStagesScheduler(currentAttempt.get());
        coordinatorStagesScheduler.schedule();
        firstAttempt.ifPresent(scheduler -> distributedStagesSchedulingTask = executor.submit(scheduler::schedule, null));
    }

    /**
     * Creates the distributed stages scheduler for {@code attempt} and publishes it in
     * {@link #distributedStagesScheduler}.
     *
     * <p>A state listener is attached to the new scheduler. It moves the query to RUNNING, finishes the
     * distributed stages on success, and finishes or cancels the query when there are no coordinator
     * stages. On failure it either schedules a delayed retry (see {@link #shouldRetry}) or fails the
     * stages and the query.
     *
     * @param attempt zero-based attempt number; values other than {@code 0} require {@link RetryPolicy#QUERY}
     * @return the new scheduler, or empty if the query is already done
     */
    private synchronized Optional<DistributedStagesScheduler> createDistributedStagesScheduler(int attempt) {
        Verify.verify(attempt == 0 || retryPolicy == RetryPolicy.QUERY, "unexpected attempt %s for retry policy %s", attempt, retryPolicy);
        if (queryStateMachine.isDone()) {
            return Optional.empty();
        }

        DistributedStagesScheduler distributedStagesScheduler = switch (retryPolicy) {
            // unqualified enum labels compile on every Java version that supports arrow switches
            case QUERY, NONE -> {
                if (attempt > 0) {
                    dynamicFilterService.registerQueryRetry(queryStateMachine.getQueryId(), attempt);
                }
                yield DistributedStagesScheduler.create(
                        queryStateMachine,
                        schedulerStats,
                        nodeScheduler,
                        nodePartitioningManager,
                        stageManager,
                        coordinatorStagesScheduler,
                        executionPolicy,
                        failureDetector,
                        schedulerExecutor,
                        splitSourceFactory,
                        splitBatchSize,
                        dynamicFilterService,
                        tableExecuteContextManager,
                        retryPolicy,
                        attempt);
            }
            default -> throw new IllegalArgumentException("Unexpected retry policy: " + retryPolicy);
        };
        this.distributedStagesScheduler.set(distributedStagesScheduler);

        distributedStagesScheduler.addStateChangeListener(state -> {
            if (queryStateMachine.getQueryState() == QueryState.STARTING && (state == DistributedStagesSchedulerState.RUNNING || state.isDone())) {
                queryStateMachine.transitionToRunning();
            }

            if (state.isDone() && !state.isFailure()) {
                stageManager.getDistributedStagesInTopologicalOrder().forEach(stage -> stageManager.get(stage.getStageId()).finish());
            }

            // Without coordinator stages, the distributed stages decide whether the query finishes or is canceled.
            if (stageManager.getCoordinatorStagesInTopologicalOrder().isEmpty()) {
                if (state == DistributedStagesSchedulerState.FINISHED) {
                    queryStateMachine.transitionToFinishing();
                }
                else if (state == DistributedStagesSchedulerState.CANCELED) {
                    queryStateMachine.transitionToCanceled();
                }
            }

            if (state == DistributedStagesSchedulerState.FAILED) {
                StageFailureInfo stageFailureInfo = distributedStagesScheduler.getFailureCause()
                        .orElseGet(() -> new StageFailureInfo(Failures.toFailure(new VerifyException("distributedStagesScheduler failed but failure cause is not present")), Optional.empty()));
                ErrorCode errorCode = stageFailureInfo.getFailureInfo().getErrorCode();
                if (shouldRetry(errorCode)) {
                    long delayInMillis = computeRetryDelayMillis(retryInitialDelay.toMillis(), retryMaxDelay.toMillis(), retryDelayScaleFactor, currentAttempt.get());
                    currentAttempt.incrementAndGet();
                    scheduleRetryWithDelay(delayInMillis);
                }
                else {
                    // fail the stage that caused the failure (if known) and abort all others
                    stageManager.getDistributedStagesInTopologicalOrder().forEach(stage -> {
                        if (stageFailureInfo.getFailedStageId().isPresent() && stageFailureInfo.getFailedStageId().get().equals(stage.getStageId())) {
                            stage.fail(stageFailureInfo.getFailureInfo().toException());
                        }
                        else {
                            stage.abort();
                        }
                    });
                    queryStateMachine.transitionToFailed(stageFailureInfo.getFailureInfo().toException());
                }
            }
        });
        return Optional.of(distributedStagesScheduler);
    }

    /**
     * Computes the exponential back-off delay to wait before the retry that follows {@code attempt}.
     *
     * <p>The delay is {@code initialDelayMillis * scaleFactor^attempt}, capped at {@code maxDelayMillis}.
     * It is computed in floating point and truncated only at the end. This keeps non-integral scale factors
     * (e.g. 1.5) effective on every attempt. It also prevents the {@code long} overflow a large attempt
     * number would otherwise cause.
     */
    private static long computeRetryDelayMillis(long initialDelayMillis, long maxDelayMillis, double scaleFactor, int attempt) {
        double delayMillis = initialDelayMillis * Math.pow(scaleFactor, attempt);
        if (Double.isNaN(delayMillis) || delayMillis >= maxDelayMillis) {
            return maxDelayMillis;
        }
        return (long) delayMillis;
    }

    /**
     * Returns whether a failed distributed attempt with {@code errorCode} should be retried.
     *
     * <p>This requires all of the following:
     * <ul>
     *   <li>the retry policy is QUERY;</li>
     *   <li>the retry budget is not exhausted;</li>
     *   <li>the error is retryable.</li>
     * </ul>
     */
    private boolean shouldRetry(ErrorCode errorCode) {
        if (retryPolicy != RetryPolicy.QUERY) {
            return false;
        }
        if (currentAttempt.get() >= maxQueryRetryAttempts) {
            return false;
        }
        return isRetryableErrorCode(errorCode);
    }

    /**
     * Returns whether an error code is retryable.
     *
     * <p>These error codes count as retryable:
     * <ul>
     *   <li>a missing error code;</li>
     *   <li>internal errors;</li>
     *   <li>external errors;</li>
     *   <li>{@code CLUSTER_OUT_OF_MEMORY}.</li>
     * </ul>
     */
    private static boolean isRetryableErrorCode(ErrorCode errorCode) {
        if (errorCode == null) {
            return true;
        }
        ErrorType type = errorCode.getType();
        if (type == ErrorType.INTERNAL_ERROR || type == ErrorType.EXTERNAL) {
            return true;
        }
        return errorCode.getCode() == StandardErrorCode.CLUSTER_OUT_OF_MEMORY.toErrorCode().getCode();
    }

    /**
     * Schedules {@link #scheduleRetry} on the scheduler executor after {@code delayInMillis}.
     *
     * <p>If the executor refuses the task, or the call throws for any other reason, the query is failed
     * with that error.
     */
    private void scheduleRetryWithDelay(long delayInMillis) {
        try {
            schedulerExecutor.schedule(this::scheduleRetry, delayInMillis, TimeUnit.MILLISECONDS);
        }
        catch (Throwable schedulingFailure) {
            queryStateMachine.transitionToFailed(schedulingFailure);
        }
    }

    /**
     * Starts the next distributed attempt.
     *
     * <p>First waits (up to 5 minutes) for the previously submitted scheduling task to complete. Then
     * creates a scheduler for {@link #currentAttempt} and submits its scheduling loop. Any failure fails
     * the query. An interrupt also fails the query, and the thread's interrupt status is restored.
     */
    private synchronized void scheduleRetry() {
        try {
            Preconditions.checkState(distributedStagesSchedulingTask != null, "schedulingTask is expected to be set");
            distributedStagesSchedulingTask.get(5L, TimeUnit.MINUTES);
            Optional<DistributedStagesScheduler> distributedStagesScheduler = createDistributedStagesScheduler(currentAttempt.get());
            distributedStagesScheduler.ifPresent(scheduler -> distributedStagesSchedulingTask = executor.submit(scheduler::schedule, null));
        }
        catch (InterruptedException e) {
            // preserve the interrupt for the executor thread instead of silently clearing it
            Thread.currentThread().interrupt();
            queryStateMachine.transitionToFailed(e);
        }
        catch (Throwable t) {
            queryStateMachine.transitionToFailed(t);
        }
    }

    /**
     * Cancels {@code stageId} in the coordinator stages scheduler and, if one exists, in the current
     * distributed stages scheduler. Runs inside a {@code SetThreadName} scope named after the query.
     */
    @Override
    public synchronized void cancelStage(StageId stageId) {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            coordinatorStagesScheduler.cancelStage(stageId);
            Optional.ofNullable(distributedStagesScheduler.get())
                    .ifPresent(scheduler -> scheduler.cancelStage(stageId));
        }
    }

    /**
     * Fails {@code taskId} remotely through the stage manager with {@code failureCause}. Runs inside a
     * {@code SetThreadName} scope named after the query.
     */
    @Override
    public void failTask(TaskId taskId, Throwable failureCause) {
        try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
            stageManager.failTaskRemotely(taskId, failureCause);
        }
    }

    /** Returns the basic stage statistics as reported by the stage manager. */
    @Override
    public BasicStageStats getBasicStageStats() {
        return stageManager.getBasicStageStats();
    }

    /** Returns the stage info as reported by the stage manager. */
    @Override
    public StageInfo getStageInfo() {
        return stageManager.getStageInfo();
    }

    /** Returns the user memory reservation as reported by the stage manager. */
    @Override
    public long getUserMemoryReservation() {
        return stageManager.getUserMemoryReservation();
    }

    /** Returns the total memory reservation as reported by the stage manager. */
    @Override
    public long getTotalMemoryReservation() {
        return stageManager.getTotalMemoryReservation();
    }

    /** Returns the total CPU time as reported by the stage manager. */
    @Override
    public Duration getTotalCpuTime() {
        return stageManager.getTotalCpuTime();
    }

    /**
     * Schedules the stages that run on the coordinator.
     *
     * <p>Each coordinator stage gets a single task on the current node. Each stage consumed by the
     * coordinator (the output stage and every direct child of a coordinator stage) writes to a single
     * partition. {@link #schedule()} runs at most once.
     */
    private static class CoordinatorStagesScheduler {
        // Bucket-to-partition mapping used for every stage consumed by the coordinator.
        private static final int[] SINGLE_PARTITION = new int[]{0};

        private final QueryStateMachine queryStateMachine;
        private final NodeScheduler nodeScheduler;
        private final Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator;
        private final Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator;
        private final TaskLifecycleListener taskLifecycleListener;
        private final StageManager stageManager;
        // Coordinator stage executions in the order StageManager reports coordinator stages.
        private final List<StageExecution> stageExecutions;
        private final AtomicReference<DistributedStagesScheduler> distributedStagesScheduler;
        private final SqlTaskManager coordinatorTaskManager;
        private final AtomicBoolean scheduled = new AtomicBoolean();

        /**
         * Creates a scheduler for the coordinator stages of {@code stageManager} and wires its listeners.
         *
         * <p>Each stage execution is created with the task lifecycle listener of the one created before it.
         * The first one receives a listener that reports to the query output. The final listener in that
         * chain is exposed through {@link #getTaskLifecycleListener()}.
         */
        public static CoordinatorStagesScheduler create(QueryStateMachine queryStateMachine, NodeScheduler nodeScheduler, StageManager stageManager, FailureDetector failureDetector, Executor executor, AtomicReference<DistributedStagesScheduler> distributedStagesScheduler, SqlTaskManager coordinatorTaskManager) {
            Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator = createOutputBuffersForStagesConsumedByCoordinator(stageManager);
            Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator = createBucketToPartitionForStagesConsumedByCoordinator(stageManager);

            TaskLifecycleListener taskLifecycleListener = new QueryOutputTaskLifecycleListener(queryStateMachine);
            ImmutableList.Builder<StageExecution> stageExecutions = ImmutableList.builder();
            for (SqlStage stage : stageManager.getCoordinatorStagesInTopologicalOrder()) {
                PipelinedStageExecution stageExecution = PipelinedStageExecution.createPipelinedStageExecution(
                        stage,
                        outputBuffersForStagesConsumedByCoordinator,
                        taskLifecycleListener,
                        failureDetector,
                        executor,
                        bucketToPartitionForStagesConsumedByCoordinator.get(stage.getFragment().getId()),
                        0);
                stageExecutions.add(stageExecution);
                // chain: the next stage execution is created with this execution's lifecycle listener
                taskLifecycleListener = stageExecution.getTaskLifecycleListener();
            }

            CoordinatorStagesScheduler coordinatorStagesScheduler = new CoordinatorStagesScheduler(
                    queryStateMachine,
                    nodeScheduler,
                    outputBuffersForStagesConsumedByCoordinator,
                    bucketToPartitionForStagesConsumedByCoordinator,
                    taskLifecycleListener,
                    stageManager,
                    stageExecutions.build(),
                    distributedStagesScheduler,
                    coordinatorTaskManager);
            coordinatorStagesScheduler.initialize();
            return coordinatorStagesScheduler;
        }

        /**
         * Returns the stages whose output is read on the coordinator: the output stage followed by the direct
         * children of every coordinator stage, in the order StageManager reports them.
         */
        private static List<SqlStage> getStagesConsumedByCoordinator(StageManager stageManager) {
            ImmutableList.Builder<SqlStage> stages = ImmutableList.builder();
            stages.add(stageManager.getOutputStage());
            for (SqlStage coordinatorStage : stageManager.getCoordinatorStagesInTopologicalOrder()) {
                stages.addAll(stageManager.getChildren(coordinatorStage.getStageId()));
            }
            return stages.build();
        }

        /**
         * Creates a single-partition output buffer manager for each stage consumed by the coordinator.
         *
         * @throws IllegalArgumentException if a stage is consumed twice, or if its output partitioning is not single-node
         */
        private static Map<PlanFragmentId, PipelinedOutputBufferManager> createOutputBuffersForStagesConsumedByCoordinator(StageManager stageManager) {
            ImmutableMap.Builder<PlanFragmentId, PipelinedOutputBufferManager> result = ImmutableMap.builder();
            for (SqlStage stage : getStagesConsumedByCoordinator(stageManager)) {
                result.put(stage.getFragment().getId(), createSingleStreamOutputBuffer(stage));
            }
            return result.buildOrThrow();
        }

        /**
         * Creates a one-partition output buffer manager for {@code stage}.
         *
         * @throws IllegalArgumentException if the stage's output partitioning is not single-node
         */
        private static PipelinedOutputBufferManager createSingleStreamOutputBuffer(SqlStage stage) {
            PartitioningHandle partitioningHandle = stage.getFragment().getOutputPartitioningScheme().getPartitioning().getHandle();
            Preconditions.checkArgument(partitioningHandle.isSingleNode(), "partitioning is expected to be single node: %s", partitioningHandle);
            return new PartitionedPipelinedOutputBufferManager(partitioningHandle, 1);
        }

        /**
         * Maps every stage consumed by the coordinator to {@link #SINGLE_PARTITION}.
         *
         * @throws IllegalArgumentException if a stage is consumed twice
         */
        private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionForStagesConsumedByCoordinator(StageManager stageManager) {
            ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
            for (SqlStage stage : getStagesConsumedByCoordinator(stageManager)) {
                result.put(stage.getFragment().getId(), Optional.of(SINGLE_PARTITION));
            }
            return result.buildOrThrow();
        }

        private CoordinatorStagesScheduler(QueryStateMachine queryStateMachine, NodeScheduler nodeScheduler, Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator, Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator, TaskLifecycleListener taskLifecycleListener, StageManager stageManager, List<StageExecution> stageExecutions, AtomicReference<DistributedStagesScheduler> distributedStagesScheduler, SqlTaskManager coordinatorTaskManager) {
            this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
            this.nodeScheduler = Objects.requireNonNull(nodeScheduler, "nodeScheduler is null");
            this.outputBuffersForStagesConsumedByCoordinator = ImmutableMap.copyOf(Objects.requireNonNull(outputBuffersForStagesConsumedByCoordinator, "outputBuffersForStagesConsumedByCoordinator is null"));
            this.bucketToPartitionForStagesConsumedByCoordinator = ImmutableMap.copyOf(Objects.requireNonNull(bucketToPartitionForStagesConsumedByCoordinator, "bucketToPartitionForStagesConsumedByCoordinator is null"));
            this.taskLifecycleListener = Objects.requireNonNull(taskLifecycleListener, "taskLifecycleListener is null");
            this.stageManager = Objects.requireNonNull(stageManager, "stageManager is null");
            this.stageExecutions = ImmutableList.copyOf(Objects.requireNonNull(stageExecutions, "stageExecutions is null"));
            this.distributedStagesScheduler = Objects.requireNonNull(distributedStagesScheduler, "distributedStagesScheduler is null");
            this.coordinatorTaskManager = Objects.requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null");
        }

        /**
         * Registers state listeners that propagate coordinator stage execution states to the stage manager,
         * the query state machine, child stage executions and the distributed stages scheduler.
         *
         * @throws VerifyException if the stage executions are not a parent-to-child chain
         */
        private void initialize() {
            for (StageExecution stageExecution : stageExecutions) {
                stageExecution.addStateChangeListener(state -> {
                    if (queryStateMachine.isDone()) {
                        return;
                    }
                    if (state == StageExecution.State.FAILED) {
                        RuntimeException failureCause = stageExecution.getFailureCause()
                                .map(ExecutionFailureInfo::toException)
                                .orElseGet(() -> new VerifyException(String.format("stage execution for stage %s is failed but failure cause is not present", stageExecution.getStageId())));
                        stageManager.get(stageExecution.getStageId()).fail(failureCause);
                        queryStateMachine.transitionToFailed(failureCause);
                    }
                    else if (state == StageExecution.State.ABORTED) {
                        stageManager.get(stageExecution.getStageId()).abort();
                        queryStateMachine.transitionToFailed(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "Query stage was aborted"));
                    }
                    else if (state.isDone()) {
                        stageManager.get(stageExecution.getStageId()).finish();
                    }
                });
            }

            // Executions must form a chain where execution i+1 runs the only child stage of execution i.
            // When a parent starts flushing or is done, its child execution is canceled.
            for (int index = 0; index + 1 < stageExecutions.size(); index++) {
                StageExecution parentExecution = stageExecutions.get(index);
                StageExecution childExecution = stageExecutions.get(index + 1);
                Set<SqlStage> childStages = stageManager.getChildren(parentExecution.getStageId());
                Verify.verify(childStages.size() == 1, "exactly one child stage is expected");
                SqlStage childStage = Iterables.getOnlyElement(childStages);
                Verify.verify(childStage.getStageId().equals(childExecution.getStageId()), "stage execution order doesn't match the stage order");
                parentExecution.addStateChangeListener(newState -> {
                    if (newState == StageExecution.State.FLUSHING || newState.isDone()) {
                        childExecution.cancel();
                    }
                });
            }

            if (stageExecutions.isEmpty()) {
                return;
            }

            // The root coordinator stage drives the query to FINISHING or CANCELED.
            StageExecution root = stageExecutions.get(0);
            root.addStateChangeListener(state -> {
                if (state == StageExecution.State.FINISHED) {
                    queryStateMachine.transitionToFinishing();
                }
                else if (state == StageExecution.State.CANCELED) {
                    queryStateMachine.transitionToCanceled();
                }
            });

            // When the last coordinator stage starts flushing or is done, the distributed stages are canceled.
            StageExecution last = stageExecutions.get(stageExecutions.size() - 1);
            last.addStateChangeListener(newState -> {
                if (newState == StageExecution.State.FLUSHING || newState.isDone()) {
                    DistributedStagesScheduler scheduler = distributedStagesScheduler.get();
                    if (scheduler != null) {
                        scheduler.cancel();
                    }
                }
            });
        }

        /**
         * Schedules one task per coordinator stage on the current node and marks each stage's scheduling
         * complete. Moves the query from STARTING to RUNNING once a task is created. Calls after the first
         * are no-ops.
         */
        public synchronized void schedule() {
            if (!scheduled.compareAndSet(false, true)) {
                return;
            }

            TaskFailureReporter failureReporter = new TaskFailureReporter(distributedStagesScheduler);
            queryStateMachine.addOutputTaskFailureListener(failureReporter);

            InternalNode coordinator = nodeScheduler.createNodeSelector(queryStateMachine.getSession(), Optional.empty()).selectCurrentNode();
            for (StageExecution stageExecution : stageExecutions) {
                Optional<RemoteTask> remoteTask = stageExecution.scheduleTask(coordinator, 0, ImmutableMultimap.<PlanNodeId, Split>of());
                stageExecution.schedulingComplete();
                remoteTask.ifPresent(task -> coordinatorTaskManager.addSourceTaskFailureListener(task.getTaskId(), failureReporter));
                if (queryStateMachine.getQueryState() == QueryState.STARTING && remoteTask.isPresent()) {
                    queryStateMachine.transitionToRunning();
                }
            }
        }

        /** Returns the output buffer managers of the stages consumed by the coordinator, keyed by fragment id. */
        public Map<PlanFragmentId, PipelinedOutputBufferManager> getOutputBuffersForStagesConsumedByCoordinator() {
            return outputBuffersForStagesConsumedByCoordinator;
        }

        /** Returns the bucket-to-partition mappings of the stages consumed by the coordinator, keyed by fragment id. */
        public Map<PlanFragmentId, Optional<int[]>> getBucketToPartitionForStagesConsumedByCoordinator() {
            return bucketToPartitionForStagesConsumedByCoordinator;
        }

        /** Returns the last task lifecycle listener of the coordinator stage chain built in {@link #create}. */
        public TaskLifecycleListener getTaskLifecycleListener() {
            return taskLifecycleListener;
        }

        /** Cancels every coordinator stage execution whose stage id equals {@code stageId}. */
        public void cancelStage(StageId stageId) {
            for (StageExecution stageExecution : stageExecutions) {
                if (stageExecution.getStageId().equals(stageId)) {
                    stageExecution.cancel();
                }
            }
        }

        /** Cancels all coordinator stage executions. */
        public void cancel() {
            stageExecutions.forEach(StageExecution::cancel);
        }

        /** Aborts all coordinator stage executions. */
        public void abort() {
            stageExecutions.forEach(StageExecution::abort);
        }
    }

    /**
     * Creates and drives the stage schedulers for all distributed (non coordinator-only) stages of a
     * query attempt.
     *
     * <p>Instances are obtained from {@link #create}, which builds one {@link PipelinedStageExecution}
     * and one {@link StageScheduler} per distributed stage and registers the state listeners. The
     * {@link #schedule()} loop then runs the stage schedulers chosen by the {@link ExecutionSchedule}
     * until the schedule reports it is finished. Stage schedulers are always closed when
     * {@link #schedule()} exits.
     */
    private static class DistributedStagesScheduler {
        private final DistributedStagesSchedulerStateMachine stateMachine;
        private final QueryStateMachine queryStateMachine;
        private final SplitSchedulerStats schedulerStats;
        private final StageManager stageManager;
        private final ExecutionSchedule executionSchedule;
        private final Map<StageId, StageScheduler> stageSchedulers;
        private final Map<StageId, StageExecution> stageExecutions;
        private final DynamicFilterService dynamicFilterService;
        // Ensures schedule() runs at most once.
        private final AtomicBoolean started = new AtomicBoolean();

        /**
         * Builds a scheduler for every distributed stage known to {@code stageManager}.
         *
         * <p>Stage executions are created in topological order. A stage whose parent is absent or
         * coordinator-only reports task lifecycle events to the coordinator scheduler's listener; any
         * other stage reports to its parent stage execution's listener. When {@code retryPolicy} is not
         * {@link RetryPolicy#NONE}, the coordinator listener is wrapped in a
         * {@link TaskLifecycleListenerBridge} whose buffered "no more tasks" signals are released only
         * once this scheduler reaches {@link DistributedStagesSchedulerState#FINISHED}.
         *
         * @param coordinatorStagesScheduler source of the output buffers, bucket-to-partition maps and
         *        task lifecycle listener for the stages consumed by the coordinator
         * @param retryPolicy query retry policy; anything but {@code NONE} enables the listener bridge
         * @param attempt attempt number handed to every created stage execution
         * @return an initialized scheduler on which {@link #schedule()} has not been called yet
         */
        public static DistributedStagesScheduler create(QueryStateMachine queryStateMachine, SplitSchedulerStats schedulerStats, NodeScheduler nodeScheduler, NodePartitioningManager nodePartitioningManager, StageManager stageManager, CoordinatorStagesScheduler coordinatorStagesScheduler, ExecutionPolicy executionPolicy, FailureDetector failureDetector, ScheduledExecutorService executor, SplitSourceFactory splitSourceFactory, int splitBatchSize, DynamicFilterService dynamicFilterService, TableExecuteContextManager tableExecuteContextManager, RetryPolicy retryPolicy, int attempt) {
            DistributedStagesSchedulerStateMachine stateMachine = new DistributedStagesSchedulerStateMachine(queryStateMachine.getQueryId(), executor);

            // Memoize node partition maps per (handle, partitionCount) while the schedulers are built.
            // Scaled hash writers are mapped onto the FIXED_HASH_DISTRIBUTION node layout.
            Map<PartitioningKey, NodePartitionMap> partitioningCacheMap = new HashMap<>();
            Function<PartitioningKey, NodePartitionMap> partitioningCache = partitioningKey -> partitioningCacheMap.computeIfAbsent(
                    partitioningKey,
                    partitioning -> nodePartitioningManager.getNodePartitioningMap(
                            queryStateMachine.getSession(),
                            partitioning.handle().equals(SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION) ? SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION : partitioning.handle(),
                            partitioning.partitionCount()));

            Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap = createBucketToPartitionMap(
                    coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(), stageManager, partitioningCache);
            Map<PlanFragmentId, PipelinedOutputBufferManager> outputBufferManagers = createOutputBufferManagers(
                    coordinatorStagesScheduler.getOutputBuffersForStagesConsumedByCoordinator(), stageManager, bucketToPartitionMap);

            TaskLifecycleListener coordinatorTaskLifecycleListener = coordinatorStagesScheduler.getTaskLifecycleListener();
            if (retryPolicy != RetryPolicy.NONE) {
                // Hold back "no more tasks" signals to the coordinator until this scheduler finishes.
                TaskLifecycleListenerBridge taskLifecycleListenerBridge = new TaskLifecycleListenerBridge(coordinatorTaskLifecycleListener);
                coordinatorTaskLifecycleListener = taskLifecycleListenerBridge;
                stateMachine.addStateChangeListener(state -> {
                    if (state == DistributedStagesSchedulerState.FINISHED) {
                        taskLifecycleListenerBridge.notifyNoMoreSourceTasks();
                    }
                });
            }

            Map<StageId, PipelinedStageExecution> stageExecutions = new LinkedHashMap<>();
            for (SqlStage sqlStage : stageManager.getDistributedStagesInTopologicalOrder()) {
                Optional<SqlStage> parentStage = stageManager.getParent(sqlStage.getStageId());
                TaskLifecycleListener taskLifecycleListener;
                if (parentStage.isEmpty() || parentStage.get().getFragment().getPartitioning().isCoordinatorOnly()) {
                    taskLifecycleListener = coordinatorTaskLifecycleListener;
                } else {
                    // The parent execution is expected to exist already because stages are visited in topological order.
                    StageId parentStageId = parentStage.get().getStageId();
                    StageExecution parentStageExecution = Objects.requireNonNull(stageExecutions.get(parentStageId), () -> "execution is null for stage: " + parentStageId);
                    taskLifecycleListener = parentStageExecution.getTaskLifecycleListener();
                }
                PlanFragment fragment = sqlStage.getFragment();
                PipelinedStageExecution stageExecution = PipelinedStageExecution.createPipelinedStageExecution(
                        stageManager.get(fragment.getId()),
                        outputBufferManagers,
                        taskLifecycleListener,
                        failureDetector,
                        executor,
                        bucketToPartitionMap.get(fragment.getId()),
                        attempt);
                stageExecutions.put(sqlStage.getStageId(), stageExecution);
            }

            ImmutableMap.Builder<StageId, StageScheduler> stageSchedulers = ImmutableMap.builder();
            for (StageExecution stageExecution : stageExecutions.values()) {
                List<StageExecution> children = stageManager.getChildren(stageExecution.getStageId()).stream()
                        .map(stage -> Objects.requireNonNull((StageExecution) stageExecutions.get(stage.getStageId()), () -> "stage execution not found for stage: " + stage))
                        .collect(ImmutableList.toImmutableList());
                StageScheduler scheduler = createStageScheduler(queryStateMachine, stageExecution, splitSourceFactory, children, partitioningCache, nodeScheduler, nodePartitioningManager, splitBatchSize, dynamicFilterService, executor, tableExecuteContextManager);
                stageSchedulers.put(stageExecution.getStageId(), scheduler);
            }

            DistributedStagesScheduler distributedStagesScheduler = new DistributedStagesScheduler(
                    stateMachine,
                    queryStateMachine,
                    schedulerStats,
                    stageManager,
                    executionPolicy.createExecutionSchedule(stageExecutions.values()),
                    stageSchedulers.buildOrThrow(),
                    ImmutableMap.copyOf(stageExecutions),
                    dynamicFilterService);
            distributedStagesScheduler.initialize();
            return distributedStagesScheduler;
        }

        /**
         * Maps every fragment to the bucket-to-partition mapping of the stage that consumes it.
         *
         * <p>Entries for coordinator-consumed stages are copied verbatim. Every other child fragment gets the
         * mapping computed from its parent's partitioning.
         */
        private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator, StageManager stageManager, Function<PartitioningKey, NodePartitionMap> partitioningCache) {
            ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
            result.putAll(bucketToPartitionForStagesConsumedByCoordinator);
            for (SqlStage stage : stageManager.getDistributedStagesInTopologicalOrder()) {
                PlanFragment fragment = stage.getFragment();
                Optional<int[]> bucketToPartition = getBucketToPartition(fragment.getPartitioning(), partitioningCache, fragment.getRoot(), fragment.getRemoteSourceNodes(), fragment.getPartitionCount());
                for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) {
                    result.put(childStage.getFragment().getId(), bucketToPartition);
                }
            }
            return result.buildOrThrow();
        }

        /**
         * Computes the bucket-to-partition mapping implied by a fragment's partitioning.
         *
         * @return a one-element all-zero mapping for {@code SOURCE_DISTRIBUTION} and
         *         {@code SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION}. Returns {@link Optional#empty()} for fragments
         *         that contain a table scan and whose remote sources are all {@code REPLICATE} exchanges.
         *         Otherwise returns the mapping of the cached node partition map. For fragments without
         *         table scans, fails with {@code NO_NODES_AVAILABLE} (via {@code Failures.checkCondition}) if
         *         that map has no nodes.
         */
        private static Optional<int[]> getBucketToPartition(PartitioningHandle partitioningHandle, Function<PartitioningKey, NodePartitionMap> partitioningCache, PlanNode fragmentRoot, List<RemoteSourceNode> remoteSourceNodes, Optional<Integer> partitionCount) {
            if (partitioningHandle.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION) || partitioningHandle.equals(SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                return Optional.of(new int[1]);
            }
            boolean hasTableScan = PlanNodeSearcher.searchFrom(fragmentRoot).where(node -> node instanceof TableScanNode).findFirst().isPresent();
            if (hasTableScan) {
                if (remoteSourceNodes.stream().allMatch(node -> node.getExchangeType() == ExchangeNode.Type.REPLICATE)) {
                    return Optional.empty();
                }
                NodePartitionMap nodePartitionMap = partitioningCache.apply(new PartitioningKey(partitioningHandle, partitionCount));
                return Optional.of(nodePartitionMap.getBucketToPartition());
            }
            NodePartitionMap nodePartitionMap = partitioningCache.apply(new PartitioningKey(partitioningHandle, partitionCount));
            List<InternalNode> partitionToNode = nodePartitionMap.getPartitionToNode();
            Failures.checkCondition(!partitionToNode.isEmpty(), StandardErrorCode.NO_NODES_AVAILABLE, "No worker nodes available");
            return Optional.of(nodePartitionMap.getBucketToPartition());
        }

        /**
         * Creates an output buffer manager for every child fragment, chosen by the child's output partitioning.
         *
         * <p>The managers are:
         * <ul>
         *   <li>a broadcast manager for {@code FIXED_BROADCAST_DISTRIBUTION};</li>
         *   <li>a scaled manager for {@code SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION};</li>
         *   <li>otherwise a partitioned manager sized from the largest partition index in the fragment's
         *       bucket-to-partition mapping, which must be present.</li>
         * </ul>
         */
        private static Map<PlanFragmentId, PipelinedOutputBufferManager> createOutputBufferManagers(Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator, StageManager stageManager, Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap) {
            ImmutableMap.Builder<PlanFragmentId, PipelinedOutputBufferManager> result = ImmutableMap.builder();
            result.putAll(outputBuffersForStagesConsumedByCoordinator);
            for (SqlStage parentStage : stageManager.getDistributedStagesInTopologicalOrder()) {
                for (SqlStage childStage : stageManager.getChildren(parentStage.getStageId())) {
                    PlanFragmentId fragmentId = childStage.getFragment().getId();
                    PartitioningHandle partitioningHandle = childStage.getFragment().getOutputPartitioningScheme().getPartitioning().getHandle();
                    PipelinedOutputBufferManager outputBufferManager;
                    if (partitioningHandle.equals(SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION)) {
                        outputBufferManager = new BroadcastPipelinedOutputBufferManager();
                    } else if (partitioningHandle.equals(SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                        outputBufferManager = new ScaledPipelinedOutputBufferManager();
                    } else {
                        Optional<int[]> bucketToPartition = bucketToPartitionMap.get(fragmentId);
                        Preconditions.checkArgument(bucketToPartition.isPresent(), "bucketToPartition is expected to be present for fragment: %s", fragmentId);
                        int partitionCount = Ints.max(bucketToPartition.get()) + 1;
                        outputBufferManager = new PartitionedPipelinedOutputBufferManager(partitioningHandle, partitionCount);
                    }
                    result.put(fragmentId, outputBufferManager);
                }
            }
            return result.buildOrThrow();
        }

        /**
         * Chooses and builds the {@link StageScheduler} implementation for a stage.
         *
         * <p>The choice is:
         * <ul>
         *   <li>source-distributed stages get a (multi-)source partitioned scheduler;</li>
         *   <li>scaled round-robin writers get a {@link ScaledWriterScheduler}, finished once all child stages
         *       are done;</li>
         *   <li>stages without split sources get a {@link FixedCountScheduler};</li>
         *   <li>any other stage gets a {@link FixedSourcePartitionedScheduler}.</li>
         * </ul>
         * Split sources created here are closed when the query reaches a done state.
         */
        private static StageScheduler createStageScheduler(QueryStateMachine queryStateMachine, StageExecution stageExecution, SplitSourceFactory splitSourceFactory, List<StageExecution> childStageExecutions, Function<PartitioningKey, NodePartitionMap> partitioningCache, NodeScheduler nodeScheduler, NodePartitioningManager nodePartitioningManager, int splitBatchSize, DynamicFilterService dynamicFilterService, ScheduledExecutorService executor, TableExecuteContextManager tableExecuteContextManager) {
            Session session = queryStateMachine.getSession();
            Span stageSpan = stageExecution.getStageSpan();
            PlanFragment fragment = stageExecution.getFragment();
            PartitioningHandle partitioningHandle = fragment.getPartitioning();
            Optional<Integer> partitionCount = fragment.getPartitionCount();
            Map<PlanNodeId, SplitSource> splitSources = splitSourceFactory.createSplitSources(session, stageSpan, fragment);
            closeSplitSourcesWhenQueryDone(queryStateMachine, splitSources);

            if (partitioningHandle.equals(SystemPartitioningHandle.SOURCE_DISTRIBUTION)) {
                if (splitSources.size() == 1) {
                    Map.Entry<PlanNodeId, SplitSource> entry = Iterables.getOnlyElement(splitSources.entrySet());
                    PlanNodeId planNodeId = entry.getKey();
                    SplitSource splitSource = entry.getValue();
                    // Internal catalog handles are dropped before the node selector is created.
                    Optional<CatalogHandle> catalogHandle = Optional.of(splitSource.getCatalogHandle()).filter(catalog -> !catalog.getType().isInternal());
                    NodeSelector nodeSelector = nodeScheduler.createNodeSelector(session, catalogHandle);
                    DynamicSplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stageExecution::getAllTasks);
                    return SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler(stageExecution, planNodeId, splitSource, placementPolicy, splitBatchSize, dynamicFilterService, tableExecuteContextManager, () -> childStageExecutions.stream().anyMatch(StageExecution::isAnyTaskBlocked));
                }
                Set<CatalogHandle> allCatalogHandles = splitSources.values().stream()
                        .map(SplitSource::getCatalogHandle)
                        .filter(catalog -> !catalog.getType().isInternal())
                        .collect(ImmutableSet.toImmutableSet());
                Preconditions.checkState(allCatalogHandles.size() <= 1, "table scans that are within one stage should read from same catalog");
                Optional<CatalogHandle> catalogHandle = allCatalogHandles.size() == 1 ? Optional.of(Iterables.getOnlyElement(allCatalogHandles)) : Optional.empty();
                NodeSelector nodeSelector = nodeScheduler.createNodeSelector(session, catalogHandle);
                return new MultiSourcePartitionedScheduler(stageExecution, splitSources, new DynamicSplitPlacementPolicy(nodeSelector, stageExecution::getAllTasks), splitBatchSize, dynamicFilterService, tableExecuteContextManager, () -> childStageExecutions.stream().anyMatch(StageExecution::isAnyTaskBlocked));
            }

            if (partitioningHandle.equals(SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                Supplier<Collection<TaskStatus>> sourceTasksProvider = () -> childStageExecutions.stream()
                        .map(StageExecution::getTaskStatuses)
                        .flatMap(Collection::stream)
                        .collect(ImmutableList.toImmutableList());
                Supplier<Collection<TaskStatus>> writerTasksProvider = stageExecution::getTaskStatuses;
                Preconditions.checkState(partitionCount.isPresent(), "Partition count cannot be empty when scale writers is used");
                ScaledWriterScheduler scheduler = new ScaledWriterScheduler(stageExecution, sourceTasksProvider, writerTasksProvider, nodeScheduler.createNodeSelector(session, Optional.empty()), executor, SystemSessionProperties.getWriterScalingMinDataProcessed(session), partitionCount.get());
                // Finish the writer scheduler once every child stage is done.
                whenAllStages(childStageExecutions, StageExecution.State::isDone).addListener(scheduler::finish, MoreExecutors.directExecutor());
                return scheduler;
            }

            if (splitSources.isEmpty()) {
                NodePartitionMap nodePartitionMap = partitioningCache.apply(new PartitioningKey(partitioningHandle, partitionCount));
                List<InternalNode> partitionToNode = nodePartitionMap.getPartitionToNode();
                Failures.checkCondition(!partitionToNode.isEmpty(), StandardErrorCode.NO_NODES_AVAILABLE, "No worker nodes available");
                return new FixedCountScheduler(stageExecution, partitionToNode);
            }

            List<PlanNodeId> schedulingOrder = fragment.getPartitionedSources();
            Optional<CatalogHandle> catalogHandle = partitioningHandle.getCatalogHandle();
            Preconditions.checkArgument(catalogHandle.isPresent(), "No catalog handle for partitioning handle: %s", partitioningHandle);
            List<InternalNode> stageNodeList;
            BucketNodeMap bucketNodeMap;
            if (fragment.getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == ExchangeNode.Type.REPLICATE)) {
                bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle);
                stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogHandle).allNodes());
                Collections.shuffle(stageNodeList);
            } else {
                NodePartitionMap nodePartitionMap = partitioningCache.apply(new PartitioningKey(partitioningHandle, partitionCount));
                stageNodeList = nodePartitionMap.getPartitionToNode();
                bucketNodeMap = nodePartitionMap.asBucketNodeMap();
            }
            return new FixedSourcePartitionedScheduler(stageExecution, splitSources, schedulingOrder, stageNodeList, bucketNodeMap, splitBatchSize, nodeScheduler.createNodeSelector(session, catalogHandle), dynamicFilterService, tableExecuteContextManager);
        }

        /**
         * Registers a query state listener that closes {@code splitSources} once the query is done.
         * Does nothing when there are no split sources.
         */
        private static void closeSplitSourcesWhenQueryDone(QueryStateMachine queryStateMachine, Map<PlanNodeId, SplitSource> splitSources) {
            if (splitSources.isEmpty()) {
                return;
            }
            queryStateMachine.addStateChangeListener(new StateMachine.StateChangeListener<QueryState>() {
                // Cleared on the first done state, so the sources are closed at most once.
                private final AtomicReference<Collection<SplitSource>> splitSourcesReference = new AtomicReference<>(splitSources.values());

                @Override
                public void stateChanged(QueryState newState) {
                    if (!newState.isDone()) {
                        return;
                    }
                    Collection<SplitSource> sources = splitSourcesReference.getAndSet(null);
                    if (sources != null) {
                        closeSplitSources(sources);
                    }
                }
            });
        }

        /**
         * Closes every split source. A failure is logged and does not prevent the remaining sources
         * from being closed.
         */
        private static void closeSplitSources(Collection<SplitSource> splitSources) {
            for (SplitSource source : splitSources) {
                try {
                    source.close();
                } catch (Throwable t) {
                    log.warn(t, "Error closing split source");
                }
            }
        }

        /**
         * Returns a future that completes once every stage in {@code stages} has reported a state
         * matching {@code predicate}.
         *
         * @throws IllegalArgumentException if {@code stages} is empty
         */
        private static ListenableFuture<Void> whenAllStages(Collection<StageExecution> stages, Predicate<StageExecution.State> predicate) {
            Preconditions.checkArgument(!stages.isEmpty(), "stages is empty");
            Set<StageId> stageIds = stages.stream()
                    .map(StageExecution::getStageId)
                    .collect(Collectors.toCollection(Sets::newConcurrentHashSet));
            SettableFuture<Void> future = SettableFuture.create();
            for (StageExecution stageExecution : stages) {
                stageExecution.addStateChangeListener((StageExecution.State state) -> {
                    // Only the listener that removes the last pending id completes the future.
                    if (predicate.test(state) && stageIds.remove(stageExecution.getStageId()) && stageIds.isEmpty()) {
                        future.set(null);
                    }
                });
            }
            return future;
        }

        private DistributedStagesScheduler(DistributedStagesSchedulerStateMachine stateMachine, QueryStateMachine queryStateMachine, SplitSchedulerStats schedulerStats, StageManager stageManager, ExecutionSchedule executionSchedule, Map<StageId, StageScheduler> stageSchedulers, Map<StageId, StageExecution> stageExecutions, DynamicFilterService dynamicFilterService) {
            this.stateMachine = Objects.requireNonNull(stateMachine, "stateMachine is null");
            this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
            this.schedulerStats = Objects.requireNonNull(schedulerStats, "schedulerStats is null");
            this.stageManager = Objects.requireNonNull(stageManager, "stageManager is null");
            this.executionSchedule = Objects.requireNonNull(executionSchedule, "executionSchedule is null");
            this.stageSchedulers = ImmutableMap.copyOf(Objects.requireNonNull(stageSchedulers, "stageSchedulers is null"));
            this.stageExecutions = ImmutableMap.copyOf(Objects.requireNonNull(stageExecutions, "stageExecutions is null"));
            this.dynamicFilterService = Objects.requireNonNull(dynamicFilterService, "dynamicFilterService is null");
        }

        /**
         * Installs the stage state listeners.
         *
         * <p>The listeners are:
         * <ul>
         *   <li>children are canceled when their parent becomes FLUSHING or done;</li>
         *   <li>the dynamic filter service is told when a stage can no longer schedule tasks;</li>
         *   <li>the scheduler fails when any stage fails;</li>
         *   <li>the scheduler finishes when every stage is done.</li>
         * </ul>
         */
        private void initialize() {
            for (StageExecution stageExecution : stageExecutions.values()) {
                List<StageExecution> childStageExecutions = stageManager.getChildren(stageExecution.getStageId()).stream()
                        .map(stage -> Objects.requireNonNull(stageExecutions.get(stage.getStageId()), () -> "stage execution not found for stage: " + stage))
                        .collect(ImmutableList.toImmutableList());
                if (childStageExecutions.isEmpty()) {
                    continue;
                }
                stageExecution.addStateChangeListener((StageExecution.State newState) -> {
                    if (newState == StageExecution.State.FLUSHING || newState.isDone()) {
                        childStageExecutions.forEach(StageExecution::cancel);
                    }
                });
            }

            Set<StageId> finishedStages = Sets.newConcurrentHashSet();
            for (StageExecution stageExecution : stageExecutions.values()) {
                stageExecution.addStateChangeListener((StageExecution.State state) -> {
                    if (stateMachine.getState().isDone()) {
                        return;
                    }
                    int numberOfTasks = stageExecution.getAllTasks().size();
                    if (!state.canScheduleMoreTasks()) {
                        dynamicFilterService.stageCannotScheduleMoreTasks(stageExecution.getStageId(), stageExecution.getAttemptId(), numberOfTasks);
                    }
                    if (state == StageExecution.State.FAILED) {
                        RuntimeException failureCause = stageExecution.getFailureCause()
                                .map(ExecutionFailureInfo::toException)
                                .orElseGet(() -> new VerifyException(String.format("stage execution for stage %s is failed by failure cause is not present", stageExecution.getStageId())));
                        fail(failureCause, Optional.of(stageExecution.getStageId()));
                    } else if (state.isDone()) {
                        finishedStages.add(stageExecution.getStageId());
                        if (finishedStages.containsAll(stageExecutions.keySet())) {
                            stateMachine.transitionToFinished();
                        }
                    }
                });
            }
        }

        /**
         * Runs the scheduling loop until the execution schedule is finished.
         *
         * <p>Any throwable raised while scheduling is reported through {@link #fail}. Every stage
         * scheduler is closed exactly once when the method exits, whether it succeeds or fails.
         *
         * @throws IllegalStateException if called more than once
         */
        public void schedule() {
            Preconditions.checkState(started.compareAndSet(false, true), "already started");
            try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) {
                stageSchedulers.values().forEach(StageScheduler::start);
                while (!executionSchedule.isFinished()) {
                    StagesScheduleResult stagesScheduleResult = executionSchedule.getStagesToSchedule();
                    List<ListenableFuture<Void>> blockedStages = scheduleStages(stagesScheduleResult);
                    if (!blockedStages.isEmpty()) {
                        waitForProgress(blockedStages, stagesScheduleResult);
                    }
                }
                verifyAllStagesScheduled();
            } catch (Throwable t) {
                fail(t, Optional.empty());
            } finally {
                closeStageSchedulers();
            }
        }

        /**
         * Runs one scheduling round over the stages selected by the execution schedule.
         *
         * @return the still-pending blocked futures of stages that are not yet fully scheduled
         */
        private List<ListenableFuture<Void>> scheduleStages(StagesScheduleResult stagesScheduleResult) {
            List<ListenableFuture<Void>> blockedStages = new ArrayList<>();
            for (StageExecution stageExecution : stagesScheduleResult.getStagesToSchedule()) {
                stageExecution.beginScheduling();
                ScheduleResult scheduleResult = stageSchedulers.get(stageExecution.getStageId()).schedule();

                // The first stage that gets tasks moves the whole scheduler to RUNNING.
                if (stateMachine.getState() == DistributedStagesSchedulerState.PLANNED && !stageExecution.getAllTasks().isEmpty()) {
                    stateMachine.transitionToRunning();
                }

                if (scheduleResult.isFinished()) {
                    stageExecution.schedulingComplete();
                } else if (!scheduleResult.getBlocked().isDone()) {
                    blockedStages.add(scheduleResult.getBlocked());
                }

                schedulerStats.getSplitsScheduledPerIteration().add(scheduleResult.getSplitsScheduled());
                if (scheduleResult.getBlockedReason().isPresent()) {
                    switch (scheduleResult.getBlockedReason().get()) {
                        case WRITER_SCALING -> {
                            // not tracked in scheduler stats
                        }
                        case WAITING_FOR_SOURCE -> schedulerStats.getWaitingForSource().update(1);
                        case SPLIT_QUEUES_FULL -> schedulerStats.getSplitQueuesFull().update(1);
                        default -> throw new UnsupportedOperationException("Unknown blocked reason: " + scheduleResult.getBlockedReason().get());
                    }
                }
            }
            return blockedStages;
        }

        /**
         * Waits, for at most one second, until any blocked stage future or the schedule's reschedule
         * future (if present) completes. Afterwards it cancels the blocked stage futures.
         */
        private void waitForProgress(List<ListenableFuture<Void>> blockedStages, StagesScheduleResult stagesScheduleResult) {
            ImmutableList.Builder<ListenableFuture<Void>> futures = ImmutableList.builder();
            futures.addAll(blockedStages);
            stagesScheduleResult.getRescheduleFuture().ifPresent(futures::add);
            try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) {
                MoreFutures.tryGetFutureValue(MoreFutures.whenAnyComplete(futures.build()), 1, TimeUnit.SECONDS);
            }
            for (ListenableFuture<Void> blockedStage : blockedStages) {
                blockedStage.cancel(true);
            }
        }

        /**
         * Verifies that, once scheduling is complete, every stage is SCHEDULED, RUNNING, FLUSHING or done.
         *
         * @throws TrinoException with {@code GENERIC_INTERNAL_ERROR} for a stage in any other state
         */
        private void verifyAllStagesScheduled() {
            for (StageExecution stageExecution : stageExecutions.values()) {
                StageExecution.State state = stageExecution.getState();
                if (state != StageExecution.State.SCHEDULED && state != StageExecution.State.RUNNING && state != StageExecution.State.FLUSHING && !state.isDone()) {
                    throw new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, String.format("Scheduling is complete, but stage %s is in state %s", stageExecution.getStageId(), state));
                }
            }
        }

        /**
         * Closes every stage scheduler. A close failure is reported through {@link #fail} and the
         * remaining schedulers are still closed.
         */
        private void closeStageSchedulers() {
            for (StageScheduler scheduler : stageSchedulers.values()) {
                try {
                    scheduler.close();
                } catch (Throwable t) {
                    fail(t, Optional.empty());
                }
            }
        }

        /** Cancels the execution of {@code stageId}, if this scheduler manages it. */
        public void cancelStage(StageId stageId) {
            StageExecution stageExecution = stageExecutions.get(stageId);
            if (stageExecution != null) {
                stageExecution.cancel();
            }
        }

        /** Moves to CANCELED (unless already done) and cancels every stage execution. */
        public void cancel() {
            stateMachine.transitionToCanceled();
            stageExecutions.values().forEach(StageExecution::cancel);
        }

        /** Moves to ABORTED (unless already done) and aborts every stage execution. */
        public void abort() {
            stateMachine.transitionToAborted();
            stageExecutions.values().forEach(StageExecution::abort);
        }

        /**
         * Records {@code failureCause}, moves to FAILED (unless already done) and aborts every stage execution.
         *
         * @param failedStageId the stage the failure is attributed to, if known
         */
        public void fail(Throwable failureCause, Optional<StageId> failedStageId) {
            stateMachine.transitionToFailed(failureCause, failedStageId);
            stageExecutions.values().forEach(StageExecution::abort);
        }

        /**
         * Fails the given task and then the whole scheduler. The report is ignored if the task does
         * not belong to a stage execution managed here.
         */
        public void reportTaskFailure(TaskId taskId, Throwable failureCause) {
            StageExecution stageExecution = stageExecutions.get(taskId.getStageId());
            if (stageExecution == null) {
                return;
            }
            List<RemoteTask> tasks = stageExecution.getAllTasks();
            if (tasks.stream().noneMatch(task -> task.getTaskId().equals(taskId))) {
                return;
            }
            stageExecution.failTask(taskId, failureCause);
            fail(failureCause, Optional.of(taskId.getStageId()));
        }

        public void addStateChangeListener(StateMachine.StateChangeListener<DistributedStagesSchedulerState> stateChangeListener) {
            stateMachine.addStateChangeListener(stateChangeListener);
        }

        public Optional<StageFailureInfo> getFailureCause() {
            return stateMachine.getFailureCause();
        }
    }

    private static enum DistributedStagesSchedulerState {
        PLANNED(false, false),
        RUNNING(false, false),
        FINISHED(true, false),
        CANCELED(true, false),
        ABORTED(true, true),
        FAILED(true, true);

        public static final Set<DistributedStagesSchedulerState> TERMINAL_STATES;
        private final boolean doneState;
        private final boolean failureState;

        private DistributedStagesSchedulerState(boolean doneState, boolean failureState) {
            Preconditions.checkArgument((!failureState || doneState ? 1 : 0) != 0, (String)"%s is a non-done failure state", (Object)this.name());
            this.doneState = doneState;
            this.failureState = failureState;
        }

        public boolean isDone() {
            return this.doneState;
        }

        public boolean isFailure() {
            return this.failureState;
        }

        static {
            TERMINAL_STATES = (Set)Stream.of(DistributedStagesSchedulerState.values()).filter(DistributedStagesSchedulerState::isDone).collect(ImmutableSet.toImmutableSet());
        }
    }

    /**
     * Immutable pairing of a failure with the stage it is attributed to, if any.
     */
    private static class StageFailureInfo {
        private final ExecutionFailureInfo failureInfo;
        private final Optional<StageId> failedStageId;

        private StageFailureInfo(ExecutionFailureInfo failureInfo, Optional<StageId> failedStageId) {
            Objects.requireNonNull(failureInfo, "failureInfo is null");
            Objects.requireNonNull(failedStageId, "failedStageId is null");
            this.failureInfo = failureInfo;
            this.failedStageId = failedStageId;
        }

        /** @return the recorded failure, never {@code null} */
        public ExecutionFailureInfo getFailureInfo() {
            return failureInfo;
        }

        /** @return the failed stage, or empty when the failure is not attributed to a stage */
        public Optional<StageId> getFailedStageId() {
            return failedStageId;
        }
    }

    /**
     * Cache key for node partition maps: a partitioning handle plus an optional partition count.
     * Both components must be non-null.
     */
    private record PartitioningKey(PartitioningHandle handle, Optional<Integer> partitionCount) {
        public PartitioningKey {
            Objects.requireNonNull(handle, "handle cannot be null");
            Objects.requireNonNull(partitionCount, "partitionCount cannot be null");
        }
    }

    /**
     * {@link TaskLifecycleListener} decorator that forwards task creation immediately and buffers
     * "no more tasks" notifications until {@link #notifyNoMoreSourceTasks()} is invoked.
     *
     * <p>Once that happens the bridge is done. Any further callback then fails the
     * {@code checkState} precondition. All methods synchronize on this instance.
     */
    private static class TaskLifecycleListenerBridge
    implements TaskLifecycleListener {
        private final TaskLifecycleListener listener;
        @GuardedBy(value="this")
        private final Set<PlanFragmentId> noMoreSourceTasks = new HashSet<>();
        @GuardedBy(value="this")
        private boolean done;

        private TaskLifecycleListenerBridge(TaskLifecycleListener listener) {
            this.listener = Objects.requireNonNull(listener, "listener is null");
        }

        @Override
        public synchronized void taskCreated(PlanFragmentId fragmentId, RemoteTask task) {
            checkNotDone();
            listener.taskCreated(fragmentId, task);
        }

        @Override
        public synchronized void noMoreTasks(PlanFragmentId fragmentId) {
            checkNotDone();
            // Deferred: forwarded only by notifyNoMoreSourceTasks().
            noMoreSourceTasks.add(fragmentId);
        }

        /**
         * Marks the bridge done and replays every buffered "no more tasks" signal to the delegate.
         */
        public synchronized void notifyNoMoreSourceTasks() {
            checkNotDone();
            done = true;
            for (PlanFragmentId fragmentId : noMoreSourceTasks) {
                listener.noMoreTasks(fragmentId);
            }
        }

        // Shared precondition for every entry point; callers already hold the monitor.
        @GuardedBy(value="this")
        private void checkNotDone() {
            Preconditions.checkState(!done, "unexpected state");
        }
    }

    private static class DistributedStagesSchedulerStateMachine {
        private final QueryId queryId;
        private final StateMachine<DistributedStagesSchedulerState> state;
        // Holds the first failure recorded through transitionToFailed, or null if none.
        private final AtomicReference<StageFailureInfo> failureCause = new AtomicReference<>();

        /**
         * Creates a state machine that starts in PLANNED and treats every done state as terminal.
         *
         * @param queryId query whose scheduler this tracks; used in log messages
         * @param executor executor handed to the underlying {@link StateMachine}
         */
        public DistributedStagesSchedulerStateMachine(QueryId queryId, Executor executor) {
            Objects.requireNonNull(queryId, "queryId is null");
            Objects.requireNonNull(executor, "executor is null");
            this.queryId = queryId;
            this.state = new StateMachine<DistributedStagesSchedulerState>(
                    "Distributed stages scheduler",
                    executor,
                    DistributedStagesSchedulerState.PLANNED,
                    DistributedStagesSchedulerState.TERMINAL_STATES);
        }

        public DistributedStagesSchedulerState getState() {
            return this.state.get();
        }

        public boolean transitionToRunning() {
            return this.state.setIf(DistributedStagesSchedulerState.RUNNING, currentState -> !currentState.isDone());
        }

        public boolean transitionToFinished() {
            return this.state.setIf(DistributedStagesSchedulerState.FINISHED, currentState -> !currentState.isDone());
        }

        public boolean transitionToCanceled() {
            return this.state.setIf(DistributedStagesSchedulerState.CANCELED, currentState -> !currentState.isDone());
        }

        public boolean transitionToAborted() {
            return this.state.setIf(DistributedStagesSchedulerState.ABORTED, currentState -> !currentState.isDone());
        }

        public boolean transitionToFailed(Throwable throwable, Optional<StageId> failedStageId) {
            Objects.requireNonNull(throwable, "throwable is null");
            this.failureCause.compareAndSet(null, new StageFailureInfo(Failures.toFailure(throwable), failedStageId));
            boolean failed = this.state.setIf(DistributedStagesSchedulerState.FAILED, currentState -> !currentState.isDone());
            if (failed) {
                log.error(throwable, "Failure in distributed stage for query %s", new Object[]{this.queryId});
            } else {
                log.debug(throwable, "Failure in distributed stage for query %s after finished", new Object[]{this.queryId});
            }
            return failed;
        }

        public Optional<StageFailureInfo> getFailureCause() {
            return Optional.ofNullable(this.failureCause.get());
        }

        public void addStateChangeListener(StateMachine.StateChangeListener<DistributedStagesSchedulerState> stateChangeListener) {
            this.state.addStateChangeListener(stateChangeListener);
        }
    }

    /**
     * {@link TaskFailureListener} that forwards task failures to whichever
     * {@link DistributedStagesScheduler} is held in the shared reference at the time of the failure.
     * <p>
     * A {@link TrinoException} carrying {@code REMOTE_TASK_FAILED} is logged at debug level and not
     * forwarded; per the log text, such failures are discovered while fetching task results. Any other
     * failure is logged as a warning and forwarded, provided a scheduler is currently set.
     */
    private static class TaskFailureReporter
    implements TaskFailureListener {
        // Read on every failure, so a scheduler installed after construction is still picked up.
        private final AtomicReference<DistributedStagesScheduler> distributedStagesScheduler;

        /**
         * @param distributedStagesScheduler holder for the scheduler to notify; the holder itself must
         *         not be null, but its content may be null
         * @throws NullPointerException if {@code distributedStagesScheduler} is null
         */
        private TaskFailureReporter(AtomicReference<DistributedStagesScheduler> distributedStagesScheduler) {
            // Fail fast here rather than with an NPE inside the failure-handling path later on.
            this.distributedStagesScheduler = Objects.requireNonNull(distributedStagesScheduler, "distributedStagesScheduler is null");
        }

        @Override
        public void onTaskFailed(TaskId taskId, Throwable failure) {
            if (failure instanceof TrinoException trinoException
                    && StandardErrorCode.REMOTE_TASK_FAILED.toErrorCode().equals(trinoException.getErrorCode())) {
                log.debug("Task failure discovered while fetching task results: %s", taskId);
                return;
            }
            log.warn(failure, "Reported task failure: %s", taskId);
            DistributedStagesScheduler scheduler = distributedStagesScheduler.get();
            if (scheduler != null) {
                scheduler.reportTaskFailure(taskId, failure);
            }
        }
    }

    /**
     * {@link TaskLifecycleListener} that publishes the output location of each created task as a
     * query-result input on the {@link QueryStateMachine}.
     * <p>
     * Each created task contributes one {@link DirectExchangeInput} pointing at
     * {@code <task self URI>/results/0}. {@link #noMoreTasks} adds no inputs and marks the input set
     * as complete.
     */
    private static class QueryOutputTaskLifecycleListener
    implements TaskLifecycleListener {
        private final QueryStateMachine queryStateMachine;

        private QueryOutputTaskLifecycleListener(QueryStateMachine queryStateMachine) {
            this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
        }

        @Override
        public void taskCreated(PlanFragmentId fragmentId, RemoteTask task) {
            URI taskUri = HttpUriBuilder.uriBuilderFrom(task.getTaskStatus().getSelf())
                    .appendPath("results")
                    .appendPath("0")
                    .build();
            DirectExchangeInput input = new DirectExchangeInput(task.getTaskId(), taskUri.toString());
            // The explicit type witness matters: in a cast context ImmutableList.of(...) infers
            // ImmutableList<Object>, which cannot be cast to List<ExchangeInput> and fails to compile.
            queryStateMachine.updateInputsForQueryResults(ImmutableList.<ExchangeInput>of(input), false);
        }

        @Override
        public void noMoreTasks(PlanFragmentId fragmentId) {
            // No new inputs; the flag marks the set of result inputs as final.
            queryStateMachine.updateInputsForQueryResults(ImmutableList.<ExchangeInput>of(), true);
        }
    }
}

