/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.output;

import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.SessionTestUtils;
import io.trino.block.BlockAssertions;
import io.trino.execution.StateMachine;
import io.trino.execution.buffer.BufferResult;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.OutputBufferInfo;
import io.trino.execution.buffer.OutputBufferStatus;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.execution.buffer.PipelinedOutputBuffers;
import io.trino.execution.buffer.TestingPagesSerdeFactory;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.operator.BucketPartitionFunction;
import io.trino.operator.DriverContext;
import io.trino.operator.Operator;
import io.trino.operator.PartitionFunction;
import io.trino.operator.exchange.PageChannelSelector;
import io.trino.operator.output.PartitionedOutputOperator;
import io.trino.operator.output.PositionsAppenderFactory;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingTaskContext;
import io.trino.type.BlockTypeOperators;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

/**
 * Tests how {@link PartitionedOutputOperator.PartitionedOutputOperatorFactory} reuses page
 * partitioners, and the memory their buffers retain, across operators created for successive
 * splits. It also checks that the retained memory is released when the factory is told there
 * will be no more operators, including when flushing fails.
 */
@TestInstance(value=TestInstance.Lifecycle.PER_CLASS)
@Execution(value=ExecutionMode.CONCURRENT)
public class TestPagePartitionerPool {
    // Shared by both tests; PER_CLASS lifecycle lets @BeforeAll/@AfterAll be instance methods.
    private ScheduledExecutorService driverYieldExecutor;

    @BeforeAll
    public void setUp() {
        this.driverYieldExecutor = Executors.newScheduledThreadPool(0, Threads.threadsNamed("TestPagePartitionerPool-driver-yield-%s"));
    }

    @AfterAll
    public void destroy() {
        // shutdownNow() also cancels delayed tasks that are still queued. With plain shutdown(),
        // a ScheduledThreadPoolExecutor by default still runs them after the test class is done.
        this.driverYieldExecutor.shutdownNow();
    }

    @Test
    public void testBuffersReusedAcrossSplits() {
        Page split = new Page(BlockAssertions.createLongsBlock(1));
        // One byte larger than a single split. Per the assertions below, one split stays buffered,
        // while a second split landing in the same reused partitioner results in a flushed page.
        DataSize maxPagePartitioningBufferSize = DataSize.ofBytes(split.getSizeInBytes() + 1);
        OutputBufferMock outputBuffer = new OutputBufferMock();
        AggregatedMemoryContext memoryContext = AggregatedMemoryContext.newSimpleAggregatedMemoryContext();
        PartitionedOutputOperator.PartitionedOutputOperatorFactory factory = createFactory(maxPagePartitioningBufferSize, outputBuffer, memoryContext);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(0L);

        // One operator: after finish() the split is still held in memory and nothing was enqueued.
        long initialRetainedBytesOneOperator = processSplitsConcurrently(factory, memoryContext, split);
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(0);
        Assertions.assertThat(memoryContext.getBytes()).isGreaterThanOrEqualTo(initialRetainedBytesOneOperator + split.getSizeInBytes());

        // A second single-operator run: one page reaches the output buffer and retained memory
        // drops back to the level observed right after operator creation.
        processSplitsConcurrently(factory, memoryContext, split);
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(1);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(initialRetainedBytesOneOperator);

        // Two operators alive at once: both splits stay buffered.
        long initialRetainedBytesTwoOperators = processSplitsConcurrently(factory, memoryContext, split, split);
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(1);
        Assertions.assertThat(memoryContext.getBytes()).isGreaterThanOrEqualTo(initialRetainedBytesTwoOperators + 2L * split.getSizeInBytes());

        // Repeating with two operators flushes both buffered splits.
        processSplitsConcurrently(factory, memoryContext, split, split);
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(3);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(initialRetainedBytesTwoOperators);

        // Four operators at once: the assertions expect the same retained floor as with two
        // operators (presumably the pool keeps at most two partitioners; see createFactory).
        processSplitsConcurrently(factory, memoryContext, split, split, split, split);
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(5);
        Assertions.assertThat(memoryContext.getBytes()).isGreaterThanOrEqualTo(initialRetainedBytesTwoOperators + 2L * split.getSizeInBytes());

        processSplitsConcurrently(factory, memoryContext, split, split);
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(7);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(initialRetainedBytesTwoOperators);

        processSplitsConcurrently(factory, memoryContext, split);
        Assertions.assertThat(memoryContext.getBytes()).isGreaterThanOrEqualTo(initialRetainedBytesTwoOperators + split.getSizeInBytes());

        // noMoreOperators() while one operator is still open: one more page is enqueued and
        // retained memory falls to the single-operator baseline.
        Operator operator = factory.createOperator(driverContext());
        factory.noMoreOperators();
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(8);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(initialRetainedBytesOneOperator);

        // Finishing the last operator after noMoreOperators() releases everything.
        operator.addInput(split);
        operator.finish();
        Assertions.assertThat(outputBuffer.totalEnqueuedPageCount()).isEqualTo(9);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(0L);
    }

