/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.base.Verify;
import com.google.common.base.VerifyException;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.concurrent.MoreFutures;
import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.execution.ExecutionFailureInfo;
import io.trino.execution.Lifespan;
import io.trino.execution.RemoteTask;
import io.trino.execution.SqlStage;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.scheduler.BucketNodeMap;
import io.trino.execution.scheduler.NodeAllocator;
import io.trino.execution.scheduler.TaskDescriptor;
import io.trino.execution.scheduler.TaskDescriptorStorage;
import io.trino.execution.scheduler.TaskLifecycleListener;
import io.trino.execution.scheduler.TaskSource;
import io.trino.execution.scheduler.TaskSourceFactory;
import io.trino.failuredetector.FailureDetector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.operator.ExchangeOperator;
import io.trino.spi.ErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.ErrorType;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.exchange.Exchange;
import io.trino.spi.exchange.ExchangeSinkHandle;
import io.trino.spi.exchange.ExchangeSinkInstanceHandle;
import io.trino.spi.exchange.ExchangeSourceHandle;
import io.trino.split.RemoteSplit;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.util.Failures;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import java.util.function.Function;
import javax.annotation.concurrent.GuardedBy;

/**
 * Schedules the tasks of a single stage with task-level fault tolerance.
 *
 * <p>Once every source exchange has published its source handles, a {@link TaskSource} is created and drained
 * into a queue of partitions. Each queued partition is run as a {@link RemoteTask} on a node obtained from the
 * {@link NodeAllocator}. When a task fails with an error whose type is not {@link ErrorType#USER_ERROR} (or that
 * carries no error code) and the stage-wide retry budget is not exhausted, the partition is re-queued and a new
 * attempt is scheduled; otherwise the whole stage is failed.
 *
 * <p>Mutable state is guarded by {@code this}. Cleanup ({@link #cancel()}, {@link #abort()}) takes the lock only
 * to snapshot state and invokes tasks, futures, the node allocator and exchanges outside of it.
 */
public class FaultTolerantStageScheduler {
    private static final Logger log = Logger.get(FaultTolerantStageScheduler.class);

    private final Session session;
    private final SqlStage stage;
    private final FailureDetector failureDetector;
    private final TaskSourceFactory taskSourceFactory;
    private final NodeAllocator nodeAllocator;
    private final TaskDescriptorStorage taskDescriptorStorage;
    private final TaskLifecycleListener taskLifecycleListener;
    private final Optional<Exchange> sinkExchange;
    private final Optional<int[]> sinkBucketToPartitionMap;
    private final Map<PlanFragmentId, Exchange> sourceExchanges;
    private final Optional<int[]> sourceBucketToPartitionMap;
    private final Optional<BucketNodeMap> sourceBucketNodeMap;

    // Future the engine waits on before calling schedule() again.
    @GuardedBy("this")
    private ListenableFuture<Void> blocked = Futures.immediateVoidFuture();
    // Pending node request for the partition at the head of queuedPartitions; null when none is outstanding.
    @GuardedBy("this")
    private ListenableFuture<InternalNode> acquireNodeFuture;
    // Completed (and replaced) whenever a running task reaches a terminal state.
    @GuardedBy("this")
    private SettableFuture<Void> taskFinishedFuture;
    // Created lazily once all source exchange handles are available.
    @GuardedBy("this")
    private TaskSource taskSource;
    @GuardedBy("this")
    private final Map<Integer, ExchangeSinkHandle> partitionToExchangeSinkHandleMap = new HashMap<>();
    // Every attempt ever created, per partition; used to derive the next attempt id.
    @GuardedBy("this")
    private final Multimap<Integer, RemoteTask> partitionToRemoteTaskMap = ArrayListMultimap.create();
    @GuardedBy("this")
    private final Map<TaskId, RemoteTask> runningTasks = new HashMap<>();
    @GuardedBy("this")
    private final Map<TaskId, InternalNode> runningNodes = new HashMap<>();
    @GuardedBy("this")
    private final Set<Integer> allPartitions = new HashSet<>();
    @GuardedBy("this")
    private final Queue<Integer> queuedPartitions = new ArrayDeque<>();
    @GuardedBy("this")
    private final Set<Integer> finishedPartitions = new HashSet<>();
    // Shared retry budget for all partitions of this stage.
    @GuardedBy("this")
    private int remainingRetryAttempts;
    @GuardedBy("this")
    private Throwable failure;
    @GuardedBy("this")
    private boolean closed;

