/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DateTimeEncoding;
import com.facebook.presto.common.type.DateType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.IntegerType;
import com.facebook.presto.common.type.RealType;
import com.facebook.presto.common.type.SmallintType;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.TimeType;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.common.type.TimestampType;
import com.facebook.presto.common.type.TimestampWithTimeZoneType;
import com.facebook.presto.common.type.TinyintType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.AggregationTestUtils;
import com.facebook.presto.operator.aggregation.sketch.kll.KllSketchAggregationState;
import com.facebook.presto.operator.scalar.AbstractTestFunctions;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.facebook.presto.type.IntervalYearMonthType;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.apache.datasketches.common.ArrayOfDoublesSerDe;
import org.apache.datasketches.common.ArrayOfItemsSerDe;
import org.apache.datasketches.kll.KllItemsSketch;
import org.apache.datasketches.memory.Memory;
import org.apache.datasketches.memory.WritableMemory;
import org.intellij.lang.annotations.Language;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for the {@code sketch_kll} and {@code sketch_kll_with_k} aggregation functions.
 *
 * <p>Most tests feed the same values both to the aggregation under test and to a
 * {@link KllItemsSketch} built directly in the test, then verify that the sketch
 * deserialized from the aggregation result matches the locally built one.
 */
public class TestKllSketchAggregationFunction
extends AbstractTestFunctions {
    private static final MetadataManager METADATA = MetadataManager.createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_AND_TYPE_MANAGER = METADATA.getFunctionAndTypeManager();
    private static final JavaAggregationFunctionImplementation DOUBLE_FUNCTION = getFunction(DoubleType.DOUBLE);
    private static final JavaAggregationFunctionImplementation DOUBLE_WITH_K_FUNCTION = getFunction("sketch_kll_with_k", DoubleType.DOUBLE, BigintType.BIGINT);

    /**
     * Aggregates 100 doubles with {@code sketch_kll} and compares the result with a
     * sketch built locally from the same values.
     */
    @Test
    public void testDouble() {
        double[] items = randomIncreasingDoubles(100);
        BlockBuilder out = DoubleType.DOUBLE.createBlockBuilder(null, items.length);
        KllItemsSketch<Double> sketch = KllItemsSketch.newHeapInstance(Double::compareTo, new ArrayOfDoublesSerDe());
        for (double item : items) {
            DoubleType.DOUBLE.writeDouble(out, item);
            sketch.update(item);
        }

        Object result = AggregationTestUtils.executeAggregation(DOUBLE_FUNCTION, out.build());
        KllItemsSketch<Double> recreated = deserialize(result, Double::compareTo, new ArrayOfDoublesSerDe());
        checkSketchesEqual(boxed(items), sketch, recreated);
    }

    /**
     * Same as {@link #testDouble()} but passes an explicit {@code k} to
     * {@code sketch_kll_with_k} and to the locally built sketch.
     */
    @Test
    public void testDoubleWithK() {
        int k = 150;
        double[] items = randomIncreasingDoubles(100);
        BlockBuilder out = DoubleType.DOUBLE.createBlockBuilder(null, items.length);
        BlockBuilder kBlock = BigintType.BIGINT.createBlockBuilder(null, items.length);
        KllItemsSketch<Double> sketch = KllItemsSketch.newHeapInstance(k, Double::compareTo, new ArrayOfDoublesSerDe());
        for (double item : items) {
            DoubleType.DOUBLE.writeDouble(out, item);
            BigintType.BIGINT.writeLong(kBlock, k);
            sketch.update(item);
        }

        Object result = AggregationTestUtils.executeAggregation(DOUBLE_WITH_K_FUNCTION, out.build(), kBlock.build());
        KllItemsSketch<Double> recreated = deserialize(result, Double::compareTo, new ArrayOfDoublesSerDe());
        checkSketchesEqual(boxed(items), sketch, recreated);
    }

    /**
     * Verifies that {@code k} values just outside the range named in the error message
     * (7 and 65536) are rejected with a {@link PrestoException}.
     */
    @Test
    public void testInvalidK() {
        double[] items = randomIncreasingDoubles(10);
        BlockBuilder inputBuilder = DoubleType.DOUBLE.createBlockBuilder(null, items.length);
        for (double item : items) {
            DoubleType.DOUBLE.writeDouble(inputBuilder, item);
        }
        Block input = inputBuilder.build();

        assertInvalidK(input, 7L);
        assertInvalidK(input, 65536L);
    }

    /**
     * Supplies one row per tested type: the type, a supplier of random values in the
     * representation the type's writer expects, and the writer that appends such a value
     * to a {@link BlockBuilder}.
     */
    @DataProvider(name = "testTypes")
    public Object[][] testTypesProvider() {
        return new Object[][] {
                typeCase(TinyintType.TINYINT, () -> (long) ThreadLocalRandom.current().nextInt(0, 127), TinyintType.TINYINT::writeLong),
                typeCase(SmallintType.SMALLINT, () -> (long) ThreadLocalRandom.current().nextInt(0, Short.MAX_VALUE), SmallintType.SMALLINT::writeLong),
                typeCase(IntegerType.INTEGER, () -> (long) ThreadLocalRandom.current().nextInt(), IntegerType.INTEGER::writeLong),
                typeCase(BigintType.BIGINT, () -> ThreadLocalRandom.current().nextLong(), BigintType.BIGINT::writeLong),
                typeCase(RealType.REAL, () -> (long) Float.floatToIntBits(ThreadLocalRandom.current().nextFloat()), RealType.REAL::writeLong),
                typeCase(DoubleType.DOUBLE, () -> ThreadLocalRandom.current().nextDouble(), DoubleType.DOUBLE::writeDouble),
                typeCase(VarcharType.VARCHAR, () -> Slices.utf8Slice(String.valueOf("abcdefghijklmnopqrstuvwxyz".charAt(ThreadLocalRandom.current().nextInt(26)))), VarcharType.VARCHAR::writeSlice),
                typeCase(BooleanType.BOOLEAN, () -> ThreadLocalRandom.current().nextBoolean(), BooleanType.BOOLEAN::writeBoolean),
                typeCase(DateType.DATE, () -> ThreadLocalRandom.current().nextLong(0L, 100L), DateType.DATE::writeLong),
                typeCase(TimeType.TIME, () -> ThreadLocalRandom.current().nextLong(0L, 100L), TimeType.TIME::writeLong),
                typeCase(TimestampType.TIMESTAMP, () -> ThreadLocalRandom.current().nextLong(0L, 100L), TimestampType.TIMESTAMP::writeLong),
                typeCase(TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE, () -> DateTimeEncoding.packDateTimeWithZone(ThreadLocalRandom.current().nextLong(0L, 100L), TimeZoneKey.UTC_KEY), TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE::writeLong),
                typeCase(IntervalYearMonthType.INTERVAL_YEAR_MONTH, () -> ThreadLocalRandom.current().nextLong(0L, 100L), IntervalYearMonthType.INTERVAL_YEAR_MONTH::writeLong),
        };
    }

    /**
     * Aggregates 100 generated values of {@code type} with {@code sketch_kll} and compares
     * the result with a sketch built locally using the comparator, serde and conversion
     * reported by {@link KllSketchAggregationState#getSketchParameters}.
     *
     * @param type the SQL type under test
     * @param values produces values in the form {@code writeBlockValue} accepts
     * @param writeBlockValue appends one generated value to the input block
     */
    // The sketch item type depends on the SQL type and is unknown at compile time, so the
    // sketch parameters and sketches are deliberately handled as raw types here.
    @SuppressWarnings({"rawtypes", "unchecked"})
    @Test(dataProvider = "testTypes")
    public void testTypes(Type type, Supplier<Object> values, BiConsumer<BlockBuilder, Object> writeBlockValue) {
        int length = 100;
        JavaAggregationFunctionImplementation function = getFunction(type);
        BlockBuilder out = type.createBlockBuilder(null, length);
        KllSketchAggregationState.SketchParameters parameters = KllSketchAggregationState.getSketchParameters(type);
        Comparator comparator = parameters.getComparator();
        ArrayOfItemsSerDe serde = parameters.getSerde();
        KllItemsSketch sketch = KllItemsSketch.newHeapInstance(comparator, serde);

        List<Object> addedValues = Stream.generate(values).limit(length).collect(ImmutableList.toImmutableList());
        for (Object value : addedValues) {
            writeBlockValue.accept(out, value);
            // The sketch stores the converted form of the value, not the block representation.
            sketch.update(parameters.getConversion().apply(value));
        }

        Object result = AggregationTestUtils.executeAggregation(function, out.build());
        KllItemsSketch recreated = deserialize(result, comparator, serde);
        List<Object> sketchItems = addedValues.stream()
                .<Object>map(value -> parameters.getConversion().apply(value))
                .collect(Collectors.toList());
        checkSketchesEqual(sketchItems, sketch, recreated);
    }

    /** An empty input block is expected to aggregate to {@code null}. */
    @Test
    public void testEmptyInput() {
        Block empty = DoubleType.DOUBLE.createBlockBuilder(null, 0).build();
        AggregationTestUtils.assertAggregation(DOUBLE_FUNCTION, null, empty);
    }

    /** A block containing only nulls is expected to aggregate to {@code null}. */
    @Test
    public void testNulls() {
        Block nulls = DoubleType.DOUBLE.createBlockBuilder(null, 2).appendNull().appendNull().build();
        AggregationTestUtils.assertAggregation(DOUBLE_FUNCTION, null, nulls);
    }

    /**
     * Asserts that {@code runnable} throws an exception whose class is exactly
     * {@code exceptionType} and whose message (empty string if {@code null}) fully
     * matches {@code regex}.
     *
     * @throws AssertionError if nothing is thrown, or the exception does not match
     */
    private static void assertThrows(Assert.ThrowingRunnable runnable, Class<?> exceptionType, @Language("regexp") String regex) {
        Throwable thrown = null;
        try {
            runnable.run();
        }
        catch (Throwable e) {
            thrown = e;
        }
        // Raised outside the try block so the catch above cannot swallow it and misreport
        // a missing exception as an exception-class mismatch.
        if (thrown == null) {
            throw new AssertionError("no exception was thrown");
        }
        Assert.assertEquals(thrown.getClass(), exceptionType);
        String message = Optional.ofNullable(thrown.getMessage()).orElse("");
        Assert.assertTrue(message.matches(regex), String.format("Error message: '%s' didn't match regex: '%s'", thrown.getMessage(), regex));
    }

    /**
     * Asserts that {@code sketch_kll_with_k} rejects {@code k} for every row of {@code input}.
     */
    private static void assertInvalidK(Block input, long k) {
        BlockBuilder kBuilder = BigintType.BIGINT.createBlockBuilder(null, input.getPositionCount());
        for (int position = 0; position < input.getPositionCount(); position++) {
            BigintType.BIGINT.writeLong(kBuilder, k);
        }
        Block kBlock = kBuilder.build();
        assertThrows(
                () -> AggregationTestUtils.executeAggregation(DOUBLE_WITH_K_FUNCTION, input, kBlock),
                PrestoException.class,
                "k value must satisfy 8 <= k <= 65535: " + k);
    }

    /**
     * Asserts that {@code actual} matches {@code expected}: same {@code k}, same rank for
     * every item in {@code items} (within 1e-8), and identical sorted-view cumulative
     * weights and quantiles.
     */
    private static <T> void checkSketchesEqual(List<T> items, KllItemsSketch<T> expected, KllItemsSketch<T> actual) {
        // TestNG's assertEquals takes (actual, expected), in that order.
        Assert.assertEquals(actual.getK(), expected.getK());
        for (T item : items) {
            Assert.assertEquals(actual.getRank(item), expected.getRank(item), 1.0E-8);
        }
        Assert.assertEquals(actual.getSortedView().getCumulativeWeights(), expected.getSortedView().getCumulativeWeights(), "weights are not equal");
        Assert.assertEquals(actual.getSortedView().getQuantiles(), expected.getSortedView().getQuantiles(), "quantiles are not equal");
    }

    /** Packs one data-provider row; {@code T} ties the supplier's value type to the writer's. */
    private static <T> Object[] typeCase(Type type, Supplier<T> values, BiConsumer<BlockBuilder, T> writer) {
        return new Object[] {type, values, writer};
    }

    /** Returns {@code count} doubles starting at 0.0, each adding a random step in [0, 1). */
    private static double[] randomIncreasingDoubles(int count) {
        return DoubleStream.iterate(0.0, previous -> previous + ThreadLocalRandom.current().nextDouble())
                .limit(count)
                .toArray();
    }

    /** Boxes a primitive array into a list, preserving order. */
    private static List<Double> boxed(double[] items) {
        return Arrays.stream(items).boxed().collect(Collectors.toList());
    }

    /**
     * Wraps the bytes of an aggregation result as a {@link KllItemsSketch}.
     *
     * @param aggregationResult the value returned by the aggregation; must be a {@link SqlVarbinary}
     */
    private static <T> KllItemsSketch<T> deserialize(Object aggregationResult, Comparator<? super T> comparator, ArrayOfItemsSerDe<T> serde) {
        SqlVarbinary serialized = (SqlVarbinary) aggregationResult;
        Memory memory = WritableMemory.writableWrap(serialized.getBytes());
        return KllItemsSketch.wrap(memory, comparator, serde);
    }

    /** Looks up the {@code sketch_kll} implementation for the given argument types. */
    private static JavaAggregationFunctionImplementation getFunction(Type... types) {
        return getFunction("sketch_kll", types);
    }

    /** Looks up the implementation of the aggregation {@code name} for the given argument types. */
    private static JavaAggregationFunctionImplementation getFunction(String name, Type... types) {
        return FUNCTION_AND_TYPE_MANAGER.getJavaAggregateFunctionImplementation(
                FUNCTION_AND_TYPE_MANAGER.lookupFunction(name, TypeSignatureProvider.fromTypes(types)));
    }
}

