/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;
import com.google.errorprone.annotations.ThreadSafe;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.trino.Session;
import io.trino.execution.BasicStageInfo;
import io.trino.execution.BasicStageStats;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTask;
import io.trino.execution.RemoteTaskFactory;
import io.trino.execution.StageId;
import io.trino.execution.StageInfo;
import io.trino.execution.StageState;
import io.trino.execution.StageStateMachine;
import io.trino.execution.StateMachine;
import io.trino.execution.TableInfo;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.server.DynamicFilterService;
import io.trino.spi.metrics.Metrics;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;

@ThreadSafe
public final class SqlStage {
    private final Session session;
    private final StageStateMachine stateMachine;
    private final RemoteTaskFactory remoteTaskFactory;
    private final NodeTaskMap nodeTaskMap;
    private final boolean summarizeTaskInfo;
    private final Set<DynamicFilterId> outboundDynamicFilterIds;
    private final Map<TaskId, RemoteTask> tasks = new ConcurrentHashMap<TaskId, RemoteTask>();
    @GuardedBy(value="this")
    private final Set<TaskId> allTasks = new HashSet<TaskId>();
    @GuardedBy(value="this")
    private final Set<TaskId> finishedTasks = new HashSet<TaskId>();
    @GuardedBy(value="this")
    private final Set<TaskId> tasksWithFinalInfo = new HashSet<TaskId>();

    public static SqlStage createSqlStage(StageId stageId, PlanFragment fragment, Map<PlanNodeId, TableInfo> tables, RemoteTaskFactory remoteTaskFactory, Session session, boolean summarizeTaskInfo, NodeTaskMap nodeTaskMap, Executor stateMachineExecutor, Tracer tracer, Span schedulerSpan, SplitSchedulerStats schedulerStats) {
        Objects.requireNonNull(stageId, "stageId is null");
        Objects.requireNonNull(fragment, "fragment is null");
        Preconditions.checkArgument((boolean)fragment.getOutputPartitioningScheme().getBucketToPartition().isEmpty(), (Object)"bucket to partition is not expected to be set at this point");
        Objects.requireNonNull(tables, "tables is null");
        Objects.requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        Objects.requireNonNull(stateMachineExecutor, "stateMachineExecutor is null");
        Objects.requireNonNull(tracer, "tracer is null");
        Objects.requireNonNull(schedulerStats, "schedulerStats is null");
        StageStateMachine stateMachine = new StageStateMachine(stageId, fragment, tables, stateMachineExecutor, tracer, schedulerSpan, schedulerStats);
        SqlStage sqlStage = new SqlStage(session, stateMachine, remoteTaskFactory, nodeTaskMap, summarizeTaskInfo);
        sqlStage.initialize();
        return sqlStage;
    }

    private SqlStage(Session session, StageStateMachine stateMachine, RemoteTaskFactory remoteTaskFactory, NodeTaskMap nodeTaskMap, boolean summarizeTaskInfo) {
        this.session = Objects.requireNonNull(session, "session is null");
        this.stateMachine = stateMachine;
        this.remoteTaskFactory = Objects.requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        this.nodeTaskMap = Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        this.summarizeTaskInfo = summarizeTaskInfo;
        this.outboundDynamicFilterIds = DynamicFilterService.getOutboundDynamicFilters(stateMachine.getFragment());
    }

    private void initialize() {
        this.stateMachine.addStateChangeListener(newState -> this.checkAllTaskFinal());
    }

    public StageId getStageId() {
        return this.stateMachine.getStageId();
    }

    public Span getStageSpan() {
        return this.stateMachine.getStageSpan();
    }

    public StageState getState() {
        return this.stateMachine.getState();
    }

    public synchronized void finish() {
        if (this.stateMachine.transitionToFinished()) {
            this.tasks.values().forEach(RemoteTask::cancel);
        }
    }

    public synchronized void abort() {
        if (this.stateMachine.transitionToAborted()) {
            this.tasks.values().forEach(RemoteTask::abort);
        }
    }

    public synchronized void fail(Throwable throwable) {
        Objects.requireNonNull(throwable, "throwable is null");
        if (this.stateMachine.transitionToFailed(throwable)) {
            this.tasks.values().forEach(RemoteTask::abort);
        }
    }

    public void failTaskRemotely(TaskId taskId, Throwable failureCause) {
        RemoteTask task = Objects.requireNonNull(this.tasks.get(taskId), () -> "task not found: " + String.valueOf(taskId));
        task.failRemotely(failureCause);
    }

    public void addFinalStageInfoListener(StateMachine.StateChangeListener<StageInfo> stateChangeListener) {
        this.stateMachine.addFinalStageInfoListener(stateChangeListener);
    }

    public PlanFragment getFragment() {
        return this.stateMachine.getFragment();
    }

    public long getUserMemoryReservation() {
        return this.stateMachine.getUserMemoryReservation();
    }

    public long getTotalMemoryReservation() {
        return this.stateMachine.getTotalMemoryReservation();
    }

    public Duration getTotalCpuTime() {
        long millis = this.tasks.values().stream().mapToLong(task -> task.getTaskInfo().stats().getTotalCpuTime().toMillis()).sum();
        return new Duration((double)millis, TimeUnit.MILLISECONDS);
    }

    public BasicStageStats getBasicStageStats() {
        return this.stateMachine.getBasicStageStats(this::getAllTaskInfo);
    }