    /**
     * @param retryAttempts total number of task retries allowed for the whole stage; must be {@code >= 0}
     * @throws IllegalArgumentException if the stage uses grouped execution or {@code retryAttempts} is negative
     * @throws NullPointerException if any reference argument is null
     */
    public FaultTolerantStageScheduler(
            Session session,
            SqlStage stage,
            FailureDetector failureDetector,
            TaskSourceFactory taskSourceFactory,
            NodeAllocator nodeAllocator,
            TaskDescriptorStorage taskDescriptorStorage,
            TaskLifecycleListener taskLifecycleListener,
            Optional<Exchange> sinkExchange,
            Optional<int[]> sinkBucketToPartitionMap,
            Map<PlanFragmentId, Exchange> sourceExchanges,
            Optional<int[]> sourceBucketToPartitionMap,
            Optional<BucketNodeMap> sourceBucketNodeMap,
            int retryAttempts) {
        this.stage = Objects.requireNonNull(stage, "stage is null");
        Preconditions.checkArgument(
                !this.stage.getFragment().getStageExecutionDescriptor().isStageGroupedExecution(),
                "grouped execution is expected to be disabled");
        this.session = Objects.requireNonNull(session, "session is null");
        this.failureDetector = Objects.requireNonNull(failureDetector, "failureDetector is null");
        this.taskSourceFactory = Objects.requireNonNull(taskSourceFactory, "taskSourceFactory is null");
        this.nodeAllocator = Objects.requireNonNull(nodeAllocator, "nodeAllocator is null");
        this.taskDescriptorStorage = Objects.requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null");
        this.taskLifecycleListener = Objects.requireNonNull(taskLifecycleListener, "taskLifecycleListener is null");
        this.sinkExchange = Objects.requireNonNull(sinkExchange, "sinkExchange is null");
        this.sinkBucketToPartitionMap = Objects.requireNonNull(sinkBucketToPartitionMap, "sinkBucketToPartitionMap is null");
        this.sourceExchanges = ImmutableMap.copyOf(Objects.requireNonNull(sourceExchanges, "sourceExchanges is null"));
        this.sourceBucketToPartitionMap = Objects.requireNonNull(sourceBucketToPartitionMap, "sourceBucketToPartitionMap is null");
        this.sourceBucketNodeMap = Objects.requireNonNull(sourceBucketNodeMap, "sourceBucketNodeMap is null");
        Preconditions.checkArgument(retryAttempts >= 0, "retryAttempts must be greater than or equal to 0: %s", retryAttempts);
        this.remainingRetryAttempts = retryAttempts;
    }

    public StageId getStageId() {
        return this.stage.getStageId();
    }

    /**
     * Returns a view of the current blocking future; cancelling the returned future does not cancel the
     * scheduler's internal one.
     */
    public synchronized ListenableFuture<Void> isBlocked() {
        return Futures.nonCancellationPropagating(this.blocked);
    }

