/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution.scheduler.policy;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.trino.execution.ExecutionFailureInfo;
import io.trino.execution.RemoteTask;
import io.trino.execution.StageId;
import io.trino.execution.StateMachine;
import io.trino.execution.TaskId;
import io.trino.execution.TaskStatus;
import io.trino.execution.scheduler.StageExecution;
import io.trino.execution.scheduler.TaskLifecycleListener;
import io.trino.execution.scheduler.policy.PhasedExecutionSchedule;
import io.trino.execution.scheduler.policy.PlanUtils;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.InternalNode;
import io.trino.metadata.Metadata;
import io.trino.metadata.MetadataManager;
import io.trino.metadata.Split;
import io.trino.server.DynamicFilterService;
import io.trino.spi.QueryId;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanFragmentId;
import io.trino.sql.planner.plan.PlanNodeId;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.assertj.core.api.Assertions;
import org.jgrapht.DirectedGraph;
import org.testng.annotations.Test;

public class TestPhasedExecutionSchedule {
    /**
     * Dynamic filter service handed to every schedule built by these tests; it is constructed
     * with the testing metadata/function managers and a direct executor service.
     */
    private final DynamicFilterService dynamicFilterService = new DynamicFilterService(
            MetadataManager.createTestMetadataManager(),
            FunctionManager.createTestingFunctionManager(),
            new TypeOperators(),
            MoreExecutors.newDirectExecutorService());

    /**
     * Partitioned join: the probe side must wait for the build side, while the join stage itself
     * is active from the start.
     */
    @Test
    public void testPartitionedJoin() {
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment probeFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment joinFragment = PlanUtils.createJoinPlanFragment(
                JoinNode.Type.INNER, JoinNode.DistributionType.PARTITIONED, "join", buildFragment, probeFragment);

        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution probeStage = new TestingStageExecution(probeFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(
                ImmutableSet.of(buildStage, probeStage, joinStage), dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments())
                .containsExactly(buildFragment.getId(), probeFragment.getId(), joinFragment.getId());

        // the only dependency edge is build -> probe
        DirectedGraph<PlanFragmentId, PhasedExecutionSchedule.FragmentsEdge> dependencies = schedule.getFragmentDependency();
        Assertions.assertThat(dependencies.edgeSet())
                .containsExactlyInAnyOrder(new PhasedExecutionSchedule.FragmentsEdge(buildFragment.getId(), probeFragment.getId()));

        // build and join stages are active immediately
        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(buildFragment.getId(), joinFragment.getId());

        // the reschedule future completes once the build stage reaches FLUSHING,
        // after which the probe stage becomes active
        ListenableFuture<?> rescheduleFuture = schedule.getRescheduleFuture().orElseThrow();
        Assertions.assertThat(rescheduleFuture).isNotDone();
        buildStage.setState(StageExecution.State.FLUSHING);
        Assertions.assertThat(rescheduleFuture).isDone();
        schedule.schedule();
        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(joinFragment.getId(), probeFragment.getId());

        // finishing only the probe stage does not trigger a reschedule; finishing the join
        // stage as well lets the schedule drain completely
        rescheduleFuture = schedule.getRescheduleFuture().orElseThrow();
        Assertions.assertThat(rescheduleFuture).isNotDone();
        probeStage.setState(StageExecution.State.FINISHED);
        Assertions.assertThat(rescheduleFuture).isNotDone();
        joinStage.setState(StageExecution.State.FINISHED);
        schedule.schedule();
        Assertions.assertThat(getActiveFragments(schedule)).isEmpty();
        Assertions.assertThat(schedule.isFinished()).isTrue();
    }

    /**
     * Broadcast join inside a source stage: the join source stage is started only after the
     * build stage reports a blocked task.
     */
    @Test
    public void testBroadcastSourceJoin() {
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment joinSourceFragment = PlanUtils.createBroadcastJoinPlanFragment("probe", buildFragment);

        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution joinSourceStage = new TestingStageExecution(joinSourceFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(
                ImmutableSet.of(joinSourceStage, buildStage), dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments())
                .containsExactly(buildFragment.getId(), joinSourceFragment.getId());

        // the only dependency edge is build -> join source
        DirectedGraph<PlanFragmentId, PhasedExecutionSchedule.FragmentsEdge> dependencies = schedule.getFragmentDependency();
        Assertions.assertThat(dependencies.edgeSet())
                .containsExactlyInAnyOrder(new PhasedExecutionSchedule.FragmentsEdge(buildFragment.getId(), joinSourceFragment.getId()));

        // only the build stage is active at first
        Assertions.assertThat(getActiveFragments(schedule)).containsExactly(buildFragment.getId());

        // once a build task is blocked, the join source stage is activated as well
        buildStage.setAnyTaskBlocked(true);
        schedule.schedule();
        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(buildFragment.getId(), joinSourceFragment.getId());
    }

    /**
     * Replicated join whose probe side is an aggregation: only the join depends on the build side,
     * so the source, aggregation and build stages are all active immediately.
     */
    @Test
    public void testAggregation() {
        PlanFragment sourceFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment aggregationFragment = PlanUtils.createAggregationFragment("aggregation", sourceFragment);
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment joinFragment = PlanUtils.createJoinPlanFragment(
                JoinNode.Type.INNER, JoinNode.DistributionType.REPLICATED, "join", buildFragment, aggregationFragment);

        TestingStageExecution sourceStage = new TestingStageExecution(sourceFragment);
        TestingStageExecution aggregationStage = new TestingStageExecution(aggregationFragment);
        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(
                ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments())
                .containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId());

        DirectedGraph<PlanFragmentId, PhasedExecutionSchedule.FragmentsEdge> dependencies = schedule.getFragmentDependency();
        Assertions.assertThat(dependencies.edgeSet())
                .containsExactly(new PhasedExecutionSchedule.FragmentsEdge(buildFragment.getId(), joinFragment.getId()));

        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId());
    }

