/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.output;

import com.google.common.collect.ImmutableList;
import io.airlift.concurrent.Threads;
import io.airlift.slice.Slice;
import io.airlift.units.DataSize;
import io.trino.Session;
import io.trino.block.BlockAssertions;
import io.trino.execution.StateMachine;
import io.trino.execution.buffer.BufferState;
import io.trino.execution.buffer.OutputBuffer;
import io.trino.execution.buffer.OutputBuffers;
import io.trino.execution.buffer.PagesSerdeFactory;
import io.trino.execution.buffer.PartitionedOutputBuffer;
import io.trino.jmh.Benchmarks;
import io.trino.memory.context.AggregatedMemoryContext;
import io.trino.memory.context.LocalMemoryContext;
import io.trino.memory.context.SimpleLocalMemoryContext;
import io.trino.operator.BucketPartitionFunction;
import io.trino.operator.DriverContext;
import io.trino.operator.HashGenerator;
import io.trino.operator.OperatorFactories;
import io.trino.operator.OutputFactory;
import io.trino.operator.PageTestUtils;
import io.trino.operator.PartitionFunction;
import io.trino.operator.PrecomputedHashGenerator;
import io.trino.operator.TaskContext;
import io.trino.operator.TrinoOperatorFactories;
import io.trino.operator.output.PartitionedOutputOperator;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockEncodingSerde;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.TestingBlockEncodingSerde;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.HashBucketFunction;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingSession;
import io.trino.testing.TestingTaskContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;
import org.testng.annotations.Test;

