/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.base.Preconditions;
import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.concurrent.Threads;
import io.airlift.slice.SizeOf;
import io.airlift.slice.Slice;
import io.airlift.stats.GcMonitor;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.tracing.Tracing;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.opentelemetry.api.trace.Span;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.execution.ScheduledSplit;
import io.trino.execution.SplitAssignment;
import io.trino.execution.SqlTaskExecution;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStateMachine;
import io.trino.execution.TaskTestUtils;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.OutputBufferStateMachine;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.execution.buffer.PagesSerdeUtil;
import io.trino.execution.buffer.PartitionedOutputBuffer;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.execution.executor.timesharing.TimeSharingTaskExecutor;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.SimpleLocalMemoryContext;
import io.trino.metadata.Split;
import io.trino.operator.DriverContext;
import io.trino.operator.DriverFactory;
import io.trino.operator.OperatorContext;
import io.trino.operator.SourceOperator;
import io.trino.operator.SourceOperatorFactory;
import io.trino.operator.TaskContext;
import io.trino.operator.output.TaskOutputOperator;
import io.trino.spi.Page;
import io.trino.spi.QueryId;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.spi.block.TestingBlockEncodingSerde;
import io.trino.spi.connector.ConnectorSplit;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.LocalExecutionPlanner;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingHandles;
import java.util.List;
import java.util.Objects;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.Supplier;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestSqlTaskExecution {
    private static final PipelinedOutputBuffers.OutputBufferId OUTPUT_BUFFER_ID = new PipelinedOutputBuffers.OutputBufferId(0);
    private static final Duration ASSERT_WAIT_TIMEOUT = new Duration(1.0, TimeUnit.HOURS);
    public static final TaskId TASK_ID = new TaskId(new StageId("query", 0), 0, 0);

    /**
     * Drives one {@link SqlTaskExecution} built from a single driver factory (a
     * {@link TestingScanOperatorFactory} feeding a task output operator) through its
     * whole life cycle and checks the observable state at each step:
     * <ol>
     * <li>after {@code start()} the task state is {@code RUNNING};</li>
     * <li>a split assignment for an unrelated plan node ({@code "tableScan1"}) may fail
     * with a {@link NullPointerException}, which is tolerated;</li>
     * <li>one split with a count of 123 is added and 123 positions are consumed;</li>
     * <li>while the scan is paused, two more splits (counts 300 and 200) are added in a
     * final assignment, the factory is awaited until it reports no more operators,
     * the scan is resumed and 500 positions are consumed;</li>
     * <li>the buffer completes, the task moves {@code RUNNING -> FLUSHING}, and after the
     * consumer destroys its buffer, {@code FLUSHING -> FINISHED}.</li>
     * </ol>
     *
     * NOTE: this source was produced by a decompiler (CFR), which reported that it removed
     * a self-catching try block in this method, so the structure may differ slightly from
     * the original source.
     */
    @Test
    public void testSimple() throws Exception {
        ScheduledExecutorService notificationPool = Executors.newScheduledThreadPool(10, Threads.threadsNamed("task-notification-%s"));
        ScheduledExecutorService yieldPool = Executors.newScheduledThreadPool(2, Threads.threadsNamed("driver-yield-%s"));
        ScheduledExecutorService timeoutPool = Executors.newScheduledThreadPool(2, Threads.threadsNamed("driver-timeout-%s"));
        TimeSharingTaskExecutor executor = new TimeSharingTaskExecutor(5, 10, 3, 4, Ticker.systemTicker());
        executor.start();
        try {
            // Wire up the task: state machine, output buffer + consumer, and a two-operator driver.
            TaskStateMachine stateMachine = new TaskStateMachine(TASK_ID, notificationPool);
            PartitionedOutputBuffer buffer = newTestingOutputBuffer(notificationPool);
            OutputBufferConsumer consumer = new OutputBufferConsumer(buffer, OUTPUT_BUFFER_ID);
            TestingScanOperatorFactory scanFactory = new TestingScanOperatorFactory(0, TaskTestUtils.TABLE_SCAN_NODE_ID);
            TaskOutputOperator.TaskOutputOperatorFactory outputFactory = new TaskOutputOperator.TaskOutputOperatorFactory(
                    1,
                    TaskTestUtils.TABLE_SCAN_NODE_ID,
                    buffer,
                    Function.identity(),
                    new PagesSerdeFactory(new TestingBlockEncodingSerde(), false));
            DriverFactory driverFactory = new DriverFactory(0, true, true, ImmutableList.of(scanFactory, outputFactory), OptionalInt.empty());
            LocalExecutionPlanner.LocalExecutionPlan plan = new LocalExecutionPlanner.LocalExecutionPlan(
                    ImmutableList.of(driverFactory),
                    ImmutableList.of(TaskTestUtils.TABLE_SCAN_NODE_ID));
            TaskContext context = newTestingTaskContext(notificationPool, yieldPool, timeoutPool, stateMachine);
            SqlTaskExecution execution = new SqlTaskExecution(
                    stateMachine,
                    context,
                    Span.getInvalid(),
                    buffer,
                    plan,
                    executor,
                    TaskTestUtils.createTestSplitMonitor(),
                    Tracing.noopTracer(),
                    notificationPool);
            execution.start();
            Assertions.assertThat(stateMachine.getState()).isEqualTo(TaskState.RUNNING);

            // Assignment for a plan node that is not part of the plan, carrying a larger
            // sequence id (3) than the splits added afterwards.
            try {
                PlanNodeId unknownNodeId = new PlanNodeId("tableScan1");
                ScheduledSplit strayedSplit = newScheduledSplit(3, unknownNodeId, 400000, 400);
                execution.addSplitAssignments(ImmutableList.of(new SplitAssignment(unknownNodeId, ImmutableSet.of(strayedSplit), false)));
            }
            catch (NullPointerException ignored) {
                // Deliberately tolerated; the test carries on with the real scan node below.
                // NOTE(review): presumably thrown because no pipeline exists for "tableScan1" - confirm.
            }

            // First real split: 123 positions must become readable from the buffer.
            execution.addSplitAssignments(ImmutableList.of(new SplitAssignment(
                    TaskTestUtils.TABLE_SCAN_NODE_ID,
                    ImmutableSet.of(newScheduledSplit(0, TaskTestUtils.TABLE_SCAN_NODE_ID, 100000, 123)),
                    false)));
            consumer.consume(123, ASSERT_WAIT_TIMEOUT);

            // Hold the scan operators in getOutput() while the final assignment is delivered.
            scanFactory.getPauser().pause();
            execution.addSplitAssignments(ImmutableList.of(new SplitAssignment(
                    TaskTestUtils.TABLE_SCAN_NODE_ID,
                    ImmutableSet.of(
                            newScheduledSplit(1, TaskTestUtils.TABLE_SCAN_NODE_ID, 200000, 300),
                            newScheduledSplit(2, TaskTestUtils.TABLE_SCAN_NODE_ID, 300000, 200)),
                    true)));
            waitUntilEquals(scanFactory::isOverallNoMoreOperators, true, ASSERT_WAIT_TIMEOUT);
            scanFactory.getPauser().resume();

            // 300 + 200 positions from the two paused splits, then the buffer must be complete.
            consumer.consume(500, ASSERT_WAIT_TIMEOUT);
            consumer.assertBufferComplete(ASSERT_WAIT_TIMEOUT);

            TaskState afterRunning = stateMachine.getStateChange(TaskState.RUNNING).get(10L, TimeUnit.SECONDS);
            Assertions.assertThat(afterRunning).isEqualTo(TaskState.FLUSHING);
            consumer.abort();
            TaskState afterFlushing = stateMachine.getStateChange(TaskState.FLUSHING).get(10L, TimeUnit.SECONDS);
            Assertions.assertThat(afterFlushing).isEqualTo(TaskState.FINISHED);
        }
        finally {
            executor.stop();
            notificationPool.shutdownNow();
            yieldPool.shutdown();
            timeoutPool.shutdown();
        }
    }

    /**
     * Builds a {@link TaskContext} for the test by creating a fresh {@link QueryContext}
     * (query id {@code "queryid"}, 1 MB / 1 GB sizing values, a {@link TestingGcMonitor})
     * and registering {@code taskStateMachine} with it under the shared test session.
     *
     * @param taskNotificationExecutor executor handed to the query context for notifications
     * @param driverYieldExecutor scheduled executor handed to the query context
     * @param driverTimeoutExecutor scheduled executor handed to the query context
     * @param taskStateMachine state machine of the task the context is created for
     * @return the task context returned by {@code QueryContext.addTaskContext}
     */
    private TaskContext newTestingTaskContext(ScheduledExecutorService taskNotificationExecutor, ScheduledExecutorService driverYieldExecutor, ScheduledExecutorService driverTimeoutExecutor, TaskStateMachine taskStateMachine) {
        DataSize oneMegabyte = DataSize.of(1L, DataSize.Unit.MEGABYTE);
        DataSize oneGigabyte = DataSize.of(1L, DataSize.Unit.GIGABYTE);
        QueryContext context = new QueryContext(
                new QueryId("queryid"),
                oneMegabyte,
                new MemoryPool(oneGigabyte),
                new TestingGcMonitor(),
                taskNotificationExecutor,
                driverYieldExecutor,
                driverTimeoutExecutor,
                oneMegabyte,
                new SpillSpaceTracker(oneGigabyte));
        // The no-op runnable and the two false flags are passed through unchanged.
        Runnable noOp = () -> {};
        return context.addTaskContext(taskStateMachine, SessionTestUtils.TEST_SESSION, noOp, false, false);
    }

    /**
     * Creates the {@link PartitionedOutputBuffer} the task writes into: a partitioned
     * buffer layout with the single buffer {@code OUTPUT_BUFFER_ID} (mapped to 0) and no
     * further buffer ids, a 1 MB size value, and a fresh simple memory context tagged
     * {@code "test"}.
     *
     * @param taskNotificationExecutor executor used both by the buffer's state machine and the buffer itself
     * @return a new output buffer named after {@code TASK_ID}
     */
    private PartitionedOutputBuffer newTestingOutputBuffer(ScheduledExecutorService taskNotificationExecutor) {
        OutputBufferStateMachine stateMachine = new OutputBufferStateMachine(TASK_ID, taskNotificationExecutor);
        PipelinedOutputBuffers layout = PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                .withBuffer(OUTPUT_BUFFER_ID, 0)
                .withNoMoreBufferIds();
        DataSize capacity = DataSize.of(1L, DataSize.Unit.MEGABYTE);
        return new PartitionedOutputBuffer(
                TASK_ID.toString(),
                stateMachine,
                layout,
                capacity,
                () -> new SimpleLocalMemoryContext(AggregatedMemoryContext.newSimpleAggregatedMemoryContext(), "test"),
                taskNotificationExecutor);
    }

    /**
     * Polls {@code actualSupplier} about every 10 ms until it yields a value equal to
     * {@code expected} or {@code timeout} has elapsed, then performs a final assertion so
     * that a timeout surfaces as a normal assertion failure.
     *
     * <p>If the calling thread is interrupted while sleeping, the interrupt flag is
     * restored and waiting stops immediately; the final assertion still runs against the
     * current value. Previously the interrupt was swallowed and polling continued for the
     * remainder of the timeout.
     *
     * @param actualSupplier source of the value under observation; called once per poll
     * @param expected value to wait for; compared null-safely
     * @param timeout maximum time to keep polling
     * @param <T> type of the observed value
     */
    private <T> void waitUntilEquals(Supplier<T> actualSupplier, T expected, Duration timeout) {
        long nanoUntil = System.nanoTime() + timeout.toMillis() * 1000000L;
        // Subtraction-based comparison stays correct even if nanoTime() wraps around.
        while (System.nanoTime() - nanoUntil < 0L) {
            if (Objects.equals(expected, actualSupplier.get())) {
                return;
            }
            try {
                Thread.sleep(10L);
            }
            catch (InterruptedException e) {
                // Preserve the interrupt for callers and stop waiting instead of spinning on.
                Thread.currentThread().interrupt();
                break;
            }
        }
        Assertions.assertThat(actualSupplier.get()).isEqualTo(expected);
    }

    /**
     * Wraps a {@link TestingSplit} covering {@code [begin, begin + count)} in a
     * {@link Split} for the test catalog and schedules it for {@code planNodeId}.
     *
     * @param sequenceId sequence id given to the scheduled split
     * @param planNodeId plan node the split is addressed to
     * @param begin begin value stored in the testing split
     * @param count added to {@code begin} to obtain the split's end value
     * @return the scheduled split
     */
    private ScheduledSplit newScheduledSplit(int sequenceId, PlanNodeId planNodeId, int begin, int count) {
        TestingSplit connectorSplit = new TestingSplit(begin, begin + count);
        Split split = new Split(TestingHandles.TEST_CATALOG_HANDLE, connectorSplit);
        return new ScheduledSplit(sequenceId, planNodeId, split);
    }

    /**
     * Test-side reader of one buffer of an {@link OutputBuffer}. It keeps track of the
     * next sequence id to request, of how many positions have been read beyond what the
     * test asked for so far ({@code surplusPositions}), and of whether the buffer has
     * reported completion.
     */
    private static class OutputBufferConsumer {
        private final OutputBuffer outputBuffer;
        private final PipelinedOutputBuffers.OutputBufferId outputBufferId;
        private int sequenceId;
        private int surplusPositions;
        private boolean bufferComplete;

        public OutputBufferConsumer(OutputBuffer outputBuffer, PipelinedOutputBuffers.OutputBufferId outputBufferId) {
            this.outputBuffer = outputBuffer;
            this.outputBufferId = outputBufferId;
        }

        /**
         * Reads from the buffer until at least {@code positions} further positions are
         * available, counting positions left over from earlier reads.
         *
         * @param positions number of positions the caller expects to obtain
         * @param timeout overall budget for all buffer reads made by this call
         * @throws TimeoutException if a read does not complete within the remaining budget
         */
        public void consume(int positions, Duration timeout) throws ExecutionException, InterruptedException, TimeoutException {
            long deadlineNanos = System.nanoTime() + timeout.toMillis() * 1000000L;
            surplusPositions -= positions;
            while (surplusPositions < 0) {
                // Still owing positions, so the buffer must not have completed yet.
                Assertions.assertThat(bufferComplete).describedAs("bufferComplete is set before enough positions are consumed").isFalse();
                BufferResult batch = fetchNext(deadlineNanos);
                List<Slice> pages = batch.getSerializedPages();
                for (Slice page : pages) {
                    surplusPositions += PagesSerdeUtil.getSerializedPagePositionCount(page);
                }
                sequenceId += pages.size();
            }
        }

        /**
         * Asserts that no surplus positions are outstanding and then reads until the
         * buffer reports completion, requiring every page seen on the way to be empty.
         *
         * @param timeout overall budget for all buffer reads made by this call
         * @throws TimeoutException if a read does not complete within the remaining budget
         */
        public void assertBufferComplete(Duration timeout) throws InterruptedException, ExecutionException, TimeoutException {
            Assertions.assertThat(surplusPositions).isEqualTo(0);
            long deadlineNanos = System.nanoTime() + timeout.toMillis() * 1000000L;
            while (!bufferComplete) {
                BufferResult batch = fetchNext(deadlineNanos);
                List<Slice> pages = batch.getSerializedPages();
                for (Slice page : pages) {
                    Assertions.assertThat(PagesSerdeUtil.getSerializedPagePositionCount(page)).isEqualTo(0);
                }
                sequenceId += pages.size();
            }
        }

        /**
         * Destroys this consumer's buffer and asserts that the output buffer as a whole
         * then reports the {@code FINISHED} state.
         */
        public void abort() {
            outputBuffer.destroy(outputBufferId);
            Assertions.assertThat(outputBuffer.getInfo().getState()).isEqualTo(BufferState.FINISHED);
        }

        /**
         * Requests up to 1 MB starting at the current sequence id, waits for the result
         * until {@code deadlineNanos}, and records the result's completion flag. The
         * sequence id is left for the caller to advance.
         */
        private BufferResult fetchNext(long deadlineNanos) throws ExecutionException, InterruptedException, TimeoutException {
            DataSize maxSize = DataSize.of(1L, DataSize.Unit.MEGABYTE);
            long remainingNanos = deadlineNanos - System.nanoTime();
            BufferResult batch = outputBuffer.get(outputBufferId, (long) sequenceId, maxSize).get(remainingNanos, TimeUnit.NANOSECONDS);
            bufferComplete = batch.isBufferComplete();
            return batch;
        }
    }

    /**
     * Source operator factory used by the test. Every operator it creates shares one
     * {@link Pauser}, so the test can hold all scan operators inside {@code getOutput()},
     * and the factory records whether {@link #noMoreOperators()} has been called.
     */
    public static class TestingScanOperatorFactory
    implements SourceOperatorFactory {
        private final int operatorId;
        private final PlanNodeId sourceId;
        private final Pauser pauser = new Pauser();
        // volatile: the test thread polls isOverallNoMoreOperators() in a sleep loop, and
        // without a happens-before edge a write from another thread might never become
        // visible to it. NOTE(review): assumes noMoreOperators() can run on an executor
        // thread rather than the polling thread - confirm against SqlTaskExecution.
        private volatile boolean overallNoMoreOperators;

        /**
         * @param operatorId id passed to the driver context when registering an operator
         * @param sourceId plan node id reported by the factory and its operators; must not be null
         */
        public TestingScanOperatorFactory(int operatorId, PlanNodeId sourceId) {
            this.operatorId = operatorId;
            this.sourceId = Objects.requireNonNull(sourceId, "sourceId is null");
        }

        public PlanNodeId getSourceId() {
            return this.sourceId;
        }

        /**
         * Registers a new operator context with {@code driverContext} and returns a scan
         * operator bound to it.
         *
         * @throws IllegalStateException if {@link #noMoreOperators()} was already called
         */
        public SourceOperator createOperator(DriverContext driverContext) {
            Preconditions.checkState(!this.overallNoMoreOperators, (Object)"noMoreOperators() has been called");
            OperatorContext operatorContext = driverContext.addOperatorContext(this.operatorId, this.sourceId, TestingScanOperator.class.getSimpleName());
            return new TestingScanOperator(operatorContext, this.sourceId);
        }

        /** Marks the factory as closed for further operator creation. */
        public void noMoreOperators() {
            this.overallNoMoreOperators = true;
        }

        /** @return whether {@link #noMoreOperators()} has been called */
        public boolean isOverallNoMoreOperators() {
            return this.overallNoMoreOperators;
        }

        /** @return the pauser shared by all operators created by this factory */
        public Pauser getPauser() {
            return this.pauser;
        }

        /**
         * Scan operator that accepts at most one {@link TestingSplit} and emits a single
         * page built from that split's begin/end values, after which it is finished.
         * Its blocked future completes once a split arrives or no more splits are announced.
         */
        public class TestingScanOperator
        implements SourceOperator {
            private final OperatorContext operatorContext;
            private final PlanNodeId planNodeId;
            private final SettableFuture<Void> blocked = SettableFuture.create();
            private TestingSplit split;
            private boolean finished;

            public TestingScanOperator(OperatorContext operatorContext, PlanNodeId planNodeId) {
                this.operatorContext = Objects.requireNonNull(operatorContext, "operatorContext is null");
                this.planNodeId = Objects.requireNonNull(planNodeId, "planNodeId is null");
            }

            public OperatorContext getOperatorContext() {
                return this.operatorContext;
            }

            public PlanNodeId getSourceId() {
                return this.planNodeId;
            }

            /**
             * Stores the split's connector split and unblocks the operator. A split
             * offered after the operator has finished is dropped.
             *
             * @throws IllegalStateException if a split was already set
             */
            public void addSplit(Split split) {
                Objects.requireNonNull(split, "split is null");
                Preconditions.checkState(this.split == null, (Object)"Table scan split already set");
                if (this.finished) {
                    return;
                }
                this.split = (TestingSplit)split.getConnectorSplit();
                this.blocked.set(null);
            }

            /** Finishes the operator if it never received a split, and unblocks it either way. */
            public void noMoreSplits() {
                if (this.split == null) {
                    this.finish();
                }
                this.blocked.set(null);
            }

            public void close() {
                this.finish();
            }

            public void finish() {
                this.finished = true;
            }

            public boolean isFinished() {
                return this.finished;
            }

            public ListenableFuture<Void> isBlocked() {
                return this.blocked;
            }

            public boolean needsInput() {
                return false;
            }

            public void addInput(Page page) {
                throw new UnsupportedOperationException(this.getClass().getName() + " cannot take input");
            }

            /**
             * Returns {@code null} while no split is set. Otherwise waits on the factory's
             * pauser, then returns one page holding the string sequence block for the
             * split's begin/end values and marks the operator finished.
             */
            public Page getOutput() {
                if (this.split == null) {
                    return null;
                }
                // Blocks here while the test has the pauser paused.
                TestingScanOperatorFactory.this.pauser.await();
                Page result = new Page(new Block[]{BlockAssertions.createStringSequenceBlock(this.split.getBegin(), this.split.getEnd())});
                this.finish();
                return result;
            }
        }
    }

    /**
     * Simple gate for test threads: {@link #await()} returns immediately while the gate
     * is open and blocks while it is paused. A new instance starts open.
     *
     * <p>The current gate is a future held in a volatile field: an already-completed
     * future means "open", an incomplete one means "paused".
     */
    public static class Pauser {
        private volatile SettableFuture<Void> future = SettableFuture.create();

        public Pauser() {
            // Start in the open state.
            this.future.set(null);
        }

        /** Closes the gate by installing a fresh incomplete future; no-op if already paused. */
        public void pause() {
            if (!this.future.isDone()) {
                return;
            }
            this.future = SettableFuture.create();
        }

        /** Opens the gate by completing the current future; no-op if already open. */
        public void resume() {
            if (this.future.isDone()) {
                return;
            }
            this.future.set(null);
        }

        /**
         * Waits until the gate observed at call time is open.
         *
         * @throws RuntimeException wrapping any failure of the wait; if the cause is an
         *         {@link InterruptedException}, the thread's interrupt flag is restored
         *         before the exception is thrown
         */
        public void await() {
            try {
                this.future.get();
            }
            catch (InterruptedException e) {
                // get() cleared the interrupt flag; put it back so callers can still see it.
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            catch (Throwable e) {
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Connector split used by the test; it carries nothing but a {@code begin} and an
     * {@code end} integer. The constructor is annotated for Jackson creation from the
     * {@code "begin"} and {@code "end"} properties.
     */
    public static class TestingSplit
    implements ConnectorSplit {
        private static final int INSTANCE_SIZE = SizeOf.instanceSize(TestingSplit.class);

        private final int begin;
        private final int end;

        @JsonCreator
        public TestingSplit(@JsonProperty("begin") int begin, @JsonProperty("end") int end) {
            this.begin = begin;
            this.end = end;
        }

        /** @return the begin value given at construction */
        public int getBegin() {
            return begin;
        }

        /** @return the end value given at construction */
        public int getEnd() {
            return end;
        }

        /** @return this split itself, which serves as its own info object */
        public Object getInfo() {
            return this;
        }

        /** @return the precomputed instance size of this class */
        public long getRetainedSizeInBytes() {
            return INSTANCE_SIZE;
        }
    }
}

