/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.execution.scheduler;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.prestosql.execution.SqlStageExecution;
import io.prestosql.execution.StageState;
import io.prestosql.execution.scheduler.ExecutionSchedule;
import io.prestosql.sql.planner.PlanFragment;
import io.prestosql.sql.planner.plan.ExchangeNode;
import io.prestosql.sql.planner.plan.IndexJoinNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanFragmentId;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanVisitor;
import io.prestosql.sql.planner.plan.RemoteSourceNode;
import io.prestosql.sql.planner.plan.SemiJoinNode;
import io.prestosql.sql.planner.plan.SpatialJoinNode;
import io.prestosql.sql.planner.plan.UnionNode;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.concurrent.NotThreadSafe;
import org.jgrapht.DirectedGraph;
import org.jgrapht.alg.StrongConnectivityInspector;
import org.jgrapht.graph.DefaultDirectedGraph;
import org.jgrapht.graph.DefaultEdge;
import org.jgrapht.traverse.TopologicalOrderIterator;

/**
 * {@link ExecutionSchedule} that releases stages to the scheduler in phases.
 *
 * <p>Phases are computed by {@link #extractPhases(Collection)}. Each phase is a strongly connected
 * component of a dependency graph between plan fragments, and phases are returned in topological
 * order of the graph between components.
 *
 * <p>Stages are activated phase by phase until the active set contains at least one stage whose
 * fragment has partitioned sources. While such a stage is active, later phases are held back. An
 * active stage leaves the active set once it is SCHEDULED, RUNNING, or in a done state.
 *
 * <p>Instances are not thread-safe (see {@code @NotThreadSafe}).
 */