    @Test
    public void testMemoryReleasedOnFailure() {
        Page split = new Page(BlockAssertions.createLongsBlock(1));
        DataSize maxPagePartitioningBufferSize = DataSize.ofBytes(split.getSizeInBytes() + 1);
        final RuntimeException exception = new RuntimeException();
        // Output buffer whose partitioned enqueue always fails with the exception above.
        OutputBufferMock outputBuffer = new OutputBufferMock() {
            @Override
            public void enqueue(int partition, List<Slice> pages) {
                throw exception;
            }
        };
        AggregatedMemoryContext memoryContext = AggregatedMemoryContext.newSimpleAggregatedMemoryContext();
        PartitionedOutputOperator.PartitionedOutputOperatorFactory factory = createFactory(maxPagePartitioningBufferSize, outputBuffer, memoryContext);

        long initialRetainedBytesOneOperator = processSplitsConcurrently(factory, memoryContext, split);
        Assertions.assertThat(memoryContext.getBytes()).isGreaterThanOrEqualTo(initialRetainedBytesOneOperator + split.getSizeInBytes());

        // The failure from enqueue must propagate unchanged (same instance), and memory must
        // still be fully released.
        Assertions.assertThatThrownBy(() -> factory.noMoreOperators()).isSameAs(exception);
        Assertions.assertThat(memoryContext.getBytes()).isEqualTo(0L);
    }

    /**
     * Creates a partitioned-output operator factory over a single BIGINT column. Every position
     * is routed to bucket 0 (the bucket function returns 0, and {@code new int[1]} maps that
     * bucket to partition 0).
     *
     * @param maxPagePartitioningBufferSize buffer size limit handed to the factory
     * @param outputBuffer mock that records enqueued pages
     * @param memoryContext memory context whose bytes the tests assert on
     * @return the configured factory
     */
    private static PartitionedOutputOperator.PartitionedOutputOperatorFactory createFactory(DataSize maxPagePartitioningBufferSize, OutputBufferMock outputBuffer, AggregatedMemoryContext memoryContext) {
        return new PartitionedOutputOperator.PartitionedOutputOperatorFactory(
                0,
                new PlanNodeId("0"),
                ImmutableList.of(BigintType.BIGINT),
                PageChannelSelector.identitySelection(),
                new BucketPartitionFunction((page, position) -> 0, new int[1]),
                ImmutableList.of(0),
                ImmutableList.of(),
                false,
                OptionalInt.empty(),
                outputBuffer,
                new TestingPagesSerdeFactory(),
                maxPagePartitioningBufferSize,
                new PositionsAppenderFactory(new BlockTypeOperators()),
                Optional.empty(),
                memoryContext,
                // NOTE(review): presumably the page partitioner pool size; confirm against the
                // factory constructor.
                2,
                Optional.empty());
    }

    /**
     * Creates one operator per split up front, so all of them are alive at the same time.
     * It then feeds each operator its split and finishes them all on the calling thread.
     *
     * @return memory-context bytes observed after the operators were created but before any input
     */
    private long processSplitsConcurrently(PartitionedOutputOperator.PartitionedOutputOperatorFactory factory, AggregatedMemoryContext memoryContext, Page... splits) {
        List<Operator> operators = Stream.of(splits)
                .map(split -> factory.createOperator(driverContext()))
                .collect(ImmutableList.toImmutableList());
        long initialRetainedBytes = memoryContext.getBytes();
        for (int i = 0; i < operators.size(); i++) {
            operators.get(i).addInput(splits[i]);
        }
        operators.forEach(Operator::finish);
        return initialRetainedBytes;
    }

    private DriverContext driverContext() {
        return TestingTaskContext.builder(MoreExecutors.directExecutor(), this.driverYieldExecutor, SessionTestUtils.TEST_SESSION)
                .build()
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }

    /**
     * Output buffer stub that only supports partitioned {@code enqueue}. It counts the pages
     * enqueued per partition; every other operation throws {@link UnsupportedOperationException}.
     */
    private static class OutputBufferMock
            implements OutputBuffer {
        private final Map<Integer, Integer> partitionBufferPages = new HashMap<>();

        /** Returns the number of pages enqueued so far, summed over all partitions. */
        public int totalEnqueuedPageCount() {
            return this.partitionBufferPages.values().stream().mapToInt(Integer::intValue).sum();
        }

        @Override
        public void enqueue(int partition, List<Slice> pages) {
            this.partitionBufferPages.merge(partition, pages.size(), Integer::sum);
        }

        @Override
        public OutputBufferInfo getInfo() {
            throw new UnsupportedOperationException();
        }

        @Override
        public BufferState getState() {
            throw new UnsupportedOperationException();
        }

        @Override
        public double getUtilization() {
            throw new UnsupportedOperationException();
        }

        @Override
        public OutputBufferStatus getStatus() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void addStateChangeListener(StateMachine.StateChangeListener<BufferState> stateChangeListener) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setOutputBuffers(OutputBuffers newOutputBuffers) {
            throw new UnsupportedOperationException();
        }

        @Override
        public ListenableFuture<BufferResult> get(PipelinedOutputBuffers.OutputBufferId bufferId, long token, DataSize maxSize) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void acknowledge(PipelinedOutputBuffers.OutputBufferId bufferId, long token) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void destroy(PipelinedOutputBuffers.OutputBufferId bufferId) {
            throw new UnsupportedOperationException();
        }

        @Override
        public ListenableFuture<Void> isFull() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void enqueue(List<Slice> pages) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void setNoMorePages() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void destroy() {
            throw new UnsupportedOperationException();
        }

        @Override
        public void abort() {
            throw new UnsupportedOperationException();
        }

        @Override
        public long getPeakMemoryUsage() {
            throw new UnsupportedOperationException();
        }

        @Override
        public Optional<Throwable> getFailureCause() {
            throw new UnsupportedOperationException();
        }
    }
}

