/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.Traverser;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.trino.Session;
import io.trino.execution.BasicStageStats;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.QueryStateMachine;
import io.trino.execution.RemoteTaskFactory;
import io.trino.execution.SqlStage;
import io.trino.execution.StageId;
import io.trino.execution.StageInfo;
import io.trino.execution.TableInfo;
import io.trino.execution.TaskId;
import io.trino.execution.scheduler.SplitSchedulerStats;
import io.trino.metadata.Metadata;
import io.trino.spi.QueryId;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.SubPlan;
import io.trino.sql.planner.plan.PlanFragmentId;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
 * Holds the {@link SqlStage}s created for every fragment of a query plan, together with the
 * parent/child relationships between them, and aggregates per-stage information and stats.
 */
class StageManager {
    private final QueryStateMachine queryStateMachine;
    private final Map<StageId, SqlStage> stages;
    // All stages in breadth-first order from the root fragment, so a parent precedes its children.
    private final List<SqlStage> stagesInTopologicalOrder;
    // Subset of stagesInTopologicalOrder whose fragment partitioning is coordinator-only.
    private final List<SqlStage> coordinatorStagesInTopologicalOrder;
    // Subset of stagesInTopologicalOrder whose fragment partitioning is not coordinator-only.
    private final List<SqlStage> distributedStagesInTopologicalOrder;
    private final StageId rootStageId;
    private final Map<StageId, Set<StageId>> children;
    private final Map<StageId, StageId> parents;

    /**
     * Creates one {@link SqlStage} per fragment of {@code planTree} and wires up the stage tree.
     *
     * <p>The plan tree is walked breadth-first starting at {@code planTree}, so the first stage
     * visited (the root fragment) becomes the root/output stage.
     *
     * @return a fully initialized manager with final-stage-info listeners registered
     */
    static StageManager create(QueryStateMachine queryStateMachine, Metadata metadata, RemoteTaskFactory taskFactory, NodeTaskMap nodeTaskMap, Tracer tracer, Span schedulerSpan, SplitSchedulerStats schedulerStats, SubPlan planTree, boolean summarizeTaskInfo) {
        Session session = queryStateMachine.getSession();
        ImmutableMap.Builder<StageId, SqlStage> stages = ImmutableMap.builder();
        ImmutableList.Builder<SqlStage> stagesInTopologicalOrder = ImmutableList.builder();
        ImmutableList.Builder<SqlStage> coordinatorStagesInTopologicalOrder = ImmutableList.builder();
        ImmutableList.Builder<SqlStage> distributedStagesInTopologicalOrder = ImmutableList.builder();
        StageId rootStageId = null;
        ImmutableMap.Builder<StageId, Set<StageId>> children = ImmutableMap.builder();
        ImmutableMap.Builder<StageId, StageId> parents = ImmutableMap.builder();
        for (SubPlan planNode : Traverser.forTree(SubPlan::getChildren).breadthFirst(planTree)) {
            PlanFragment fragment = planNode.getFragment();
            SqlStage stage = SqlStage.createSqlStage(
                    getStageId(session.getQueryId(), fragment.getId()),
                    fragment,
                    TableInfo.extract(session, metadata, fragment),
                    taskFactory,
                    session,
                    summarizeTaskInfo,
                    nodeTaskMap,
                    queryStateMachine.getStateMachineExecutor(),
                    tracer,
                    schedulerSpan,
                    schedulerStats);
            StageId stageId = stage.getStageId();
            stages.put(stageId, stage);
            stagesInTopologicalOrder.add(stage);
            if (fragment.getPartitioning().isCoordinatorOnly()) {
                coordinatorStagesInTopologicalOrder.add(stage);
            } else {
                distributedStagesInTopologicalOrder.add(stage);
            }
            // Breadth-first traversal visits the root fragment first.
            if (rootStageId == null) {
                rootStageId = stageId;
            }
            Set<StageId> childStageIds = planNode.getChildren().stream()
                    .map(childStage -> getStageId(session.getQueryId(), childStage.getFragment().getId()))
                    .collect(ImmutableSet.toImmutableSet());
            children.put(stageId, childStageIds);
            // In a tree each child has exactly one parent; buildOrThrow() rejects duplicates.
            childStageIds.forEach(child -> parents.put(child, stageId));
        }
        StageManager stageManager = new StageManager(
                queryStateMachine,
                stages.buildOrThrow(),
                stagesInTopologicalOrder.build(),
                coordinatorStagesInTopologicalOrder.build(),
                distributedStagesInTopologicalOrder.build(),
                rootStageId,
                children.buildOrThrow(),
                parents.buildOrThrow());
        stageManager.initialize();
        return stageManager;
    }