    /**
     * Same plan as {@link #testAggregation()}, but the dependent join stage is aborted before it
     * was ever activated; the schedule must still report completion once the remaining stages finish.
     */
    @Test
    public void testDependentStageAbortedBeforeStarted() {
        PlanFragment sourceFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment aggregationFragment = PlanUtils.createAggregationFragment("aggregation", sourceFragment);
        PlanFragment buildFragment = PlanUtils.createTableScanPlanFragment("build");
        PlanFragment joinFragment = PlanUtils.createJoinPlanFragment(
                JoinNode.Type.INNER, JoinNode.DistributionType.REPLICATED, "join", buildFragment, aggregationFragment);

        TestingStageExecution sourceStage = new TestingStageExecution(sourceFragment);
        TestingStageExecution aggregationStage = new TestingStageExecution(aggregationFragment);
        TestingStageExecution buildStage = new TestingStageExecution(buildFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(
                ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), dynamicFilterService);
        Assertions.assertThat(schedule.getSortedFragments())
                .containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId());

        DirectedGraph<PlanFragmentId, PhasedExecutionSchedule.FragmentsEdge> dependencies = schedule.getFragmentDependency();
        Assertions.assertThat(dependencies.edgeSet())
                .containsExactly(new PhasedExecutionSchedule.FragmentsEdge(buildFragment.getId(), joinFragment.getId()));

        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId());

        // abort the join stage while it is still inactive, then finish everything else
        joinStage.setState(StageExecution.State.ABORTED);
        buildStage.setState(StageExecution.State.FINISHED);
        aggregationStage.setState(StageExecution.State.FINISHED);
        sourceStage.setState(StageExecution.State.FINISHED);
        schedule.schedule();
        Assertions.assertThat(schedule.isFinished()).isTrue();
    }

    /**
     * A single join stage with both a broadcast build and a partitioned build: the probe stage
     * becomes active only after both build stages reach FLUSHING.
     */
    @Test
    public void testStageWithBroadcastAndPartitionedJoin() {
        PlanFragment broadcastBuildFragment = PlanUtils.createTableScanPlanFragment("broadcast_build");
        PlanFragment partitionedBuildFragment = PlanUtils.createTableScanPlanFragment("partitioned_build");
        PlanFragment probeFragment = PlanUtils.createTableScanPlanFragment("probe");
        PlanFragment joinFragment = PlanUtils.createBroadcastAndPartitionedJoinPlanFragment(
                "join", broadcastBuildFragment, partitionedBuildFragment, probeFragment);

        TestingStageExecution broadcastBuildStage = new TestingStageExecution(broadcastBuildFragment);
        TestingStageExecution partitionedBuildStage = new TestingStageExecution(partitionedBuildFragment);
        TestingStageExecution probeStage = new TestingStageExecution(probeFragment);
        TestingStageExecution joinStage = new TestingStageExecution(joinFragment);

        PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(
                ImmutableSet.of(broadcastBuildStage, partitionedBuildStage, probeStage, joinStage), dynamicFilterService);

        DirectedGraph<PlanFragmentId, PhasedExecutionSchedule.FragmentsEdge> dependencies = schedule.getFragmentDependency();
        Assertions.assertThat(dependencies.edgeSet()).containsExactlyInAnyOrder(
                new PhasedExecutionSchedule.FragmentsEdge(broadcastBuildFragment.getId(), probeFragment.getId()),
                new PhasedExecutionSchedule.FragmentsEdge(partitionedBuildFragment.getId(), probeFragment.getId()),
                new PhasedExecutionSchedule.FragmentsEdge(broadcastBuildFragment.getId(), joinFragment.getId()));

        // both build stages and the join stage are active initially
        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(partitionedBuildFragment.getId(), broadcastBuildFragment.getId(), joinFragment.getId());

        // flushing the broadcast build alone is not enough to start the probe stage
        broadcastBuildStage.setState(StageExecution.State.FLUSHING);
        schedule.schedule();
        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(partitionedBuildFragment.getId(), joinFragment.getId());

        // once the partitioned build also flushes, the probe stage starts
        partitionedBuildStage.setState(StageExecution.State.FLUSHING);
        schedule.schedule();
        Assertions.assertThat(getActiveFragments(schedule))
                .containsExactly(joinFragment.getId(), probeFragment.getId());
    }

    /**
     * Returns the fragment ids of the schedule's currently active stages, in the iteration order
     * of {@code getActiveStages()}.
     */
    private static Set<PlanFragmentId> getActiveFragments(PhasedExecutionSchedule schedule) {
        return schedule.getActiveStages().stream()
                .map(stage -> stage.getFragment().getId())
                .collect(ImmutableSet.toImmutableSet());
    }

    /**
     * Minimal {@link StageExecution} stub. Only the fragment, the state (with one change
     * listener) and the "any task blocked" flag are functional; every other operation throws
     * {@link UnsupportedOperationException}.
     */
    private static final class TestingStageExecution
            implements StageExecution {
        private final PlanFragment fragment;
        // Only a single listener is kept: a later registration replaces an earlier one.
        private StateMachine.StateChangeListener<StageExecution.State> stateChangeListener;
        private boolean anyTaskBlocked;
        private StageExecution.State state = StageExecution.State.SCHEDULING;

        public TestingStageExecution(PlanFragment fragment) {
            this.fragment = Objects.requireNonNull(fragment, "fragment is null");
        }

        @Override
        public PlanFragment getFragment() {
            return fragment;
        }

        @Override
        public boolean isAnyTaskBlocked() {
            return anyTaskBlocked;
        }

        /** Test hook: sets the value reported by {@link #isAnyTaskBlocked()}. */
        public void setAnyTaskBlocked(boolean anyTaskBlocked) {
            this.anyTaskBlocked = anyTaskBlocked;
        }

        /**
         * Test hook: updates the stage state and synchronously notifies the registered listener,
         * if any.
         */
        public void setState(StageExecution.State state) {
            this.state = state;
            if (stateChangeListener != null) {
                // pass the typed state; the listener is a StateChangeListener<StageExecution.State>
                stateChangeListener.stateChanged(state);
            }
        }

        @Override
        public StageExecution.State getState() {
            return state;
        }

        @Override
        public void addStateChangeListener(StateMachine.StateChangeListener<StageExecution.State> stateChangeListener) {
            this.stateChangeListener = Objects.requireNonNull(stateChangeListener, "stateChangeListener is null");
        }

        /**
         * Returns the same id for every instance.
         * NOTE(review): only safe while the schedule under test does not key stages by StageId.
         */
        @Override
        public StageId getStageId() {
            return new StageId(new QueryId("id"), 0);
        }

        @Override
        public int getAttemptId() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void beginScheduling() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void transitionToSchedulingSplits() {
            throw new UnsupportedOperationException();
        }

        @Override
        public TaskLifecycleListener getTaskLifecycleListener() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void schedulingComplete() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void schedulingComplete(PlanNodeId partitionedSource) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void cancel() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void abort() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void recordGetSplitTime(long start) {
            throw new UnsupportedOperationException();
        }

        @Override
        public Optional<RemoteTask> scheduleTask(InternalNode node, int partition, Multimap<PlanNodeId, Split> initialSplits) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void failTask(TaskId taskId, Throwable failureCause) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void failTaskRemotely(TaskId taskId, Throwable failureCause) {
            throw new UnsupportedOperationException();
        }

        @Override
        public List<RemoteTask> getAllTasks() {
            throw new UnsupportedOperationException();
        }

        @Override
        public List<TaskStatus> getTaskStatuses() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Optional<ExecutionFailureInfo> getFailureCause() {
            throw new UnsupportedOperationException();
        }
    }
}

