/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.block.BlockAssertions;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.BigintType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.operator.aggregation.AbstractTestAggregationFunction;
import com.facebook.presto.operator.aggregation.AggregationTestUtils;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.JavaAggregationFunctionImplementation;
import com.facebook.presto.sql.analyzer.TypeSignatureProvider;
import com.google.common.collect.ImmutableList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Shared tests and data generators for the precision/recall family of aggregation functions.
 *
 * <p>A concrete subclass supplies the function name. The validation tests in this class run
 * against the four-argument overload {@code (bigint, boolean, double, double)}, whose last
 * argument is the weight. The sequence data produced by {@link #getSequenceBlocks(int, int)}
 * feeds the three-argument overload {@code (bigint, boolean, double)} declared by
 * {@link #getFunctionParameterTypes()}.
 */
public abstract class TestPrecisionRecallAggregation
extends AbstractTestAggregationFunction {
    /** Bucket count written into every generated row. */
    private static final int NUM_BINS = 3;
    /** Lower bound (inclusive) of the prediction range that generated rows label {@code false}. */
    private static final double MIN_FALSE_PRED = 0.2;
    /** Upper bound (inclusive) of the prediction range that generated rows label {@code false}. */
    private static final double MAX_FALSE_PRED = 0.5;

    private final String functionName;
    private JavaAggregationFunctionImplementation precisionRecallFunction;

    protected TestPrecisionRecallAggregation(String functionName) {
        this.functionName = functionName;
    }

    /** Resolves the four-argument (bucket count, outcome, prediction, weight) overload used by the validation tests. */
    @BeforeClass
    public void setUp() {
        FunctionAndTypeManager functionAndTypeManager = MetadataManager.createTestMetadataManager().getFunctionAndTypeManager();
        this.precisionRecallFunction = functionAndTypeManager.getJavaAggregateFunctionImplementation(
                functionAndTypeManager.lookupFunction(
                        this.functionName,
                        TypeSignatureProvider.fromTypes(new Type[] {BigintType.BIGINT, BooleanType.BOOLEAN, DoubleType.DOUBLE, DoubleType.DOUBLE})));
    }

    @Test
    public void testNegativeWeight() {
        assertAggregationFailsMentioning(
                ImmutableList.of("weight", "negative"),
                BlockAssertions.createLongsBlock(200L),
                BlockAssertions.createBooleansBlock(true),
                BlockAssertions.createDoublesBlock(0.2),
                BlockAssertions.createDoublesBlock(-0.2));
    }

    @Test
    public void testTooHighPrediction() {
        assertAggregationFailsMentioning(
                ImmutableList.of("prediction"),
                BlockAssertions.createLongsBlock(200L),
                BlockAssertions.createBooleansBlock(true),
                BlockAssertions.createDoublesBlock(1.2),
                BlockAssertions.createDoublesBlock(0.2));
    }

    @Test
    public void testTooLowPrediction() {
        assertAggregationFailsMentioning(
                ImmutableList.of("prediction"),
                BlockAssertions.createLongsBlock(200L),
                BlockAssertions.createBooleansBlock(true),
                BlockAssertions.createDoublesBlock(-1.2),
                BlockAssertions.createDoublesBlock(0.2));
    }

    @Test
    public void testNonConstantBuckets() {
        // Two rows that disagree on the bucket count (200 vs 300) must be rejected.
        assertAggregationFailsMentioning(
                ImmutableList.of("bucket"),
                BlockAssertions.createLongsBlock(200L, 300L),
                BlockAssertions.createBooleansBlock(true, false),
                BlockAssertions.createDoublesBlock(0.2, 0.3),
                BlockAssertions.createDoublesBlock(1.0, 1.0));
    }

    /**
     * Calls {@link AggregationTestUtils#assertAggregation} on the four-argument overload with the given
     * blocks and asserts that it throws a {@link PrestoException} whose message contains every
     * expected fragment (compared in lower case).
     *
     * @param expectedMessageFragments lower-case substrings the exception message must contain
     * @param blocks input columns: bucket count, outcome, prediction, weight
     */
    private void assertAggregationFailsMentioning(List<String> expectedMessageFragments, Block... blocks) {
        try {
            AggregationTestUtils.assertAggregation(this.precisionRecallFunction, (Object) 0.0, blocks);
        }
        catch (PrestoException e) {
            // String.valueOf turns a null message into "null" so the assertion below fails readably instead of with an NPE.
            String message = String.valueOf(e.getMessage()).toLowerCase(Locale.ENGLISH);
            for (String fragment : expectedMessageFragments) {
                Assert.assertTrue(message.contains(fragment), "Exception message '" + e.getMessage() + "' does not mention '" + fragment + "'");
            }
            return;
        }
        Assert.fail("Expected PrestoException");
    }

    /**
     * Builds the bucket-count, outcome and prediction columns for {@code length} rows.
     *
     * <p>Every row uses {@link #NUM_BINS} as its bucket count. Outcome and prediction come from
     * {@link #getResult(int, int, int)}. A negative {@code start} is replaced by its absolute value,
     * the same normalization {@link #getResultsIterator(int, int)} applies.
     */
    @Override
    public Block[] getSequenceBlocks(int start, int length) {
        int effectiveStart = Math.abs(start);
        BlockBuilder bucketCountBlockBuilder = BigintType.BIGINT.createBlockBuilder(null, length);
        BlockBuilder outcomeBlockBuilder = BooleanType.BOOLEAN.createBlockBuilder(null, length);
        BlockBuilder predictionBlockBuilder = DoubleType.DOUBLE.createBlockBuilder(null, length);
        for (int i = effectiveStart; i < effectiveStart + length; ++i) {
            Result result = getResult(effectiveStart, length, i);
            BigintType.BIGINT.writeLong(bucketCountBlockBuilder, NUM_BINS);
            BooleanType.BOOLEAN.writeBoolean(outcomeBlockBuilder, result.outcome);
            DoubleType.DOUBLE.writeDouble(predictionBlockBuilder, result.prediction);
        }
        return new Block[] {bucketCountBlockBuilder.build(), outcomeBlockBuilder.build(), predictionBlockBuilder.build()};
    }

    /**
     * Returns an iterator over the expected per-bucket tallies for the rows that
     * {@link #getSequenceBlocks(int, int)} produces for the same {@code start}/{@code length}.
     *
     * <p>Bucket {@code k} has left edge {@code k / NUM_BINS} and right edge {@code (k + 1) / NUM_BINS}.
     * Every row counts as weight 1.0. The {@code total*} fields count all true/false rows, and the
     * {@code remaining*} fields count those whose prediction is at least the bucket's left edge.
     * Iteration continues while some true-outcome row has a prediction at or above the next
     * bucket's left edge.
     */
    protected static Iterator<BucketResult> getResultsIterator(final int start, final int length) {
        // Same normalization as getSequenceBlocks, so both sides describe the same rows.
        final int effectiveStart = Math.abs(start);
        return new Iterator<BucketResult>() {
            private int bucket;

            @Override
            public boolean hasNext() {
                double left = (double) bucket / NUM_BINS;
                for (int j = effectiveStart; j < effectiveStart + length; ++j) {
                    Result result = getResult(effectiveStart, length, j);
                    if (result.outcome && result.prediction >= left) {
                        return true;
                    }
                }
                return false;
            }

            @Override
            public BucketResult next() {
                double left = (double) bucket / NUM_BINS;
                double right = (double) (bucket + 1) / NUM_BINS;
                double totalTrue = 0.0;
                double totalFalse = 0.0;
                double remainingTrue = 0.0;
                double remainingFalse = 0.0;
                for (int j = effectiveStart; j < effectiveStart + length; ++j) {
                    Result result = getResult(effectiveStart, length, j);
                    boolean atOrAboveLeft = result.prediction >= left;
                    if (result.outcome) {
                        totalTrue += 1.0;
                        if (atOrAboveLeft) {
                            remainingTrue += 1.0;
                        }
                    }
                    else {
                        totalFalse += 1.0;
                        if (atOrAboveLeft) {
                            remainingFalse += 1.0;
                        }
                    }
                }
                ++bucket;
                return new BucketResult(left, right, totalTrue, totalFalse, remainingTrue, remainingFalse);
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    @Override
    protected String getFunctionName() {
        return this.functionName;
    }

    /** Three-argument overload. The first type matches the BIGINT bucket-count column from getSequenceBlocks. */
    @Override
    protected List<String> getFunctionParameterTypes() {
        return ImmutableList.of("bigint", "boolean", "double");
    }

    /**
     * Deterministic row generator.
     *
     * <p>The prediction is {@code (i - start) / (length + 1)}, which lies in {@code [0, 1)} for
     * {@code i} in {@code [start, start + length)}. The outcome is {@code false} exactly when the
     * prediction lies in {@code [MIN_FALSE_PRED, MAX_FALSE_PRED]}.
     */
    protected static Result getResult(int start, int length, int i) {
        double prediction = (double) (i - start) / (length + 1);
        boolean outcome = prediction < MIN_FALSE_PRED || prediction > MAX_FALSE_PRED;
        return new Result(outcome, prediction);
    }

    /** Expected tallies for one bucket, as produced by {@link #getResultsIterator(int, int)}. */
    protected static class BucketResult {
        /** Bucket's lower edge. */
        public final Double left;
        /** Bucket's upper edge. */
        public final Double right;
        /** Total weight of all true-outcome rows. */
        public final Double totalTrueWeight;
        /** Total weight of all false-outcome rows. */
        public final Double totalFalseWeight;
        /** Weight of true-outcome rows whose prediction is at least {@code left}. */
        public final Double remainingTrueWeight;
        /** Weight of false-outcome rows whose prediction is at least {@code left}. */
        public final Double remainingFalseWeight;

        public BucketResult(Double left, Double right, Double totalTrueWeight, Double totalFalseWeight, Double remainingTrueWeight, Double remainingFalseWeight) {
            this.left = left;
            this.right = right;
            this.totalTrueWeight = totalTrueWeight;
            this.totalFalseWeight = totalFalseWeight;
            this.remainingTrueWeight = remainingTrueWeight;
            this.remainingFalseWeight = remainingFalseWeight;
        }
    }

    /** One generated row: its label and its prediction. Protected because {@link #getResult} exposes it. */
    protected static final class Result {
        public final boolean outcome;
        public final double prediction;

        public Result(boolean outcome, double prediction) {
            this.outcome = outcome;
            this.prediction = prediction;
        }
    }
}

