/*
 * Decompiled with CFR 0.152.
 */
package io.prestosql.operator.aggregation;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import io.airlift.testing.Assertions;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.MetadataManager;
import io.prestosql.operator.aggregation.AggregationTestUtils;
import io.prestosql.operator.aggregation.InternalAggregationFunction;
import io.prestosql.spi.Page;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockBuilder;
import io.prestosql.spi.type.DoubleType;
import io.prestosql.spi.type.Type;
import io.prestosql.sql.analyzer.TypeSignatureProvider;
import io.prestosql.sql.tree.QualifiedName;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Base test for the {@code approx_distinct(value, maxStandardError)} aggregation.
 *
 * <p>Subclasses supply the value type and a generator of random values. The tests compare the
 * grouped, global and partial aggregation paths against exact counts and against each other,
 * for each standard error returned by {@link #provideStandardErrors()}.
 */
public abstract class AbstractTestApproximateCountDistinct {
    protected static final Metadata metadata = MetadataManager.createTestMetadataManager();

    /** Returns the type of the values counted by {@code approx_distinct}. */
    protected abstract Type getValueType();

    /** Returns a random, non-null value of {@link #getValueType()}. */
    protected abstract Object randomValue();

    /**
     * Returns the upper bound on the number of distinct values used by the randomized tests.
     *
     * <p>Building a sample keeps drawing from {@link #randomValue()} until this many distinct
     * values exist. Subclasses whose value domain is smaller than this must override it, or
     * the tests never terminate.
     */
    protected int getUniqueValuesCount() {
        return 20000;
    }

    /** Maximum standard errors passed as the second argument of {@code approx_distinct}. */
    @DataProvider(name = "provideStandardErrors")
    public Object[][] provideStandardErrors() {
        return new Object[][] {{0.023}, {0.0115}};
    }

    @Test(dataProvider = "provideStandardErrors")
    public void testNoPositions(double maxStandardError) {
        assertCount(ImmutableList.of(), maxStandardError, 0L);
    }

    @Test(dataProvider = "provideStandardErrors")
    public void testSinglePosition(double maxStandardError) {
        assertCount(ImmutableList.of(randomValue()), maxStandardError, 1L);
    }

    @Test(dataProvider = "provideStandardErrors")
    public void testAllPositionsNull(double maxStandardError) {
        assertCount(Collections.nCopies(100, null), maxStandardError, 0L);
    }

    /**
     * Randomly interleaves nulls into a sample and expects the same estimate as for the sample
     * without nulls, i.e. nulls must not affect the result.
     */
    @Test(dataProvider = "provideStandardErrors")
    public void testMixedNullsAndNonNulls(double maxStandardError) {
        int uniques = getUniqueValuesCount();
        List<Object> baseline = createRandomSample(uniques, (int) (uniques * 1.5));

        // Nulls are inserted without reordering the non-null values, so the non-null input
        // seen by the aggregation is exactly the baseline sequence.
        List<Object> mixed = new ArrayList<>();
        Iterator<Object> iterator = baseline.iterator();
        while (iterator.hasNext()) {
            mixed.add(ThreadLocalRandom.current().nextBoolean() ? null : iterator.next());
        }

        assertCount(mixed, maxStandardError, estimateGroupByCount(baseline, maxStandardError));
    }

    /**
     * Checks the distribution of relative error over many random samples. The mean must be near
     * zero, and the standard deviation must be within the requested standard error plus 1%
     * slack.
     */
    @Test(dataProvider = "provideStandardErrors")
    public void testMultiplePositions(double maxStandardError) {
        DescriptiveStatistics stats = new DescriptiveStatistics();
        for (int i = 0; i < 500; i++) {
            int uniques = ThreadLocalRandom.current().nextInt(getUniqueValuesCount()) + 1;
            List<Object> values = createRandomSample(uniques, (int) (uniques * 1.5));
            long actual = estimateGroupByCount(values, maxStandardError);
            stats.addValue((double) (actual - uniques) / uniques);
        }
        // Bound the magnitude of the mean error. A one-sided "mean < 1%" check would accept an
        // arbitrarily large systematic underestimate.
        Assertions.assertLessThan(Math.abs(stats.getMean()), 0.01);
        Assertions.assertLessThan(stats.getStandardDeviation(), 0.01 + maxStandardError);
    }

    /** Partial (intermediate + final) aggregation must agree exactly with grouped aggregation. */
    @Test(dataProvider = "provideStandardErrors")
    public void testMultiplePositionsPartial(double maxStandardError) {
        for (int i = 0; i < 100; i++) {
            int uniques = ThreadLocalRandom.current().nextInt(getUniqueValuesCount()) + 1;
            List<Object> values = createRandomSample(uniques, (int) (uniques * 1.5));
            Assert.assertEquals(
                    estimateCountPartial(values, maxStandardError),
                    estimateGroupByCount(values, maxStandardError));
        }
    }

    /**
     * Asserts that the global and partial aggregations return {@code expectedCount}. The grouped
     * aggregation is also checked, but only when {@code values} is non-empty.
     */
    protected void assertCount(List<?> values, double maxStandardError, long expectedCount) {
        if (!values.isEmpty()) {
            Assert.assertEquals(estimateGroupByCount(values, maxStandardError), expectedCount);
        }
        Assert.assertEquals(estimateCount(values, maxStandardError), expectedCount);
        Assert.assertEquals(estimateCountPartial(values, maxStandardError), expectedCount);
    }

    private long estimateGroupByCount(List<?> values, double maxStandardError) {
        return (Long) AggregationTestUtils.groupedAggregation(
                getAggregationFunction(), createPage(values, maxStandardError));
    }

    private long estimateCount(List<?> values, double maxStandardError) {
        return (Long) AggregationTestUtils.aggregation(
                getAggregationFunction(), createPage(values, maxStandardError));
    }

    private long estimateCountPartial(List<?> values, double maxStandardError) {
        return (Long) AggregationTestUtils.partialAggregation(
                getAggregationFunction(), createPage(values, maxStandardError));
    }

    /** Resolves {@code approx_distinct(<value type>, double)} from the test metadata. */
    private InternalAggregationFunction getAggregationFunction() {
        return metadata.getAggregateFunctionImplementation(
                metadata.resolveFunction(
                        QualifiedName.of("approx_distinct"),
                        TypeSignatureProvider.fromTypes(getValueType(), DoubleType.DOUBLE)));
    }

    /**
     * Builds a two-channel page: the values, and a constant column of {@code maxStandardError}.
     * An empty input yields an empty page with no channels.
     */
    private Page createPage(List<?> values, double maxStandardError) {
        if (values.isEmpty()) {
            return new Page(0);
        }
        Block valueBlock = createBlock(getValueType(), values);
        Block errorBlock = createBlock(
                DoubleType.DOUBLE, Collections.nCopies(values.size(), maxStandardError));
        return new Page(values.size(), valueBlock, errorBlock);
    }

    /**
     * Writes {@code values} into a block of {@code type}. Nulls become null positions, and the
     * write method is chosen from the type's Java representation.
     */
    private static Block createBlock(Type type, List<?> values) {
        BlockBuilder blockBuilder = type.createBlockBuilder(null, values.size());
        // Loop-invariant: the Java representation depends only on the type.
        Class<?> javaType = type.getJavaType();
        for (Object value : values) {
            if (value == null) {
                blockBuilder.appendNull();
            }
            else if (javaType == boolean.class) {
                type.writeBoolean(blockBuilder, (Boolean) value);
            }
            else if (javaType == long.class) {
                type.writeLong(blockBuilder, (Long) value);
            }
            else if (javaType == double.class) {
                type.writeDouble(blockBuilder, (Double) value);
            }
            else if (javaType == Slice.class) {
                Slice slice = (Slice) value;
                type.writeSlice(blockBuilder, slice, 0, slice.length());
            }
            else {
                type.writeObject(blockBuilder, value);
            }
        }
        return blockBuilder.build();
    }

    /**
     * Returns {@code total} values containing exactly {@code uniques} distinct values. The distinct
     * values come first, followed by randomly chosen repeats of earlier elements.
     *
     * @throws IllegalArgumentException if {@code uniques > total}
     */
    private List<Object> createRandomSample(int uniques, int total) {
        Preconditions.checkArgument(uniques <= total, "uniques (%s) must be <= total (%s)", uniques, total);
        List<Object> result = new ArrayList<>(total);
        result.addAll(makeRandomSet(uniques));
        Random random = ThreadLocalRandom.current();
        while (result.size() < total) {
            result.add(result.get(random.nextInt(result.size())));
        }
        return result;
    }

    /**
     * Draws from {@link #randomValue()} until {@code count} distinct values are collected. This
     * does not terminate if the generator cannot produce that many distinct values.
     */
    private Set<Object> makeRandomSet(int count) {
        Set<Object> result = new HashSet<>();
        while (result.size() < count) {
            result.add(randomValue());
        }
        return result;
    }
}