    /**
     * Derives the stage id for a fragment.
     *
     * <p>Parses the fragment id's string form as an integer, so a non-numeric fragment id
     * throws {@link NumberFormatException}.
     */
    private static StageId getStageId(QueryId queryId, PlanFragmentId fragmentId) {
        return new StageId(queryId, Integer.parseInt(fragmentId.toString()));
    }

    private StageManager(QueryStateMachine queryStateMachine, Map<StageId, SqlStage> stages, List<SqlStage> stagesInTopologicalOrder, List<SqlStage> coordinatorStagesInTopologicalOrder, List<SqlStage> distributedStagesInTopologicalOrder, StageId rootStageId, Map<StageId, Set<StageId>> children, Map<StageId, StageId> parents) {
        this.queryStateMachine = Objects.requireNonNull(queryStateMachine, "queryStateMachine is null");
        this.stages = ImmutableMap.copyOf(Objects.requireNonNull(stages, "stages is null"));
        this.stagesInTopologicalOrder = ImmutableList.copyOf(Objects.requireNonNull(stagesInTopologicalOrder, "stagesInTopologicalOrder is null"));
        this.coordinatorStagesInTopologicalOrder = ImmutableList.copyOf(Objects.requireNonNull(coordinatorStagesInTopologicalOrder, "coordinatorStagesInTopologicalOrder is null"));
        this.distributedStagesInTopologicalOrder = ImmutableList.copyOf(Objects.requireNonNull(distributedStagesInTopologicalOrder, "distributedStagesInTopologicalOrder is null"));
        this.rootStageId = Objects.requireNonNull(rootStageId, "rootStageId is null");
        this.children = ImmutableMap.copyOf(Objects.requireNonNull(children, "children is null"));
        this.parents = ImmutableMap.copyOf(Objects.requireNonNull(parents, "parents is null"));
    }

