/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.io.BaseEncoding;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.stats.cardinality.HyperLogLog;
import io.airlift.testing.Assertions;
import io.trino.metadata.Metadata;
import io.trino.metadata.MetadataManager;
import io.trino.operator.aggregation.AggregationTestUtils;
import io.trino.operator.aggregation.InternalAggregationFunction;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.SqlVarbinary;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.tree.QualifiedName;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Shared test suite for the {@code $approx_set} aggregation over an arbitrary value type.
 *
 * <p>Subclasses supply the value type, a random value generator, and a fixed sample together with
 * the expected serialized sketch for that sample. Each scenario runs the aggregation through the
 * plain, partial and grouped entry points of {@code AggregationTestUtils} and checks the resulting
 * HyperLogLog sketches.
 */
public abstract class AbstractTestApproximateSetGeneric {
    /** Value written into the second (DOUBLE) channel of every non-empty test page. */
    private static final double STD_ERROR = 0.023;
    protected static final Metadata metadata = MetadataManager.createTestMetadataManager();

    /** Returns the SQL type of the values fed to {@code $approx_set}. */
    protected abstract Type getValueType();

    /**
     * Returns a random value in the Java representation of {@link #getValueType()}
     * (boolean/long/double/Slice, or an object accepted by {@code Type.writeObject}).
     * Must not return null: {@link #testSinglePosition()} wraps it in {@code ImmutableList.of}.
     */
    protected abstract Object randomValue();

    /**
     * Returns the number of distinct values used by the randomized tests (an upper bound for
     * most of them). {@link #randomValue()} must be able to produce at least this many distinct
     * values, because sample generation keeps drawing until it has that many.
     */
    protected int getUniqueValuesCount() {
        return 20000;
    }

    @Test
    public void testNoPositions() {
        // An empty page produces no sketch at all (null), not an empty sketch.
        Assert.assertNull(this.estimateSet(ImmutableList.of()));
        Assert.assertNull(this.estimateSetPartial(ImmutableList.of()));
    }

    @Test
    public void testSinglePosition() {
        this.assertCount(ImmutableList.of(this.randomValue()), 1L);
    }

    @Test
    public void testAllPositionsNull() {
        // Input consisting only of nulls is expected to behave like empty input.
        List<Object> justNulls = Collections.nCopies(100, null);
        Assert.assertNull(this.estimateSet(justNulls));
        Assert.assertNull(this.estimateSetPartial(justNulls));
        Assert.assertNull(this.estimateSetGrouped(justNulls));
    }

    @Test
    public void testMixedNullsAndNonNulls() {
        int uniques = this.getUniqueValuesCount();
        List<Object> baseline = this.createRandomSample(uniques, (int) (uniques * 1.5));

        // Randomly interleave nulls while keeping the non-null values in their original order,
        // so the sketch of `mixed` is expected to match the sketch of `baseline`.
        Iterator<Object> iterator = baseline.iterator();
        List<Object> mixed = new ArrayList<>();
        while (iterator.hasNext()) {
            mixed.add(ThreadLocalRandom.current().nextBoolean() ? null : iterator.next());
        }

        this.assertCount(mixed, this.estimateSetGrouped(baseline).cardinality());
    }

