/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.stats.TDigest;
import io.trino.block.BlockAssertions;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.aggregation.AggregationTestUtils;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.SqlVarbinary;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.stream.LongStream;
import org.junit.jupiter.api.Test;
import org.testng.Assert;

/**
 * Tests {@code tdigest_agg} in its unweighted form {@code tdigest_agg(double)} and its weighted form
 * {@code tdigest_agg(double, double)}. Each aggregation result is compared against a {@link TDigest}
 * built directly from the same values and weights.
 */
public class TestTDigestAggregationFunction {
    private static final String FUNCTION_NAME = "tdigest_agg";
    private static final TestingFunctionResolution FUNCTION_RESOLUTION = new TestingFunctionResolution();
    /** Argument types of the unweighted form: a single double value. */
    private static final List<TypeSignatureProvider> VALUE_ARGUMENT_TYPES =
            TypeSignatureProvider.fromTypes(DoubleType.DOUBLE);
    /** Argument types of the weighted form: a double value followed by a double weight. */
    private static final List<TypeSignatureProvider> VALUE_AND_WEIGHT_ARGUMENT_TYPES =
            TypeSignatureProvider.fromTypes(DoubleType.DOUBLE, DoubleType.DOUBLE);

    /**
     * Approximate equality for serialized t-digests (values are expected to be {@link SqlVarbinary}).
     * <p>
     * Two {@code null}s are equal. If exactly one side is {@code null}, a {@link NullPointerException} is thrown.
     * Otherwise both digests are deserialized. They are considered equal when:
     * <ul>
     *   <li>their min values match exactly,</li>
     *   <li>their max values match exactly, and</li>
     *   <li>their values at a fixed set of quantiles differ by at most one thousandth of the actual
     *       digest's {@code [min, max]} range.</li>
     * </ul>
     */
    private static final BiFunction<Object, Object, Boolean> TDIGEST_EQUALITY = (actualBinary, expectedBinary) -> {
        if (actualBinary == null && expectedBinary == null) {
            return true;
        }
        Objects.requireNonNull(actualBinary, "actual value was null");
        Objects.requireNonNull(expectedBinary, "expected value was null");
        TDigest actual = deserializeDigest(actualBinary);
        TDigest expected = deserializeDigest(expectedBinary);
        double maxError = (actual.getMax() - actual.getMin()) / 1000.0;
        return actual.getMin() == expected.getMin()
                && actual.getMax() == expected.getMax()
                && returnSimilarResults(actual, expected, maxError);
    };

    @Test
    public void testTdigestAggregationFunction() {
        List<Double> weights = ImmutableList.of(1.5, 2.0, 1.1, 1.111, 3.5, 4.4, 4.4, 1.0, 9.9, 9.0);

        // Nulls interleaved with values: the expected digest only uses the weights paired with
        // the non-null values (positions 0, 2, 4, 6 and 8).
        testAggregation(
                BlockAssertions.createDoublesBlock(1.0, null, 2.0, null, 3.0, null, 4.0, null, 5.0, null),
                BlockAssertions.createDoublesBlock(weights),
                ImmutableList.of(1.5, 1.1, 3.5, 4.4, 9.9),
                1.0, 2.0, 3.0, 4.0, 5.0);

        // Only nulls: nothing is added, so the expected value is null.
        testAggregation(
                BlockAssertions.createDoublesBlock(null, null, null, null, null),
                BlockAssertions.createRepeatedValuesBlock(1.0, 5),
                ImmutableList.of());

        // Negative values, each with its own weight.
        testAggregation(
                BlockAssertions.createDoublesBlock(-1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0, -9.0, -10.0),
                BlockAssertions.createDoublesBlock(weights),
                weights,
                -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, -7.0, -8.0, -9.0, -10.0);

        // Positive values, each with its own weight.
        testAggregation(
                BlockAssertions.createDoublesBlock(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0),
                BlockAssertions.createDoublesBlock(weights),
                weights,
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);

        // Empty input: the expected value is null.
        testAggregation(
                BlockAssertions.createDoublesBlock(new Double[0]),
                BlockAssertions.createRepeatedValuesBlock(1.0, 0),
                ImmutableList.of());

        // A single value.
        testAggregation(
                BlockAssertions.createDoublesBlock(1.0),
                BlockAssertions.createRepeatedValuesBlock(1.1, 1),
                ImmutableList.of(1.1),
                1.0);

        // A larger sequence [-1000, 1000) is compared approximately (see TDIGEST_EQUALITY).
        List<Double> sequenceWeights = LongStream.range(-1000L, 1000L)
                .asDoubleStream()
                .map(number -> 2.0 - number / 1000.0)
                .boxed()
                .collect(ImmutableList.toImmutableList());
        testAggregation(
                TDIGEST_EQUALITY,
                BlockAssertions.createDoubleSequenceBlock(-1000, 1000),
                BlockAssertions.createDoublesBlock(sequenceWeights),
                sequenceWeights,
                LongStream.range(-1000L, 1000L).asDoubleStream().toArray());
    }