@NotThreadSafe
public class PhasedExecutionSchedule
implements ExecutionSchedule {
    /** Phases not yet activated, in scheduling order; index 0 is the next phase to activate. */
    private final List<Set<SqlStageExecution>> schedulePhases;
    /** Activated stages that were not SCHEDULED, RUNNING, or done at the last check. */
    private final Set<SqlStageExecution> activeSources = new HashSet<>();

    /**
     * Builds a phased schedule for the given stages.
     *
     * @param stages the stages to schedule; phases are computed from their fragments and mapped
     *     back to the stage owning each fragment
     */
    public PhasedExecutionSchedule(Collection<SqlStageExecution> stages) {
        List<PlanFragment> fragments = stages.stream()
                .map(SqlStageExecution::getFragment)
                .collect(ImmutableList.toImmutableList());
        List<Set<PlanFragmentId>> phases = extractPhases(fragments);

        Map<PlanFragmentId, SqlStageExecution> stagesByFragmentId = stages.stream()
                .collect(ImmutableMap.toImmutableMap(stage -> stage.getFragment().getId(), Function.identity()));

        this.schedulePhases = new ArrayList<>(phases.size());
        for (Set<PlanFragmentId> phase : phases) {
            this.schedulePhases.add(phase.stream()
                    .map(stagesByFragmentId::get)
                    .collect(Collectors.toCollection(HashSet::new)));
        }
    }

    /**
     * Returns the stages that should currently be scheduled, activating further phases if needed.
     *
     * <p>NOTE(review): the returned set is this schedule's internal active set, not a copy. It
     * changes on later calls, and callers should not modify it.
     *
     * @return the active stages, or an empty set once the schedule is finished
     */
    @Override
    public Set<SqlStageExecution> getStagesToSchedule() {
        removeCompletedStages();
        addPhasesIfNecessary();
        if (isFinished()) {
            return ImmutableSet.of();
        }
        return activeSources;
    }

    /** Removes active stages that are SCHEDULED, RUNNING, or in a done state. */
    private void removeCompletedStages() {
        activeSources.removeIf(stage -> {
            StageState state = stage.getState();
            return state == StageState.SCHEDULED || state == StageState.RUNNING || state.isDone();
        });
    }

    /**
     * Activates pending phases, in order, until at least one active stage has partitioned sources
     * or no phases remain. Does nothing if such a stage is already active.
     */
    private void addPhasesIfNecessary() {
        if (hasSourceDistributedStage(activeSources)) {
            return;
        }
        while (!schedulePhases.isEmpty()) {
            Set<SqlStageExecution> phase = schedulePhases.remove(0);
            activeSources.addAll(phase);
            if (hasSourceDistributedStage(phase)) {
                return;
            }
        }
    }

    /** Returns true if any stage in {@code phase} has a fragment with partitioned sources. */
    private static boolean hasSourceDistributedStage(Set<SqlStageExecution> phase) {
        return phase.stream().anyMatch(stage -> !stage.getFragment().getPartitionedSources().isEmpty());
    }

    /** Returns true once no stage is active and no phase is pending. */
    @Override
    public boolean isFinished() {
        return activeSources.isEmpty() && schedulePhases.isEmpty();
    }

    /**
     * Splits the fragments into scheduling phases.
     *
     * <p>A directed graph is built with one vertex per fragment. {@link Visitor} adds an edge for
     * each of these relationships:
     * <ul>
     *   <li>join build sources to probe sources;</li>
     *   <li>a fragment to each remote fragment it reads;</li>
     *   <li>each union/exchange/remote input's sources to the next input's sources.</li>
     * </ul>
     *
     * <p>Fragments in the same strongly connected component form one phase. Phases are returned in
     * a topological order of the graph between components, so for every edge crossing two phases,
     * the edge's source phase comes first.
     *
     * @param fragments all fragments of the plan
     * @return the phases, in scheduling order
     */
    @VisibleForTesting
    static List<Set<PlanFragmentId>> extractPhases(Collection<PlanFragment> fragments) {
        DirectedGraph<PlanFragmentId, DefaultEdge> graph = new DefaultDirectedGraph<>(DefaultEdge.class);
        fragments.forEach(fragment -> graph.addVertex(fragment.getId()));

        Visitor visitor = new Visitor(fragments, graph);
        for (PlanFragment fragment : fragments) {
            visitor.processFragment(fragment.getId());
        }

        List<Set<PlanFragmentId>> components = new StrongConnectivityInspector<>(graph).stronglyConnectedSets();

        Map<PlanFragmentId, Set<PlanFragmentId>> componentMembership = new HashMap<>();
        for (Set<PlanFragmentId> component : components) {
            for (PlanFragmentId planFragmentId : component) {
                componentMembership.put(planFragmentId, component);
            }
        }

        // Build the graph between components (phases).
        DirectedGraph<Set<PlanFragmentId>, DefaultEdge> componentGraph = new DefaultDirectedGraph<>(DefaultEdge.class);
        components.forEach(componentGraph::addVertex);
        for (DefaultEdge edge : graph.edgeSet()) {
            Set<PlanFragmentId> from = componentMembership.get(graph.getEdgeSource(edge));
            Set<PlanFragmentId> to = componentMembership.get(graph.getEdgeTarget(edge));
            // An edge inside one component would become a self-loop (a cycle), which would prevent
            // a topological ordering of the components, so it is skipped.
            if (!from.equals(to)) {
                componentGraph.addEdge(from, to);
            }
        }

        return ImmutableList.copyOf(new TopologicalOrderIterator<>(componentGraph));
    }

    /**
     * Walks fragment plans and records ordering edges in the shared fragment graph.
     *
     * <p>Each visit returns the "sources" of the visited subtree: the fragment ids whose leaf
     * nodes feed it, including fragments reached through remote sources.
     */
    private static class Visitor
    extends PlanVisitor<Set<PlanFragmentId>, PlanFragmentId> {
        private final Map<PlanFragmentId, PlanFragment> fragments;
        private final DirectedGraph<PlanFragmentId, DefaultEdge> graph;
        /** Memoized result of {@link #processFragment(PlanFragmentId)} per fragment. */
        private final Map<PlanFragmentId, Set<PlanFragmentId>> fragmentSources = new HashMap<>();

        public Visitor(Collection<PlanFragment> fragments, DirectedGraph<PlanFragmentId, DefaultEdge> graph) {
            this.fragments = fragments.stream()
                    .collect(ImmutableMap.toImmutableMap(PlanFragment::getId, Function.identity()));
            this.graph = graph;
        }

        /**
         * Returns the sources of the given fragment (the fragment itself plus the sources of its
         * plan), visiting it at most once.
         */
        public Set<PlanFragmentId> processFragment(PlanFragmentId planFragmentId) {
            // Deliberately not computeIfAbsent: processing recurses into this method and would
            // modify the map during computeIfAbsent.
            Set<PlanFragmentId> cached = fragmentSources.get(planFragmentId);
            if (cached != null) {
                return cached;
            }
            Set<PlanFragmentId> sources = processFragment(fragments.get(planFragmentId));
            fragmentSources.put(planFragmentId, sources);
            return sources;
        }

        private Set<PlanFragmentId> processFragment(PlanFragment fragment) {
            Set<PlanFragmentId> sources = fragment.getRoot().accept(this, fragment.getId());
            return ImmutableSet.<PlanFragmentId>builder()
                    .add(fragment.getId())
                    .addAll(sources)
                    .build();
        }

        @Override
        public Set<PlanFragmentId> visitJoin(JoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getRight(), node.getLeft(), currentFragmentId);
        }

        @Override
        public Set<PlanFragmentId> visitSpatialJoin(SpatialJoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getRight(), node.getLeft(), currentFragmentId);
        }

        @Override
        public Set<PlanFragmentId> visitSemiJoin(SemiJoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getFilteringSource(), node.getSource(), currentFragmentId);
        }

        @Override
        public Set<PlanFragmentId> visitIndexJoin(IndexJoinNode node, PlanFragmentId currentFragmentId) {
            return processJoin(node.getIndexSource(), node.getProbeSource(), currentFragmentId);
        }

        /**
         * Adds an edge from every build-side source to every probe-side source, placing the build
         * side's phase no later than the probe side's.
         *
         * @return the union of build and probe sources
         */
        private Set<PlanFragmentId> processJoin(PlanNode build, PlanNode probe, PlanFragmentId currentFragmentId) {
            Set<PlanFragmentId> buildSources = build.accept(this, currentFragmentId);
            Set<PlanFragmentId> probeSources = probe.accept(this, currentFragmentId);
            addEdges(buildSources, probeSources);
            return ImmutableSet.<PlanFragmentId>builder()
                    .addAll(buildSources)
                    .addAll(probeSources)
                    .build();
        }

        @Override
        public Set<PlanFragmentId> visitRemoteSource(RemoteSourceNode node, PlanFragmentId currentFragmentId) {
            ImmutableSet.Builder<PlanFragmentId> sources = ImmutableSet.builder();
            Set<PlanFragmentId> previousFragmentSources = ImmutableSet.of();
            for (PlanFragmentId remoteFragment : node.getSourceFragmentIds()) {
                // The current fragment depends on the remote fragment.
                graph.addEdge(currentFragmentId, remoteFragment);

                Set<PlanFragmentId> remoteFragmentSources = processFragment(remoteFragment);
                sources.addAll(remoteFragmentSources);

                // With several remote fragments, chain each one's sources after the previous one's.
                addEdges(previousFragmentSources, remoteFragmentSources);
                previousFragmentSources = remoteFragmentSources;
            }
            return sources.build();
        }

        @Override
        public Set<PlanFragmentId> visitExchange(ExchangeNode node, PlanFragmentId currentFragmentId) {
            Preconditions.checkArgument(node.getScope() == ExchangeNode.Scope.LOCAL, "Only local exchanges are supported in the phased execution scheduler");
            return processSequentially(node.getSources(), currentFragmentId);
        }

        @Override
        public Set<PlanFragmentId> visitUnion(UnionNode node, PlanFragmentId currentFragmentId) {
            return processSequentially(node.getSources(), currentFragmentId);
        }

        /**
         * Visits {@code subPlans} in order, adding edges from every source of each sub-plan to every
         * source of the next one. Unless they fall into the same phase, earlier sub-plans' sources
         * are therefore ordered before later ones.
         *
         * @return the union of the sources of all sub-plans
         */
        private Set<PlanFragmentId> processSequentially(List<PlanNode> subPlans, PlanFragmentId currentFragmentId) {
            ImmutableSet.Builder<PlanFragmentId> allSources = ImmutableSet.builder();
            Set<PlanFragmentId> previousSources = ImmutableSet.of();
            for (PlanNode subPlan : subPlans) {
                Set<PlanFragmentId> currentSources = subPlan.accept(this, currentFragmentId);
                allSources.addAll(currentSources);
                addEdges(previousSources, currentSources);
                previousSources = currentSources;
            }
            return allSources.build();
        }

        /**
         * Leaf nodes report the current fragment as their source; single-child nodes pass through
         * their child's sources; other multi-child nodes are unsupported.
         */
        @Override
        protected Set<PlanFragmentId> visitPlan(PlanNode node, PlanFragmentId currentFragmentId) {
            List<PlanNode> sources = node.getSources();
            if (sources.isEmpty()) {
                return ImmutableSet.of(currentFragmentId);
            }
            if (sources.size() == 1) {
                return sources.get(0).accept(this, currentFragmentId);
            }
            throw new UnsupportedOperationException("not yet implemented: " + node.getClass().getName());
        }

        /** Adds an edge from every fragment in {@code sourceFragments} to every fragment in {@code targetFragments}. */
        private void addEdges(Set<PlanFragmentId> sourceFragments, Set<PlanFragmentId> targetFragments) {
            for (PlanFragmentId targetFragment : targetFragments) {
                for (PlanFragmentId sourceFragment : sourceFragments) {
                    graph.addEdge(sourceFragment, targetFragment);
                }
            }
        }
    }
}

