/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import io.airlift.units.Duration;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.execution.BasicStageStats;
import io.trino.execution.Lifespan;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTask;
import io.trino.execution.RemoteTaskFactory;
import io.trino.execution.StageId;
import io.trino.execution.StageInfo;
import io.trino.execution.StageStateMachine;
import io.trino.execution.StateMachine;
import io.trino.execution.TableInfo;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskStatus;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.server.DynamicFilterService;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

@ThreadSafe
public final class SqlStage {
    private final Session session;
    private final StageStateMachine stateMachine;
    private final RemoteTaskFactory remoteTaskFactory;
    private final NodeTaskMap nodeTaskMap;
    private final boolean summarizeTaskInfo;
    private final Set<DynamicFilterId> outboundDynamicFilterIds;
    private final Map<TaskId, RemoteTask> tasks = new ConcurrentHashMap<TaskId, RemoteTask>();
    @GuardedBy(value="this")
    private final Set<TaskId> allTasks = new HashSet<TaskId>();
    @GuardedBy(value="this")
    private final Set<TaskId> finishedTasks = new HashSet<TaskId>();
    @GuardedBy(value="this")
    private final Set<TaskId> tasksWithFinalInfo = new HashSet<TaskId>();

    public static SqlStage createSqlStage(StageId stageId, PlanFragment fragment, Map<PlanNodeId, TableInfo> tables, RemoteTaskFactory remoteTaskFactory, Session session, boolean summarizeTaskInfo, NodeTaskMap nodeTaskMap, Executor executor, SplitSchedulerStats schedulerStats) {
        Objects.requireNonNull(stageId, "stageId is null");
        Objects.requireNonNull(fragment, "fragment is null");
        Preconditions.checkArgument((boolean)fragment.getPartitioningScheme().getBucketToPartition().isEmpty(), (Object)"bucket to partition is not expected to be set at this point");
        Objects.requireNonNull(tables, "tables is null");
        Objects.requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        Objects.requireNonNull(session, "session is null");
        Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        Objects.requireNonNull(executor, "executor is null");
        Objects.requireNonNull(schedulerStats, "schedulerStats is null");
        SqlStage sqlStage = new SqlStage(session, new StageStateMachine(stageId, fragment, tables, executor, schedulerStats), remoteTaskFactory, nodeTaskMap, summarizeTaskInfo);
        sqlStage.initialize();
        return sqlStage;
    }

    private SqlStage(Session session, StageStateMachine stateMachine, RemoteTaskFactory remoteTaskFactory, NodeTaskMap nodeTaskMap, boolean summarizeTaskInfo) {
        this.session = Objects.requireNonNull(session, "session is null");
        this.stateMachine = stateMachine;
        this.remoteTaskFactory = Objects.requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
        this.nodeTaskMap = Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        this.summarizeTaskInfo = summarizeTaskInfo;
        this.outboundDynamicFilterIds = SystemSessionProperties.isEnableCoordinatorDynamicFiltersDistribution(session) ? DynamicFilterService.getOutboundDynamicFilters(stateMachine.getFragment()) : ImmutableSet.of();
    }

    private void initialize() {
        this.stateMachine.addStateChangeListener(newState -> this.checkAllTaskFinal());
    }

    public StageId getStageId() {
        return this.stateMachine.getStageId();
    }

    public synchronized void finish() {
        this.stateMachine.transitionToFinished();
        this.tasks.values().forEach(RemoteTask::cancel);
    }

    public synchronized void abort() {
        this.stateMachine.transitionToAborted();
        this.tasks.values().forEach(RemoteTask::abort);
    }

    public synchronized void fail(Throwable throwable) {
        Objects.requireNonNull(throwable, "throwable is null");
        this.stateMachine.transitionToFailed(throwable);
        this.tasks.values().forEach(RemoteTask::abort);
    }

    public void addFinalStageInfoListener(StateMachine.StateChangeListener<StageInfo> stateChangeListener) {
        this.stateMachine.addFinalStageInfoListener(stateChangeListener);
    }

    public PlanFragment getFragment() {
        return this.stateMachine.getFragment();
    }

    public long getUserMemoryReservation() {
        return this.stateMachine.getUserMemoryReservation();
    }

    public long getTotalMemoryReservation() {
        return this.stateMachine.getTotalMemoryReservation();
    }

