/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.airlift.log.Logger;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskFailureListener;
import io.trino.execution.TaskId;
import io.trino.execution.TaskState;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Executor;
import java.util.concurrent.LinkedBlockingQueue;
import org.joda.time.DateTime;

/**
 * Tracks the lifecycle state of a single task, the causes of its own failure, and failures
 * reported for its source tasks.
 *
 * <p>The task starts in {@link TaskState#RUNNING}. State is held by a {@link StateMachine}
 * configured with {@link TaskState#TERMINAL_TASK_STATES}; the supplied executor is handed to that
 * state machine. Source-task failure listeners are always invoked on the supplied executor,
 * never on the calling thread.
 */
@ThreadSafe
public class TaskStateMachine {
    private static final Logger log = Logger.get(TaskStateMachine.class);

    private final DateTime createdTime = DateTime.now();
    private final TaskId taskId;
    private final Executor executor;
    private final StateMachine<TaskState> taskState;
    // Every cause passed to failed(), in arrival order. Exposed as-is by getFailureCauses().
    private final LinkedBlockingQueue<Throwable> failureCauses = new LinkedBlockingQueue<>();

    // First failure reported for each source task; later reports for the same task are not stored.
    @GuardedBy("this")
    private final Map<TaskId, Throwable> sourceTaskFailures = new HashMap<>();
    @GuardedBy("this")
    private final List<TaskFailureListener> sourceTaskFailureListeners = new ArrayList<>();

    /**
     * Creates a state machine for {@code taskId} in the {@code RUNNING} state.
     *
     * @param taskId id of the task being tracked; must not be null
     * @param executor executor used by the underlying state machine and for delivering
     *     source-task failure notifications; must not be null
     * @throws NullPointerException if either argument is null
     */
    public TaskStateMachine(TaskId taskId, Executor executor) {
        this.taskId = Objects.requireNonNull(taskId, "taskId is null");
        this.executor = Objects.requireNonNull(executor, "executor is null");
        this.taskState = new StateMachine<TaskState>("task " + taskId, executor, TaskState.RUNNING, TaskState.TERMINAL_TASK_STATES);
        this.taskState.addStateChangeListener(newState -> log.debug("Task %s is %s", taskId, newState));
    }

    /** Returns the time at which this state machine was constructed. */
    public DateTime getCreatedTime() {
        return createdTime;
    }

    /** Returns the id of the task tracked by this state machine. */
    public TaskId getTaskId() {
        return taskId;
    }

    /** Returns the current task state. */
    public TaskState getState() {
        return taskState.get();
    }

    /**
     * Returns a future that completes when the task leaves {@code currentState}.
     *
     * @param currentState the state the caller believes the task is in; must not be null or done
     * @return a future for the next state, or an already-completed future holding the current
     *     state if the task is already done
     * @throws NullPointerException if {@code currentState} is null
     * @throws IllegalArgumentException if {@code currentState} is already a done state
     */
    public ListenableFuture<TaskState> getStateChange(TaskState currentState) {
        Objects.requireNonNull(currentState, "currentState is null");
        Preconditions.checkArgument(!currentState.isDone(), "Current state is already done");
        ListenableFuture<TaskState> future = taskState.getStateChange(currentState);
        // Re-check after obtaining the future: if the task is already done, answer immediately
        // with that state rather than handing back the change future.
        TaskState state = taskState.get();
        if (state.isDone()) {
            return Futures.immediateFuture(state);
        }
        return future;
    }

    /**
     * Returns the live queue of failure causes recorded by {@link #failed(Throwable)}.
     * This is the internal queue itself, not a copy; callers that drain it mutate this object.
     */
    public LinkedBlockingQueue<Throwable> getFailureCauses() {
        return failureCauses;
    }

    /** Moves the task to {@code FLUSHING}, but only if it is currently {@code RUNNING}. */
    public void transitionToFlushing() {
        taskState.setIf(TaskState.FLUSHING, currentState -> currentState == TaskState.RUNNING);
    }

    /** Moves the task to {@code FINISHED} unless it is already terminating or done. */
    public void finished() {
        taskState.setIf(TaskState.FINISHED, currentState -> !currentState.isTerminatingOrDone());
    }

    /** Starts cancellation by moving to {@code CANCELING} unless already terminating or done. */
    public void cancel() {
        startTermination(TaskState.CANCELING);
    }

