/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.GroupByIdBlock;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.GroupedAggregator;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Int128;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.tree.QualifiedName;
import java.util.List;
import java.util.OptionalInt;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.options.WarmupMode;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * JMH benchmarks for the decimal {@code avg} and {@code sum} aggregation functions.
 *
 * <p>Three scenarios are measured:
 * <ul>
 *   <li>{@link #benchmark}: partial aggregation over raw decimal input;</li>
 *   <li>{@link #benchmarkEvaluateIntermediate}: partial aggregation followed by evaluating every
 *       group's intermediate state;</li>
 *   <li>{@link #benchmarkEvaluateFinal}: final aggregation over precomputed intermediate states
 *       followed by evaluating every group.</li>
 * </ul>
 *
 * <p>{@link #verify()} is a TestNG smoke test and is also run by {@link #main} before JMH starts.
 */
@State(value=Scope.Thread)
@OutputTimeUnit(value=TimeUnit.NANOSECONDS)
@Fork(value=3)
@Warmup(iterations=10)
@Measurement(iterations=10)
@BenchmarkMode(value={Mode.AverageTime})
public class BenchmarkDecimalAggregation {
    /** Number of raw input rows (and group-id positions) generated by {@link BenchmarkData#setup()}. */
    private static final int ELEMENT_COUNT = 1_000_000;

    /**
     * Measures partial aggregation of {@link #ELEMENT_COUNT} raw decimal values.
     *
     * @param data benchmark state prepared by {@link BenchmarkData#setup()}
     * @return the aggregator, returned so JMH consumes the result
     */
    @Benchmark
    @OperationsPerInvocation(value=ELEMENT_COUNT)
    public GroupedAggregator benchmark(BenchmarkData data) {
        GroupedAggregator aggregator = data.getPartialAggregatorFactory().createGroupedAggregator();
        aggregator.processPage(data.getGroupIds(), data.getValues());
        return aggregator;
    }

    /**
     * Measures partial aggregation of {@link #ELEMENT_COUNT} raw values followed by evaluating
     * every group's intermediate state into a block.
     *
     * @param data benchmark state prepared by {@link BenchmarkData#setup()}
     * @return one evaluated intermediate value per group, in group-id order
     */
    @Benchmark
    @OperationsPerInvocation(value=ELEMENT_COUNT)
    public Block benchmarkEvaluateIntermediate(BenchmarkData data) {
        GroupedAggregator aggregator = data.getPartialAggregatorFactory().createGroupedAggregator();
        aggregator.processPage(data.getGroupIds(), data.getValues());
        return evaluateGroups(aggregator, data.getGroupCount());
    }

    /**
     * Measures the final aggregation step: feeding precomputed intermediate states into a FINAL
     * aggregator and evaluating every group.
     *
     * @param data benchmark state prepared by {@link BenchmarkData#setup()}
     * @return one final value per group, in group-id order
     */
    @Benchmark
    public Block benchmarkEvaluateFinal(BenchmarkData data) {
        GroupedAggregator aggregator = data.getFinalAggregatorFactory().createGroupedAggregator();
        // The same intermediate page is processed twice, presumably so the second pass merges into
        // states that already hold data.
        // NOTE(review): row i of the intermediate page is the state of group i, but it is paired
        // with position i of getGroupIds(), which holds a random group id. Confirm whether an
        // identity group mapping was intended here.
        aggregator.processPage(data.getGroupIds(), data.getIntermediateValues());
        aggregator.processPage(data.getGroupIds(), data.getIntermediateValues());
        return evaluateGroups(aggregator, data.getGroupCount());
    }

    /**
     * Smoke test: builds the benchmark state, checks that the group-id block and the value page
     * have matching position counts, and runs {@link #benchmark} once.
     */
    @Test
    public void verify() {
        BenchmarkData data = new BenchmarkData();
        data.setup();
        Assert.assertEquals(data.getGroupIds().getPositionCount(), data.getValues().getPositionCount());
        new BenchmarkDecimalAggregation().benchmark(data);
    }

    /** Runs {@link #verify()} and then launches all benchmarks in this class. */
    public static void main(String[] args) throws Exception {
        new BenchmarkDecimalAggregation().verify();
        Benchmarks.benchmark(BenchmarkDecimalAggregation.class, WarmupMode.BULK).run();
    }

    /**
     * Evaluates groups {@code 0 .. groupCount - 1} of {@code aggregator}, in order, into one block.
     *
     * @param aggregator aggregator whose groups are evaluated
     * @param groupCount number of groups to evaluate, starting at group id 0
     * @return a block with one position per evaluated group
     */
    private static Block evaluateGroups(GroupedAggregator aggregator, int groupCount) {
        BlockBuilder builder = aggregator.getType().createBlockBuilder(null, groupCount);
        for (int groupId = 0; groupId < groupCount; groupId++) {
            aggregator.evaluate(groupId, builder);
        }
        return builder.build();
    }

    /**
     * Per-thread benchmark state: the aggregator factories plus the input pages and group ids,
     * parameterized by decimal representation, aggregation function and group count.
     */
    @State(value=Scope.Thread)
    public static class BenchmarkData {
        /** {@code "SHORT"} selects {@code decimal(14, 3)}; {@code "LONG"} selects {@code decimal(30, 10)}. */
        @Param(value={"SHORT", "LONG"})
        private String type = "SHORT";
        /** Name of the aggregation function that is resolved and benchmarked. */
        @Param(value={"avg", "sum"})
        private String function = "avg";
        /** Number of distinct groups that the random group ids are drawn from. */
        @Param(value={"10", "1000"})
        private int groupCount = 10;
        private AggregatorFactory partialAggregatorFactory;
        private AggregatorFactory finalAggregatorFactory;
        private GroupByIdBlock groupIds;
        private Page values;
        private Page intermediateValues;

        /**
         * Resolves the aggregation function, generates {@link #ELEMENT_COUNT} decimal values and
         * random group ids, and precomputes the intermediate (partial) states used by the final
         * aggregation benchmark.
         *
         * @throws IllegalArgumentException if {@code type} is not {@code "SHORT"} or {@code "LONG"}
         */
        @Setup
        public void setup() {
            TestingFunctionResolution functionResolution = new TestingFunctionResolution();
            switch (this.type) {
                case "SHORT": {
                    DecimalType shortDecimal = DecimalType.createDecimalType(14, 3);
                    this.values = this.createValues(functionResolution, shortDecimal, shortDecimal::writeLong);
                    break;
                }
                case "LONG": {
                    DecimalType longDecimal = DecimalType.createDecimalType(30, 10);
                    this.values = this.createValues(
                            functionResolution,
                            longDecimal,
                            (builder, value) -> longDecimal.writeObject(builder, Int128.valueOf(value)));
                    break;
                }
                default:
                    // Fail fast: otherwise values and the factories stay null and setup crashes below.
                    throw new IllegalArgumentException("Unsupported decimal type parameter: " + this.type);
            }

            BlockBuilder ids = BigintType.BIGINT.createBlockBuilder(null, ELEMENT_COUNT);
            for (int i = 0; i < ELEMENT_COUNT; i++) {
                BigintType.BIGINT.writeLong(ids, ThreadLocalRandom.current().nextLong(this.groupCount));
            }
            this.groupIds = new GroupByIdBlock(this.groupCount, ids.build());
            this.intermediateValues = new Page(
                    this.createIntermediateValues(this.partialAggregatorFactory.createGroupedAggregator(), this.groupIds, this.values));
        }

        /**
         * Runs {@code inputPage} through {@code aggregator} and evaluates every group, producing
         * one intermediate state per group id.
         */
        private Block createIntermediateValues(GroupedAggregator aggregator, GroupByIdBlock groupIds, Page inputPage) {
            aggregator.processPage(groupIds, inputPage);
            return evaluateGroups(aggregator, Math.toIntExact(groupIds.getGroupCount()));
        }

        /**
         * Resolves {@code function} for {@code type}, initializes the partial and final aggregator
         * factories (both reading channel 0), and builds a single-channel page holding the values
         * {@code 0 .. ELEMENT_COUNT - 1} written with {@code writer}.
         */
        private Page createValues(TestingFunctionResolution functionResolution, DecimalType type, ValueWriter writer) {
            TestingAggregationFunction implementation = functionResolution.getAggregateFunction(
                    QualifiedName.of(this.function),
                    TypeSignatureProvider.fromTypes(new Type[] {type}));
            List<Integer> inputChannels = ImmutableList.of(0);
            this.partialAggregatorFactory = implementation.createAggregatorFactory(AggregationNode.Step.PARTIAL, inputChannels, OptionalInt.empty());
            this.finalAggregatorFactory = implementation.createAggregatorFactory(AggregationNode.Step.FINAL, inputChannels, OptionalInt.empty());

            BlockBuilder builder = type.createBlockBuilder(null, ELEMENT_COUNT);
            for (int i = 0; i < ELEMENT_COUNT; i++) {
                writer.write(builder, i);
            }
            return new Page(builder.build());
        }

        public AggregatorFactory getPartialAggregatorFactory() {
            return this.partialAggregatorFactory;
        }

        public AggregatorFactory getFinalAggregatorFactory() {
            return this.finalAggregatorFactory;
        }

        public Page getValues() {
            return this.values;
        }

        public GroupByIdBlock getGroupIds() {
            return this.groupIds;
        }

        public int getGroupCount() {
            return this.groupCount;
        }

        public Page getIntermediateValues() {
            return this.intermediateValues;
        }

        /** Writes the integer {@code value} into {@code builder} using a decimal representation. */
        @FunctionalInterface
        interface ValueWriter {
            void write(BlockBuilder builder, int value);
        }
    }
}