    public Duration getTotalCpuTime() {
        long millis = this.tasks.values().stream().mapToLong(task -> task.getTaskInfo().getStats().getTotalCpuTime().toMillis()).sum();
        return new Duration((double)millis, TimeUnit.MILLISECONDS);
    }

    public BasicStageStats getBasicStageStats() {
        return this.stateMachine.getBasicStageStats(this::getAllTaskInfo);
    }

    public StageInfo getStageInfo() {
        return this.stateMachine.getStageInfo(this::getAllTaskInfo);
    }

    private Iterable<TaskInfo> getAllTaskInfo() {
        return (Iterable)this.tasks.values().stream().map(RemoteTask::getTaskInfo).collect(ImmutableList.toImmutableList());
    }

    public synchronized Optional<RemoteTask> createTask(InternalNode node, int partition, int attempt, Optional<int[]> bucketToPartition, OutputBuffers outputBuffers, Multimap<PlanNodeId, Split> splits, Multimap<PlanNodeId, Lifespan> noMoreSplitsForLifespan, Set<PlanNodeId> noMoreSplits) {
        if (this.stateMachine.getState().isDone()) {
            return Optional.empty();
        }
        TaskId taskId = new TaskId(this.stateMachine.getStageId(), partition, attempt);
        Preconditions.checkArgument((!this.tasks.containsKey(taskId) ? 1 : 0) != 0, (String)"A task with id %s already exists", (Object)taskId);
        this.stateMachine.transitionToScheduling();
        RemoteTask task = this.remoteTaskFactory.createRemoteTask(this.session, taskId, node, this.stateMachine.getFragment().withBucketToPartition(bucketToPartition), splits, outputBuffers, this.nodeTaskMap.createPartitionedSplitCountTracker(node, taskId), this.outboundDynamicFilterIds, this.summarizeTaskInfo);
        noMoreSplitsForLifespan.forEach(task::noMoreSplits);
        noMoreSplits.forEach(task::noMoreSplits);
        this.tasks.put(taskId, task);
        this.allTasks.add(taskId);
        this.nodeTaskMap.addTask(node, task);
        task.addStateChangeListener(this::updateTaskStatus);
        task.addStateChangeListener(new MemoryUsageListener());
        task.addFinalTaskInfoListener(this::updateFinalTaskInfo);
        return Optional.of(task);
    }

    public void recordGetSplitTime(long start) {
        this.stateMachine.recordGetSplitTime(start);
    }

    private synchronized void updateTaskStatus(TaskStatus status) {
        if (status.getState().isDone()) {
            this.finishedTasks.add(status.getTaskId());
        }
        if (!this.finishedTasks.containsAll(this.allTasks)) {
            this.stateMachine.transitionToRunning();
        } else {
            this.stateMachine.transitionToPending();
        }
    }

    private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) {
        this.tasksWithFinalInfo.add(finalTaskInfo.getTaskStatus().getTaskId());
        this.checkAllTaskFinal();
    }

    private synchronized void checkAllTaskFinal() {
        if (this.stateMachine.getState().isDone() && this.tasksWithFinalInfo.containsAll(this.tasks.keySet())) {
            List finalTaskInfos = (List)this.tasks.values().stream().map(RemoteTask::getTaskInfo).collect(ImmutableList.toImmutableList());
            this.stateMachine.setAllTasksFinal(finalTaskInfos);
        }
    }

    public String toString() {
        return this.stateMachine.toString();
    }

    private class MemoryUsageListener
    implements StateMachine.StateChangeListener<TaskStatus> {
        private long previousUserMemory;
        private long previousRevocableMemory;

        private MemoryUsageListener() {
        }

        @Override
        public synchronized void stateChanged(TaskStatus taskStatus) {
            long currentUserMemory = taskStatus.getMemoryReservation().toBytes();
            long currentRevocableMemory = taskStatus.getRevocableMemoryReservation().toBytes();
            long deltaUserMemoryInBytes = currentUserMemory - this.previousUserMemory;
            long deltaRevocableMemoryInBytes = currentRevocableMemory - this.previousRevocableMemory;
            long deltaTotalMemoryInBytes = currentUserMemory + currentRevocableMemory - (this.previousUserMemory + this.previousRevocableMemory);
            this.previousUserMemory = currentUserMemory;
            this.previousRevocableMemory = currentRevocableMemory;
            SqlStage.this.stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaRevocableMemoryInBytes, deltaTotalMemoryInBytes);
        }
    }
}

