/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Floats;
import io.airlift.stats.QuantileDigest;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.aggregation.AggregationTestUtils;
import io.trino.operator.aggregation.FloatingPointBitsConverterUtil;
import io.trino.operator.aggregation.TestMergeQuantileDigestFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.SqlVarbinary;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.query.QueryAssertions;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.assertj.core.api.AssertProvider;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

/**
 * Tests the {@code qdigest_agg} aggregation over {@code bigint}, {@code real} and {@code double} inputs.
 *
 * <p>Every scenario is executed with three argument shapes: the value column alone, the value column plus
 * a {@code bigint} weight column, and both of those plus a {@code double} column filled with the
 * scenario's {@code maxError}. For each shape the serialized digest returned by the aggregation is
 * compared with a locally built {@link QuantileDigest}, and the quantiles that SQL extracts from the
 * returned digest are checked against bounds taken from the sorted expected values.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestQuantileDigestAggregationFunction {
    private static final Joiner ARRAY_JOINER = Joiner.on(",");
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();
    private static final String NAME = "qdigest_agg";
    private QueryAssertions assertions;

    /** Creates the {@link QueryAssertions} instance shared by all tests of this class. */
    @BeforeAll
    public void init() {
        this.assertions = new QueryAssertions();
    }

    /** Closes and drops the shared {@link QueryAssertions} instance, if one was created. */
    @AfterAll
    public void teardown() {
        // Guard against init() not having assigned the field, so teardown cannot fail with an NPE of its own.
        if (this.assertions != null) {
            this.assertions.close();
            this.assertions = null;
        }
    }

    /** Runs the {@code double} scenarios: interleaved nulls, all nulls, negatives, positives, empty, single value and a 2000-value sequence. */
    @Test
    public void testDoublesWithWeights() {
        // Values interleaved with nulls; only the non-null values are expected.
        this.testAggregationDouble(
                BlockAssertions.createDoublesBlock(1.0, null, 2.0, null, 3.0, null, 4.0, null, 5.0, null),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                1.0, 2.0, 3.0, 4.0, 5.0);

        // Only nulls: no expected values, so the expected aggregation result is null.
        this.testAggregationDouble(
                BlockAssertions.createDoublesBlock(null, null, null, null, null),
                BlockAssertions.createRepeatedValuesBlock(1L, 5),
                Double.NaN,
                new double[0]);

        // Negative values.
        this.testAggregationDouble(
                BlockAssertions.createDoublesBlock(-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0, -9.0, -10.0),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0, -9.0, -10.0);

        // Positive values.
        this.testAggregationDouble(
                BlockAssertions.createDoublesBlock(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);

        // Empty input.
        this.testAggregationDouble(
                BlockAssertions.createDoublesBlock(new Double[0]),
                BlockAssertions.createRepeatedValuesBlock(1L, 0),
                Double.NaN,
                new double[0]);

        // Single value.
        this.testAggregationDouble(
                BlockAssertions.createDoublesBlock(1.0),
                BlockAssertions.createRepeatedValuesBlock(1L, 1),
                0.01,
                1.0);

        // Sequence of 2000 values; expected values are -1000 (inclusive) to 1000 (exclusive).
        this.testAggregationDouble(
                BlockAssertions.createDoubleSequenceBlock(-1000, 1000),
                BlockAssertions.createRepeatedValuesBlock(1L, 2000),
                0.01,
                LongStream.range(-1000L, 1000L).asDoubleStream().toArray());
    }

    /** Runs the {@code real} scenarios, mirroring {@link #testDoublesWithWeights()}. */
    @Test
    public void testRealsWithWeights() {
        // Values interleaved with nulls; only the non-null values are expected.
        this.testAggregationReal(
                BlockAssertions.createBlockOfReals(Float.valueOf(1.0f), null, Float.valueOf(2.0f), null, Float.valueOf(3.0f), null, Float.valueOf(4.0f), null, Float.valueOf(5.0f), null),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                1.0f, 2.0f, 3.0f, 4.0f, 5.0f);

        // Only nulls: no expected values, so the expected aggregation result is null.
        this.testAggregationReal(
                BlockAssertions.createBlockOfReals(null, null, null, null, null),
                BlockAssertions.createRepeatedValuesBlock(1L, 5),
                Double.NaN,
                new float[0]);

        // Negative values.
        this.testAggregationReal(
                BlockAssertions.createBlockOfReals(Float.valueOf(-1.0f), Float.valueOf(-2.0f), Float.valueOf(-3.0f), Float.valueOf(-4.0f), Float.valueOf(-5.0f), Float.valueOf(-6.0f), Float.valueOf(-7.0f), Float.valueOf(-8.0f), Float.valueOf(-9.0f), Float.valueOf(-10.0f)),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                -1.0f, -2.0f, -3.0f, -4.0f, -5.0f, -6.0f, -7.0f, -8.0f, -9.0f, -10.0f);

        // Positive values.
        this.testAggregationReal(
                BlockAssertions.createBlockOfReals(Float.valueOf(1.0f), Float.valueOf(2.0f), Float.valueOf(3.0f), Float.valueOf(4.0f), Float.valueOf(5.0f), Float.valueOf(6.0f), Float.valueOf(7.0f), Float.valueOf(8.0f), Float.valueOf(9.0f), Float.valueOf(10.0f)),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f);

        // Empty input.
        this.testAggregationReal(
                BlockAssertions.createBlockOfReals(new Float[0]),
                BlockAssertions.createRepeatedValuesBlock(1L, 0),
                Double.NaN,
                new float[0]);

        // Single value.
        this.testAggregationReal(
                BlockAssertions.createBlockOfReals(Float.valueOf(1.0f)),
                BlockAssertions.createRepeatedValuesBlock(1L, 1),
                0.01,
                1.0f);

        // Sequence of 2000 values. The explicit (float) conversion replaces the deprecated Float constructor
        // and produces the same values.
        float[] expectedSequence = Floats.toArray(
                LongStream.range(-1000L, 1000L)
                        .mapToObj(value -> Float.valueOf((float) value))
                        .collect(ImmutableList.toImmutableList()));
        this.testAggregationReal(
                BlockAssertions.createSequenceBlockOfReal(-1000, 1000),
                BlockAssertions.createRepeatedValuesBlock(1L, 2000),
                0.01,
                expectedSequence);
    }

    /** Runs the {@code bigint} scenarios, mirroring {@link #testDoublesWithWeights()}. */
    @Test
    public void testBigintsWithWeight() {
        // Values interleaved with nulls; only the non-null values are expected.
        this.testAggregationBigint(
                BlockAssertions.createLongsBlock(1L, null, 2L, null, 3L, null, 4L, null, 5L, null),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                1L, 2L, 3L, 4L, 5L);

        // Only nulls: no expected values, so the expected aggregation result is null.
        this.testAggregationBigint(
                BlockAssertions.createLongsBlock(null, null, null, null, null),
                BlockAssertions.createRepeatedValuesBlock(1L, 5),
                Double.NaN,
                new long[0]);

        // Negative values.
        this.testAggregationBigint(
                BlockAssertions.createLongsBlock(-1, -2, -3, -4, -5, -6, -7, -8, -9, -10),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                -1L, -2L, -3L, -4L, -5L, -6L, -7L, -8L, -9L, -10L);

        // Positive values.
        this.testAggregationBigint(
                BlockAssertions.createLongsBlock(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
                BlockAssertions.createRepeatedValuesBlock(1L, 10),
                0.01,
                1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L);

        // Empty input.
        this.testAggregationBigint(
                BlockAssertions.createLongsBlock(new int[0]),
                BlockAssertions.createRepeatedValuesBlock(1L, 0),
                Double.NaN,
                new long[0]);

        // Single value.
        this.testAggregationBigint(
                BlockAssertions.createLongsBlock(1),
                BlockAssertions.createRepeatedValuesBlock(1L, 1),
                0.01,
                1L);

        // Sequence of 2000 values; expected values are -1000 (inclusive) to 1000 (exclusive).
        this.testAggregationBigint(
                BlockAssertions.createLongSequenceBlock(-1000, 1000),
                BlockAssertions.createRepeatedValuesBlock(1L, 2000),
                0.01,
                LongStream.range(-1000L, 1000L).toArray());
    }

    /**
     * Checks the {@code bigint} aggregation for the three argument shapes.
     *
     * @param inputBlock value column
     * @param weightsBlock weight column, used by the two- and three-argument shapes
     * @param maxError error used for the expected digest and as the content of the third column
     * @param inputs values expected to be contained in the resulting digest
     */
    private void testAggregationBigint(Block inputBlock, Block weightsBlock, double maxError, long ... inputs) {
        // Value only.
        this.testAggregationBigints(
                TypeSignatureProvider.fromTypes(BigintType.BIGINT),
                new Page(inputBlock),
                maxError,
                inputs);
        // Value and weight.
        this.testAggregationBigints(
                TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT),
                new Page(inputBlock, weightsBlock),
                maxError,
                inputs);
        // Value, weight and a double column repeating maxError (presumably the accuracy argument).
        this.testAggregationBigints(
                TypeSignatureProvider.fromTypes(BigintType.BIGINT, BigintType.BIGINT, DoubleType.DOUBLE),
                new Page(inputBlock, weightsBlock, BlockAssertions.createRepeatedValuesBlock(maxError, inputBlock.getPositionCount())),
                maxError,
                inputs);
    }

    /**
     * Checks the {@code real} aggregation for the three argument shapes.
     *
     * @param longsBlock value column (despite the name, the column is typed {@code real} here)
     * @param weightsBlock weight column, used by the two- and three-argument shapes
     * @param maxError error used for the expected digest and as the content of the third column
     * @param inputs values expected to be contained in the resulting digest
     */
    private void testAggregationReal(Block longsBlock, Block weightsBlock, double maxError, float ... inputs) {
        // Value only.
        this.testAggregationReal(
                TypeSignatureProvider.fromTypes(RealType.REAL),
                new Page(longsBlock),
                maxError,
                inputs);
        // Value and weight.
        this.testAggregationReal(
                TypeSignatureProvider.fromTypes(RealType.REAL, BigintType.BIGINT),
                new Page(longsBlock, weightsBlock),
                maxError,
                inputs);
        // Value, weight and a double column repeating maxError (presumably the accuracy argument).
        this.testAggregationReal(
                TypeSignatureProvider.fromTypes(RealType.REAL, BigintType.BIGINT, DoubleType.DOUBLE),
                new Page(longsBlock, weightsBlock, BlockAssertions.createRepeatedValuesBlock(maxError, longsBlock.getPositionCount())),
                maxError,
                inputs);
    }

    /**
     * Checks the {@code double} aggregation for the three argument shapes.
     *
     * @param longsBlock value column (despite the name, the column is typed {@code double} here)
     * @param weightsBlock weight column, used by the two- and three-argument shapes
     * @param maxError error used for the expected digest and as the content of the third column
     * @param inputs values expected to be contained in the resulting digest
     */
    private void testAggregationDouble(Block longsBlock, Block weightsBlock, double maxError, double ... inputs) {
        // Value only.
        this.testAggregationDoubles(
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE),
                new Page(longsBlock),
                maxError,
                inputs);
        // Value and weight.
        this.testAggregationDoubles(
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE, BigintType.BIGINT),
                new Page(longsBlock, weightsBlock),
                maxError,
                inputs);
        // Value, weight and a double column repeating maxError (presumably the accuracy argument).
        this.testAggregationDoubles(
                TypeSignatureProvider.fromTypes(DoubleType.DOUBLE, BigintType.BIGINT, DoubleType.DOUBLE),
                new Page(longsBlock, weightsBlock, BlockAssertions.createRepeatedValuesBlock(maxError, longsBlock.getPositionCount())),
                maxError,
                inputs);
    }

    /** Verifies one {@code bigint} argument shape against the expected digest and the percentile bounds. */
    private void testAggregationBigints(List<TypeSignatureProvider> parameterTypes, Page page, double maxError, long ... inputs) {
        AggregationTestUtils.assertAggregation(FUNCTION_RESOLUTION, NAME, parameterTypes, TestMergeQuantileDigestFunction.QDIGEST_EQUALITY, "test multiple positions", page, this.getExpectedValueLongs(maxError, inputs));
        List<Long> sortedRows = Arrays.stream(inputs).sorted().boxed().collect(Collectors.toList());
        this.assertAggregatedPercentiles("bigint", parameterTypes, page, maxError, sortedRows);
    }

    /** Verifies one {@code double} argument shape against the expected digest and the percentile bounds. */
    private void testAggregationDoubles(List<TypeSignatureProvider> parameterTypes, Page page, double maxError, double ... inputs) {
        AggregationTestUtils.assertAggregation(FUNCTION_RESOLUTION, NAME, parameterTypes, TestMergeQuantileDigestFunction.QDIGEST_EQUALITY, "test multiple positions", page, this.getExpectedValueDoubles(maxError, inputs));
        List<Double> sortedRows = Arrays.stream(inputs).sorted().boxed().collect(Collectors.toList());
        this.assertAggregatedPercentiles("double", parameterTypes, page, maxError, sortedRows);
    }

    /** Verifies one {@code real} argument shape against the expected digest and the percentile bounds. */
    private void testAggregationReal(List<TypeSignatureProvider> parameterTypes, Page page, double maxError, float ... inputs) {
        AggregationTestUtils.assertAggregation(FUNCTION_RESOLUTION, NAME, parameterTypes, TestMergeQuantileDigestFunction.QDIGEST_EQUALITY, "test multiple positions", page, this.getExpectedValuesFloats(maxError, inputs));
        // Rows are widened to double after sorting so the bounds are rendered as double literals in SQL.
        List<Double> sortedRows = Floats.asList(inputs).stream().sorted().map(Float::doubleValue).collect(Collectors.toList());
        this.assertAggregatedPercentiles("real", parameterTypes, page, maxError, sortedRows);
    }

    /**
     * Runs the aggregation on {@code page} and checks the 0.1, 0.5, 0.9 and 0.99 quantiles of the result.
     *
     * @param type SQL element type name used in {@code qdigest(<type>)}
     * @param parameterTypes argument types used to resolve the aggregation
     * @param page input page matching {@code parameterTypes}
     * @param maxError error used to widen the accepted bounds
     * @param sortedRows expected values in ascending order
     */
    private void assertAggregatedPercentiles(String type, List<TypeSignatureProvider> parameterTypes, Page page, double maxError, List<? extends Number> sortedRows) {
        SqlVarbinary returned = (SqlVarbinary) AggregationTestUtils.aggregation(FUNCTION_RESOLUTION.getAggregateFunction(NAME, parameterTypes), page);
        this.assertPercentileWithinError(type, returned, maxError, sortedRows, 0.1, 0.5, 0.9, 0.99);
    }

    /**
     * Builds the expected serialized digest for {@code long} values.
     *
     * @return the serialized digest, or {@code null} when {@code values} is empty
     */
    private Object getExpectedValueLongs(double maxError, long ... values) {
        if (values.length == 0) {
            return null;
        }
        QuantileDigest qdigest = new QuantileDigest(maxError);
        for (long value : values) {
            qdigest.add(value);
        }
        return new SqlVarbinary(qdigest.serialize().getBytes());
    }

    /**
     * Builds the expected serialized digest for {@code double} values, each converted with
     * {@link FloatingPointBitsConverterUtil#doubleToSortableLong(double)} before being added.
     *
     * @return the serialized digest, or {@code null} when {@code values} is empty
     */
    private Object getExpectedValueDoubles(double maxError, double ... values) {
        if (values.length == 0) {
            return null;
        }
        QuantileDigest qdigest = new QuantileDigest(maxError);
        for (double value : values) {
            qdigest.add(FloatingPointBitsConverterUtil.doubleToSortableLong(value));
        }
        return new SqlVarbinary(qdigest.serialize().getBytes());
    }

    /**
     * Builds the expected serialized digest for {@code float} values, each converted with
     * {@link FloatingPointBitsConverterUtil#floatToSortableInt(float)} before being added.
     *
     * @return the serialized digest, or {@code null} when {@code values} is empty
     */
    private Object getExpectedValuesFloats(double maxError, float ... values) {
        if (values.length == 0) {
            return null;
        }
        QuantileDigest qdigest = new QuantileDigest(maxError);
        for (float value : values) {
            qdigest.add((long) FloatingPointBitsConverterUtil.floatToSortableInt(value));
        }
        return new SqlVarbinary(qdigest.serialize().getBytes());
    }

    /**
     * Checks every percentile individually via {@code value_at_quantile} and then all of them together via
     * {@code values_at_quantiles}. Does nothing when {@code rows} is empty.
     */
    private void assertPercentileWithinError(String type, SqlVarbinary binary, double error, List<? extends Number> rows, double ... percentiles) {
        if (rows.isEmpty()) {
            return;
        }
        for (double percentile : percentiles) {
            this.assertPercentileWithinError(type, binary, error, rows, percentile);
        }
        this.assertPercentilesWithinError(type, binary, error, rows, percentiles);
    }

    /**
     * Asserts that {@code value_at_quantile} for one percentile lies between the lower and upper bound
     * derived from {@code rows}.
     *
     * @param rows expected values in ascending order; must not be empty
     */
    private void assertPercentileWithinError(String type, SqlVarbinary binary, double error, List<? extends Number> rows, double percentile) {
        Number lowerBound = this.getLowerBound(error, rows, percentile);
        Number upperBound = this.getUpperBound(error, rows, percentile);
        String digestLiteral = toHexLiteral(binary);

        String atLeastLower = String.format("value_at_quantile(CAST(a AS qdigest(%s)), %s) >= %s", type, percentile, lowerBound);
        Assertions.assertThat(this.assertions.expression(atLeastLower).binding("a", digestLiteral))
                .isEqualTo(true);

        String atMostUpper = String.format("value_at_quantile(CAST(a AS qdigest(%s)), %s) <= %s", type, percentile, upperBound);
        Assertions.assertThat(this.assertions.expression(atMostUpper).binding("a", digestLiteral))
                .isEqualTo(true);
    }

    /**
     * Asserts that every element returned by {@code values_at_quantiles} lies between the lower and upper
     * bound of its percentile. Percentiles are sorted ascending before being sent to SQL.
     *
     * @param rows expected values in ascending order; must not be empty
     */
    private void assertPercentilesWithinError(String type, SqlVarbinary binary, double error, List<? extends Number> rows, double[] percentiles) {
        List<Double> boxedPercentiles = Arrays.stream(percentiles).sorted().boxed().collect(ImmutableList.toImmutableList());
        List<Number> lowerBounds = boxedPercentiles.stream()
                .map(percentile -> this.getLowerBound(error, rows, percentile))
                .collect(ImmutableList.toImmutableList());
        List<Number> upperBounds = boxedPercentiles.stream()
                .map(percentile -> this.getUpperBound(error, rows, percentile))
                .collect(ImmutableList.toImmutableList());
        String digestLiteral = toHexLiteral(binary);
        String joinedPercentiles = ARRAY_JOINER.join(boxedPercentiles);
        List<Boolean> allTrue = Collections.nCopies(percentiles.length, true);

        String atLeastLower = String.format("zip_with(values_at_quantiles(CAST(a AS qdigest(%s)), ARRAY[%s]), ARRAY[%s], (value, lowerbound) -> value >= lowerbound)", type, joinedPercentiles, ARRAY_JOINER.join(lowerBounds));
        Assertions.assertThat(this.assertions.expression(atLeastLower).binding("a", digestLiteral))
                .hasType(new ArrayType(BooleanType.BOOLEAN))
                .isEqualTo(allTrue);

        String atMostUpper = String.format("zip_with(values_at_quantiles(CAST(a AS qdigest(%s)), ARRAY[%s]), ARRAY[%s], (value, upperbound) -> value <= upperbound)", type, joinedPercentiles, ARRAY_JOINER.join(upperBounds));
        Assertions.assertThat(this.assertions.expression(atMostUpper).binding("a", digestLiteral))
                .hasType(new ArrayType(BooleanType.BOOLEAN))
                .isEqualTo(allTrue);
    }

    /** Renders {@code binary} as an {@code X'..'} literal, collapsing each whitespace run in its text form to one space. */
    private static String toHexLiteral(SqlVarbinary binary) {
        return "X'%s'".formatted(binary.toString().replaceAll("\\s+", " "));
    }

    /**
     * Returns the smallest acceptable value for {@code percentile}: the row at the percentile's index moved
     * down by half the error margin.
     *
     * @param rows expected values in ascending order; must not be empty
     */
    private Number getLowerBound(double error, List<? extends Number> rows, double percentile) {
        int percentileIndex = (int) ((double) rows.size() * percentile);
        // A NaN error converts to a margin of 0.
        int marginOfError = (int) ((double) rows.size() * error / 2.0);
        return rows.get(clampIndex(percentileIndex - marginOfError, rows.size()));
    }

    /**
     * Returns the largest acceptable value for {@code percentile}: the row at the percentile's index moved
     * up by half the error margin.
     *
     * @param rows expected values in ascending order; must not be empty
     */
    private Number getUpperBound(double error, List<? extends Number> rows, double percentile) {
        int percentileIndex = (int) ((double) rows.size() * percentile);
        // A NaN error converts to a margin of 0.
        int marginOfError = (int) ((double) rows.size() * error / 2.0);
        return rows.get(clampIndex(percentileIndex + marginOfError, rows.size()));
    }

    /** Restricts {@code index} to {@code [0, size - 1]} so that a percentile of 1.0 or a large margin stays in range. */
    private static int clampIndex(int index, int size) {
        return Integer.max(0, Integer.min(index, size - 1));
    }
}

