/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import io.airlift.http.client.HttpUriBuilder;
import io.airlift.log.Logger;
import io.trino.exchange.DirectExchangeInput;
import io.trino.execution.ExecutionFailureInfo;
import io.trino.execution.RemoteTask;
import io.trino.execution.SqlStage;
import io.trino.execution.StageId;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskId;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.OutputBufferStatus;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.scheduler.PipelinedOutputBufferManager;
import io.trino.execution.scheduler.StageExecution;
import io.trino.execution.scheduler.TaskLifecycleListener;
import io.trino.failuredetector.FailureDetector;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.operator.ExchangeOperator;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.split.RemoteSplit;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.RemoteSourceNode;
import io.trino.util.Failures;
import java.net.URI;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.annotation.concurrent.GuardedBy;

public class PipelinedStageExecution
implements StageExecution {
    /** Logger shared by this class and its nested state machine. */
    private static final Logger log = Logger.get(PipelinedStageExecution.class);
    /** Holds the current stage state and the first recorded failure cause. */
    private final PipelinedStageStateMachine stateMachine;
    /** The stage being executed; used to create tasks and to look up the plan fragment. */
    private final SqlStage stage;
    /** Output buffer managers keyed by fragment id; looked up for this stage's own fragment and for each source fragment. */
    private final Map<PlanFragmentId, PipelinedOutputBufferManager> outputBufferManagers;
    /** Listener informed of every task created here and of the moment no more tasks can be scheduled. */
    private final TaskLifecycleListener taskLifecycleListener;
    /** Consulted when deciding whether a task failure should be rewritten as REMOTE_HOST_GONE. */
    private final FailureDetector failureDetector;
    /** Passed through unchanged to task creation. */
    private final Optional<int[]> bucketToPartition;
    /** Maps each source fragment id to the remote source node of this stage's fragment that lists it. */
    private final Map<PlanFragmentId, RemoteSourceNode> exchangeSources;
    /** Attempt number handed to task creation and reported by {@code getAttemptId()}. */
    private final int attempt;
    /** Tasks of this stage keyed by partition id. */
    private final Map<Integer, RemoteTask> tasks = new ConcurrentHashMap<Integer, RemoteTask>();
    /** Ids of all tasks registered by {@code scheduleTask}. */
    @GuardedBy(value="this")
    private final Set<TaskId> allTasks = new HashSet<TaskId>();
    /** Ids of tasks that have been observed in the FINISHED state. */
    private final Set<TaskId> finishedTasks = ConcurrentHashMap.newKeySet();
    /** Ids of tasks that have been observed in the FLUSHING state and are not yet recorded as finished. */
    private final Set<TaskId> flushingTasks = ConcurrentHashMap.newKeySet();
    /** Tasks of source fragments, keyed by the fragment that produced them. */
    @GuardedBy(value="this")
    private final Multimap<PlanFragmentId, RemoteTask> sourceTasks = HashMultimap.create();
    /** Source fragments for which "no more tasks" has been signalled. */
    @GuardedBy(value="this")
    private final Set<PlanFragmentId> completeSourceFragments = new HashSet<PlanFragmentId>();
    /** Plan nodes that will receive no more splits; replayed onto every newly scheduled task. */
    @GuardedBy(value="this")
    private final Set<PlanNodeId> completeSources = new HashSet<PlanNodeId>();

    /**
     * Creates and initializes a pipelined execution for the given stage.
     *
     * <p>Builds the source-fragment-to-remote-source-node index from the stage's fragment,
     * constructs the execution with a fresh state machine, and registers the internal
     * state change listener before handing the instance to the caller.
     *
     * @param stage the stage to execute
     * @param outputBufferManagers output buffer managers keyed by fragment id
     * @param taskLifecycleListener listener notified about tasks created by this execution
     * @param failureDetector detector used when rewriting task failures
     * @param executor executor handed to the state machine
     * @param bucketToPartition passed through to task creation
     * @param attempt attempt number of this execution
     * @return the initialized execution
     * @throws IllegalArgumentException if the same source fragment id is listed by more than one remote source node
     */
    public static PipelinedStageExecution createPipelinedStageExecution(SqlStage stage, Map<PlanFragmentId, PipelinedOutputBufferManager> outputBufferManagers, TaskLifecycleListener taskLifecycleListener, FailureDetector failureDetector, Executor executor, Optional<int[]> bucketToPartition, int attempt) {
        PipelinedStageStateMachine stateMachine = new PipelinedStageStateMachine(stage.getStageId(), executor);
        // Typed builder: no raw types or unchecked casts are needed to index the remote sources.
        ImmutableMap.Builder<PlanFragmentId, RemoteSourceNode> exchangeSources = ImmutableMap.builder();
        for (RemoteSourceNode remoteSourceNode : stage.getFragment().getRemoteSourceNodes()) {
            for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) {
                exchangeSources.put(planFragmentId, remoteSourceNode);
            }
        }
        PipelinedStageExecution execution = new PipelinedStageExecution(stateMachine, stage, outputBufferManagers, taskLifecycleListener, failureDetector, bucketToPartition, exchangeSources.buildOrThrow(), attempt);
        // The listener captures the instance, so it is registered only after construction has completed.
        execution.initialize();
        return execution;
    }

    /**
     * Stores the collaborators of this execution, rejecting null references and taking
     * immutable copies of the two maps. Instances are obtained through
     * {@code createPipelinedStageExecution}.
     */
    private PipelinedStageExecution(PipelinedStageStateMachine stateMachine, SqlStage stage, Map<PlanFragmentId, PipelinedOutputBufferManager> outputBufferManagers, TaskLifecycleListener taskLifecycleListener, FailureDetector failureDetector, Optional<int[]> bucketToPartition, Map<PlanFragmentId, RemoteSourceNode> exchangeSources, int attempt) {
        Objects.requireNonNull(stateMachine, "stateMachine is null");
        this.stateMachine = stateMachine;

        Objects.requireNonNull(stage, "stage is null");
        this.stage = stage;

        Objects.requireNonNull(outputBufferManagers, "outputBufferManagers is null");
        this.outputBufferManagers = ImmutableMap.copyOf(outputBufferManagers);

        Objects.requireNonNull(taskLifecycleListener, "taskLifecycleListener is null");
        this.taskLifecycleListener = taskLifecycleListener;

        Objects.requireNonNull(failureDetector, "failureDetector is null");
        this.failureDetector = failureDetector;

        Objects.requireNonNull(bucketToPartition, "bucketToPartition is null");
        this.bucketToPartition = bucketToPartition;

        Objects.requireNonNull(exchangeSources, "exchangeSources is null");
        this.exchangeSources = ImmutableMap.copyOf(exchangeSources);

        this.attempt = attempt;
    }

    /**
     * Registers the listener that reacts once the stage reaches a state in which no more
     * tasks can be scheduled: the task lifecycle listener is told so, and every source
     * fragment's output buffer manager is told that no more buffers will be added.
     */
    private void initialize() {
        stateMachine.addStateChangeListener(newState -> {
            if (newState.canScheduleMoreTasks()) {
                return;
            }
            taskLifecycleListener.noMoreTasks(stage.getFragment().getId());
            updateSourceTasksOutputBuffers(PipelinedOutputBufferManager::noMoreBuffers);
        });
    }

    /** Returns the current state as reported by the state machine. */
    @Override
    public StageExecution.State getState() {
        StageExecution.State current = stateMachine.getState();
        return current;
    }

    /**
     * Registers a listener with the underlying state machine.
     *
     * @param stateChangeListener listener to register
     */
    @Override
    public void addStateChangeListener(StateMachine.StateChangeListener<StageExecution.State> stateChangeListener) {
        stateMachine.addStateChangeListener(stateChangeListener);
    }

    /** Requests the PLANNED to SCHEDULING transition; the outcome of the request is ignored. */
    @Override
    public void beginScheduling() {
        stateMachine.transitionToScheduling();
    }

    /** Requests the transition to SCHEDULING_SPLITS; the outcome of the request is ignored. */
    @Override
    public void transitionToSchedulingSplits() {
        stateMachine.transitionToSchedulingSplits();
    }

    /**
     * Marks task scheduling as complete.
     *
     * <p>Nothing happens unless the state machine actually moves to SCHEDULED. When it does,
     * the stage is advanced to FLUSHING and/or FINISHED if the already observed task states
     * allow it, and every partitioned source of the fragment is marked as complete.
     */
    @Override
    public void schedulingComplete() {
        boolean becameScheduled = stateMachine.transitionToScheduled();
        if (becameScheduled) {
            if (isStageFlushing()) {
                stateMachine.transitionToFlushing();
            }
            if (isStageFinished()) {
                stateMachine.transitionToFinished();
            }
            for (PlanNodeId sourceId : stage.getFragment().getPartitionedSources()) {
                schedulingComplete(sourceId);
            }
        }
    }

    /**
     * Tells every existing task that the given source will produce no more splits and
     * remembers the source so that tasks scheduled later receive the same signal.
     *
     * @param partitionedSource plan node whose split scheduling is complete
     */
    @Override
    public synchronized void schedulingComplete(PlanNodeId partitionedSource) {
        getAllTasks().forEach(existingTask -> existingTask.noMoreSplits(partitionedSource));
        completeSources.add(partitionedSource);
    }

    /**
     * Cancels the stage. Tasks are cancelled only when this call is the one that moves the
     * state machine to CANCELED; if the stage was already done, the tasks are left untouched.
     */
    @Override
    public synchronized void cancel() {
        if (!stateMachine.transitionToCanceled()) {
            return;
        }
        for (RemoteTask task : tasks.values()) {
            task.cancel();
        }
    }

    /** Requests the ABORTED transition and then aborts every task, whether or not the transition took effect. */
    @Override
    public synchronized void abort() {
        stateMachine.transitionToAborted();
        for (RemoteTask task : tasks.values()) {
            task.abort();
        }
    }

    /**
     * Records the failure with the state machine and then aborts every task, whether or
     * not the FAILED transition took effect.
     *
     * @param failureCause the reason for the failure
     */
    public synchronized void fail(Throwable failureCause) {
        stateMachine.transitionToFailed(failureCause);
        for (RemoteTask task : tasks.values()) {
            task.abort();
        }
    }

    /**
     * Fails a single task locally and then fails the whole stage with the same cause.
     *
     * @param taskId id of the task to fail; its partition id is used for the lookup
     * @param failureCause the reason for the failure
     * @throws NullPointerException if no task is registered for the task's partition
     */
    @Override
    public synchronized void failTask(TaskId taskId, Throwable failureCause) {
        RemoteTask failedTask = tasks.get(taskId.getPartitionId());
        Objects.requireNonNull(failedTask, () -> "task not found: " + taskId);
        failedTask.failLocallyImmediately(failureCause);
        fail(failureCause);
    }

    /**
     * Creates, wires up and starts a task for the given partition.
     *
     * @param node node on which the task is created
     * @param partition partition id of the new task; must not already have a task
     * @param initialSplits splits handed to the task at creation
     * @return the started task, or empty if the stage is already done or the stage did not create a task
     * @throws IllegalArgumentException if a task already exists for {@code partition}
     */
    @Override
    public synchronized Optional<RemoteTask> scheduleTask(InternalNode node, int partition, Multimap<PlanNodeId, Split> initialSplits) {
        if (stateMachine.getState().isDone()) {
            return Optional.empty();
        }
        Preconditions.checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition);

        PipelinedOutputBuffers outputBuffers = outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers();
        // The empty set is inferred from the parameter type; casting ImmutableSet.of() to a
        // parameterized Set is not legal Java.
        Optional<RemoteTask> optionalTask = stage.createTask(node, partition, attempt, bucketToPartition, outputBuffers, initialSplits, ImmutableSet.of(), Optional.empty());
        if (optionalTask.isEmpty()) {
            return Optional.empty();
        }
        RemoteTask task = optionalTask.get();
        tasks.put(partition, task);

        // Point the new task at every source task that has not finished yet.
        ImmutableMultimap.Builder<PlanNodeId, Split> exchangeSplits = ImmutableMultimap.builder();
        sourceTasks.forEach((fragmentId, sourceTask) -> {
            TaskStatus status = sourceTask.getTaskStatus();
            if (status.getState() != TaskState.FINISHED) {
                PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId();
                exchangeSplits.put(planNodeId, createExchangeSplit(sourceTask, task));
            }
        });

        allTasks.add(task.getTaskId());

        task.addSplits(exchangeSplits.build());
        // Replay "no more splits" for every source that completed before this task existed.
        completeSources.forEach(task::noMoreSplits);

        task.addStateChangeListener(this::updateTaskStatus);
        task.start();

        taskLifecycleListener.taskCreated(stage.getFragment().getId(), task);

        // Every source task needs an output buffer for the new consumer.
        PipelinedOutputBuffers.OutputBufferId outputBufferId = new PipelinedOutputBuffers.OutputBufferId(task.getTaskId().getPartitionId());
        updateSourceTasksOutputBuffers(outputBufferManager -> outputBufferManager.addOutputBuffer(outputBufferId));

        return Optional.of(task);
    }

    /**
     * Reacts to a status update of one of this stage's tasks.
     *
     * <p>Updates are ignored once the stage is done. A failing or failed task fails the stage
     * with the task's first reported failure; a canceling, canceled, aborting or aborted task
     * fails the stage with an internal error. Flushing and finished tasks are recorded, and
     * the stage state is advanced when the recorded task states allow it.
     *
     * @param taskStatus the latest status of the task
     */
    private void updateTaskStatus(TaskStatus taskStatus) {
        if (stateMachine.getState().isDone()) {
            return;
        }

        TaskState taskState = taskStatus.getState();
        boolean progressObserved = false;
        switch (taskState) {
            case FAILING:
            case FAILED: {
                RuntimeException failure = taskStatus.getFailures().stream()
                        .findFirst()
                        .map(this::rewriteTransportFailure)
                        .map(ExecutionFailureInfo::toException)
                        .orElseGet(() -> new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
                fail(failure);
                break;
            }
            case CANCELING:
            case CANCELED:
            case ABORTING:
            case ABORTED: {
                // The stage is not done here (checked above), so the task is ahead of its stage.
                String message = String.format("A task is in the %s state but stage is %s", taskState, stateMachine.getState());
                fail(new TrinoException(StandardErrorCode.GENERIC_INTERNAL_ERROR, message));
                break;
            }
            case FLUSHING: {
                progressObserved = addFlushingTask(taskStatus.getTaskId());
                break;
            }
            case FINISHED: {
                progressObserved = addFinishedTask(taskStatus.getTaskId());
                break;
            }
            default: {
                break;
            }
        }

        // Re-read the state: the handling above may have failed the stage.
        StageExecution.State stageState = stateMachine.getState();
        boolean canAdvance = stageState == StageExecution.State.SCHEDULED
                || stageState == StageExecution.State.RUNNING
                || stageState == StageExecution.State.FLUSHING;
        if (!canAdvance) {
            return;
        }
        if (taskState == TaskState.RUNNING) {
            stateMachine.transitionToRunning();
        }
        if (progressObserved) {
            if (isStageFlushing()) {
                stateMachine.transitionToFlushing();
            }
            if (isStageFinished()) {
                stateMachine.transitionToFinished();
            }
        }
    }

    /**
     * Returns true when at least one task is flushing and every registered task is
     * accounted for as either flushing or finished.
     */
    private synchronized boolean isStageFlushing() {
        if (flushingTasks.isEmpty()) {
            return false;
        }
        int flushingOrFinished = finishedTasks.size() + flushingTasks.size();
        return allTasks.size() == flushingOrFinished;
    }

    /**
     * Returns true when as many tasks have finished as have been registered.
     *
     * @throws IllegalStateException if the counts match but the finished set does not contain every registered task
     */
    private synchronized boolean isStageFinished() {
        if (finishedTasks.size() != allTasks.size()) {
            return false;
        }
        Preconditions.checkState(finishedTasks.containsAll(allTasks), "Finished tasks should contain all tasks");
        return true;
    }

    /**
     * Records a task as flushing unless it is already known as flushing or finished.
     *
     * <p>The membership test is done once without the lock and the finished check is
     * repeated under the lock before the task is added.
     *
     * @param taskId id of the task observed in the FLUSHING state
     * @return true if the task was newly added to the flushing set
     */
    private boolean addFlushingTask(TaskId taskId) {
        if (flushingTasks.contains(taskId) || finishedTasks.contains(taskId)) {
            return false;
        }
        synchronized (this) {
            if (finishedTasks.contains(taskId)) {
                return false;
            }
            return flushingTasks.add(taskId);
        }
    }

    /**
     * Records a task as finished and removes it from the flushing set.
     *
     * <p>The membership test is done once without the lock; the two set updates happen
     * together under the lock.
     *
     * @param taskId id of the task observed in the FINISHED state
     * @return true if the task was newly added to the finished set
     */
    private boolean addFinishedTask(TaskId taskId) {
        if (finishedTasks.contains(taskId)) {
            return false;
        }
        synchronized (this) {
            boolean newlyFinished = finishedTasks.add(taskId);
            flushingTasks.remove(taskId);
            return newlyFinished;
        }
    }

    /**
     * Replaces the error code of a failure with REMOTE_HOST_GONE when the failure names a
     * remote host and the failure detector reports that host as GONE. All other fields are
     * copied unchanged; failures that do not match are returned as is.
     *
     * @param executionFailureInfo the failure reported by a task
     * @return the original failure, or a copy carrying the REMOTE_HOST_GONE error code
     */
    private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) {
        boolean remoteHostGone = executionFailureInfo.getRemoteHost() != null
                && failureDetector.getState(executionFailureInfo.getRemoteHost()) == FailureDetector.State.GONE;
        if (!remoteHostGone) {
            return executionFailureInfo;
        }
        return new ExecutionFailureInfo(
                executionFailureInfo.getType(),
                executionFailureInfo.getMessage(),
                executionFailureInfo.getCause(),
                executionFailureInfo.getSuppressed(),
                executionFailureInfo.getStack(),
                executionFailureInfo.getErrorLocation(),
                StandardErrorCode.REMOTE_HOST_GONE.toErrorCode(),
                executionFailureInfo.getRemoteHost());
    }

    /**
     * Returns a new listener through which this execution is informed about the tasks of
     * its source fragments: created tasks are forwarded to {@code sourceTaskCreated} and
     * the end of task creation to {@code noMoreSourceTasks}.
     */
    @Override
    public TaskLifecycleListener getTaskLifecycleListener() {
        return new TaskLifecycleListener() {
            @Override
            public void taskCreated(PlanFragmentId fragmentId, RemoteTask task) {
                sourceTaskCreated(fragmentId, task);
            }

            @Override
            public void noMoreTasks(PlanFragmentId fragmentId) {
                noMoreSourceTasks(fragmentId);
            }
        };
    }

    /**
     * Registers a newly created task of a source fragment.
     *
     * <p>The source task receives the current output buffers of its fragment's manager, and
     * every task already created for this stage receives an exchange split pointing at it.
     *
     * @param fragmentId id of the fragment the source task belongs to
     * @param sourceTask the newly created source task
     * @throws NullPointerException if {@code fragmentId} is null
     * @throws IllegalArgumentException if {@code fragmentId} is not a source fragment of this stage
     */
    private synchronized void sourceTaskCreated(PlanFragmentId fragmentId, RemoteTask sourceTask) {
        Objects.requireNonNull(fragmentId, "fragmentId is null");
        RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
        Preconditions.checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet());

        // Arguments keep their static types; the multimap is typed, so Object-typed arguments do not compile.
        sourceTasks.put(fragmentId, sourceTask);

        PipelinedOutputBufferManager outputBufferManager = outputBufferManagers.get(fragmentId);
        sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers());

        for (RemoteTask destinationTask : getAllTasks()) {
            destinationTask.addSplits(ImmutableMultimap.of(remoteSource.getId(), createExchangeSplit(sourceTask, destinationTask)));
        }
    }

    /**
     * Records that a source fragment will create no more tasks. Once every fragment feeding
     * the corresponding remote source node is complete, the node is remembered as complete
     * and every existing task is told that it will receive no more splits for it.
     *
     * @param fragmentId id of the source fragment that finished creating tasks
     * @throws IllegalArgumentException if {@code fragmentId} is not a source fragment of this stage
     */
    private synchronized void noMoreSourceTasks(PlanFragmentId fragmentId) {
        RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
        Preconditions.checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet());

        completeSourceFragments.add(fragmentId);

        boolean sourceComplete = completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds());
        if (!sourceComplete) {
            return;
        }
        completeSources.add(remoteSource.getId());
        getAllTasks().forEach(existingTask -> existingTask.noMoreSplits(remoteSource.getId()));
    }

    /**
     * Applies an update to the output buffer manager of every source fragment and pushes the
     * resulting output buffers to all tasks currently registered for that fragment.
     *
     * @param updater action applied to each source fragment's output buffer manager
     */
    private synchronized void updateSourceTasksOutputBuffers(Consumer<PipelinedOutputBufferManager> updater) {
        for (PlanFragmentId sourceFragment : exchangeSources.keySet()) {
            PipelinedOutputBufferManager outputBufferManager = outputBufferManagers.get(sourceFragment);
            updater.accept(outputBufferManager);
            // The key keeps its static type; the multimap is typed, so an Object-typed key does not compile.
            for (RemoteTask sourceTask : sourceTasks.get(sourceFragment)) {
                sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers());
            }
        }
    }

    /** Returns an immutable snapshot of the tasks currently registered for this stage. */
    @Override
    public List<RemoteTask> getAllTasks() {
        List<RemoteTask> snapshot = ImmutableList.copyOf(tasks.values());
        return snapshot;
    }

    /** Returns the current status of every task in the snapshot taken by {@code getAllTasks()}, as an immutable list. */
    @Override
    public List<TaskStatus> getTaskStatuses() {
        // The collector already yields a typed immutable list, so no raw cast is needed.
        return getAllTasks().stream()
                .map(RemoteTask::getTaskStatus)
                .collect(ImmutableList.toImmutableList());
    }

    /** Returns true if at least one task reports an overutilized output buffer. */
    @Override
    public boolean isAnyTaskBlocked() {
        return getTaskStatuses().stream()
                .anyMatch(status -> status.getOutputBufferStatus().isOverutilized());
    }

    /**
     * Forwards the split-time measurement to the stage.
     *
     * @param start start value passed through unchanged to the stage
     */
    @Override
    public void recordGetSplitTime(long start) {
        stage.recordGetSplitTime(start);
    }

    /** Returns the id of the stage being executed. */
    @Override
    public StageId getStageId() {
        StageId stageId = stage.getStageId();
        return stageId;
    }

    /** Returns the attempt number this execution was created with. */
    @Override
    public int getAttemptId() {
        return attempt;
    }

    /** Returns the plan fragment of the stage being executed. */
    @Override
    public PlanFragment getFragment() {
        PlanFragment fragment = stage.getFragment();
        return fragment;
    }

    /** Returns the failure cause recorded by the state machine, or empty if none has been recorded. */
    @Override
    public Optional<ExecutionFailureInfo> getFailureCause() {
        Optional<ExecutionFailureInfo> cause = stateMachine.getFailureCause();
        return cause;
    }

    /** Returns the string form of the state machine, which includes the stage id and current state. */
    @Override
    public String toString() {
        return stateMachine.toString();
    }

    /**
     * Builds the split through which {@code destinationTask} reads the output of {@code sourceTask}.
     *
     * <p>The location is the source task's own URI with {@code results} and the destination
     * task's partition id appended as path segments.
     *
     * @param sourceTask task producing the data
     * @param destinationTask task consuming the data
     * @return a remote split wrapping a direct exchange input for the source task
     */
    private static Split createExchangeSplit(RemoteTask sourceTask, RemoteTask destinationTask) {
        URI sourceLocation = sourceTask.getTaskStatus().getSelf();
        HttpUriBuilder resultsLocation = HttpUriBuilder.uriBuilderFrom(sourceLocation).appendPath("results");
        URI splitLocation = resultsLocation
                .appendPath(String.valueOf(destinationTask.getTaskId().getPartitionId()))
                .build();
        DirectExchangeInput exchangeInput = new DirectExchangeInput(sourceTask.getTaskId(), splitLocation.toString());
        return new Split(ExchangeOperator.REMOTE_CATALOG_HANDLE, new RemoteSplit(exchangeInput));
    }

    /**
     * State machine of a pipelined stage execution.
     *
     * <p>Wraps a {@link StateMachine} that starts in PLANNED and treats every state whose
     * {@code isDone()} is true as terminal. It also remembers the first failure passed to
     * {@link #transitionToFailed(Throwable)}.
     */
    private static class PipelinedStageStateMachine {
        /** All states for which {@code isDone()} returns true. */
        private static final Set<StageExecution.State> TERMINAL_STAGE_STATES = Stream.of(StageExecution.State.values())
                .filter(StageExecution.State::isDone)
                .collect(ImmutableSet.toImmutableSet());

        private final StageId stageId;
        private final StateMachine<StageExecution.State> state;
        /** First failure reported; set at most once. */
        private final AtomicReference<ExecutionFailureInfo> failureCause = new AtomicReference<>();

        /**
         * @param stageId id of the stage, used in the state machine name and in log messages
         * @param executor executor handed to the underlying state machine
         */
        private PipelinedStageStateMachine(StageId stageId, Executor executor) {
            this.stageId = Objects.requireNonNull(stageId, "stageId is null");
            this.state = new StateMachine<>("Pipelined stage execution " + stageId, executor, StageExecution.State.PLANNED, TERMINAL_STAGE_STATES);
            // The parameter type is inferred; no type variable named T exists in this scope.
            this.state.addStateChangeListener(newState -> log.debug("Pipelined stage execution %s is %s", stageId, newState));
        }

        /** Returns the current state. */
        public StageExecution.State getState() {
            return state.get();
        }

        /** Moves from PLANNED to SCHEDULING; returns true if the state changed. */
        public boolean transitionToScheduling() {
            return state.compareAndSet(StageExecution.State.PLANNED, StageExecution.State.SCHEDULING);
        }

        /** Moves to SCHEDULING_SPLITS from PLANNED or SCHEDULING; returns true if the state changed. */
        public boolean transitionToSchedulingSplits() {
            return state.setIf(StageExecution.State.SCHEDULING_SPLITS, currentState -> currentState == StageExecution.State.PLANNED || currentState == StageExecution.State.SCHEDULING);
        }

        /** Moves to SCHEDULED from PLANNED, SCHEDULING or SCHEDULING_SPLITS; returns true if the state changed. */
        public boolean transitionToScheduled() {
            return state.setIf(StageExecution.State.SCHEDULED, currentState -> currentState == StageExecution.State.PLANNED || currentState == StageExecution.State.SCHEDULING || currentState == StageExecution.State.SCHEDULING_SPLITS);
        }

        /** Moves to RUNNING unless the stage is already RUNNING, FLUSHING or done; returns true if the state changed. */
        public boolean transitionToRunning() {
            return state.setIf(StageExecution.State.RUNNING, currentState -> currentState != StageExecution.State.RUNNING && currentState != StageExecution.State.FLUSHING && !currentState.isDone());
        }

        /** Moves to FLUSHING unless the stage is already FLUSHING or done; returns true if the state changed. */
        public boolean transitionToFlushing() {
            return state.setIf(StageExecution.State.FLUSHING, currentState -> currentState != StageExecution.State.FLUSHING && !currentState.isDone());
        }

        /** Moves to FINISHED unless the stage is already done; returns true if the state changed. */
        public boolean transitionToFinished() {
            return state.setIf(StageExecution.State.FINISHED, currentState -> !currentState.isDone());
        }

        /** Moves to CANCELED unless the stage is already done; returns true if the state changed. */
        public boolean transitionToCanceled() {
            return state.setIf(StageExecution.State.CANCELED, currentState -> !currentState.isDone());
        }

        /** Moves to ABORTED unless the stage is already done; returns true if the state changed. */
        public boolean transitionToAborted() {
            return state.setIf(StageExecution.State.ABORTED, currentState -> !currentState.isDone());
        }

        /**
         * Records the failure (only the first one is kept) and moves to FAILED unless the
         * stage is already done. The failure is logged at error level when the transition
         * happened and at debug level otherwise.
         *
         * @param throwable the failure cause
         * @return true if the state changed to FAILED
         * @throws NullPointerException if {@code throwable} is null
         */
        public boolean transitionToFailed(Throwable throwable) {
            Objects.requireNonNull(throwable, "throwable is null");
            // The cause is recorded before the transition is attempted, so it is stored even
            // when the stage was already done.
            failureCause.compareAndSet(null, Failures.toFailure(throwable));
            boolean failed = state.setIf(StageExecution.State.FAILED, currentState -> !currentState.isDone());
            if (failed) {
                log.error(throwable, "Pipelined stage execution for stage %s failed", stageId);
            } else {
                log.debug(throwable, "Failure in pipelined stage execution for stage %s after finished", stageId);
            }
            return failed;
        }

        /** Returns the first recorded failure, or empty if none has been recorded. */
        public Optional<ExecutionFailureInfo> getFailureCause() {
            return Optional.ofNullable(failureCause.get());
        }

        /** Registers a listener with the underlying state machine. */
        public void addStateChangeListener(StateMachine.StateChangeListener<StageExecution.State> stateChangeListener) {
            state.addStateChangeListener(stateChangeListener);
        }

        @Override
        public String toString() {
            return MoreObjects.toStringHelper(this)
                    .add("stageId", stageId)
                    .add("state", state)
                    .toString();
        }
    }
}