    /**
     * Makes as much scheduling progress as possible without waiting.
     *
     * <p>Creates the task source once all source exchange handles are available, pulls task descriptors from it,
     * and starts a task for each queued partition as nodes become available. Whenever progress requires waiting
     * (source handles, node allocation, or running tasks), {@link #isBlocked()} is updated and the method returns.
     *
     * @throws Exception the stage failure recorded earlier, if any
     */
    public synchronized void schedule() throws Exception {
        if (this.failure != null) {
            Throwables.propagateIfPossible(this.failure, Exception.class);
            throw new RuntimeException(this.failure);
        }
        if (this.closed || this.isFinished() || !this.blocked.isDone()) {
            return;
        }

        if (this.taskSource == null) {
            Map<PlanFragmentId, ListenableFuture<List<ExchangeSourceHandle>>> sourceHandles = this.sourceExchanges.entrySet().stream()
                    .collect(ImmutableMap.toImmutableMap(
                            Map.Entry::getKey,
                            entry -> MoreFutures.toListenableFuture(entry.getValue().getSourceHandles())));

            List<ListenableFuture<List<ExchangeSourceHandle>>> blockedFutures = sourceHandles.values().stream()
                    .filter(future -> !future.isDone())
                    .collect(ImmutableList.toImmutableList());
            if (!blockedFutures.isEmpty()) {
                this.blocked = MoreFutures.asVoid(Futures.allAsList(blockedFutures));
                return;
            }

            Multimap<PlanFragmentId, ExchangeSourceHandle> exchangeSources = sourceHandles.entrySet().stream()
                    .collect(ImmutableListMultimap.flatteningToImmutableListMultimap(
                            Map.Entry::getKey,
                            entry -> MoreFutures.getFutureValue(entry.getValue()).stream()));

            this.taskSource = this.taskSourceFactory.create(
                    this.session,
                    this.stage.getFragment(),
                    this.sourceExchanges,
                    exchangeSources,
                    this.stage::recordGetSplitTime,
                    this.sourceBucketToPartitionMap,
                    this.sourceBucketNodeMap);
        }

        while (!this.queuedPartitions.isEmpty() || !this.taskSource.isFinished()) {
            // Refill the queue from the task source until there is something to schedule or the source is drained.
            while (this.queuedPartitions.isEmpty() && !this.taskSource.isFinished()) {
                List<TaskDescriptor> tasks = this.taskSource.getMoreTasks();
                for (TaskDescriptor task : tasks) {
                    this.queuedPartitions.add(task.getPartitionId());
                    this.allPartitions.add(task.getPartitionId());
                    this.taskDescriptorStorage.put(this.stage.getStageId(), task);
                    this.sinkExchange.ifPresent(exchange -> {
                        ExchangeSinkHandle exchangeSinkHandle = exchange.addSink(task.getPartitionId());
                        this.partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle);
                    });
                }
                if (this.taskSource.isFinished()) {
                    this.sinkExchange.ifPresent(Exchange::noMoreSinks);
                }
            }

            if (this.queuedPartitions.isEmpty()) {
                break;
            }

            // Peek rather than poll: the partition stays queued until a node has actually been acquired.
            int partition = this.queuedPartitions.peek();
            Optional<TaskDescriptor> taskDescriptorOptional = this.taskDescriptorStorage.get(this.stage.getStageId(), partition);
            if (taskDescriptorOptional.isEmpty()) {
                // Descriptor no longer stored; presumably the query was terminated. NOTE(review): confirm.
                return;
            }
            TaskDescriptor taskDescriptor = taskDescriptorOptional.get();

            if (this.acquireNodeFuture == null) {
                this.acquireNodeFuture = this.nodeAllocator.acquire(taskDescriptor.getNodeRequirements());
            }
            if (!this.acquireNodeFuture.isDone()) {
                this.blocked = MoreFutures.asVoid(this.acquireNodeFuture);
                return;
            }
            InternalNode node = MoreFutures.getFutureValue(this.acquireNodeFuture);
            this.acquireNodeFuture = null;
            this.queuedPartitions.poll();

            Multimap<PlanNodeId, Split> tableScanSplits = taskDescriptor.getSplits();
            Multimap<PlanNodeId, Split> remoteSplits = createRemoteSplits(taskDescriptor.getExchangeSourceHandles());
            Multimap<PlanNodeId, Split> taskSplits = ImmutableListMultimap.<PlanNodeId, Split>builder()
                    .putAll(tableScanSplits)
                    .putAll(remoteSplits)
                    .build();

            int attemptId = this.getNextAttemptIdForPartition(partition);

            OutputBuffers outputBuffers;
            Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle;
            if (this.sinkExchange.isPresent()) {
                ExchangeSinkHandle sinkHandle = this.partitionToExchangeSinkHandleMap.get(partition);
                exchangeSinkInstanceHandle = Optional.of(this.sinkExchange.get().instantiateSink(sinkHandle, attemptId));
                outputBuffers = OutputBuffers.createSpoolingExchangeOutputBuffers(exchangeSinkInstanceHandle.get());
            }
            else {
                exchangeSinkInstanceHandle = Optional.empty();
                // No sink exchange: a single partitioned output buffer (id 0) with no further buffers.
                outputBuffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED)
                        .withBuffer(new OutputBuffers.OutputBufferId(0), 0)
                        .withNoMoreBufferIds();
            }

