/*
 * Decompiled with CFR 0.152.
 */
package io.trino.memory;

import com.google.common.collect.Iterables;
import io.airlift.concurrent.Threads;
import io.airlift.stats.GcMonitor;
import io.airlift.stats.TestingGcMonitor;
import io.airlift.units.DataSize;
import io.trino.ExceededMemoryLimitException;
import io.trino.execution.StageId;
import io.trino.execution.TaskId;
import io.trino.execution.TaskStateMachine;
import io.trino.memory.MemoryPool;
import io.trino.memory.QueryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.memory.context.MemoryTrackingContext;
import io.trino.operator.DriverContext;
import io.trino.operator.DriverStats;
import io.trino.operator.OperatorContext;
import io.trino.operator.OperatorStats;
import io.trino.operator.PipelineContext;
import io.trino.operator.PipelineStats;
import io.trino.operator.TaskContext;
import io.trino.operator.TaskStats;
import io.trino.spi.QueryId;
import io.trino.spiller.SpillSpaceTracker;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.assertj.core.api.AbstractLongAssert;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Verifies memory accounting for a single query: allocations made through operator-, pipeline- and
 * task-level {@link LocalMemoryContext}s must show up in the corresponding
 * {@link MemoryTrackingContext}s, in the operator/driver/pipeline/task stats, and in the
 * {@link MemoryPool} reservation.
 *
 * <p>{@link #setUpTest()} builds a fresh {@link MemoryPool} and a fresh
 * query → task → pipeline → driver → operator context chain before every test.
 */
@TestInstance(TestInstance.Lifecycle.PER_METHOD)
public class TestMemoryTracking {
    private static final DataSize queryMaxMemory = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize memoryPoolSize = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize maxSpillSize = DataSize.of(1, DataSize.Unit.GIGABYTE);
    private static final DataSize queryMaxSpillSize = DataSize.of(1, DataSize.Unit.GIGABYTE);
    // Static, so shared by every test method; the tests in this class do not exercise spilling.
    private static final SpillSpaceTracker spillSpaceTracker = new SpillSpaceTracker(maxSpillSize);

    private QueryContext queryContext;
    private TaskContext taskContext;
    private PipelineContext pipelineContext;
    private DriverContext driverContext;
    private OperatorContext operatorContext;
    private MemoryPool memoryPool;
    private ExecutorService notificationExecutor;
    private ScheduledExecutorService yieldExecutor;

    /**
     * Shuts down the executors and drops all context references.
     *
     * <p>JUnit runs {@code @AfterEach} even when {@code @BeforeEach} failed, so the executors are
     * null-checked to avoid masking the original setup failure with a NullPointerException.
     */
    @AfterEach
    public void tearDown() {
        if (notificationExecutor != null) {
            notificationExecutor.shutdownNow();
            notificationExecutor = null;
        }
        if (yieldExecutor != null) {
            yieldExecutor.shutdownNow();
            yieldExecutor = null;
        }
        queryContext = null;
        taskContext = null;
        pipelineContext = null;
        driverContext = null;
        operatorContext = null;
        memoryPool = null;
    }

    /**
     * Creates the executors, a fresh memory pool and the full query → task → pipeline → driver →
     * operator context chain used by every test.
     */
    @BeforeEach
    public void setUpTest() {
        notificationExecutor = Executors.newCachedThreadPool(Threads.daemonThreadsNamed("local-query-runner-executor-%s"));
        yieldExecutor = Executors.newScheduledThreadPool(2, Threads.daemonThreadsNamed("local-query-runner-scheduler-%s"));
        memoryPool = new MemoryPool(memoryPoolSize);
        queryContext = new QueryContext(
                new QueryId("test_query"),
                queryMaxMemory,
                memoryPool,
                new TestingGcMonitor(),
                notificationExecutor,
                yieldExecutor,
                queryMaxSpillSize,
                spillSpaceTracker);
        taskContext = queryContext.addTaskContext(
                new TaskStateMachine(testTaskId(), notificationExecutor),
                TestingSession.testSessionBuilder().build(),
                () -> {},
                true,
                true);
        pipelineContext = taskContext.addPipelineContext(0, true, true, false);
        driverContext = pipelineContext.addDriverContext();
        operatorContext = driverContext.addOperatorContext(1, new PlanNodeId("a"), "test-operator");
    }

    @Test
    public void testOperatorAllocations() {
        MemoryTrackingContext operatorMemoryContext = operatorContext.getOperatorMemoryContext();
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocableMemory = operatorContext.localRevocableMemoryContext();

        userMemory.setBytes(100);
        assertOperatorMemoryAllocations(operatorMemoryContext, 100, 0);
        userMemory.setBytes(500);
        assertOperatorMemoryAllocations(operatorMemoryContext, 500, 0);
        userMemory.setBytes(userMemory.getBytes() - 500);
        assertOperatorMemoryAllocations(operatorMemoryContext, 0, 0);

        revocableMemory.setBytes(300);
        assertOperatorMemoryAllocations(operatorMemoryContext, 0, 300);

        // Driving the user reservation below zero must be rejected.
        Assertions.assertThatThrownBy(() -> userMemory.setBytes(userMemory.getBytes() - 500))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("bytes cannot be negative");

        operatorContext.destroy();
        assertOperatorMemoryAllocations(operatorMemoryContext, 0, 0);
    }

    @Test
    public void testLocalTotalMemoryLimitExceeded() {
        LocalMemoryContext memoryContext = operatorContext.newLocalUserMemoryContext("test");
        memoryContext.setBytes(100);
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 100, 0);

        // Exactly at the limit is allowed ...
        memoryContext.setBytes(queryMaxMemory.toBytes());
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), queryMaxMemory.toBytes(), 0);

        // ... one byte over is not.
        Assertions.assertThatThrownBy(() -> memoryContext.setBytes(queryMaxMemory.toBytes() + 1))
                .isInstanceOf(ExceededMemoryLimitException.class)
                .hasMessage("Query exceeded per-node memory limit of %1$s [Allocated: %1$s, Delta: 1B, Top Consumers: {test=%1$s}]", queryMaxMemory);
    }

    @Test
    public void testLocalAllocations() {
        long pipelineLocalAllocation = 1_000_000L;
        long taskLocalAllocation = 10_000_000L;

        LocalMemoryContext pipelineLocalMemoryContext = pipelineContext.localMemoryContext();
        pipelineLocalMemoryContext.setBytes(pipelineLocalAllocation);
        assertLocalMemoryAllocations(pipelineContext.getPipelineMemoryContext(), pipelineLocalAllocation, pipelineLocalAllocation);

        LocalMemoryContext taskLocalMemoryContext = taskContext.localMemoryContext();
        taskLocalMemoryContext.setBytes(taskLocalAllocation);
        // The task context aggregates its own allocation plus the pipeline's.
        assertLocalMemoryAllocations(
                taskContext.getTaskMemoryContext(),
                pipelineLocalAllocation + taskLocalAllocation,
                pipelineLocalAllocation + taskLocalAllocation);

        Assertions.assertThat(pipelineContext.getPipelineStats().getUserMemoryReservation().toBytes())
                .describedAs("task level allocations should not be visible at the pipeline level")
                .isEqualTo(pipelineLocalAllocation);

        pipelineLocalMemoryContext.setBytes(pipelineLocalMemoryContext.getBytes() - pipelineLocalAllocation);
        assertLocalMemoryAllocations(pipelineContext.getPipelineMemoryContext(), taskLocalAllocation, 0);

        taskLocalMemoryContext.setBytes(taskLocalMemoryContext.getBytes() - taskLocalAllocation);
        assertLocalMemoryAllocations(taskContext.getTaskMemoryContext(), 0, 0);
    }

    @Test
    public void testStats() {
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();

        userMemory.setBytes(100_000_000L);
        assertStatsAtAllLevels(100_000_000L, 0);

        userMemory.setBytes(600_000_000L);
        assertStatsAtAllLevels(600_000_000L, 0);

        userMemory.setBytes(userMemory.getBytes() - 300_000_000L);
        assertStatsAtAllLevels(300_000_000L, 0);

        userMemory.setBytes(userMemory.getBytes() - 300_000_000L);
        assertStatsAtAllLevels(0, 0);

        operatorContext.destroy();
        assertStatsAtAllLevels(0, 0);
    }

    @Test
    public void testRevocableMemoryAllocations() {
        LocalMemoryContext userMemory = operatorContext.localUserMemoryContext();
        LocalMemoryContext revocableMemory = operatorContext.localRevocableMemoryContext();

        revocableMemory.setBytes(100_000_000L);
        assertStatsAtAllLevels(0, 100_000_000L);

        userMemory.setBytes(100_000_000L);
        revocableMemory.setBytes(200_000_000L);
        assertStatsAtAllLevels(100_000_000L, 200_000_000L);
    }

    @Test
    public void testTrySetBytes() {
        LocalMemoryContext localMemoryContext = operatorContext.localUserMemoryContext();

        Assertions.assertThat(localMemoryContext.trySetBytes(100_000_000L)).isTrue();
        assertStatsAtAllLevels(100_000_000L, 0);

        Assertions.assertThat(localMemoryContext.trySetBytes(200_000_000L)).isTrue();
        assertStatsAtAllLevels(200_000_000L, 0);

        Assertions.assertThat(localMemoryContext.trySetBytes(100_000_000L)).isTrue();
        assertStatsAtAllLevels(100_000_000L, 0);

        // A request larger than the whole pool fails and leaves the reservation unchanged.
        Assertions.assertThat(localMemoryContext.trySetBytes(memoryPool.getMaxBytes() + 1)).isFalse();
        assertStatsAtAllLevels(100_000_000L, 0);
    }

    @Test
    public void testTrySetZeroBytesFullPool() {
        LocalMemoryContext localMemoryContext = operatorContext.localUserMemoryContext();
        // Exhaust the pool, then re-set the current (unchanged) byte count: that must still succeed.
        memoryPool.reserve(testTaskId(), "test", memoryPool.getFreeBytes());
        Assertions.assertThat(localMemoryContext.trySetBytes(localMemoryContext.getBytes())).isTrue();
    }

    @Test
    public void testDestroy() {
        LocalMemoryContext newLocalUserMemoryContext = operatorContext.localUserMemoryContext();
        LocalMemoryContext newLocalRevocableMemoryContext = operatorContext.localRevocableMemoryContext();
        newLocalRevocableMemoryContext.setBytes(200_000L);
        newLocalUserMemoryContext.setBytes(400_000L);
        Assertions.assertThat(operatorContext.getOperatorMemoryContext().getUserMemory()).isEqualTo(400_000L);

        operatorContext.destroy();
        assertOperatorMemoryAllocations(operatorContext.getOperatorMemoryContext(), 0, 0);
    }

    /**
     * Returns the TaskId used for the task context created in {@link #setUpTest()}.
     */
    private static TaskId testTaskId() {
        return new TaskId(new StageId("test_query", 0), 0, 0);
    }

    /**
     * Snapshots the operator, driver, pipeline and task stats (in that order) from the current
     * contexts and asserts their user/revocable reservations.
     */
    private void assertStatsAtAllLevels(long expectedUserMemory, long expectedRevocableMemory) {
        assertStats(
                operatorContext.getNestedOperatorStats(),
                driverContext.getDriverStats(),
                pipelineContext.getPipelineStats(),
                taskContext.getTaskStats(),
                expectedUserMemory,
                expectedRevocableMemory);
    }

    /**
     * Asserts that every stats level reports the same user and revocable memory reservations.
     *
     * @param nestedOperatorStats must contain exactly one element (the single test operator)
     */
    private void assertStats(List<OperatorStats> nestedOperatorStats, DriverStats driverStats, PipelineStats pipelineStats, TaskStats taskStats, long expectedUserMemory, long expectedRevocableMemory) {
        OperatorStats operatorStats = Iterables.getOnlyElement(nestedOperatorStats);

        Assertions.assertThat(operatorStats.getUserMemoryReservation().toBytes()).isEqualTo(expectedUserMemory);
        Assertions.assertThat(driverStats.getUserMemoryReservation().toBytes()).isEqualTo(expectedUserMemory);
        Assertions.assertThat(pipelineStats.getUserMemoryReservation().toBytes()).isEqualTo(expectedUserMemory);
        Assertions.assertThat(taskStats.getUserMemoryReservation().toBytes()).isEqualTo(expectedUserMemory);

        Assertions.assertThat(operatorStats.getRevocableMemoryReservation().toBytes()).isEqualTo(expectedRevocableMemory);
        Assertions.assertThat(driverStats.getRevocableMemoryReservation().toBytes()).isEqualTo(expectedRevocableMemory);
        Assertions.assertThat(pipelineStats.getRevocableMemoryReservation().toBytes()).isEqualTo(expectedRevocableMemory);
        Assertions.assertThat(taskStats.getRevocableMemoryReservation().toBytes()).isEqualTo(expectedRevocableMemory);
    }

    /**
     * Asserts the context's user and revocable memory, and that the pool's reserved bytes equal the
     * expected user memory (revocable memory is not expected in the pool reservation here).
     */
    private void assertOperatorMemoryAllocations(MemoryTrackingContext memoryTrackingContext, long expectedUserMemory, long expectedRevocableMemory) {
        Assertions.assertThat(memoryTrackingContext.getUserMemory())
                .describedAs("User memory verification failed")
                .isEqualTo(expectedUserMemory);
        Assertions.assertThat(memoryPool.getReservedBytes())
                .describedAs("Memory pool verification failed")
                .isEqualTo(expectedUserMemory);
        Assertions.assertThat(memoryTrackingContext.getRevocableMemory())
                .describedAs("Revocable memory verification failed")
                .isEqualTo(expectedRevocableMemory);
    }

    /**
     * Asserts the context's user memory and the pool's total reserved bytes, which may differ when
     * allocations exist at other levels of the context hierarchy.
     */
    private void assertLocalMemoryAllocations(MemoryTrackingContext memoryTrackingContext, long expectedPoolMemory, long expectedContextUserMemory) {
        Assertions.assertThat(memoryTrackingContext.getUserMemory())
                .describedAs("User memory verification failed")
                .isEqualTo(expectedContextUserMemory);
        Assertions.assertThat(memoryPool.getReservedBytes())
                .describedAs("Memory pool verification failed")
                .isEqualTo(expectedPoolMemory);
    }
}

