/*
 * Decompiled with CFR 0.152.
 */
package io.trino.execution;

import com.google.common.base.Ticker;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import io.airlift.concurrent.Threads;
import io.airlift.configuration.secrets.SecretsResolver;
import io.airlift.stats.CounterStat;
import io.airlift.stats.GcMonitor;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.tracing.Tracing;
import io.airlift.units.DataSize;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.Tracer;
import io.trino.exchange.ExchangeManagerRegistry;
import io.trino.execution.MemoryRevokingScheduler;
import io.trino.execution.SplitAssignment;
import io.trino.execution.SqlTask;
import io.trino.execution.SqlTaskExecutionFactory;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskManagerConfig;
import io.trino.execution.TaskTestUtils;
import io.trino.execution.TestSqlTask;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.executor.TaskExecutor;
import io.trino.execution.executor.timesharing.TimeSharingTaskExecutor;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.operator.DriverContext;
import io.trino.operator.OperatorContext;
import io.trino.operator.PipelineContext;
import io.trino.operator.TaskContext;
import io.trino.spi.QueryId;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.LocalExecutionPlanner;
import io.trino.sql.planner.plan.PlanNodeId;
import java.net.URI;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Tests for {@link MemoryRevokingScheduler}.
 *
 * <p>Operators reserve revocable memory against a deliberately tiny (10 byte) {@link MemoryPool}.
 * The tests then check which operators end up with memory revoking requested once the pool is
 * over-committed.
 */
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public class TestMemoryRevokingScheduler {
    private final AtomicInteger idGenerator = new AtomicInteger();
    private final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(DataSize.of(10, DataSize.Unit.GIGABYTE));
    private final Map<QueryId, QueryContext> queryContexts = new HashMap<>();

    private MemoryPool memoryPool;
    private TaskExecutor taskExecutor;
    // Single-threaded; awaitAsynchronousCallbacksRun() relies on that to wait for queued callbacks.
    private ScheduledExecutorService executor;
    private ScheduledExecutorService scheduledExecutor;
    private SqlTaskExecutionFactory sqlTaskExecutionFactory;
    // Every operator created by the current test. Operators not passed to
    // assertMemoryRevokingRequestedFor(...) are asserted NOT to have revoking requested.
    private Set<OperatorContext> allOperatorContexts;

    @BeforeEach
    public void setUp() {
        memoryPool = new MemoryPool(DataSize.ofBytes(10));
        taskExecutor = new TimeSharingTaskExecutor(8, 16, 3, 4, Ticker.systemTicker());
        taskExecutor.start();
        executor = Executors.newScheduledThreadPool(1, Threads.threadsNamed("task-notification-%s"));
        // Distinct name so these threads are not confused with the notification executor in thread dumps
        scheduledExecutor = Executors.newScheduledThreadPool(2, Threads.threadsNamed("test-scheduled-executor-%s"));
        LocalExecutionPlanner planner = TaskTestUtils.createTestingPlanner();
        sqlTaskExecutionFactory = new SqlTaskExecutionFactory(
                executor,
                taskExecutor,
                planner,
                TaskTestUtils.createTestSplitMonitor(),
                Tracing.noopTracer(),
                new TaskManagerConfig());
        allOperatorContexts = null;
    }

    @AfterEach
    public void tearDown() {
        queryContexts.clear();
        memoryPool = null;
        taskExecutor.stop();
        taskExecutor = null;
        executor.shutdownNow();
        executor = null;
        scheduledExecutor.shutdownNow();
        scheduledExecutor = null;
        sqlTaskExecutionFactory = null;
        allOperatorContexts = null;
    }

    /**
     * Drives {@link MemoryRevokingScheduler#requestMemoryRevokingIfNeeded()} explicitly.
     *
     * <p>Two queries with five operators in total grow and shrink their revocable memory. After
     * each step the test checks the pool's free bytes and the exact set of operators with
     * revoking requested.
     */
    @Test
    public void testScheduleMemoryRevoking() throws Exception {
        QueryContext q1 = getOrCreateQueryContext(new QueryId("q1"));
        QueryContext q2 = getOrCreateQueryContext(new QueryId("q2"));

        SqlTask sqlTask1 = newSqlTask(q1.getQueryId());
        SqlTask sqlTask2 = newSqlTask(q2.getQueryId());

        // q1: one pipeline, two drivers, operators 1 and 2 on the first driver, operator 3 on the second
        TaskContext taskContext1 = getOrCreateTaskContext(sqlTask1);
        PipelineContext pipelineContext11 = taskContext1.addPipelineContext(0, false, false, false);
        DriverContext driverContext111 = pipelineContext11.addDriverContext();
        OperatorContext operatorContext1 = driverContext111.addOperatorContext(1, new PlanNodeId("na"), "na");
        OperatorContext operatorContext2 = driverContext111.addOperatorContext(2, new PlanNodeId("na"), "na");
        DriverContext driverContext112 = pipelineContext11.addDriverContext();
        OperatorContext operatorContext3 = driverContext112.addOperatorContext(3, new PlanNodeId("na"), "na");

        // q2: one pipeline, one driver, operators 4 and 5
        TaskContext taskContext2 = getOrCreateTaskContext(sqlTask2);
        PipelineContext pipelineContext21 = taskContext2.addPipelineContext(1, false, false, false);
        DriverContext driverContext211 = pipelineContext21.addDriverContext();
        OperatorContext operatorContext4 = driverContext211.addOperatorContext(4, new PlanNodeId("na"), "na");
        OperatorContext operatorContext5 = driverContext211.addOperatorContext(5, new PlanNodeId("na"), "na");

        List<SqlTask> tasks = ImmutableList.of(sqlTask1, sqlTask2);
        MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(memoryPool, () -> tasks, executor, 1.0, 1.0);
        allOperatorContexts = ImmutableSet.of(operatorContext1, operatorContext2, operatorContext3, operatorContext4, operatorContext5);

        assertMemoryRevokingNotRequested();

        // Nothing reserved yet: the whole 10 byte pool is free
        requestMemoryRevoking(scheduler);
        Assertions.assertThat(memoryPool.getFreeBytes()).isEqualTo(10L);
        assertMemoryRevokingNotRequested();

        LocalMemoryContext revocableMemory1 = operatorContext1.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory3 = operatorContext3.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory4 = operatorContext4.localRevocableMemoryContext();
        LocalMemoryContext revocableMemory5 = operatorContext5.localRevocableMemoryContext();

        // 3 + 6 = 9 bytes reserved; the pool still has 1 byte free
        revocableMemory1.setBytes(3);
        revocableMemory3.setBytes(6);
        Assertions.assertThat(memoryPool.getFreeBytes()).isEqualTo(1L);
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingNotRequested();

        // 9 + 7 = 16 bytes reserved: the pool is over-committed by 6 bytes
        revocableMemory4.setBytes(7);
        Assertions.assertThat(memoryPool.getFreeBytes()).isEqualTo(-6L);
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operatorContext1, operatorContext3);
        // Requesting again without any memory change leaves the requested set unchanged
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operatorContext1, operatorContext3);

        // Operator 1 releases its memory and acknowledges the request: 3 bytes over-committed
        revocableMemory1.setBytes(0);
        operatorContext1.resetMemoryRevokingRequested();
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operatorContext3);
        Assertions.assertThat(memoryPool.getFreeBytes()).isEqualTo(-3L);

        // Operator 5 reserves 3 bytes: 6 bytes over-committed, requested set unchanged
        revocableMemory5.setBytes(3);
        Assertions.assertThat(memoryPool.getFreeBytes()).isEqualTo(-6L);
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operatorContext3);

        // Operator 5 grows to 4 bytes: 7 bytes over-committed, operator 4 is now asked to revoke too
        revocableMemory5.setBytes(4);
        Assertions.assertThat(memoryPool.getFreeBytes()).isEqualTo(-7L);
        requestMemoryRevoking(scheduler);
        assertMemoryRevokingRequestedFor(operatorContext3, operatorContext4);
    }

    /**
     * Checks that revoking is requested without an explicit
     * {@link MemoryRevokingScheduler#requestMemoryRevokingIfNeeded()} call once pool listeners are
     * registered and an operator over-commits the pool.
     */
    @Test
    public void testImmediateMemoryRevoking() throws Exception {
        SqlTask sqlTask = newSqlTask(new QueryId("query"));
        OperatorContext operatorContext = createContexts(sqlTask);
        allOperatorContexts = ImmutableSet.of(operatorContext);
        List<SqlTask> tasks = ImmutableList.of(sqlTask);
        MemoryRevokingScheduler scheduler = new MemoryRevokingScheduler(memoryPool, () -> tasks, executor, 1.0, 1.0);
        scheduler.registerPoolListeners();

        // 12 bytes exceeds the 10 byte pool; the request is presumably triggered by the listeners registered above
        operatorContext.localRevocableMemoryContext().setBytes(12);
        awaitAsynchronousCallbacksRun();
        assertMemoryRevokingRequestedFor(operatorContext);
    }

    /**
     * Creates a task, one pipeline, one driver and a single operator (id 1) for {@code sqlTask}.
     *
     * @return the created operator's context
     */
    private OperatorContext createContexts(SqlTask sqlTask) {
        TaskContext taskContext = getOrCreateTaskContext(sqlTask);
        PipelineContext pipelineContext = taskContext.addPipelineContext(0, false, false, false);
        DriverContext driverContext = pipelineContext.addDriverContext();
        return driverContext.addOperatorContext(1, new PlanNodeId("na"), "na");
    }

    /**
     * Asks the scheduler to request revoking, if needed.
     *
     * <p>It then waits for work queued on {@link #executor} to finish, so that any resulting
     * requests are visible.
     */
    private void requestMemoryRevoking(MemoryRevokingScheduler scheduler) throws Exception {
        scheduler.requestMemoryRevokingIfNeeded();
        awaitAsynchronousCallbacksRun();
    }

    /**
     * Submits a no-op task to {@link #executor} and blocks until it completes.
     *
     * <p>Because the executor has a single thread, this acts as a barrier for callbacks queued
     * earlier.
     */
    private void awaitAsynchronousCallbacksRun() throws Exception {
        executor.invokeAll(Collections.singletonList(() -> null));
    }

    /**
     * Asserts the exact set of operators with revoking requested.
     *
     * <p>Every given operator must have revoking requested. Every other operator in
     * {@link #allOperatorContexts} must not.
     */
    private void assertMemoryRevokingRequestedFor(OperatorContext... operatorContexts) {
        Set<OperatorContext> expectedRequested = ImmutableSet.copyOf(operatorContexts);
        for (OperatorContext operatorContext : expectedRequested) {
            Assertions.assertThat(operatorContext.isMemoryRevokingRequested())
                    .describedAs("expected memory requested for operator " + operatorContext.getOperatorId())
                    .isTrue();
        }
        for (OperatorContext operatorContext : Sets.difference(allOperatorContexts, expectedRequested)) {
            Assertions.assertThat(operatorContext.isMemoryRevokingRequested())
                    .describedAs("expected memory not requested for operator " + operatorContext.getOperatorId())
                    .isFalse();
        }
    }

    /** Asserts that no operator in {@link #allOperatorContexts} has revoking requested. */
    private void assertMemoryRevokingNotRequested() {
        assertMemoryRevokingRequestedFor();
    }

    /**
     * Creates a new {@link SqlTask} for {@code queryId}.
     *
     * <p>The task gets a fresh partition id and shares the per-query {@link QueryContext}.
     */
    private SqlTask newSqlTask(QueryId queryId) {
        QueryContext queryContext = getOrCreateQueryContext(queryId);
        TaskId taskId = new TaskId(new StageId(queryId.getId(), 0), idGenerator.incrementAndGet(), 0);
        URI location = URI.create("fake://task/" + taskId);
        return SqlTask.createSqlTask(
                taskId,
                location,
                "fake",
                queryContext,
                Tracing.noopTracer(),
                sqlTaskExecutionFactory,
                executor,
                ignored -> {},
                DataSize.of(32, DataSize.Unit.MEGABYTE),
                DataSize.of(200, DataSize.Unit.MEGABYTE),
                new ExchangeManagerRegistry(OpenTelemetry.noop(), Tracing.noopTracer(), new SecretsResolver(ImmutableMap.of())),
                new CounterStat());
    }

    /** Returns the cached {@link QueryContext} for {@code queryId}, creating it against {@link #memoryPool} on first use. */
    private QueryContext getOrCreateQueryContext(QueryId queryId) {
        return queryContexts.computeIfAbsent(queryId, id -> new QueryContext(
                id,
                DataSize.of(1, DataSize.Unit.MEGABYTE),
                memoryPool,
                new TestingGcMonitor(),
                executor,
                scheduledExecutor,
                scheduledExecutor,
                DataSize.of(1, DataSize.Unit.GIGABYTE),
                spillSpaceTracker));
    }

    /**
     * Returns the task's {@link TaskContext}.
     *
     * <p>If the context does not exist yet, it is created first by sending the task an update with
     * no splits and a single partitioned output buffer.
     *
     * @throws IllegalStateException if the context is still absent after the update
     */
    private TaskContext getOrCreateTaskContext(SqlTask sqlTask) {
        if (sqlTask.getTaskContext().isEmpty()) {
            TaskTestUtils.updateTask(
                    sqlTask,
                    ImmutableList.of(),
                    PipelinedOutputBuffers.createInitial(PipelinedOutputBuffers.BufferType.PARTITIONED)
                            .withBuffer(TestSqlTask.OUT, 0)
                            .withNoMoreBufferIds());
        }
        return sqlTask.getTaskContext().orElseThrow(() -> new IllegalStateException("TaskContext not present"));
    }
}