            // Every task receives all of its splits up front, so all sources are marked "no more splits" at creation.
            Set<PlanNodeId> allSourcePlanNodeIds = ImmutableSet.<PlanNodeId>builder()
                    .addAll(this.stage.getFragment().getPartitionedSources())
                    .addAll(this.stage.getFragment().getRemoteSourceNodes().stream()
                            .map(PlanNode::getId)
                            .iterator())
                    .build();
            Multimap<PlanNodeId, Lifespan> noMoreSplitsForLifespan = allSourcePlanNodeIds.stream()
                    .collect(ImmutableListMultimap.toImmutableListMultimap(Function.identity(), planNodeId -> Lifespan.taskWide()));

            RemoteTask task = this.stage.createTask(
                            node,
                            partition,
                            attemptId,
                            this.sinkBucketToPartitionMap,
                            outputBuffers,
                            taskSplits,
                            noMoreSplitsForLifespan,
                            allSourcePlanNodeIds)
                    .orElseThrow(() -> new VerifyException("stage execution is expected to be active"));

            this.partitionToRemoteTaskMap.put(partition, task);
            this.runningTasks.put(task.getTaskId(), task);
            this.runningNodes.put(task.getTaskId(), node);

            if (this.taskFinishedFuture == null) {
                this.taskFinishedFuture = SettableFuture.create();
            }

            this.taskLifecycleListener.taskCreated(this.stage.getFragment().getId(), task);

