/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.slice.XxHash64;
import io.trino.jmh.Benchmarks;
import io.trino.operator.FlatGroupByHash;
import io.trino.operator.GroupByHash;
import io.trino.operator.UpdateMemory;
import io.trino.operator.Work;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import io.trino.spi.block.DictionaryBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.AbstractLongType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.VarcharType;
import io.trino.sql.gen.JoinCompiler;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.IntStream;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.profile.GCProfiler;
import org.openjdk.jmh.runner.RunnerException;

/**
 * JMH benchmark for {@link FlatGroupByHash}.
 *
 * <p>{@link #addPages} measures inserting randomly generated pages into a fresh hash, and
 * {@link #writeData} measures writing the grouped values of a pre-populated hash back out
 * into pages.
 */
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(1)
@Warmup(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@Measurement(iterations = 5, time = 500, timeUnit = TimeUnit.MILLISECONDS)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkGroupByHash {
    /** Number of input positions generated per benchmark state. */
    private static final int POSITIONS = 10000000;
    /** Group count as a string so it can be used as a JMH {@code @Param} value. */
    private static final String GROUP_COUNT_STRING = "3000000";
    private static final int GROUP_COUNT = Integer.parseInt(GROUP_COUNT_STRING);
    /** Expected-size argument passed to every {@link FlatGroupByHash} created here. */
    private static final int EXPECTED_SIZE = 10000;
    private static final TypeOperators TYPE_OPERATORS = new TypeOperators();
    private static final JoinCompiler JOIN_COMPILER = new JoinCompiler(TYPE_OPERATORS);

    /**
     * Builds a new {@link FlatGroupByHash} and feeds it every page of the benchmark data.
     *
     * @param data the generated input pages and their configuration
     * @return the populated hash, returned so the work is observable to JMH
     */
    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public Object addPages(MultiChannelBenchmarkData data) {
        GroupByHash groupByHash = new FlatGroupByHash(data.getTypes(), data.isHashEnabled(), EXPECTED_SIZE, false, JOIN_COMPILER, UpdateMemory.NOOP);
        addInputPagesToHash(groupByHash, data.getPages());
        return groupByHash;
    }

    /**
     * Appends the values of every group of the pre-filled hash to output pages, visiting
     * the groups in the order given by {@link WriteMultiChannelBenchmarkData#getGroupIdsByPhysicalOrder()}.
     *
     * @param data the pre-filled hash, the group visiting order and the output types
     * @return the list of all pages produced
     */
    @Benchmark
    @OperationsPerInvocation(POSITIONS)
    public Object writeData(WriteMultiChannelBenchmarkData data) {
        GroupByHash groupByHash = data.getPrefilledHash();
        ImmutableList.Builder<Page> pages = ImmutableList.builder();
        PageBuilder pageBuilder = new PageBuilder(POSITIONS, data.getOutputTypes());
        for (int groupId : data.getGroupIdsByPhysicalOrder()) {
            pageBuilder.declarePosition();
            groupByHash.appendValuesTo(groupId, pageBuilder);
            if (pageBuilder.isFull()) {
                pages.add(pageBuilder.build());
                pageBuilder.reset();
            }
        }
        // Flush whatever remains in the builder, then hand back every page that was
        // produced (previously the list was dropped and the last page built twice).
        pages.add(pageBuilder.build());
        return pages.build();
    }

    /**
     * Adds each page to the hash and drives the returned {@link Work} until it reports completion.
     */
    private static void addInputPagesToHash(GroupByHash groupByHash, List<Page> pages) {
        for (Page page : pages) {
            Work<?> work = groupByHash.addPage(page);
            while (!work.process()) {
                // keep processing until the work for this page is finished
            }
        }
    }

    /**
     * Returns {@code channelCount} copies of {@code keyType}, followed by one BIGINT
     * channel for the precomputed hash when {@code hashEnabled} is set.
     */
    private static List<Type> pageTypes(Type keyType, int channelCount, boolean hashEnabled) {
        List<Type> keyTypes = Collections.nCopies(channelCount, keyType);
        if (!hashEnabled) {
            return keyTypes;
        }
        return ImmutableList.<Type>builder()
                .addAll(keyTypes)
                .add(BigintType.BIGINT)
                .build();
    }

    /**
     * Generates BIGINT pages in which every key channel of a row holds the same random
     * value drawn from {@code [0, groupCount)}.
     *
     * @param positionCount total number of rows to generate
     * @param groupCount exclusive upper bound of the random key values
     * @param channelCount number of key channels
     * @param hashEnabled whether to append a channel holding {@code AbstractLongType.hash} of the key
     * @param useMixedBlockTypes whether full pages are re-encoded in rotation (see {@link #encodePage})
     * @return the generated pages; the last page holds the remaining rows and may be empty
     */
    private static List<Page> createBigintPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled, boolean useMixedBlockTypes) {
        PageBuilder pageBuilder = new PageBuilder(pageTypes(BigintType.BIGINT, channelCount, hashEnabled));
        ImmutableList.Builder<Page> pages = ImmutableList.builder();
        int pageCount = 0;
        for (int position = 0; position < positionCount; position++) {
            int rand = ThreadLocalRandom.current().nextInt(groupCount);
            pageBuilder.declarePosition();
            for (int channel = 0; channel < channelCount; channel++) {
                BigintType.BIGINT.writeLong(pageBuilder.getBlockBuilder(channel), (long) rand);
            }
            if (hashEnabled) {
                BigintType.BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), AbstractLongType.hash((long) rand));
            }
            if (pageBuilder.isFull()) {
                Page page = pageBuilder.build();
                pageBuilder.reset();
                pages.add(useMixedBlockTypes ? encodePage(page, pageCount) : page);
                pageCount++;
            }
        }
        pages.add(pageBuilder.build());
        return pages.build();
    }

    /**
     * Re-encodes a page according to its index modulo three: unchanged, run-length
     * encoded (each channel repeats its value at position 0), or dictionary encoded
     * with identity positions over the original blocks.
     */
    private static Page encodePage(Page page, int pageCount) {
        switch (pageCount % 3) {
            case 0: {
                return page;
            }
            case 1: {
                Block[] blocks = new Block[page.getChannelCount()];
                for (int channel = 0; channel < blocks.length; channel++) {
                    blocks[channel] = RunLengthEncodedBlock.create(page.getBlock(channel).getSingleValueBlock(0), page.getPositionCount());
                }
                return new Page(blocks);
            }
            default: {
                int[] positions = IntStream.range(0, page.getPositionCount()).toArray();
                Block[] blocks = new Block[page.getChannelCount()];
                for (int channel = 0; channel < blocks.length; channel++) {
                    blocks[channel] = DictionaryBlock.create(positions.length, page.getBlock(channel), positions);
                }
                return new Page(blocks);
            }
        }
    }

    /**
     * Generates VARCHAR pages in which every key channel of a row holds the same 4-byte
     * slice containing a random int drawn from {@code [0, groupCount)}.
     *
     * @param positionCount total number of rows to generate
     * @param groupCount exclusive upper bound of the random key values
     * @param channelCount number of key channels
     * @param hashEnabled whether to append a BIGINT channel holding the {@code XxHash64} of the key slice
     * @return the generated pages; the last page holds the remaining rows and may be empty
     */
    private static List<Page> createVarcharPages(int positionCount, int groupCount, int channelCount, boolean hashEnabled) {
        PageBuilder pageBuilder = new PageBuilder(pageTypes(VarcharType.VARCHAR, channelCount, hashEnabled));
        ImmutableList.Builder<Page> pages = ImmutableList.builder();
        for (int position = 0; position < positionCount; position++) {
            int rand = ThreadLocalRandom.current().nextInt(groupCount);
            Slice value = Slices.wrappedHeapBuffer((ByteBuffer) ByteBuffer.allocate(4).putInt(rand).flip());
            pageBuilder.declarePosition();
            for (int channel = 0; channel < channelCount; channel++) {
                VarcharType.VARCHAR.writeSlice(pageBuilder.getBlockBuilder(channel), value);
            }
            if (hashEnabled) {
                BigintType.BIGINT.writeLong(pageBuilder.getBlockBuilder(channelCount), XxHash64.hash(value));
            }
            if (pageBuilder.isFull()) {
                pages.add(pageBuilder.build());
                pageBuilder.reset();
            }
        }
        pages.add(pageBuilder.build());
        return pages.build();
    }

    /**
     * Runs both benchmark methods once against default-configured state, then launches
     * the JMH run with the GC profiler and a 10 GB heap.
     */
    public static void main(String[] args) throws RunnerException {
        // Single direct invocation of each benchmark before handing over to JMH.
        MultiChannelBenchmarkData data = new MultiChannelBenchmarkData();
        data.setup();
        new BenchmarkGroupByHash().addPages(data);

        WriteMultiChannelBenchmarkData writeData = new WriteMultiChannelBenchmarkData();
        writeData.setup(data);
        new BenchmarkGroupByHash().writeData(writeData);

        Benchmarks.benchmark(BenchmarkGroupByHash.class)
                .withOptions(optionsBuilder -> optionsBuilder
                        .addProfiler(GCProfiler.class)
                        .jvmArgs("-Xmx10g"))
                .run();
    }

    /**
     * Benchmark state holding the generated input pages and the key types for the
     * configured channel count, data type and hash mode.
     */
    @State(Scope.Thread)
    public static class MultiChannelBenchmarkData {
        @Param({"1", "5", "10", "15", "20"})
        private int channelCount = 1;
        @Param(GROUP_COUNT_STRING)
        private int groupCount = GROUP_COUNT;
        @Param({"true", "false"})
        private boolean hashEnabled;
        @Param({"VARCHAR", "BIGINT"})
        private String dataType = "VARCHAR";
        private List<Page> pages;
        private List<Type> types;

        /**
         * Generates {@code POSITIONS} rows of input for the configured data type.
         *
         * @throws UnsupportedOperationException if {@code dataType} is neither VARCHAR nor BIGINT
         */
        @Setup
        public void setup() {
            switch (this.dataType) {
                case "VARCHAR": {
                    this.types = Collections.nCopies(this.channelCount, VarcharType.VARCHAR);
                    this.pages = createVarcharPages(POSITIONS, this.groupCount, this.channelCount, this.hashEnabled);
                    break;
                }
                case "BIGINT": {
                    this.types = Collections.nCopies(this.channelCount, BigintType.BIGINT);
                    this.pages = createBigintPages(POSITIONS, this.groupCount, this.channelCount, this.hashEnabled, false);
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unsupported dataType");
                }
            }
        }

        public int getChannelCount() {
            return this.channelCount;
        }

        public List<Page> getPages() {
            return this.pages;
        }

        public boolean isHashEnabled() {
            return this.hashEnabled;
        }

        /** Returns the key channel types only; the optional hash channel is not included. */
        public List<Type> getTypes() {
            return this.types;
        }
    }

    /**
     * Benchmark state for {@link BenchmarkGroupByHash#writeData}: a hash already populated
     * from {@link MultiChannelBenchmarkData}, the group ids sorted by physical position,
     * and the output channel types.
     */
    @State(Scope.Thread)
    public static class WriteMultiChannelBenchmarkData {
        private GroupByHash prefilledHash;
        private int[] groupIdsByPhysicalOrder;
        private List<Type> outputTypes;

        /**
         * Populates a new hash from {@code data} and precomputes the group visiting order.
         *
         * @param data the input state whose pages, types and hash mode are used
         */
        @Setup
        public void setup(MultiChannelBenchmarkData data) {
            FlatGroupByHash flatHash = new FlatGroupByHash(data.getTypes(), data.isHashEnabled(), EXPECTED_SIZE, false, JOIN_COMPILER, UpdateMemory.NOOP);
            addInputPagesToHash(flatHash, data.getPages());
            this.prefilledHash = flatHash;

            // Boxed ids are needed because Arrays.sort only accepts a comparator for object arrays.
            Integer[] groupIds = new Integer[flatHash.getGroupCount()];
            for (int i = 0; i < groupIds.length; i++) {
                groupIds[i] = i;
            }
            Arrays.sort(groupIds, Comparator.comparing(flatHash::getPhysicalPosition));
            this.groupIdsByPhysicalOrder = Arrays.stream(groupIds).mapToInt(Integer::intValue).toArray();

            // Output carries the key channels plus a BIGINT hash channel when hashing is enabled.
            this.outputTypes = new ArrayList<>(data.getTypes());
            if (data.isHashEnabled()) {
                this.outputTypes.add(BigintType.BIGINT);
            }
        }

        public GroupByHash getPrefilledHash() {
            return this.prefilledHash;
        }

        public int[] getGroupIdsByPhysicalOrder() {
            return this.groupIdsByPhysicalOrder;
        }

        public List<Type> getOutputTypes() {
            return this.outputTypes;
        }
    }
}

