/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.node.NodeInfo;
import io.airlift.slice.Slice;
import io.airlift.stats.GcMonitor;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.tracing.Tracing;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.Span;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.connector.ConnectorServices;
import io.trino.connector.ConnectorServicesProvider;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.LocationFactory;
import io.trino.execution.ScheduledSplit;
import io.trino.execution.SplitAssignment;
import io.trino.execution.SqlTaskManager;
import io.trino.execution.StageId;
import io.trino.execution.TaskFailureListener;
import io.trino.execution.TaskId;
import io.trino.execution.TaskInfo;
import io.trino.execution.TaskManagementExecutor;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.TaskState;
import io.trino.execution.TaskStateMachine;
import io.trino.execution.TaskTestUtils;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PagesSerdeUtil;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.memory.LocalMemoryManager;
import io.trino.memory.NodeMemoryConfig;
import io.trino.memory.QueryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.metadata.InternalNode;
import io.trino.metadata.LanguageFunctionProvider;
import io.trino.metadata.WorkerLanguageFunctionProvider;
import io.trino.operator.DirectExchangeClient;
import io.trino.operator.DirectExchangeClientSupplier;
import io.trino.operator.RetryPolicy;
import io.trino.spi.QueryId;
import io.trino.spi.VersionEmbedder;
import io.trino.spi.catalog.CatalogProperties;
import io.trino.spi.connector.CatalogHandle;
import io.trino.spi.exchange.ExchangeId;
import io.trino.spiller.LocalSpillManager;
import io.trino.spiller.NodeSpillConfig;
import io.trino.testing.TestingSession;
import io.trino.util.EmbedVersion;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Reusable test suite for {@link SqlTaskManager}. Concrete subclasses choose the
 * {@link TaskExecutor} implementation under test via {@link #createTaskExecutor()}.
 *
 * <p>The class uses the {@code PER_CLASS} lifecycle: one task executor and one
 * {@link TaskManagementExecutor} are created in {@link #setUp()} and shared by every test, and
 * tests are allowed to run concurrently. Each test builds (and closes) its own
 * {@link SqlTaskManager} and takes a fresh query/task id from {@link #newTaskId()}, so tests do
 * not share task state.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
@Execution(ExecutionMode.CONCURRENT)
public abstract class BaseTestSqlTaskManager {
    /** Output buffer id declared by every test task that has a result buffer. */
    public static final PipelinedOutputBuffers.OutputBufferId OUT = new PipelinedOutputBuffers.OutputBufferId(0);

    /** Size limit passed to every {@code getTaskResults} request. */
    private static final DataSize MAX_RESULT_SIZE = DataSize.of(1, DataSize.Unit.MEGABYTE);

    // Atomic because tests run concurrently; each increment yields a distinct query id.
    private final AtomicInteger sequence = new AtomicInteger();
    private TaskExecutor taskExecutor;
    private TaskManagementExecutor taskManagementExecutor;

    /**
     * Creates the task executor under test. {@link #setUp()} calls {@code start()} on the returned
     * instance and {@link #tearDown()} calls {@code stop()} on it.
     */
    protected abstract TaskExecutor createTaskExecutor();

    @BeforeAll
    public void setUp() {
        taskExecutor = createTaskExecutor();
        taskExecutor.start();
        taskManagementExecutor = new TaskManagementExecutor();
    }

    /**
     * Releases the shared executors.
     *
     * <p>Null-safe, because {@code @AfterAll} also runs when {@link #setUp()} failed part-way; an
     * NPE here would otherwise hide the original failure. The management executor is closed in a
     * {@code finally} so it is released even if stopping the task executor throws.
     */
    @AfterAll
    public void tearDown() {
        try {
            if (taskExecutor != null) {
                taskExecutor.stop();
            }
        }
        finally {
            taskExecutor = null;
            if (taskManagementExecutor != null) {
                taskManagementExecutor.close();
            }
            taskManagementExecutor = null;
        }
    }

    /**
     * A task created without splits stays RUNNING. It becomes FINISHED once an update assigns an
     * empty split set and marks it as the last one (the {@code true} flag in
     * {@link #createTask(SqlTaskManager, Session, TaskId, Set, OutputBuffers)}).
     */
    @Test
    public void testEmptyQuery() {
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(sqlTaskManager, taskId, noBufferOutputBuffers());
            assertTaskState(taskInfo, TaskState.RUNNING);
            assertTaskState(sqlTaskManager.getTaskInfo(taskId), TaskState.RUNNING);

            taskInfo = createTask(sqlTaskManager, taskId, ImmutableSet.of(), noBufferOutputBuffers());
            assertTaskState(taskInfo, TaskState.FINISHED);
            assertTaskState(sqlTaskManager.getTaskInfo(taskId), TaskState.FINISHED);
        }
    }

    /**
     * Runs one split, reads its single one-row page from {@link #OUT}, drains the buffer, and then
     * destroys the buffer. The test expects the task to move from FLUSHING to FINISHED.
     */
    @Test
    @Timeout(30)
    public void testSimpleQuery() throws Exception {
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();
            createTask(sqlTaskManager, taskId, ImmutableSet.of(TaskTestUtils.SPLIT), singleBufferOutputBuffers());

            TaskInfo taskInfo = sqlTaskManager.getTaskInfo(taskId, 0L).get();
            assertTaskState(taskInfo, TaskState.FLUSHING);

            BufferResult results = sqlTaskManager.getTaskResults(taskId, OUT, 0L, MAX_RESULT_SIZE).getResultsFuture().get();
            Assertions.assertThat(results.isBufferComplete()).isFalse();
            Assertions.assertThat(results.getSerializedPages()).hasSize(1);
            Assertions.assertThat(PagesSerdeUtil.getSerializedPagePositionCount(results.getSerializedPages().get(0))).isEqualTo(1);

            // Keep requesting from the token just past the pages already received until the
            // buffer reports completion (the loop body runs at least once: completion was
            // asserted false above).
            while (!results.isBufferComplete()) {
                long nextToken = results.getToken() + results.getSerializedPages().size();
                results = sqlTaskManager.getTaskResults(taskId, OUT, nextToken, MAX_RESULT_SIZE).getResultsFuture().get();
            }
            Assertions.assertThat(results.isBufferComplete()).isTrue();
            Assertions.assertThat(results.getSerializedPages()).isEmpty();

            TaskInfo info = sqlTaskManager.destroyTaskResults(taskId, OUT);
            Assertions.assertThat(info.outputBuffers().getState()).isEqualTo(BufferState.FINISHED);

            taskInfo = sqlTaskManager.getTaskInfo(taskId, taskInfo.taskStatus().getVersion()).get();
            assertTaskState(taskInfo, TaskState.FINISHED);
            assertTaskState(sqlTaskManager.getTaskInfo(taskId), TaskState.FINISHED);
        }
    }

    /** A running task has no end time; after {@code cancelTask} it is CANCELED and has one. */
    @Test
    public void testCancel() throws InterruptedException, ExecutionException, TimeoutException {
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(sqlTaskManager, taskId, singleBufferOutputBuffers());
            assertTaskState(taskInfo, TaskState.RUNNING);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            assertTaskState(taskInfo, TaskState.RUNNING);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNull();

            taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.cancelTask(taskId));
            assertTaskState(taskInfo, TaskState.CANCELED);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNotNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            assertTaskState(taskInfo, TaskState.CANCELED);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNotNull();
        }
    }

    /** A running task has no end time; after {@code abortTask} it is ABORTED and has one. */
    @Test
    public void testAbort() throws InterruptedException, ExecutionException, TimeoutException {
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(sqlTaskManager, taskId, singleBufferOutputBuffers());
            assertTaskState(taskInfo, TaskState.RUNNING);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            assertTaskState(taskInfo, TaskState.RUNNING);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNull();

            taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.abortTask(taskId));
            assertTaskState(taskInfo, TaskState.ABORTED);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNotNull();

            taskInfo = sqlTaskManager.getTaskInfo(taskId);
            assertTaskState(taskInfo, TaskState.ABORTED);
            Assertions.assertThat(taskInfo.stats().getEndTime()).isNotNull();
        }
    }

    /**
     * Destroys the only output buffer of a FLUSHING task without reading it, and expects the task
     * to reach FINISHED.
     */
    @Test
    @Timeout(30)
    public void testAbortResults() throws Exception {
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) {
            TaskId taskId = newTaskId();
            createTask(sqlTaskManager, taskId, ImmutableSet.of(TaskTestUtils.SPLIT), singleBufferOutputBuffers());

            TaskInfo taskInfo = sqlTaskManager.getTaskInfo(taskId, 0L).get();
            assertTaskState(taskInfo, TaskState.FLUSHING);

            sqlTaskManager.destroyTaskResults(taskId, OUT);

            taskInfo = sqlTaskManager.getTaskInfo(taskId, taskInfo.taskStatus().getVersion()).get();
            assertTaskState(taskInfo, TaskState.FINISHED);
            assertTaskState(sqlTaskManager.getTaskInfo(taskId), TaskState.FINISHED);
        }
    }

    /** A canceled task older than the configured info max age is dropped by {@code removeOldTasks}. */
    @Test
    public void testRemoveOldTasks() throws InterruptedException, ExecutionException, TimeoutException {
        TaskManagerConfig config = new TaskManagerConfig().setInfoMaxAge(new Duration(5.0, TimeUnit.MILLISECONDS));
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(config)) {
            TaskId taskId = newTaskId();

            TaskInfo taskInfo = createTask(sqlTaskManager, taskId, singleBufferOutputBuffers());
            assertTaskState(taskInfo, TaskState.RUNNING);

            taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.cancelTask(taskId));
            assertTaskState(taskInfo, TaskState.CANCELED);
            assertTaskState(sqlTaskManager.getTaskInfo(taskId), TaskState.CANCELED);

            // Sleep well past the 5 ms max age configured above so the task qualifies for removal.
            Thread.sleep(100L);
            sqlTaskManager.removeOldTasks();

            for (TaskInfo remaining : sqlTaskManager.getAllTaskInfo()) {
                Assertions.assertThat(remaining.taskStatus().getTaskId()).isNotEqualTo(taskId);
            }
        }
    }

    /**
     * The node is configured with a 3-byte per-query memory limit. A session property can lower
     * the limit (to 1 byte), but a session property asking for more (10 bytes) leaves the
     * configured node maximum in place.
     */
    @Test
    public void testSessionPropertyMemoryLimitOverride() {
        NodeMemoryConfig memoryConfig = new NodeMemoryConfig().setMaxQueryMemoryPerNode(DataSize.ofBytes(3L));
        long configuredMaxBytes = memoryConfig.getMaxQueryMemoryPerNode().toBytes();
        try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig(), memoryConfig)) {
            TaskId reduceLimitsId = new TaskId(new StageId("q1", 0), 1, 0);
            TaskId increaseLimitsId = new TaskId(new StageId("q2", 0), 1, 0);

            QueryContext reducesLimitsContext = sqlTaskManager.getQueryContext(reduceLimitsId.getQueryId());
            QueryContext attemptsIncreaseContext = sqlTaskManager.getQueryContext(increaseLimitsId.getQueryId());

            // Before any task update, both contexts report the node-level default.
            Assertions.assertThat(reducesLimitsContext.isMemoryLimitsInitialized()).isFalse();
            Assertions.assertThat(reducesLimitsContext.getMaxUserMemory()).isEqualTo(configuredMaxBytes);
            Assertions.assertThat(attemptsIncreaseContext.isMemoryLimitsInitialized()).isFalse();
            Assertions.assertThat(attemptsIncreaseContext.getMaxUserMemory()).isEqualTo(configuredMaxBytes);

            Session reducingSession = TestingSession.testSessionBuilder().setSystemProperty("query_max_memory_per_node", "1B").build();
            createTask(sqlTaskManager, reducingSession, reduceLimitsId, ImmutableSet.of(TaskTestUtils.SPLIT), singleBufferOutputBuffers());
            Assertions.assertThat(reducesLimitsContext.isMemoryLimitsInitialized()).isTrue();
            Assertions.assertThat(reducesLimitsContext.getMaxUserMemory()).isEqualTo(1L);

            Session increasingSession = TestingSession.testSessionBuilder().setSystemProperty("query_max_memory_per_node", "10B").build();
            createTask(sqlTaskManager, increasingSession, increaseLimitsId, ImmutableSet.of(TaskTestUtils.SPLIT), singleBufferOutputBuffers());
            Assertions.assertThat(attemptsIncreaseContext.isMemoryLimitsInitialized()).isTrue();
            Assertions.assertThat(attemptsIncreaseContext.getMaxUserMemory()).isEqualTo(configuredMaxBytes);
        }
    }

    /** Creates a task manager with default {@link NodeMemoryConfig}. */
    private SqlTaskManager createSqlTaskManager(TaskManagerConfig config) {
        return createSqlTaskManager(config, new NodeMemoryConfig());
    }

    /**
     * Creates a task manager wired to the shared executors and to test doubles: no connectors,
     * fake locations, and a no-op tracer.
     */
    private SqlTaskManager createSqlTaskManager(TaskManagerConfig taskManagerConfig, NodeMemoryConfig nodeMemoryConfig) {
        // Typed locals keep the same static argument types the constructor call had before, without casts.
        VersionEmbedder versionEmbedder = new EmbedVersion("testversion");
        ConnectorServicesProvider connectorServicesProvider = new NoConnectorServicesProvider();
        LanguageFunctionProvider languageFunctionProvider = new WorkerLanguageFunctionProvider();
        LocationFactory locationFactory = new MockLocationFactory();
        GcMonitor gcMonitor = new TestingGcMonitor();
        NodeSpillConfig nodeSpillConfig = new NodeSpillConfig();
        return new SqlTaskManager(
                versionEmbedder,
                connectorServicesProvider,
                TaskTestUtils.createTestingPlanner(),
                languageFunctionProvider,
                locationFactory,
                taskExecutor,
                TaskTestUtils.createTestSplitMonitor(),
                new NodeInfo("test"),
                new LocalMemoryManager(nodeMemoryConfig),
                taskManagementExecutor,
                taskManagerConfig,
                nodeMemoryConfig,
                new LocalSpillManager(nodeSpillConfig),
                nodeSpillConfig,
                gcMonitor,
                Tracing.noopTracer(),
                new ExchangeManagerRegistry(OpenTelemetry.noop(), Tracing.noopTracer()));
    }

    /** Sends a task update under the shared test session with {@code splits} marked as the final split set. */
    private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, ImmutableSet<ScheduledSplit> splits, OutputBuffers outputBuffers) {
        return createTask(sqlTaskManager, SessionTestUtils.TEST_SESSION, taskId, splits, outputBuffers);
    }

    /**
     * Sends a task update for the test plan fragment. {@code splits} are assigned to the table
     * scan node with the no-more-splits flag set to {@code true}.
     */
    private static TaskInfo createTask(SqlTaskManager sqlTaskManager, Session session, TaskId taskId, Set<ScheduledSplit> splits, OutputBuffers outputBuffers) {
        return sqlTaskManager.updateTask(
                session,
                taskId,
                Span.getInvalid(),
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(new SplitAssignment(TaskTestUtils.TABLE_SCAN_NODE_ID, splits, true)),
                outputBuffers,
                ImmutableMap.of(),
                false);
    }

    /**
     * First registers a task context for {@code taskId} on its query context. Then it sends a
     * task update with the plan fragment and no split assignments; the tests expect the task to
     * still be RUNNING after this.
     */
    private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, OutputBuffers outputBuffers) {
        sqlTaskManager.getQueryContext(taskId.getQueryId())
                .addTaskContext(new TaskStateMachine(taskId, MoreExecutors.directExecutor()), TestingSession.testSessionBuilder().build(), () -> {}, false, false);
        return sqlTaskManager.updateTask(
                SessionTestUtils.TEST_SESSION,
                taskId,
                Span.getInvalid(),
                Optional.of(TaskTestUtils.PLAN_FRAGMENT),
                ImmutableList.of(),
                outputBuffers,
                ImmutableMap.of(),
                false);
    }

    /** Partitioned output buffers with no buffer ids, and no more ids to come. */
    private static OutputBuffers noBufferOutputBuffers() {
        return PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                .withNoMoreBufferIds();
    }

    /** Partitioned output buffers declaring only {@link #OUT}, and no more ids to come. */
    private static OutputBuffers singleBufferOutputBuffers() {
        return PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                .withBuffer(OUT, 0)
                .withNoMoreBufferIds();
    }

    /** Asserts that {@code taskInfo} reports the {@code expected} task state. */
    private static void assertTaskState(TaskInfo taskInfo, TaskState expected) {
        Assertions.assertThat(taskInfo.taskStatus().getState()).isEqualTo(expected);
    }

    /**
     * Re-polls a task that is terminating until it leaves the terminating state.
     *
     * <p>Asserts that the given info is already terminating or done. While the state is still
     * terminating, it calls {@code getTaskInfo(taskId, version)} at most 3 times, waiting up to
     * 5 seconds for each call. It returns the last info observed, which can still be terminating
     * if all attempts are used up.
     *
     * @throws TimeoutException if one poll does not complete within 5 seconds
     */
    private static TaskInfo pollTerminatingTaskInfoUntilDone(SqlTaskManager taskManager, TaskInfo taskInfo) throws InterruptedException, ExecutionException, TimeoutException {
        Assertions.assertThat(taskInfo.taskStatus().getState().isTerminatingOrDone()).isTrue();
        TaskInfo current = taskInfo;
        int remainingAttempts = 3;
        while (remainingAttempts > 0 && current.taskStatus().getState().isTerminating()) {
            current = taskManager.getTaskInfo(current.taskStatus().getTaskId(), current.taskStatus().getVersion()).get(5L, TimeUnit.SECONDS);
            remainingAttempts--;
        }
        return current;
    }

    /** Returns a task id in a new, per-class unique query ({@code query1}, {@code query2}, ...). */
    private TaskId newTaskId() {
        return new TaskId(new StageId("query" + sequence.incrementAndGet(), 0), 1, 0);
    }

    /** Connector services provider for tests that need no catalogs; catalog lookup and pruning are unsupported. */
    private static class NoConnectorServicesProvider
            implements ConnectorServicesProvider {
        private NoConnectorServicesProvider() {
        }

        @Override
        public void loadInitialCatalogs() {
        }

        @Override
        public void ensureCatalogsLoaded(Session session, List<CatalogProperties> catalogs) {
        }

        @Override
        public void pruneCatalogs(Set<CatalogHandle> catalogsInUse) {
            throw new UnsupportedOperationException();
        }

        @Override
        public ConnectorServices getConnectorServices(CatalogHandle catalogHandle) {
            throw new UnsupportedOperationException();
        }
    }

    /** Location factory that builds URIs under the non-resolvable host {@code fake.invalid}. */
    public static class MockLocationFactory
            implements LocationFactory {
        @Override
        public URI createQueryLocation(QueryId queryId) {
            return URI.create("http://fake.invalid/query/" + queryId);
        }

        @Override
        public URI createLocalTaskLocation(TaskId taskId) {
            return URI.create("http://fake.invalid/task/" + taskId);
        }

        @Override
        public URI createTaskLocation(InternalNode node, TaskId taskId) {
            return URI.create("http://fake.invalid/task/" + node.getNodeIdentifier() + "/" + taskId);
        }

        @Override
        public URI createMemoryInfoLocation(InternalNode node) {
            return URI.create("http://fake.invalid/" + node.getNodeIdentifier() + "/memory");
        }
    }

    /** Exchange client supplier whose {@code get} is unsupported; it always throws. */
    public static class MockDirectExchangeClientSupplier
            implements DirectExchangeClientSupplier {
        @Override
        public DirectExchangeClient get(QueryId queryId, ExchangeId exchangeId, Span parentSpan, LocalMemoryContext memoryContext, TaskFailureListener taskFailureListener, RetryPolicy retryPolicy) {
            throw new UnsupportedOperationException();
        }
    }
}

