/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.operator.aggregation.fixedhistogram.FixedDoubleHistogram;
import com.facebook.presto.operator.aggregation.state.PrecisionRecallState;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.SqlType;
import com.google.common.collect.Streams;
import java.util.Collections;
import java.util.Iterator;
import java.util.NoSuchElementException;

/**
 * Shared machinery for weighted precision/recall-style aggregations.
 *
 * <p>Each input row adds its weight to one of two fixed-bucket histograms over the prediction
 * range [{@code MIN_PREDICTION_VALUE}, {@code MAX_PREDICTION_VALUE}]. The true-weights histogram
 * collects rows whose outcome is {@code true}; the false-weights histogram collects the rest.
 * {@link #getResultsIterator} walks both histograms bucket by bucket. For each bucket it reports
 * confusion-matrix weights, which subclasses can turn into output metrics.
 */
public abstract class PrecisionRecallAggregation {
    /** Weight applied by the input overload that takes no explicit weight. */
    private static final double DEFAULT_WEIGHT = 1.0;
    private static final double MIN_PREDICTION_VALUE = 0.0;
    private static final double MAX_PREDICTION_VALUE = 1.0;
    // Predictions are clamped to this value before being added to a histogram.
    // NOTE(review): presumably so that a prediction of exactly MAX_PREDICTION_VALUE falls into the
    // last bucket rather than past it (assumes half-open buckets) - confirm against FixedDoubleHistogram.
    private static final double MAX_PREDICTION_VALUE_FOR_HISTOGRAM = 0.99999999999;
    private static final String ILLEGAL_PREDICTION_VALUE_MESSAGE = String.format("Prediction value must be between %s and %s", MIN_PREDICTION_VALUE, MAX_PREDICTION_VALUE);
    private static final String NEGATIVE_WEIGHT_MESSAGE = "Weights must be non-negative";
    private static final String INCONSISTENT_BUCKET_COUNT_MESSAGE = "Bucket count must be constant";

    protected PrecisionRecallAggregation() {
    }

    /**
     * Adds one weighted observation to the aggregation state.
     *
     * <p>On the first call for a state, both histograms are created with {@code bucketCount} buckets.
     * Every later call must pass the same bucket count.
     *
     * @param state aggregation state holding the true/false weight histograms
     * @param bucketCount number of histogram buckets; must be identical for every row of a group
     * @param outcome selects the histogram: {@code true} rows go to the true weights, others to the false weights
     * @param pred prediction value; must lie in [0, 1] (NaN is rejected)
     * @param weight non-negative weight of this observation (NaN is rejected)
     * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if {@code pred} is out of range or NaN,
     *         if {@code weight} is negative or NaN, or if {@code bucketCount} differs from the state's bucket count
     */
    @InputFunction
    public static void input(@AggregationState PrecisionRecallState state, @SqlType(value="bigint") long bucketCount, @SqlType(value="boolean") boolean outcome, @SqlType(value="double") double pred, @SqlType(value="double") double weight) {
        // Both checks are negated in-range tests: every comparison with NaN is false, so the
        // straightforward "pred < min || pred > max" / "weight < 0" forms would let NaN through.
        if (!(pred >= MIN_PREDICTION_VALUE && pred <= MAX_PREDICTION_VALUE)) {
            throw invalidArgument(ILLEGAL_PREDICTION_VALUE_MESSAGE);
        }
        if (!(weight >= 0.0)) {
            throw invalidArgument(NEGATIVE_WEIGHT_MESSAGE);
        }

        if (state.getTrueWeights() == null) {
            state.setTrueWeights(new FixedDoubleHistogram((int) bucketCount, MIN_PREDICTION_VALUE, MAX_PREDICTION_VALUE));
            state.setFalseWeights(new FixedDoubleHistogram((int) bucketCount, MIN_PREDICTION_VALUE, MAX_PREDICTION_VALUE));
        }
        // Comparing as long also catches values that were truncated by the (int) cast above.
        if (bucketCount != state.getTrueWeights().getBucketCount()) {
            throw invalidArgument(INCONSISTENT_BUCKET_COUNT_MESSAGE);
        }

        double histogramPred = Math.min(pred, MAX_PREDICTION_VALUE_FOR_HISTOGRAM);
        FixedDoubleHistogram target = outcome ? state.getTrueWeights() : state.getFalseWeights();
        target.add(histogramPred, weight);
    }

    /**
     * Adds one observation with the default weight of {@code 1.0}.
     * See the five-argument overload for validation rules.
     */
    @InputFunction
    public static void input(@AggregationState PrecisionRecallState state, @SqlType(value="bigint") long bucketCount, @SqlType(value="boolean") boolean outcome, @SqlType(value="double") double pred) {
        input(state, bucketCount, outcome, pred, DEFAULT_WEIGHT);
    }

    /**
     * Merges {@code otherState} into {@code state}.
     *
     * <p>If {@code state} is still empty, it receives clones of the other state's histograms, so later
     * mutation of one state does not affect the other. If both are populated, their histograms are
     * merged. If {@code otherState} is empty, nothing happens.
     */
    @CombineFunction
    public static void combine(@AggregationState PrecisionRecallState state, @AggregationState PrecisionRecallState otherState) {
        if (otherState.getTrueWeights() == null) {
            return;
        }
        if (state.getTrueWeights() == null) {
            state.setTrueWeights(otherState.getTrueWeights().clone());
            state.setFalseWeights(otherState.getFalseWeights().clone());
            return;
        }
        state.getTrueWeights().mergeWith(otherState.getTrueWeights());
        state.getFalseWeights().mergeWith(otherState.getFalseWeights());
    }