    public StageInfo getStageInfo() {
        return this.stateMachine.getStageInfo(this::getAllTaskInfo);
    }

    public BasicStageInfo getBasicStageInfo() {
        return this.stateMachine.getBasicStageInfo(this::getAllTaskInfo);
    }

    private Iterable<TaskInfo> getAllTaskInfo() {
        return (Iterable)this.tasks.values().stream().map(RemoteTask::getTaskInfo).collect(ImmutableList.toImmutableList());
    }

    public synchronized Optional<RemoteTask> createTask(InternalNode node, int partition, int attempt, Optional<int[]> bucketToPartition, OutputBuffers outputBuffers, Multimap<PlanNodeId, Split> splits, Set<PlanNodeId> noMoreSplits, Optional<DataSize> estimatedMemory, boolean speculative) {
        if (this.stateMachine.getState().isDone()) {
            return Optional.empty();
        }
        TaskId taskId = new TaskId(this.stateMachine.getStageId(), partition, attempt);
        Preconditions.checkArgument((!this.tasks.containsKey(taskId) ? 1 : 0) != 0, (String)"A task with id %s already exists", (Object)taskId);
        this.stateMachine.transitionToScheduling();
        RemoteTask task = this.remoteTaskFactory.createRemoteTask(this.session, this.stateMachine.getStageSpan(), taskId, node, speculative, this.stateMachine.getFragment().withBucketToPartition(bucketToPartition), splits, outputBuffers, this.nodeTaskMap.createPartitionedSplitCountTracker(node, taskId), this.outboundDynamicFilterIds, estimatedMemory, this.summarizeTaskInfo);
        noMoreSplits.forEach(task::noMoreSplits);
        this.tasks.put(taskId, task);
        this.allTasks.add(taskId);
        this.nodeTaskMap.addTask(node, task);
        task.addStateChangeListener(this::updateTaskStatus);
        task.addStateChangeListener(new MemoryUsageListener());
        task.addFinalTaskInfoListener(this::updateFinalTaskInfo);
        return Optional.of(task);
    }

    public void recordSplitSourceMetrics(PlanNodeId nodeId, Metrics metrics, long start) {
        this.stateMachine.recordSplitSourceMetrics(nodeId, metrics, start);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void updateTaskStatus(TaskStatus status) {
        boolean isDone = status.getState().isDone();
        if (!isDone && this.stateMachine.getState() == StageState.RUNNING) {
            return;
        }
        SqlStage sqlStage = this;
        synchronized (sqlStage) {
            if (isDone) {
                this.finishedTasks.add(status.getTaskId());
            }
            if (this.finishedTasks.size() == this.allTasks.size()) {
                this.stateMachine.transitionToPending();
            } else {
                this.stateMachine.transitionToRunning();
            }
        }
    }

    private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) {
        this.tasksWithFinalInfo.add(finalTaskInfo.taskStatus().getTaskId());
        this.checkAllTaskFinal();
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private void checkAllTaskFinal() {
        if (!this.stateMachine.getState().isDone()) {
            return;
        }
        SqlStage sqlStage = this;
        synchronized (sqlStage) {
            if (this.tasksWithFinalInfo.size() == this.allTasks.size()) {
                List finalTaskInfos = (List)this.tasks.values().stream().map(RemoteTask::getTaskInfo).collect(ImmutableList.toImmutableList());
                this.stateMachine.setAllTasksFinal(finalTaskInfos);
            }
        }
    }

    public synchronized String toString() {
        return MoreObjects.toStringHelper((Object)this).add("stateMachine", (Object)this.stateMachine).add("summarizeTaskInfo", this.summarizeTaskInfo).add("outboundDynamicFilterIds", this.outboundDynamicFilterIds).add("tasks", this.tasks).add("allTasks", this.allTasks).add("finishedTasks", this.finishedTasks).add("tasksWithFinalInfo", this.tasksWithFinalInfo).toString();
    }

    private class MemoryUsageListener
    implements StateMachine.StateChangeListener<TaskStatus> {
        private long previousUserMemory;
        private long previousRevocableMemory;
        private boolean finalUsageReported;

        private MemoryUsageListener() {
        }

        @Override
        public synchronized void stateChanged(TaskStatus taskStatus) {
            if (this.finalUsageReported) {
                return;
            }
            long currentUserMemory = taskStatus.getMemoryReservation().toBytes();
            long currentRevocableMemory = taskStatus.getRevocableMemoryReservation().toBytes();
            long deltaUserMemoryInBytes = currentUserMemory - this.previousUserMemory;
            long deltaRevocableMemoryInBytes = currentRevocableMemory - this.previousRevocableMemory;
            long deltaTotalMemoryInBytes = currentUserMemory + currentRevocableMemory - (this.previousUserMemory + this.previousRevocableMemory);
            this.previousUserMemory = currentUserMemory;
            this.previousRevocableMemory = currentRevocableMemory;
            SqlStage.this.stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaRevocableMemoryInBytes, deltaTotalMemoryInBytes);
            if (taskStatus.getState().isDone()) {
                SqlStage.this.stateMachine.updateMemoryUsage(-currentUserMemory, -currentRevocableMemory, -(currentUserMemory + currentRevocableMemory));
                this.previousUserMemory = 0L;
                this.previousRevocableMemory = 0L;
                this.finalUsageReported = true;
            }
        }
    }
}