            task.addStateChangeListener(taskStatus -> this.updateTaskStatus(taskStatus, exchangeSinkInstanceHandle));
            task.start();
        }

        if (this.taskFinishedFuture != null && !this.taskFinishedFuture.isDone()) {
            this.blocked = this.taskFinishedFuture;
        }
    }

    /**
     * Returns true when no failure was recorded, the task source is drained, nothing is queued, and every
     * partition ever produced has a finished task.
     */
    public synchronized boolean isFinished() {
        return this.failure == null
                && this.taskSource != null
                && this.taskSource.isFinished()
                && this.queuedPartitions.isEmpty()
                && this.finishedPartitions.containsAll(this.allPartitions);
    }

    /** Closes the scheduler, cancelling running tasks. */
    public void cancel() {
        this.close(false);
    }

    /** Closes the scheduler, aborting running tasks. */
    public void abort() {
        this.close(true);
    }

    /** Records {@code t} as the stage failure (first one wins) and aborts the scheduler. */
    private void fail(Throwable t) {
        synchronized (this) {
            if (this.failure == null) {
                this.failure = t;
            }
        }
        this.close(true);
    }

    /**
     * Marks the scheduler closed and releases its resources; only the first call has an effect.
     * The cleanup steps run outside the lock (two of them verify this).
     */
    private void close(boolean abort) {
        boolean wasClosed;
        synchronized (this) {
            wasClosed = this.closed;
            this.closed = true;
        }
        if (!wasClosed) {
            this.cancelRunningTasks(abort);
            this.cancelBlockedFuture();
            this.releaseAcquiredNode();
            this.closeTaskSource();
            this.closeSinkExchange();
        }
    }

    private void cancelRunningTasks(boolean abort) {
        List<RemoteTask> tasks;
        synchronized (this) {
            tasks = ImmutableList.copyOf(this.runningTasks.values());
        }
        if (abort) {
            tasks.forEach(RemoteTask::abort);
        }
        else {
            tasks.forEach(RemoteTask::cancel);
        }
    }

    private void cancelBlockedFuture() {
        Verify.verify(!Thread.holdsLock(this));
        ListenableFuture<Void> future;
        synchronized (this) {
            future = this.blocked;
        }
        if (future != null && !future.isDone()) {
            future.cancel(true);
        }
    }

    /**
     * Cancels an outstanding node request. If the request had already completed (so cancellation had no
     * effect), the acquired node is returned to the allocator.
     */
    private void releaseAcquiredNode() {
        Verify.verify(!Thread.holdsLock(this));
        ListenableFuture<InternalNode> future;
        synchronized (this) {
            future = this.acquireNodeFuture;
            this.acquireNodeFuture = null;
        }
        if (future != null) {
            future.cancel(true);
            if (future.isDone() && !future.isCancelled()) {
                this.nodeAllocator.release(MoreFutures.getFutureValue(future));
            }
        }
    }

    /** Best-effort close of the task source; failures are logged, not propagated. */
    private void closeTaskSource() {
        TaskSource source;
        synchronized (this) {
            source = this.taskSource;
        }
        if (source != null) {
            try {
                source.close();
            }
            catch (RuntimeException e) {
                log.warn(e, "Error closing task source for stage: %s", this.stage.getStageId());
            }
        }
    }

    /** Best-effort close of the sink exchange; failures are logged, not propagated. */
    private void closeSinkExchange() {
        try {
            this.sinkExchange.ifPresent(Exchange::close);
        }
        catch (RuntimeException e) {
            log.warn(e, "Error closing sink exchange for stage: %s", this.stage.getStageId());
        }
    }

    /**
     * Fails the running task with the given id, if it is still running.
     * The task is looked up under the lock but failed outside of it, like {@code cancelRunningTasks}.
     */
    public void reportTaskFailure(TaskId taskId, Throwable failureCause) {
        RemoteTask task;
        synchronized (this) {
            task = this.runningTasks.get(taskId);
        }
        if (task != null) {
            task.fail(failureCause);
        }
    }

    /**
     * Asks the running task with the given id, if any, to fail remotely.
     * {@code runningTasks} is a guarded, non-thread-safe map, so the lookup must happen under the lock.
     */
    public void failTaskRemotely(TaskId taskId, Throwable failureCause) {
        RemoteTask task;
        synchronized (this) {
            task = this.runningTasks.get(taskId);
        }
        if (task != null) {
            task.failRemotely(failureCause);
        }
    }

    /** Returns one more than the highest attempt id created for the partition, or 0. Caller must hold the lock. */
    private int getNextAttemptIdForPartition(int partition) {
        int latestAttemptId = this.partitionToRemoteTaskMap.get(partition).stream()
                .mapToInt(task -> task.getTaskId().getAttemptId())
                .max()
                .orElse(-1);
        return latestAttemptId + 1;
    }

    /** Builds one spooling-exchange remote split per plan node, carrying all of that node's source handles. */
    private static Multimap<PlanNodeId, Split> createRemoteSplits(Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles) {
        ImmutableListMultimap.Builder<PlanNodeId, Split> result = ImmutableListMultimap.builder();
        for (PlanNodeId planNodeId : exchangeSourceHandles.keySet()) {
            List<ExchangeSourceHandle> handles = ImmutableList.copyOf(exchangeSourceHandles.get(planNodeId));
            result.put(planNodeId, new Split(
                    ExchangeOperator.REMOTE_CONNECTOR_ID,
                    new RemoteSplit(new RemoteSplit.SpoolingExchangeInput(handles)),
                    Lifespan.taskWide()));
        }
        return result.build();
    }

    /**
     * State-change listener for a task. Ignores non-terminal states. For a terminal state it removes the task,
     * releases its node, and completes the current task-finished future so {@link #schedule()} is re-run.
     *
     * <p>Outcome handling (skipped if the partition already finished or the scheduler is closed):
     * FINISHED marks the partition finished, notifies the sink exchange, and aborts other attempts of the
     * partition; FAILED re-queues the partition while the retry budget allows and the error is not a user
     * error, otherwise fails the stage. Any exception thrown here also fails the stage.
     */
    private void updateTaskStatus(TaskStatus taskStatus, Optional<ExchangeSinkInstanceHandle> exchangeSinkInstanceHandle) {
        TaskState state = taskStatus.getState();
        if (!state.isDone()) {
            return;
        }

        try {
            RuntimeException stageFailure = null;
            SettableFuture<Void> future;
            synchronized (this) {
                TaskId taskId = taskStatus.getTaskId();

                this.runningTasks.remove(taskId);
                future = this.taskFinishedFuture;
                this.taskFinishedFuture = this.runningTasks.isEmpty() ? null : SettableFuture.create();

                InternalNode node = Objects.requireNonNull(this.runningNodes.remove(taskId), () -> "node not found for task id: " + taskId);
                this.nodeAllocator.release(node);

                int partitionId = taskId.getPartitionId();
                if (!this.finishedPartitions.contains(partitionId) && !this.closed) {
                    switch (state) {
                        case FINISHED:
                            this.finishedPartitions.add(partitionId);
                            if (this.sinkExchange.isPresent()) {
                                Preconditions.checkArgument(exchangeSinkInstanceHandle.isPresent(), "exchangeSinkInstanceHandle is expected to be present");
                                this.sinkExchange.get().sinkFinished(exchangeSinkInstanceHandle.get());
                            }
                            // Other attempts of this partition are no longer needed.
                            this.partitionToRemoteTaskMap.get(partitionId).forEach(RemoteTask::abort);
                            break;
                        case CANCELED:
                            log.debug("Task cancelled: %s", taskId);
                            break;
                        case ABORTED:
                            log.debug("Task aborted: %s", taskId);
                            break;
                        case FAILED:
                            ExecutionFailureInfo failureInfo = taskStatus.getFailures().stream()
                                    .findFirst()
                                    .map(this::rewriteTransportFailure)
                                    .orElseGet(() -> Failures.toFailure(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")));
                            log.warn(failureInfo.toException(), "Task failed: %s", taskId);
                            ErrorCode errorCode = failureInfo.getErrorCode();
                            if (this.remainingRetryAttempts > 0 && (errorCode == null || errorCode.getType() != ErrorType.USER_ERROR)) {
                                this.remainingRetryAttempts--;
                                this.queuedPartitions.add(partitionId);
                                log.debug("Retrying partition %s for stage %s", partitionId, this.stage.getStageId());
                            }
                            else {
                                stageFailure = failureInfo.toException();
                            }
                            break;
                        default:
                            throw new IllegalArgumentException("Unexpected task state: " + state);
                    }
                }
            }

            // fail() closes the scheduler, which must not run while holding the lock.
            if (stageFailure != null) {
                this.fail(stageFailure);
            }
            if (future != null && !future.isDone()) {
                future.set(null);
            }
        }
        catch (Throwable t) {
            this.fail(t);
        }
    }

    /**
     * Replaces the error code with REMOTE_HOST_GONE when the failure detector reports the failure's remote host
     * as GONE; otherwise returns the failure unchanged.
     */
    private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) {
        if (executionFailureInfo.getRemoteHost() == null
                || this.failureDetector.getState(executionFailureInfo.getRemoteHost()) != FailureDetector.State.GONE) {
            return executionFailureInfo;
        }
        return new ExecutionFailureInfo(
                executionFailureInfo.getType(),
                executionFailureInfo.getMessage(),
                executionFailureInfo.getCause(),
                executionFailureInfo.getSuppressed(),
                executionFailureInfo.getStack(),
                executionFailureInfo.getErrorLocation(),
                StandardErrorCode.REMOTE_HOST_GONE.toErrorCode(),
                executionFailureInfo.getRemoteHost());
    }
}

