/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.trino.jmh.Benchmarks;
import io.trino.metadata.Metadata;
import io.trino.metadata.MetadataManager;
import io.trino.metadata.ResolvedFunction;
import io.trino.operator.GroupByIdBlock;
import io.trino.operator.aggregation.GroupedAccumulator;
import io.trino.operator.aggregation.InternalAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.UnscaledDecimal128Arithmetic;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.tree.QualifiedName;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OperationsPerInvocation;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.runner.options.WarmupMode;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * JMH benchmark for grouped decimal aggregation ({@code avg} / {@code sum}).
 *
 * <p>Each benchmark invocation feeds {@code ELEMENT_COUNT} decimal rows, assigned to
 * randomly chosen group ids, into a {@link GroupedAccumulator} via
 * {@link GroupedAccumulator#addInput}.
 */
@State(Scope.Thread)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Fork(3)
@Warmup(iterations = 10)
@Measurement(iterations = 10)
@BenchmarkMode(Mode.AverageTime)
public class BenchmarkDecimalAggregation {
    /** Number of input rows (and group ids) processed per benchmark invocation. */
    private static final int ELEMENT_COUNT = 10000;

    /**
     * Adds one page of {@code ELEMENT_COUNT} rows to the prepared accumulator.
     *
     * @param data prepared accumulator, group ids and input page
     * @return the accumulator, returned so JMH treats the work as consumed
     */
    @Benchmark
    @OperationsPerInvocation(ELEMENT_COUNT) // keeps reported time per row in sync with the data size
    public GroupedAccumulator benchmark(BenchmarkData data) {
        GroupedAccumulator accumulator = data.getAccumulator();
        accumulator.addInput(data.getGroupIds(), data.getValues());
        return accumulator;
    }

    /** Smoke test: runs setup with default parameters and performs a single benchmark invocation. */
    @Test
    public void verify() {
        BenchmarkData data = new BenchmarkData();
        data.setup();
        // Every input row must have exactly one group id.
        Assert.assertEquals(data.getGroupIds().getPositionCount(), data.getValues().getPositionCount());
        new BenchmarkDecimalAggregation().benchmark(data);
    }

    /**
     * Runs the smoke test, then launches the JMH benchmark.
     *
     * @param args ignored
     * @throws Exception if the JMH runner fails
     */
    public static void main(String[] args) throws Exception {
        new BenchmarkDecimalAggregation().verify();
        Benchmarks.benchmark(BenchmarkDecimalAggregation.class, WarmupMode.BULK).run();
    }

    /**
     * Per-thread benchmark state: the accumulator under test, its input page and the group ids.
     *
     * <p>JMH parameters:
     * <ul>
     *   <li>{@code type}: {@code SHORT} uses {@code decimal(14, 3)}, {@code LONG} uses {@code decimal(30, 10)}</li>
     *   <li>{@code function}: name of the aggregation function to resolve ({@code avg} or {@code sum})</li>
     *   <li>{@code groupCount}: number of distinct groups the rows are spread across</li>
     * </ul>
     */
    @State(Scope.Thread)
    public static class BenchmarkData {
        @Param({"SHORT", "LONG"})
        private String type = "SHORT";

        @Param({"avg", "sum"})
        private String function = "avg";

        @Param({"10", "1000"})
        private int groupCount = 10;

        private GroupedAccumulator accumulator;
        private GroupByIdBlock groupIds;
        private Page values;

        /**
         * Builds the accumulator, the decimal input page and uniformly random group ids.
         *
         * @throws IllegalArgumentException if {@code type} is not {@code SHORT} or {@code LONG}
         */
        @Setup
        public void setup() {
            Metadata metadata = MetadataManager.createTestMetadataManager();
            switch (type) {
                case "SHORT": {
                    DecimalType decimalType = DecimalType.createDecimalType(14, 3);
                    initialize(metadata, decimalType, (builder, value) -> decimalType.writeLong(builder, value));
                    break;
                }
                case "LONG": {
                    DecimalType decimalType = DecimalType.createDecimalType(30, 10);
                    initialize(metadata, decimalType, (builder, value) ->
                            decimalType.writeSlice(builder, UnscaledDecimal128Arithmetic.unscaledDecimal((long) value)));
                    break;
                }
                default:
                    // Fail fast: otherwise accumulator/values stay null and the run dies later with an NPE.
                    throw new IllegalArgumentException("Unsupported decimal type: " + type);
            }

            ThreadLocalRandom random = ThreadLocalRandom.current();
            BlockBuilder ids = BigintType.BIGINT.createBlockBuilder(null, ELEMENT_COUNT);
            for (int position = 0; position < ELEMENT_COUNT; position++) {
                BigintType.BIGINT.writeLong(ids, random.nextLong(groupCount));
            }
            groupIds = new GroupByIdBlock(groupCount, ids.build());
        }

        /** Sets both {@link #accumulator} and {@link #values} for the given decimal type. */
        private void initialize(Metadata metadata, DecimalType decimalType, ValueWriter writer) {
            accumulator = createAccumulator(metadata, decimalType);
            values = createValues(decimalType, writer);
        }

        /** Resolves {@code function} for {@code decimalType} and binds it to input channel 0 with no mask channel. */
        private GroupedAccumulator createAccumulator(Metadata metadata, DecimalType decimalType) {
            ResolvedFunction resolvedFunction = metadata.resolveFunction(
                    QualifiedName.of(function),
                    TypeSignatureProvider.fromTypes(decimalType));
            InternalAggregationFunction implementation = metadata.getAggregateFunctionImplementation(resolvedFunction);
            return implementation.bind(ImmutableList.of(0), Optional.empty()).createGroupedAccumulator();
        }

        /** Builds a single-channel page whose rows are 0 .. ELEMENT_COUNT - 1, written by {@code writer}. */
        private static Page createValues(DecimalType decimalType, ValueWriter writer) {
            BlockBuilder builder = decimalType.createBlockBuilder(null, ELEMENT_COUNT);
            for (int value = 0; value < ELEMENT_COUNT; value++) {
                writer.write(builder, value);
            }
            return new Page(builder.build());
        }

        /** @return the accumulator created during {@link #setup()} */
        public GroupedAccumulator getAccumulator() {
            return accumulator;
        }

        /** @return the single-channel page of decimal input rows */
        public Page getValues() {
            return values;
        }

        /** @return the group id assigned to each input row */
        public GroupByIdBlock getGroupIds() {
            return groupIds;
        }

        /** Appends one integer value to a block builder, encoded for a particular decimal type. */
        @FunctionalInterface
        interface ValueWriter {
            void write(BlockBuilder builder, int value);
        }
    }
}