    /**
     * Asserts both forms of {@code tdigest_agg} using the default result comparison of
     * {@link AggregationTestUtils}.
     *
     * @param doublesBlock input values passed to the aggregation; may contain nulls
     * @param weightsBlock weights passed to the weighted form, one per position of {@code doublesBlock}
     * @param weights weights for the expected weighted digest, aligned with {@code inputs}
     * @param inputs values for the expected digests
     */
    private void testAggregation(Block doublesBlock, Block weightsBlock, List<Double> weights, double... inputs) {
        // Unweighted form: the expected digest gives every input a weight of 1.0.
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                FUNCTION_NAME,
                VALUE_ARGUMENT_TYPES,
                getExpectedValue(Collections.nCopies(inputs.length, 1.0), inputs),
                new Page(doublesBlock));
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                FUNCTION_NAME,
                VALUE_AND_WEIGHT_ARGUMENT_TYPES,
                getExpectedValue(weights, inputs),
                new Page(doublesBlock, weightsBlock));
    }

    /**
     * Asserts both forms of {@code tdigest_agg} using a custom result comparison.
     *
     * @param equalAssertion comparison applied to (actual, expected) results
     * @param doublesBlock input values passed to the aggregation; may contain nulls
     * @param weightsBlock weights passed to the weighted form, one per position of {@code doublesBlock}
     * @param weights weights for the expected weighted digest, aligned with {@code inputs}
     * @param inputs values for the expected digests
     */
    private void testAggregation(BiFunction<Object, Object, Boolean> equalAssertion, Block doublesBlock, Block weightsBlock, List<Double> weights, double... inputs) {
        // Unweighted form: the expected digest gives every input a weight of 1.0.
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                FUNCTION_NAME,
                VALUE_ARGUMENT_TYPES,
                equalAssertion,
                "Test multiple values",
                new Page(doublesBlock),
                getExpectedValue(Collections.nCopies(inputs.length, 1.0), inputs));
        AggregationTestUtils.assertAggregation(
                FUNCTION_RESOLUTION,
                FUNCTION_NAME,
                VALUE_AND_WEIGHT_ARGUMENT_TYPES,
                equalAssertion,
                "Test multiple values",
                new Page(doublesBlock, weightsBlock),
                getExpectedValue(weights, inputs));
    }

    /**
     * Builds the serialized digest that {@code tdigest_agg} is expected to return.
     *
     * @param weights weight of each value, positionally aligned with {@code values}
     * @param values values to add to the digest
     * @return a {@link SqlVarbinary} holding the serialized digest, or {@code null} when {@code values} is empty
     */
    private static Object getExpectedValue(List<Double> weights, double... values) {
        Assert.assertEquals(weights.size(), values.length, "mismatched weights and values");
        if (values.length == 0) {
            return null;
        }
        TDigest tdigest = new TDigest();
        for (int i = 0; i < values.length; i++) {
            tdigest.add(values[i], weights.get(i));
        }
        return new SqlVarbinary(tdigest.serialize().getBytes());
    }

    /** Deserializes a {@link SqlVarbinary} aggregation result into a {@link TDigest}. */
    private static TDigest deserializeDigest(Object binary) {
        Slice serialized = Slices.wrappedBuffer(((SqlVarbinary) binary).getBytes());
        return TDigest.deserialize(serialized);
    }

    /**
     * Returns whether two digests agree, within {@code maxError}, at each of a fixed set of quantiles.
     * <p>
     * A NaN difference is not treated as a mismatch, because {@code NaN > maxError} is false.
     */
    private static boolean returnSimilarResults(TDigest first, TDigest second, double maxError) {
        double[] quantiles = {0.0001, 0.001, 0.01, 0.1, 0.5, 0.567, 0.89, 0.999};
        for (double quantile : quantiles) {
            if (Math.abs(first.valueAt(quantile) - second.valueAt(quantile)) > maxError) {
                return false;
            }
        }
        return true;
    }
}

