/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.AggregationMetrics;
import io.trino.operator.aggregation.AggregatorFactory;
import io.trino.operator.aggregation.GroupedAggregator;
import io.trino.operator.aggregation.TestingAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.planner.plan.AggregationNode;
import java.util.List;
import java.util.OptionalInt;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.options.WarmupMode;

/**
 * JMH benchmark for grouped decimal aggregations ({@code avg} and {@code sum}) over
 * short ({@code decimal(14, 3)}) and long ({@code decimal(30, 10)}) decimal inputs.
 * <p>
 * Three scenarios are measured: the partial step consuming raw input rows, the partial
 * step followed by evaluation of every group's intermediate state, and the final step
 * consuming intermediate state and evaluating every group.
 */
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkDecimalAggregation {
    // Fixed seed keeps the generated group-id assignment reproducible across runs.
    private static final Random RANDOM = new Random(633969769L);

    /**
     * Number of rows in the generated input page; also the operation count reported by the
     * benchmarks that process that page.
     */
    private static final int ELEMENT_COUNT = 1_000_000;

    /** Measures the partial aggregation step consuming {@link #ELEMENT_COUNT} raw input rows. */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT)
    public GroupedAggregator benchmark(BenchmarkData data) {
        GroupedAggregator aggregator = data.getPartialAggregatorFactory().createGroupedAggregator(new AggregationMetrics());
        aggregator.processPage(data.getGroupCount(), data.getGroupIds(), data.getValues());
        // Returned so JMH consumes the result and cannot eliminate the work as dead code.
        return aggregator;
    }

    /** Measures the partial step plus evaluation of every group's intermediate state. */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT)
    public Block benchmarkEvaluateIntermediate(BenchmarkData data) {
        GroupedAggregator aggregator = data.getPartialAggregatorFactory().createGroupedAggregator(new AggregationMetrics());
        aggregator.processPage(data.getGroupCount(), data.getGroupIds(), data.getValues());
        return evaluateAllGroups(aggregator, data.getGroupCount());
    }

    /** Measures the final step consuming pre-computed intermediate state, then evaluating every group. */
    @Benchmark
    public Block benchmarkEvaluateFinal(BenchmarkData data) {
        GroupedAggregator aggregator = data.getFinalAggregatorFactory().createGroupedAggregator(new AggregationMetrics());
        // The same intermediate page is fed twice, presumably so the final step also combines
        // into groups that already hold state.
        // NOTE(review): the intermediate page has groupCount positions (one per group) but is
        // paired with the random per-row group ids; if processPage indexes group ids by position,
        // some groups may receive no input. Confirm whether an identity mapping was intended.
        aggregator.processPage(data.getGroupCount(), data.getGroupIds(), data.getIntermediateValues());
        aggregator.processPage(data.getGroupCount(), data.getGroupIds(), data.getIntermediateValues());
        return evaluateAllGroups(aggregator, data.getGroupCount());
    }

    /**
     * Evaluates groups {@code 0 .. groupCount - 1} of {@code aggregator} into a new block, in group order.
     *
     * @param aggregator aggregator whose state is evaluated
     * @param groupCount number of groups to evaluate
     * @return block containing the evaluated values
     */
    private static Block evaluateAllGroups(GroupedAggregator aggregator, int groupCount) {
        BlockBuilder builder = aggregator.getType().createBlockBuilder(null, groupCount);
        for (int groupId = 0; groupId < groupCount; groupId++) {
            aggregator.evaluate(groupId, builder);
        }
        return builder.build();
    }

    /** Sanity check that setup produces consistent data and that every benchmark runs. */
    @Test
    public void verify() {
        BenchmarkData data = new BenchmarkData();
        data.setup();
        Assertions.assertThat(data.getGroupIds().length).isEqualTo(data.getValues().getPositionCount());
        benchmark(data);
        // Evaluation emits one value per group.
        Assertions.assertThat(benchmarkEvaluateIntermediate(data).getPositionCount()).isEqualTo(data.getGroupCount());
        Assertions.assertThat(benchmarkEvaluateFinal(data).getPositionCount()).isEqualTo(data.getGroupCount());
    }

    public static void main(String[] args) throws Exception {
        new BenchmarkDecimalAggregation().verify();
        Benchmarks.benchmark(BenchmarkDecimalAggregation.class, WarmupMode.BULK).run();
    }

    /** Per-thread benchmark inputs: aggregator factories, random decimal input and group ids. */
    @State(Scope.Thread)
    public static class BenchmarkData {
        /** {@code SHORT} selects decimal(14, 3); {@code LONG} selects decimal(30, 10). */
        @Param({"SHORT", "LONG"})
        private String type = "SHORT";
        /** Name of the aggregate function under test. */
        @Param({"avg", "sum"})
        private String function = "avg";
        @Param({"10", "1000"})
        private int groupCount = 10;
        @Param({"0.0", "0.05"})
        private float nullRate;

        private AggregatorFactory partialAggregatorFactory;
        private AggregatorFactory finalAggregatorFactory;
        private int[] groupIds;
        private Page values;
        private Page intermediateValues;

        /**
         * Resolves the aggregate function, generates the random input page and group ids, and
         * pre-computes the intermediate page consumed by the final-step benchmark.
         *
         * @throws IllegalArgumentException if {@code type} is not {@code SHORT} or {@code LONG}
         */
        @Setup
        public void setup() {
            DecimalType decimalType = decimalTypeFor(type);
            createAggregatorFactories(new TestingFunctionResolution(), decimalType);
            values = new Page(BlockAssertions.createRandomBlockForType(decimalType, ELEMENT_COUNT, nullRate));
            groupIds = createGroupIds();
            intermediateValues = new Page(createIntermediateValues(
                    partialAggregatorFactory.createGroupedAggregator(new AggregationMetrics()),
                    groupIds,
                    values));
        }

        /** Maps the {@code type} benchmark parameter to its decimal type, rejecting unknown values. */
        private static DecimalType decimalTypeFor(String typeName) {
            switch (typeName) {
                case "SHORT":
                    return DecimalType.createDecimalType(14, 3);
                case "LONG":
                    return DecimalType.createDecimalType(30, 10);
                default:
                    throw new IllegalArgumentException("Unsupported decimal type: " + typeName);
            }
        }

        /** Resolves {@code function} for {@code inputType} and builds PARTIAL and FINAL factories reading channel 0. */
        private void createAggregatorFactories(TestingFunctionResolution functionResolution, Type inputType) {
            TestingAggregationFunction implementation = functionResolution.getAggregateFunction(function, TypeSignatureProvider.fromTypes(inputType));
            List<Integer> inputChannels = ImmutableList.of(0);
            partialAggregatorFactory = implementation.createAggregatorFactory(AggregationNode.Step.PARTIAL, inputChannels, OptionalInt.empty());
            finalAggregatorFactory = implementation.createAggregatorFactory(AggregationNode.Step.FINAL, inputChannels, OptionalInt.empty());
        }

        /** Draws one group id in {@code [0, groupCount)} for each of the {@link #ELEMENT_COUNT} input rows. */
        private int[] createGroupIds() {
            int[] ids = new int[ELEMENT_COUNT];
            for (int i = 0; i < ids.length; i++) {
                ids[i] = RANDOM.nextInt(groupCount);
            }
            return ids;
        }

        /** Feeds {@code inputPage} through {@code aggregator} and returns the evaluated state of every group. */
        private Block createIntermediateValues(GroupedAggregator aggregator, int[] groupIds, Page inputPage) {
            aggregator.processPage(groupCount, groupIds, inputPage);
            return evaluateAllGroups(aggregator, groupCount);
        }

        public AggregatorFactory getPartialAggregatorFactory() {
            return partialAggregatorFactory;
        }

        public AggregatorFactory getFinalAggregatorFactory() {
            return finalAggregatorFactory;
        }

        public Page getValues() {
            return values;
        }

        public int[] getGroupIds() {
            return groupIds;
        }

        public int getGroupCount() {
            return groupCount;
        }

        public Page getIntermediateValues() {
            return intermediateValues;
        }
    }
}