/**
 * JMH benchmark for {@link PartitionedOutputOperator#addInput}.
 *
 * <p>Each invocation builds a fresh operator on top of a {@link PartitionedOutputBuffer} whose
 * partitioned {@code enqueue} only hands the serialized pages to a {@link Blackhole}. It then
 * feeds the operator the same data page {@code pageCount} times and finishes it.
 */
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@Fork(2)
@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkPartitionedOutputOperator {
    private static final OperatorFactories OPERATOR_FACTORIES = new TrinoOperatorFactories();

    static {
        // Runs one small addPage for several column types before any measurement takes place.
        // NOTE(review): presumably this keeps the JIT from specializing the partitioner for a
        // single block type; confirm intent before removing.
        // Any exception thrown here already surfaces as ExceptionInInitializerError, so no
        // catch-Throwable wrapper is needed (wrapping would also hide Errors such as OOM).
        // This block must stay textually after OPERATOR_FACTORIES, which BenchmarkData reads.
        List<BenchmarkData.TestType> warmupTypes = List.of(
                BenchmarkData.TestType.BIGINT,
                BenchmarkData.TestType.DICTIONARY_BIGINT,
                BenchmarkData.TestType.RLE_BIGINT,
                BenchmarkData.TestType.LONG_DECIMAL,
                BenchmarkData.TestType.INTEGER,
                BenchmarkData.TestType.SMALLINT,
                BenchmarkData.TestType.BOOLEAN,
                BenchmarkData.TestType.VARCHAR,
                BenchmarkData.TestType.ARRAY_BIGINT);
        BenchmarkPartitionedOutputOperator benchmark = new BenchmarkPartitionedOutputOperator();
        for (BenchmarkData.TestType warmupType : warmupTypes) {
            BenchmarkData data = new BenchmarkData();
            data.setType(warmupType);
            data.setup(null);
            // setup() resets pageCount from the type, so override it afterwards.
            data.setPageCount(1);
            benchmark.addPage(data);
        }
    }

    /**
     * Benchmark body: creates a new operator, adds the prepared page {@code data.getPageCount()}
     * times, then calls {@code finish()}.
     */
    @Benchmark
    public void addPage(BenchmarkData data) {
        PartitionedOutputOperator operator = data.createPartitionedOutputOperator();
        for (int i = 0; i < data.getPageCount(); i++) {
            operator.addInput(data.getDataPage());
        }
        operator.finish();
    }

    /**
     * Smoke test: one benchmark invocation with the default parameters and no Blackhole.
     */
    @Test
    public void verifyAddPage() {
        BenchmarkData data = new BenchmarkData();
        data.setup(null);
        new BenchmarkPartitionedOutputOperator().addPage(data);
    }

    /**
     * Builds a row type whose fields are named {@code field0}, {@code field1}, ... in order.
     */
    private static RowType rowTypeWithDefaultFieldNames(List<Type> types) {
        List<RowType.Field> fields = IntStream.range(0, types.size())
                .mapToObj(i -> new RowType.Field(Optional.of("field" + i), types.get(i)))
                .collect(ImmutableList.toImmutableList());
        return RowType.from(fields);
    }

    private static MapType createMapType(Type keyType, Type valueType) {
        return new MapType(keyType, valueType, new TypeOperators());
    }

    public static void main(String[] args) throws Exception {
        Benchmarks.benchmark(BenchmarkPartitionedOutputOperator.class)
                .withOptions(optionsBuilder -> optionsBuilder.jvmArgs("-Xmx16g"))
                .run();
    }

    /**
     * Per-thread benchmark state.
     *
     * <p>Holds the JMH parameters, the prepared input page and the column types (data columns
     * followed by one BIGINT partition column).
     */
    @State(Scope.Thread)
    public static class BenchmarkData {
        private static final int DEFAULT_POSITION_COUNT = 8192;
        private static final DataSize MAX_PARTITION_BUFFER_SIZE = DataSize.of(256, DataSize.Unit.MEGABYTE);
        private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(
                Threads.daemonThreadsNamed("BenchmarkPartitionedOutputOperator-executor-%s"));
        private static final ScheduledExecutorService SCHEDULER = Executors.newScheduledThreadPool(
                1, Threads.daemonThreadsNamed("BenchmarkPartitionedOutputOperator-scheduledExecutor-%s"));

        private final OperatorFactories operatorFactories;
        private final Session session;

        @Param({"2", "16", "256"})
        private int partitionCount = 256;

        @Param({"true", "false"})
        private boolean enableCompression;

        @Param({"1", "2"})
        private int channelCount = 1;

        @Param({"8192"})
        private int positionCount = DEFAULT_POSITION_COUNT;

        @Param({
                "BIGINT",
                "BIGINT_SKEWED_HASH",
                "DICTIONARY_BIGINT",
                "RLE_BIGINT",
                "BIGINT_PARTITION_CHANNEL_20_PERCENT",
                "BIGINT_DICTIONARY_PARTITION_CHANNEL_20_PERCENT",
                "BIGINT_DICTIONARY_PARTITION_CHANNEL_50_PERCENT",
                "BIGINT_DICTIONARY_PARTITION_CHANNEL_80_PERCENT",
                "BIGINT_DICTIONARY_PARTITION_CHANNEL_100_PERCENT",
                "RLE_PARTITION_BIGINT",
                "RLE_PARTITION_NULL_BIGINT",
                "LONG_DECIMAL",
                "INTEGER",
                "SMALLINT",
                "BOOLEAN",
                "VARCHAR",
                "ARRAY_BIGINT",
                "ARRAY_VARCHAR",
                "ARRAY_ARRAY_BIGINT",
                "MAP_BIGINT_BIGINT",
                "MAP_BIGINT_MAP_BIGINT_BIGINT",
                "ROW_BIGINT_BIGINT",
                "ROW_ARRAY_BIGINT_ARRAY_BIGINT"})
        private TestType type = TestType.BIGINT;

        @Param({"0", "0.2"})
        private float nullRate = 0.2f;

        private OptionalInt nullChannel;
        private List<Type> types;
        private int pageCount;
        private Page dataPage;
        // May be null (e.g. when setup is called outside JMH); the output buffer tolerates that.
        private Blackhole blackhole;

        public BenchmarkData() {
            this(OPERATOR_FACTORIES, TestingSession.testSessionBuilder().build());
        }

        protected BenchmarkData(OperatorFactories operatorFactories, Session session) {
            this.operatorFactories = Objects.requireNonNull(operatorFactories, "operatorFactories is null");
            this.session = Objects.requireNonNull(session, "session is null");
        }

        public int getPageCount() {
            return pageCount;
        }

        public void setPageCount(int pageCount) {
            this.pageCount = pageCount;
        }

        public void setType(TestType type) {
            this.type = Objects.requireNonNull(type, "type is null");
        }

        public Page getDataPage() {
            return dataPage;
        }

        /**
         * Prepares the data page, page count, null channel and final column types for the
         * current {@code type}.
         *
         * @param blackhole sink for enqueued pages; may be null
         */
        @Setup
        public void setup(Blackhole blackhole) {
            this.blackhole = blackhole;
            List<Type> dataTypes = type.getTypes(channelCount);
            dataPage = type.createPage(dataTypes, positionCount, nullRate);
            pageCount = type.getPageCount();
            // The BIGINT partition column is appended after the data columns, so its index is
            // dataTypes.size(). That is also the channel the operator partitions on (types.size() - 1).
            nullChannel = type.getNullChannel(dataTypes.size());
            types = ImmutableList.<Type>builder()
                    .addAll(dataTypes)
                    .add(BigintType.BIGINT)
                    .build();
        }

        /**
         * Builds a page of {@code channelCount} blocks, each produced by {@code standardBlock}.
         * {@code partitionBlock} is appended as the last channel.
         */
        private static Page page(int positionCount, int channelCount, Supplier<Block> standardBlock, Block partitionBlock) {
            Block[] blocks = new Block[channelCount + 1];
            for (int channel = 0; channel < channelCount; channel++) {
                blocks[channel] = standardBlock.get();
            }
            blocks[channelCount] = partitionBlock;
            return new Page(positionCount, blocks);
        }

        /**
         * Like {@link #page} with random BIGINT data columns (using {@code nullRate}) and the
         * given partition block as the last channel.
         */
        private static Page randomBigintPage(int positionCount, int channelCount, float nullRate, Block partitionBlock) {
            return page(
                    positionCount,
                    channelCount,
                    () -> BlockAssertions.createRandomBlockForType(BigintType.BIGINT, positionCount, nullRate),
                    partitionBlock);
        }

        /**
         * Creates an output buffer with buffer ids {@code 0..partitionCount-1} (buffer i mapped to
         * partition i) and an effectively unlimited size.
         */
        private PartitionedOutputBuffer createPartitionedOutputBuffer() {
            OutputBuffers buffers = OutputBuffers.createInitialEmptyOutputBuffers(OutputBuffers.BufferType.PARTITIONED);
            for (int partition = 0; partition < partitionCount; partition++) {
                buffers = buffers.withBuffer(new OutputBuffers.OutputBufferId(partition), partition);
            }
            return createPartitionedBuffer(buffers.withNoMoreBufferIds(), DataSize.of(Long.MAX_VALUE, DataSize.Unit.BYTE));
        }

        /**
         * Creates a fresh operator that partitions on the last column ({@code types.size() - 1}).
         *
         * <p>NOTE(review): {@code PrecomputedHashGenerator(0)} presumably reads the hash from
         * channel 0 of the partition-channel projection, i.e. the last page column; confirm.
         */
        private PartitionedOutputOperator createPartitionedOutputOperator() {
            BucketPartitionFunction partitionFunction = new BucketPartitionFunction(
                    new HashBucketFunction(new PrecomputedHashGenerator(0), partitionCount),
                    IntStream.range(0, partitionCount).toArray());
            PagesSerdeFactory serdeFactory = new PagesSerdeFactory(new TestingBlockEncodingSerde(), enableCompression);
            PartitionedOutputBuffer buffer = createPartitionedOutputBuffer();
            TaskContext taskContext = createTaskContext();
            OutputFactory operatorFactory = operatorFactories.partitionedOutput(
                    taskContext,
                    partitionFunction,
                    ImmutableList.of(types.size() - 1),
                    ImmutableList.of(Optional.empty()),
                    false,
                    nullChannel,
                    buffer,
                    MAX_PARTITION_BUFFER_SIZE);
            return (PartitionedOutputOperator) operatorFactory
                    .createOutputOperator(0, new PlanNodeId("plan-node-0"), types, Function.identity(), serdeFactory)
                    .createOperator(createDriverContext(taskContext));
        }

        private DriverContext createDriverContext(TaskContext taskContext) {
            return taskContext.addPipelineContext(0, true, true, false).addDriverContext();
        }

        private TaskContext createTaskContext() {
            return TestingTaskContext.builder(EXECUTOR, SCHEDULER, session).build();
        }

        private TestingPartitionedOutputBuffer createPartitionedBuffer(OutputBuffers buffers, DataSize dataSize) {
            return new TestingPartitionedOutputBuffer(
                    "task-instance-id",
                    new StateMachine<>("bufferState", SCHEDULER, BufferState.OPEN, BufferState.TERMINAL_BUFFER_STATES),
                    buffers,
                    dataSize,
                    () -> new SimpleLocalMemoryContext(AggregatedMemoryContext.newSimpleAggregatedMemoryContext(), "test"),
                    SCHEDULER,
                    blackhole);
        }

        /**
         * Input shapes under benchmark.
         *
         * <p>Each constant has a column type, a per-invocation page count and, optionally, a
         * custom page factory. Pages built via {@link BenchmarkData#page} carry the partition
         * column as their last channel. NOTE(review): the default {@code PageTestUtils}
         * factories are assumed to append a matching trailing BIGINT (hash) column as well;
         * confirm.
         */
        public enum TestType {
            BIGINT(BigintType.BIGINT, 5000),
            BIGINT_SKEWED_HASH(BigintType.BIGINT, 5000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    // NOTE(review): the second argument presumably limits the number of distinct
                    // partition values, producing skewed partitions; confirm.
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createRandomLongsBlock(positionCount, 2));
                }
            },
            DICTIONARY_BIGINT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return PageTestUtils.createRandomDictionaryPage(types, positionCount, nullRate);
                }
            },
            RLE_BIGINT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return PageTestUtils.createRandomRlePage(types, positionCount, nullRate);
                }
            },
            BIGINT_PARTITION_CHANNEL_20_PERCENT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    // Partition values cycle through positionCount / 5 distinct longs.
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createLongsBlock(LongStream.range(0, positionCount)
                                    .mapToObj(value -> value % (positionCount / 5))
                                    .collect(ImmutableList.toImmutableList())));
                }
            },
            BIGINT_DICTIONARY_PARTITION_CHANNEL_20_PERCENT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createLongDictionaryBlock(0, positionCount, positionCount / 5));
                }
            },
            BIGINT_DICTIONARY_PARTITION_CHANNEL_50_PERCENT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createLongDictionaryBlock(0, positionCount, positionCount / 2));
                }
            },
            BIGINT_DICTIONARY_PARTITION_CHANNEL_80_PERCENT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createLongDictionaryBlock(0, positionCount, (int) (positionCount * 0.8)));
                }
            },
            BIGINT_DICTIONARY_PARTITION_CHANNEL_100_PERCENT(BigintType.BIGINT, 3000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createLongDictionaryBlock(0, positionCount, positionCount));
                }
            },
            RLE_PARTITION_BIGINT(BigintType.BIGINT, 5000) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            BlockAssertions.createRLEBlock(42L, positionCount));
                }
            },
            RLE_PARTITION_NULL_BIGINT(BigintType.BIGINT, 20) {
                @Override
                public Page createPage(List<Type> types, int positionCount, float nullRate) {
                    // Partition column is a single null run-length-encoded over every position.
                    return BenchmarkData.randomBigintPage(positionCount, types.size(), nullRate,
                            new RunLengthEncodedBlock(BlockAssertions.createLongsBlock(new Long[] {null}), positionCount));
                }

                @Override
                public OptionalInt getNullChannel(int partitionChannel) {
                    // The null column is the partition column, wherever it sits for the current
                    // channelCount (previously hard-coded to 1, which was wrong for channelCount > 1).
                    return OptionalInt.of(partitionChannel);
                }
            },
            LONG_DECIMAL(DecimalType.createDecimalType(19), 5000),
            INTEGER(IntegerType.INTEGER, 5000),
            SMALLINT(SmallintType.SMALLINT, 5000),
            BOOLEAN(BooleanType.BOOLEAN, 5000),
            VARCHAR(VarcharType.VARCHAR, 5000),
            ARRAY_BIGINT(new ArrayType(BigintType.BIGINT), 1000),
            ARRAY_VARCHAR(new ArrayType(VarcharType.VARCHAR), 1000),
            ARRAY_ARRAY_BIGINT(new ArrayType(new ArrayType(BigintType.BIGINT)), 1000),
            MAP_BIGINT_BIGINT(createMapType(BigintType.BIGINT, BigintType.BIGINT), 1000),
            MAP_BIGINT_MAP_BIGINT_BIGINT(createMapType(BigintType.BIGINT, createMapType(BigintType.BIGINT, BigintType.BIGINT)), 1000),
            ROW_BIGINT_BIGINT(rowTypeWithDefaultFieldNames(List.of(BigintType.BIGINT, BigintType.BIGINT)), 1000),
            ROW_ARRAY_BIGINT_ARRAY_BIGINT(rowTypeWithDefaultFieldNames(List.of(new ArrayType(BigintType.BIGINT), new ArrayType(BigintType.BIGINT))), 1000);

            private final Type type;
            private final int pageCount;

            TestType(Type type, int pageCount) {
                this.type = Objects.requireNonNull(type, "type is null");
                this.pageCount = pageCount;
            }

            /**
             * Builds the benchmark input page for the given data column types.
             */
            public Page createPage(List<Type> types, int positionCount, float nullRate) {
                return PageTestUtils.createRandomPage(types, positionCount, nullRate);
            }

            /**
             * Returns how many times the page is added per benchmark invocation.
             */
            public int getPageCount() {
                return pageCount;
            }

            /**
             * Backward-compatible form of {@link #getNullChannel(int)}.
             *
             * <p>It assumes a single data column, i.e. the partition column sits at index 1.
             */
            public OptionalInt getNullChannel() {
                return getNullChannel(1);
            }

            /**
             * Returns the channel whose nulls the operator should treat specially, if any.
             *
             * @param partitionChannel index of the partition column in the page
             */
            public OptionalInt getNullChannel(int partitionChannel) {
                return OptionalInt.empty();
            }

            /**
             * Returns {@code channelCount} copies of this constant's column type.
             */
            public List<Type> getTypes(int channelCount) {
                return Collections.nCopies(channelCount, type);
            }
        }

        /**
         * Output buffer whose partitioned {@code enqueue} does not call the superclass.
         *
         * <p>It only passes the pages to the Blackhole (when one is present), so nothing is
         * actually buffered.
         */
        private static class TestingPartitionedOutputBuffer
                extends PartitionedOutputBuffer {
            private final Blackhole blackhole;

            public TestingPartitionedOutputBuffer(
                    String taskInstanceId,
                    StateMachine<BufferState> state,
                    OutputBuffers outputBuffers,
                    DataSize maxBufferSize,
                    Supplier<LocalMemoryContext> systemMemoryContextSupplier,
                    Executor notificationExecutor,
                    Blackhole blackhole) {
                super(taskInstanceId, state, outputBuffers, maxBufferSize, systemMemoryContextSupplier, notificationExecutor);
                this.blackhole = blackhole;
            }

            @Override
            public void enqueue(int partitionNumber, List<Slice> pages) {
                if (blackhole != null) {
                    blackhole.consume(pages);
                }
            }
        }
    }
}