    /** Starts abort by moving to {@code ABORTING} unless already terminating or done. */
    public void abort() {
        startTermination(TaskState.ABORTING);
    }

    /**
     * Records {@code cause} and starts failure by moving to {@code FAILING} unless already
     * terminating or done. The cause is recorded even if the state does not change.
     *
     * @param cause failure cause; a null cause is rejected by the queue with a
     *     {@link NullPointerException}
     */
    public void failed(Throwable cause) {
        failureCauses.add(cause);
        startTermination(TaskState.FAILING);
    }

    /**
     * Completes an in-progress termination by moving from the terminating state to its
     * corresponding done state: {@code CANCELING -> CANCELED}, {@code ABORTING -> ABORTED},
     * {@code FAILING -> FAILED}. Does nothing if the task is already done.
     *
     * @throws IllegalStateException if the task is neither done nor in a terminating state
     */
    public void terminationComplete() {
        TaskState currentState = taskState.get();
        if (currentState.isDone()) {
            return;
        }
        Preconditions.checkState(currentState.isTerminating(), "current state %s is not a terminating state", currentState);
        TaskState newState = switch (currentState) {
            case CANCELING -> TaskState.CANCELED;
            case ABORTING -> TaskState.ABORTED;
            case FAILING -> TaskState.FAILED;
            default -> throw new IllegalStateException("Unhandled terminating state: " + currentState);
        };
        // compareAndSet: if another thread changed the state since it was read, leave it alone.
        taskState.compareAndSet(currentState, newState);
    }

    // Moves to the given terminating state unless the task is already terminating or done.
    private void startTermination(TaskState terminatingState) {
        Objects.requireNonNull(terminatingState, "terminatingState is null");
        Preconditions.checkArgument(terminatingState.isTerminating(), "terminatingState %s is not a terminating state", terminatingState);
        taskState.setIf(terminatingState, currentState -> !currentState.isTerminatingOrDone());
    }

    /** Registers a listener with the underlying task state machine. */
    public void addStateChangeListener(StateMachine.StateChangeListener<TaskState> stateChangeListener) {
        taskState.addStateChangeListener(stateChangeListener);
    }

    /**
     * Registers a listener for source-task failures. Failures already recorded, one per source
     * task, are replayed to the new listener on the executor.
     *
     * @param listener listener to add; must not be null
     * @throws NullPointerException if {@code listener} is null
     */
    public void addSourceTaskFailureListener(TaskFailureListener listener) {
        // Reject null up front: a null listener would otherwise NPE later inside the executor
        // and stop the remaining listeners in sourceTaskFailed() from being notified.
        Objects.requireNonNull(listener, "listener is null");
        Map<TaskId, Throwable> failures;
        synchronized (this) {
            sourceTaskFailureListeners.add(listener);
            // Snapshot under the lock; the replay below runs outside it.
            failures = ImmutableMap.copyOf(sourceTaskFailures);
        }
        executor.execute(() -> failures.forEach(listener::onTaskFailed));
    }

    /**
     * Records a failure of a source task and notifies all currently registered listeners on the
     * executor. Only the first failure per source task is stored for later replay, but current
     * listeners are notified of every reported failure.
     *
     * @param taskId id of the failed source task; must not be null
     * @param failure the failure; must not be null
     * @throws NullPointerException if either argument is null
     */
    public void sourceTaskFailed(TaskId taskId, Throwable failure) {
        // Null keys/values would be stored in the HashMap and then make every later
        // ImmutableMap.copyOf in addSourceTaskFailureListener throw.
        Objects.requireNonNull(taskId, "taskId is null");
        Objects.requireNonNull(failure, "failure is null");
        List<TaskFailureListener> listeners;
        synchronized (this) {
            sourceTaskFailures.putIfAbsent(taskId, failure);
            // Snapshot under the lock; notification below runs outside it.
            listeners = ImmutableList.copyOf(sourceTaskFailureListeners);
        }
        executor.execute(() -> {
            for (TaskFailureListener listener : listeners) {
                listener.onTaskFailed(taskId, failure);
            }
        });
    }

    @Override
    public String toString() {
        return MoreObjects.toStringHelper(this)
                .add("taskId", taskId)
                .add("taskState", taskState)
                .add("failureCauses", failureCauses)
                .toString();
    }
}