    @Test
    public void testMultiplePositions() {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int i = 0; i < 500; ++i) {
            int uniques = ThreadLocalRandom.current().nextInt(this.getUniqueValuesCount()) + 1;
            List<Object> values = this.createRandomSample(uniques, (int) (uniques * 1.5));
            long actualCount = this.estimateSetGrouped(values).cardinality();
            // Relative error of the estimate against the true number of distinct values.
            double error = (actualCount - uniques) / (double) uniques;
            stats.addValue(error);
        }
        Assertions.assertLessThan(stats.getMean(), 0.01);
        Assertions.assertLessThan(stats.getStandardDeviation(), 0.033);
    }

    @Test
    public void testMultiplePositionsPartial() {
        for (int i = 0; i < 100; ++i) {
            int uniques = ThreadLocalRandom.current().nextInt(this.getUniqueValuesCount()) + 1;
            List<Object> values = this.createRandomSample(uniques, (int) (uniques * 1.5));
            Assert.assertEquals(this.estimateSetPartial(values).cardinality(), this.estimateSetGrouped(values).cardinality());
        }
    }

    @Test
    public void testResultStability() {
        for (int i = 0; i < 10; ++i) {
            List<Object> sample = new ArrayList<>(this.getResultStabilityTestSample());
            // The serialized sketch must not depend on the order of the input rows.
            Collections.shuffle(sample);
            String expected = this.getResultStabilityExpected();
            Assert.assertEquals(toBase16(this.estimateSet(sample)), expected);
            Assert.assertEquals(toBase16(this.estimateSetPartial(sample)), expected);
            Assert.assertEquals(toBase16(this.estimateSetGrouped(sample)), expected);
        }
    }

    /** Returns the fixed sample whose serialized sketch {@link #testResultStability()} checks. */
    protected abstract List<Object> getResultStabilityTestSample();

    /**
     * Returns the expected serialized sketch of {@link #getResultStabilityTestSample()}, encoded
     * with Guava's {@code BaseEncoding.base16()}.
     */
    protected abstract String getResultStabilityExpected();

    /**
     * Asserts that the plain and partial aggregations of {@code values}, and the grouped
     * aggregation when {@code values} is non-empty, report a cardinality of {@code expectedCount}.
     *
     * <p>If an aggregation yields no sketch (as it does for input without non-null values, see
     * testNoPositions/testAllPositionsNull), this fails with an assertion message naming the
     * aggregation path instead of throwing a NullPointerException.
     */
    protected void assertCount(List<?> values, long expectedCount) {
        // NOTE(review): the grouped path is skipped for empty input; the reason is not visible here.
        if (!values.isEmpty()) {
            Assert.assertEquals(requireCardinality(this.estimateSetGrouped(values), "grouped"), expectedCount);
        }
        Assert.assertEquals(requireCardinality(this.estimateSet(values), "plain"), expectedCount);
        Assert.assertEquals(requireCardinality(this.estimateSetPartial(values), "partial"), expectedCount);
    }

    /** Fails with a descriptive message if {@code set} is null; otherwise returns its cardinality. */
    private static long requireCardinality(HyperLogLog set, String path) {
        Assert.assertNotNull(set, path + " aggregation returned null");
        return set.cardinality();
    }

    /** Runs the grouped aggregation over {@code values}; returns null when it produces no sketch. */
    private HyperLogLog estimateSetGrouped(List<?> values) {
        return deserialize(AggregationTestUtils.groupedAggregation(this.getAggregationFunction(), this.createPage(values)));
    }

    /** Runs the plain aggregation over {@code values}; returns null when it produces no sketch. */
    private HyperLogLog estimateSet(List<?> values) {
        return deserialize(AggregationTestUtils.aggregation(this.getAggregationFunction(), this.createPage(values)));
    }

    /** Runs the partial aggregation over {@code values}; returns null when it produces no sketch. */
    private HyperLogLog estimateSetPartial(List<?> values) {
        return deserialize(AggregationTestUtils.partialAggregation(this.getAggregationFunction(), this.createPage(values)));
    }

    /** Converts an aggregation result (a {@link SqlVarbinary} or null) into a HyperLogLog, or null. */
    private static HyperLogLog deserialize(Object result) {
        if (result == null) {
            return null;
        }
        SqlVarbinary serialized = (SqlVarbinary) result;
        return HyperLogLog.newInstance(Slices.wrappedBuffer(serialized.getBytes()));
    }

    /** Encodes the serialized form of {@code set} with {@code BaseEncoding.base16()}. */
    private static String toBase16(HyperLogLog set) {
        return BaseEncoding.base16().encode(set.serialize().getBytes());
    }

    /** Resolves the single-argument {@code $approx_set} implementation for {@link #getValueType()}. */
    private InternalAggregationFunction getAggregationFunction() {
        return metadata.getAggregateFunctionImplementation(
                metadata.resolveFunction(QualifiedName.of("$approx_set"), TypeSignatureProvider.fromTypes(this.getValueType())));
    }

    /**
     * Builds a page whose channel 0 holds {@code values} and whose channel 1 holds
     * {@link #STD_ERROR} in every position. An empty list yields a page with no positions and no blocks.
     */
    private Page createPage(List<?> values) {
        if (values.isEmpty()) {
            return new Page(0);
        }
        int positions = values.size();
        return new Page(
                positions,
                createBlock(this.getValueType(), values),
                // NOTE(review): the function is resolved with a single argument type, so this
                // channel is presumably ignored by the aggregation — confirm in AggregationTestUtils.
                createBlock(DoubleType.DOUBLE, Collections.nCopies(positions, STD_ERROR)));
    }

    /**
     * Writes {@code values} into a new block of {@code type}. The write method is chosen from the
     * type's Java representation, and null elements become null positions.
     */
    private static Block createBlock(Type type, List<?> values) {
        BlockBuilder blockBuilder = type.createBlockBuilder(null, values.size());
        // The Java representation is a property of the type, so resolve it once, not per value.
        Class<?> javaType = type.getJavaType();
        for (Object value : values) {
            if (value == null) {
                blockBuilder.appendNull();
            } else if (javaType == boolean.class) {
                type.writeBoolean(blockBuilder, (Boolean) value);
            } else if (javaType == long.class) {
                type.writeLong(blockBuilder, (Long) value);
            } else if (javaType == double.class) {
                type.writeDouble(blockBuilder, (Double) value);
            } else if (javaType == Slice.class) {
                Slice slice = (Slice) value;
                type.writeSlice(blockBuilder, slice, 0, slice.length());
            } else {
                type.writeObject(blockBuilder, value);
            }
        }
        return blockBuilder.build();
    }

    /**
     * Returns {@code total} values containing exactly {@code uniques} distinct ones. The first
     * {@code uniques} elements are distinct random values; the rest are random repeats of earlier elements.
     */
    private List<Object> createRandomSample(int uniques, int total) {
        Preconditions.checkArgument(uniques <= total, "uniques (%s) must be <= total (%s)", uniques, total);
        List<Object> result = new ArrayList<>(total);
        result.addAll(this.makeRandomSet(uniques));
        ThreadLocalRandom random = ThreadLocalRandom.current();
        // Pad only with copies of values already present, so the distinct count stays `uniques`.
        while (result.size() < total) {
            result.add(result.get(random.nextInt(result.size())));
        }
        return result;
    }

    /** Draws random values until {@code count} distinct ones have been collected. */
    private Set<Object> makeRandomSet(int count) {
        Set<Object> result = new HashSet<>();
        while (result.size() < count) {
            result.add(this.randomValue());
        }
        return result;
    }
}