    /**
     * Registers a final-stage-info listener on every stage. Each callback rebuilds the whole
     * stage-info tree and pushes it to the query state machine.
     *
     * <p>Presumably kept out of the constructor so that {@code this} (captured by the listener)
     * is not published before construction completes.
     */
    private void initialize() {
        for (SqlStage stage : stages.values()) {
            stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())));
        }
    }

    /** Calls {@link SqlStage#finish()} on every stage. */
    public void finish() {
        stages.values().forEach(SqlStage::finish);
    }

    /** Calls {@link SqlStage#abort()} on every stage. */
    public void abort() {
        stages.values().forEach(SqlStage::abort);
    }

    /**
     * Forwards a remote-failure request to the stage that owns {@code taskId}.
     *
     * @throws NullPointerException if the task's stage is not managed here
     */
    public void failTaskRemotely(TaskId taskId, Throwable failureCause) {
        // Delegate the lookup to get(StageId) so the "stage not found" message is built correctly.
        get(taskId.getStageId()).failTaskRemotely(taskId, failureCause);
    }

    /** Returns all stages, parents before children (breadth-first from the root). */
    public List<SqlStage> getStagesInTopologicalOrder() {
        return stagesInTopologicalOrder;
    }

    /** Returns the coordinator-only stages, in the same relative order as {@link #getStagesInTopologicalOrder()}. */
    public List<SqlStage> getCoordinatorStagesInTopologicalOrder() {
        return coordinatorStagesInTopologicalOrder;
    }

    /** Returns the non-coordinator-only stages, in the same relative order as {@link #getStagesInTopologicalOrder()}. */
    public List<SqlStage> getDistributedStagesInTopologicalOrder() {
        return distributedStagesInTopologicalOrder;
    }

    /** Returns the stage built from the root fragment of the plan. */
    public SqlStage getOutputStage() {
        return stages.get(rootStageId);
    }

    /**
     * Looks up the stage for a plan fragment.
     *
     * @throws NullPointerException if no stage exists for the fragment
     */
    public SqlStage get(PlanFragmentId fragmentId) {
        return get(getStageId(queryStateMachine.getQueryId(), fragmentId));
    }

    /**
     * Looks up a stage by id.
     *
     * @throws NullPointerException if no stage exists with this id
     */
    public SqlStage get(StageId stageId) {
        return Objects.requireNonNull(stages.get(stageId), () -> "stage not found: " + stageId);
    }

    /** Returns the direct child stages of the given fragment's stage. */
    public Set<SqlStage> getChildren(PlanFragmentId fragmentId) {
        return getChildren(getStageId(queryStateMachine.getQueryId(), fragmentId));
    }

    /**
     * Returns the direct child stages of the given stage (empty for a leaf).
     *
     * @throws NullPointerException if no stage exists with this id
     */
    public Set<SqlStage> getChildren(StageId stageId) {
        Set<StageId> childIds = Objects.requireNonNull(children.get(stageId), () -> "stage not found: " + stageId);
        return childIds.stream()
                .map(this::get)
                .collect(ImmutableSet.toImmutableSet());
    }

    /** Returns the parent stage of the given fragment's stage, or empty for the root. */
    public Optional<SqlStage> getParent(PlanFragmentId fragmentId) {
        return getParent(getStageId(queryStateMachine.getQueryId(), fragmentId));
    }

    /** Returns the parent stage of the given stage, or empty if it has none (e.g. the root). */
    public Optional<SqlStage> getParent(StageId stageId) {
        return Optional.ofNullable(parents.get(stageId)).map(stages::get);
    }

    /** Aggregates the basic stats of all stages via {@link BasicStageStats#aggregateBasicStageStats}. */
    public BasicStageStats getBasicStageStats() {
        List<BasicStageStats> stageStats = stages.values().stream()
                .map(SqlStage::getBasicStageStats)
                .collect(ImmutableList.toImmutableList());
        return BasicStageStats.aggregateBasicStageStats(stageStats);
    }

    /** Snapshots every stage's info and assembles it into a tree rooted at the output stage. */
    public StageInfo getStageInfo() {
        Map<StageId, StageInfo> stageInfos = stages.values().stream()
                .map(SqlStage::getStageInfo)
                .collect(ImmutableMap.toImmutableMap(StageInfo::getStageId, Function.identity()));
        return buildStageInfo(rootStageId, stageInfos);
    }

    /**
     * Recursively attaches child stage infos to the info for {@code stageId}.
     * Leaf stage infos are returned unchanged.
     *
     * @throws IllegalArgumentException if {@code stageInfos} has no entry for {@code stageId}
     */
    private StageInfo buildStageInfo(StageId stageId, Map<StageId, StageInfo> stageInfos) {
        StageInfo parent = stageInfos.get(stageId);
        // Report the missing id; the previous code formatted `parent`, which is always null here.
        Preconditions.checkArgument(parent != null, "No stageInfo for %s", stageId);
        List<StageInfo> childStages = children.get(stageId).stream()
                .map(childStageId -> buildStageInfo(childStageId, stageInfos))
                .collect(ImmutableList.toImmutableList());
        if (childStages.isEmpty()) {
            return parent;
        }
        return new StageInfo(parent.getStageId(), parent.getState(), parent.getPlan(), parent.isCoordinatorOnly(), parent.getTypes(), parent.getStageStats(), parent.getTasks(), childStages, parent.getTables(), parent.getFailureCause());
    }

    /** Sums {@link SqlStage#getUserMemoryReservation()} over all stages. */
    public long getUserMemoryReservation() {
        return stages.values().stream().mapToLong(SqlStage::getUserMemoryReservation).sum();
    }

    /** Sums {@link SqlStage#getTotalMemoryReservation()} over all stages. */
    public long getTotalMemoryReservation() {
        return stages.values().stream().mapToLong(SqlStage::getTotalMemoryReservation).sum();
    }

    /** Sums every stage's total CPU time, truncating each stage's value to whole milliseconds. */
    public Duration getTotalCpuTime() {
        long millis = stages.values().stream().mapToLong(stage -> stage.getTotalCpuTime().toMillis()).sum();
        return new Duration((double) millis, TimeUnit.MILLISECONDS);
    }
}