    /**
     * Returns an iterator over per-bucket confusion-matrix weights.
     *
     * <p>Buckets are visited in the histograms' iteration order. For each bucket, the result uses
     * the bucket's left edge as the threshold. The running weights hold the sum over all buckets
     * visited before the current one:
     * <ul>
     *   <li>{@code truePositive = totalTrue - runningTrue}</li>
     *   <li>{@code falseNegative = runningTrue}</li>
     *   <li>{@code falsePositive = totalFalse - runningFalse}</li>
     *   <li>{@code trueNegative = runningFalse}</li>
     * </ul>
     * Iteration stops as soon as all true-outcome weight has been consumed. A state with no
     * true-outcome weight, or with no observations at all, therefore yields no results.
     *
     * @param state aggregation state to read; it is not modified
     * @return an iterator of {@link BucketResult}; {@code remove()} is unsupported
     */
    protected static Iterator<BucketResult> getResultsIterator(final @AggregationState PrecisionRecallState state) {
        if (state.getTrueWeights() == null) {
            return Collections.emptyIterator();
        }
        final double totalTrueWeight = totalWeight(state.getTrueWeights());
        final double totalFalseWeight = totalWeight(state.getFalseWeights());
        return new Iterator<BucketResult>() {
            private final Iterator<FixedDoubleHistogram.Bucket> trueIterator = state.getTrueWeights().iterator();
            private final Iterator<FixedDoubleHistogram.Bucket> falseIterator = state.getFalseWeights().iterator();
            private double runningTrueWeight;
            private double runningFalseWeight;

            @Override
            public boolean hasNext() {
                // Must agree with next(): both histograms are required to have a bucket left.
                return trueIterator.hasNext() && falseIterator.hasNext() && totalTrueWeight > runningTrueWeight;
            }

            @Override
            public BucketResult next() {
                if (!trueIterator.hasNext() || !falseIterator.hasNext()) {
                    throw new NoSuchElementException();
                }
                FixedDoubleHistogram.Bucket trueBucket = trueIterator.next();
                FixedDoubleHistogram.Bucket falseBucket = falseIterator.next();
                BucketResult result = new BucketResult(
                        trueBucket.getLeft(),
                        totalTrueWeight,
                        totalFalseWeight,
                        totalTrueWeight - runningTrueWeight,
                        runningFalseWeight,
                        totalFalseWeight - runningFalseWeight,
                        runningTrueWeight);
                runningTrueWeight += trueBucket.getWeight();
                runningFalseWeight += falseBucket.getWeight();
                return result;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Sums the histogram's bucket weights in iteration order using plain addition.
     *
     * <p>This deliberately avoids {@code DoubleStream.sum()}, which may use compensated summation.
     * Its result can differ in the last bits from the running {@code +=} totals in
     * {@link #getResultsIterator}, which would make the "all weight consumed" stop condition miss
     * and emit spurious trailing rows. If the histogram iterates its buckets in the same order each
     * time, the running total computed the same way reaches exactly this value.
     */
    private static double totalWeight(FixedDoubleHistogram histogram) {
        double total = 0.0;
        Iterator<FixedDoubleHistogram.Bucket> buckets = histogram.iterator();
        while (buckets.hasNext()) {
            total += buckets.next().getWeight();
        }
        return total;
    }

    /** Builds the exception used for every rejected input value. */
    private static PrestoException invalidArgument(String message) {
        return new PrestoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, message);
    }

    /**
     * Immutable confusion-matrix weights for one bucket threshold.
     *
     * <p>The field descriptions below reflect how {@link #getResultsIterator} fills them in.
     */
    protected static class BucketResult {
        /** Left edge of the bucket this result describes. */
        private final double threshold;
        /** Total weight of all {@code true}-outcome observations. */
        private final double positive;
        /** Total weight of all {@code false}-outcome observations. */
        private final double negative;
        /** {@code true}-outcome weight in this bucket and the buckets not yet visited. */
        private final double truePositive;
        /** {@code false}-outcome weight in buckets visited before this one. */
        private final double trueNegative;
        /** {@code false}-outcome weight in this bucket and the buckets not yet visited. */
        private final double falsePositive;
        /** {@code true}-outcome weight in buckets visited before this one. */
        private final double falseNegative;

        public BucketResult(double threshold, double positive, double negative, double truePositive, double trueNegative, double falsePositive, double falseNegative) {
            this.threshold = threshold;
            this.positive = positive;
            this.negative = negative;
            this.truePositive = truePositive;
            this.trueNegative = trueNegative;
            this.falsePositive = falsePositive;
            this.falseNegative = falseNegative;
        }

        public double getThreshold() {
            return threshold;
        }

        public double getPositive() {
            return positive;
        }

        public double getNegative() {
            return negative;
        }

        public double getTruePositive() {
            return truePositive;
        }

        public double getTrueNegative() {
            return trueNegative;
        }

        public double getFalsePositive() {
            return falsePositive;
        }

        public double getFalseNegative() {
            return falseNegative;
        }
    }
}

