/*
 * Decompiled with CFR 0.152.
 */
package hex;

import java.util.Arrays;
import java.util.Comparator;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Vec;

/**
 * Binned ROC curve and AUC for a binary classifier.
 *
 * <p>Predictions are first summarized into at most {@code nBins} threshold bins by
 * {@link AUCBuilder}. This class then orders the bins from the highest threshold to the
 * lowest and turns the per-bin counts into cumulative true/false positive counts. From those
 * it derives the AUC, the Gini coefficient and the threshold index that maximizes
 * {@link #DEFAULT_CM}.
 */
public class AUC2 extends Iced {
    /** Number of thresholds (bins) actually in use. */
    public final int _nBins;
    /** Bin thresholds in decreasing order. */
    public final double[] _ths;
    /** Cumulative weighted count of actual positives in bins {@code 0..i} (highest thresholds first). */
    public final double[] _tps;
    /** Cumulative weighted count of actual negatives in bins {@code 0..i} (highest thresholds first). */
    public final double[] _fps;
    /** Total weighted count of actual positives. */
    public final double _p;
    /** Total weighted count of actual negatives. */
    public final double _n;
    /** Area under the binned ROC curve. */
    public final double _auc;
    /** Gini coefficient, {@code 2 * AUC - 1}. */
    public final double _gini;
    /**
     * Index of the threshold that maximizes {@link #DEFAULT_CM}. It is -1 when no threshold
     * yields a value greater than {@code -Double.MAX_VALUE} (e.g. the criterion is NaN everywhere).
     */
    public final int _max_idx;
    /** Criterion used to pick the default threshold / confusion matrix. */
    public static final ThresholdCriterion DEFAULT_CM = ThresholdCriterion.f1;
    /** Default number of bins used by {@link #AUC2(Vec, Vec)}. */
    public static final int NBINS = 400;

    /** Returns the threshold of bin {@code idx}. */
    public double threshold(int idx) {
        return this._ths[idx];
    }

    /** Returns the cumulative true positives at bin {@code idx}. */
    public double tp(int idx) {
        return this._tps[idx];
    }

    /** Returns the cumulative false positives at bin {@code idx}. */
    public double fp(int idx) {
        return this._fps[idx];
    }

    /** Returns the true negatives at bin {@code idx}: all negatives minus the false positives. */
    public double tn(int idx) {
        return this._n - this._fps[idx];
    }

    /** Returns the false negatives at bin {@code idx}: all positives minus the true positives. */
    public double fn(int idx) {
        return this._p - this._tps[idx];
    }

    /** Returns the largest F1 score over all thresholds, or NaN if F1 is undefined everywhere. */
    public double maxF1() {
        return ThresholdCriterion.f1.max_criterion(this);
    }

    /**
     * Computes the AUC of {@code probs} against {@code actls} using {@link #NBINS} bins.
     *
     * @param probs predicted probabilities
     * @param actls actual labels (0 or 1); rows where either vec is NA are skipped
     */
    public AUC2(Vec probs, Vec actls) {
        this(NBINS, probs, actls);
    }

    /** Computes the AUC of {@code probs} against {@code actls} using at most {@code nBins} bins. */
    AUC2(int nBins, Vec probs, Vec actls) {
        this(((AUC_Impl) new AUC_Impl(nBins).doAll(new Vec[]{probs, actls}))._bldr);
    }

    /**
     * Builds the ROC summary from a populated builder.
     *
     * @param bldr builder whose first {@code bldr._n} bins are sorted by ascending threshold
     */
    public AUC2(AUCBuilder bldr) {
        this._nBins = bldr._n;
        assert (this._nBins >= 1) : "Must have >= 1 bins for AUC calculation, but got " + this._nBins;
        // The builder keeps thresholds ascending; flip them so index 0 is the highest threshold
        // and the running sums below accumulate from the most to the least confident bin.
        this._ths = reversedCopy(bldr._ths, this._nBins);
        this._tps = reversedCopy(bldr._tps, this._nBins);
        this._fps = reversedCopy(bldr._fps, this._nBins);
        double p = 0.0;
        double n = 0.0;
        for (int i = 0; i < this._nBins; ++i) {
            p += this._tps[i];
            this._tps[i] = p;
            n += this._fps[i];
            this._fps[i] = n;
        }
        this._p = p;
        this._n = n;
        this._auc = this.compute_auc();
        this._gini = 2.0 * this._auc - 1.0;
        this._max_idx = DEFAULT_CM.max_criterion_idx(this);
    }

    /** Returns the first {@code len} elements of {@code src}, reversed, as a new array. */
    private static double[] reversedCopy(double[] src, int len) {
        double[] dst = new double[len];
        for (int i = 0; i < len; ++i) {
            dst[i] = src[len - 1 - i];
        }
        return dst;
    }

    /**
     * Integrates the ROC curve with the trapezoidal rule. It walks from the origin through the
     * cumulative (fp, tp) points and normalizes by {@code _p * _n}.
     *
     * @return 1.0 if there are no negatives (checked first), 0.0 if there are no positives,
     *         otherwise the normalized area
     */
    private double compute_auc() {
        if (this._fps[this._nBins - 1] == 0.0) {
            return 1.0;
        }
        if (this._tps[this._nBins - 1] == 0.0) {
            return 0.0;
        }
        double tp0 = 0.0;
        double fp0 = 0.0;
        double area = 0.0;
        for (int i = 0; i < this._nBins; ++i) {
            area += (this._fps[i] - fp0) * (this._tps[i] + tp0) / 2.0;
            tp0 = this._tps[i];
            fp0 = this._fps[i];
        }
        return area / this._p / this._n;
    }

    /**
     * Returns the 2x2 confusion matrix at bin {@code idx}. Rows are actual (0, 1) and columns
     * are predicted (0, 1): {@code {{tn, fp}, {fn, tp}}}.
     */
    public double[][] buildCM(int idx) {
        return new double[][]{{this.tn(idx), this.fp(idx)}, {this.fn(idx), this.tp(idx)}};
    }

    /** Returns the confusion matrix at {@link #_max_idx}, or {@code null} if there is none. */
    public double[][] defaultCM() {
        return this._max_idx == -1 ? null : this.buildCM(this._max_idx);
    }

    /** Returns the threshold at {@link #_max_idx}, or 0.5 if there is none. */
    public double defaultThreshold() {
        return this._max_idx == -1 ? 0.5 : this._ths[this._max_idx];
    }

    /**
     * Returns the misclassification rate {@code (fp + fn) / (p + n)} at {@link #_max_idx},
     * or NaN if there is none.
     */
    public double defaultErr() {
        return this._max_idx == -1 ? Double.NaN : (this.fp(this._max_idx) + this.fn(this._max_idx)) / (this._p + this._n);
    }

    /**
     * Computes the exact (unbinned) AUC of {@code vprob} against {@code vacts}.
     *
     * @throws IllegalArgumentException if actuals are not integers in [0, 1] or probabilities
     *                                  fall outside [0, 1]
     */
    public static double perfectAUC(Vec vprob, Vec vacts) {
        if (vacts.min() < 0.0 || vacts.max() > 1.0 || !vacts.isInt()) {
            throw new IllegalArgumentException("Actuals are either 0 or 1");
        }
        if (vprob.min() < 0.0 || vprob.max() > 1.0) {
            throw new IllegalArgumentException("Probabilities are between 0 and 1");
        }
        Pair[] ps = new Pair[(int) vprob.length()];
        Vec.Reader rprob = new Vec.Reader(vprob);
        Vec.Reader racts = new Vec.Reader(vacts);
        for (int i = 0; i < ps.length; ++i) {
            ps[i] = new Pair(rprob.at(i), (byte) racts.at8(i));
        }
        return AUC2.perfectAUC(ps);
    }

    /**
     * Computes the exact (unbinned) AUC of predictions {@code ds} against labels {@code acts}.
     * Each label is cast to {@code byte}, and only the value 1 counts as a positive.
     */
    public static double perfectAUC(double[] ds, double[] acts) {
        Pair[] ps = new Pair[ds.length];
        for (int i = 0; i < ps.length; ++i) {
            ps[i] = new Pair(ds[i], (byte) acts[i]);
        }
        return AUC2.perfectAUC(ps);
    }

    /**
     * Computes the exact AUC by sorting all (probability, label) pairs. It accumulates one
     * trapezoid per distinct probability. Pairs are sorted in place.
     *
     * @return the AUC; NaN-or-infinite when there are no positives or no negatives
     *         (division by zero)
     */
    private static double perfectAUC(Pair[] ps) {
        // Sort descending by probability; within a tie, positives come first. The order inside a
        // tie does not affect the result, since tied pairs form a single ROC step below.
        Arrays.sort(ps, new Comparator<Pair>() {
            @Override
            public int compare(Pair a, Pair b) {
                return a._prob < b._prob ? 1 : (a._prob == b._prob ? b._act - a._act : -1);
            }
        });
        int tp0 = 0; // cumulative positives before the current probability group
        int fp0 = 0; // cumulative negatives before the current probability group
        int tp1 = 0; // running cumulative positives
        int fp1 = 0; // running cumulative negatives
        double prob = 1.0;
        double area = 0.0;
        for (Pair p : ps) {
            if (p._prob != prob) {
                // Close the previous step. Multiply in double: the int product
                // (fp1 - fp0) * (tp1 + tp0) silently overflows for large inputs.
                area += (double) (fp1 - fp0) * ((double) tp1 + (double) tp0) / 2.0;
                tp0 = tp1;
                fp0 = fp1;
                prob = p._prob;
            }
            if (p._act == 1) {
                ++tp1;
            } else {
                ++fp1;
            }
        }
        // Close the final step: the same trapezoid, split into rectangle + triangle.
        area += (double) tp0 * (double) (fp1 - fp0);
        area += (double) (tp1 - tp0) * (double) (fp1 - fp0) / 2.0;
        return area / (double) tp1 / (double) fp1;
    }

    /** One prediction and its actual label, used by the exact AUC computation. */
    private static class Pair {
        final double _prob;
        final byte _act;

        Pair(double prob, byte act) {
            this._prob = prob;
            this._act = act;
        }
    }

    /**
     * Streaming histogram of (prediction, label, weight) rows with a bounded number of bins.
     * When the bin count exceeds {@code _nBins}, the adjacent pair whose merge adds the least
     * squared error is merged.
     */
    public static class AUCBuilder extends Iced {
        /** Maximum number of bins kept after each update. */
        final int _nBins;
        /** Number of bins currently in use; entries {@code [0, _n)} of each array are valid. */
        int _n;
        /** Representative threshold per bin, ascending; merges combine them as a weighted mean. */
        final double[] _ths;
        /** Accumulated squared-error penalty of the merges that produced each bin. */
        final double[] _sqe;
        /** Weighted count of actual positives per bin. */
        final double[] _tps;
        /** Weighted count of actual negatives per bin. */
        final double[] _fps;
        /** Cached result of {@link #find_smallest_impl()}, or -1 when not cached. */
        int _ssx;

        /**
         * Creates an empty builder. The arrays hold {@code 2 * nBins} entries so that
         * {@link #reduce} can merge two builders in place before shrinking back.
         */
        public AUCBuilder(int nBins) {
            this._nBins = nBins;
            this._ths = new double[nBins << 1];
            this._sqe = new double[nBins << 1];
            this._tps = new double[nBins << 1];
            this._fps = new double[nBins << 1];
            this._ssx = -1;
        }

        /**
         * Adds one row.
         *
         * @param pred prediction (must not be NaN)
         * @param act  actual label, 0 or 1
         * @param w    row weight
         */
        public void perRow(double pred, int act, double w) {
            assert (!Double.isNaN(pred));
            assert (act == 0 || act == 1);
            int idx = Arrays.binarySearch(this._ths, 0, this._n, pred);
            if (idx >= 0) {
                // Exact threshold match: add the weight to that bin.
                if (act == 0) {
                    this._fps[idx] += w;
                } else {
                    this._tps[idx] += w;
                }
                this._ssx = -1; // counts changed, so the cached smallest-merge index is stale
                return;
            }
            idx = -idx - 1; // insertion point that keeps _ths ascending
            // NOTE(review): this fast path looks unreachable. Every insertion below is followed by
            // mergeOneBin() whenever _n exceeds _nBins, and reduce() merges until _n <= _nBins, so
            // _n > _nBins should never hold on entry. It also reads _ths[idx + 1], which may lie at
            // or past _n. Kept unchanged; confirm before relying on it.
            if (this._n > this._nBins) {
                int ssx = this.find_smallest();
                double dssx = this.compute_delta_error(this._ths[ssx + 1], this.k(ssx + 1), this._ths[ssx], this.k(ssx));
                double d0 = this.compute_delta_error(pred, w, this._ths[idx], this.k(idx));
                double d1 = this.compute_delta_error(this._ths[idx + 1], this.k(idx + 1), pred, w);
                if (d0 < dssx || d1 < dssx) {
                    if (d1 < d0) {
                        ++idx;
                    } else {
                        d0 = d1;
                    }
                    double oldk = this.k(idx);
                    if (act == 0) {
                        this._fps[idx] += w;
                    } else {
                        this._tps[idx] += w;
                    }
                    this._ths[idx] = this._ths[idx] + (pred - this._ths[idx]) / oldk;
                    this._sqe[idx] = this._sqe[idx] + d0;
                    assert (ssx == this.find_smallest());
                    return;
                }
            }
            // Keep the cached smallest-merge index aligned with the shift below.
            if (idx == this._ssx) {
                this._ssx = -1;
            } else if (idx < this._ssx) {
                ++this._ssx;
            }
            int tail = this._n - idx; // bins that move one slot right
            System.arraycopy(this._ths, idx, this._ths, idx + 1, tail);
            System.arraycopy(this._sqe, idx, this._sqe, idx + 1, tail);
            System.arraycopy(this._tps, idx, this._tps, idx + 1, tail);
            System.arraycopy(this._fps, idx, this._fps, idx + 1, tail);
            this._ths[idx] = pred;
            this._sqe[idx] = 0.0;
            if (act == 0) {
                this._tps[idx] = 0.0;
                this._fps[idx] = w;
            } else {
                this._tps[idx] = w;
                this._fps[idx] = 0.0;
            }
            ++this._n;
            if (this._n > this._nBins) {
                this.mergeOneBin();
            }
        }

        /**
         * Merges the bins of {@code bldr} into this builder. Afterwards it merges bins until at
         * most {@code _nBins} remain and no adjacent pair has zero merge error.
         */
        public void reduce(AUCBuilder bldr) {
            int x = this._n - 1; // last unread bin of this builder
            int y = bldr._n - 1; // last unread bin of bldr
            // Merge both ascending lists from the top down. The write slot x + y + 1 is never
            // below x, so no unread bin of this builder is overwritten.
            while (x + y + 1 >= 0) {
                boolean selfIsLarger = y < 0 || (x >= 0 && this._ths[x] >= bldr._ths[y]);
                AUCBuilder src = selfIsLarger ? this : bldr;
                int from = selfIsLarger ? x : y;
                int to = x + y + 1;
                this._ths[to] = src._ths[from];
                this._sqe[to] = src._sqe[from];
                this._tps[to] = src._tps[from];
                this._fps[to] = src._fps[from];
                if (selfIsLarger) {
                    --x;
                } else {
                    --y;
                }
            }
            this._n += bldr._n;
            while (this._n > this._nBins || this.dups()) {
                this.mergeOneBin();
            }
        }

        /**
         * Merges the adjacent pair chosen by {@link #find_smallest()} into one bin. The new
         * threshold is the count-weighted mean of the two; counts and squared errors are summed.
         */
        private void mergeOneBin() {
            int ssx = this.find_smallest();
            double k0 = this.k(ssx);
            double k1 = this.k(ssx + 1);
            this._ths[ssx] = (this._ths[ssx] * k0 + this._ths[ssx + 1] * k1) / (k0 + k1);
            // NOTE(review): the delta error below is evaluated against the already-merged
            // threshold _ths[ssx], not the pre-merge one; confirm this is intended.
            this._sqe[ssx] = this._sqe[ssx] + this._sqe[ssx + 1] + this.compute_delta_error(this._ths[ssx + 1], k1, this._ths[ssx], k0);
            this._tps[ssx] += this._tps[ssx + 1];
            this._fps[ssx] += this._fps[ssx + 1];
            int tail = this._n - ssx - 2; // bins to the right of the removed slot
            System.arraycopy(this._ths, ssx + 2, this._ths, ssx + 1, tail);
            System.arraycopy(this._sqe, ssx + 2, this._sqe, ssx + 1, tail);
            System.arraycopy(this._tps, ssx + 2, this._tps, ssx + 1, tail);
            System.arraycopy(this._fps, ssx + 2, this._fps, ssx + 1, tail);
            --this._n;
            this._ssx = -1;
        }

        /** Returns the cached index from {@link #find_smallest_impl()}, computing it if needed. */
        private int find_smallest() {
            if (this._ssx == -1) {
                this._ssx = this.find_smallest_impl();
                return this._ssx;
            }
            assert (this._ssx == this.find_smallest_impl());
            return this._ssx;
        }

        /**
         * Finds the left index {@code i} of the adjacent pair (i, i + 1) to merge. The first pair
         * with zero merge error wins outright. Otherwise the pair with the smallest
         * {@code _sqe[i] + _sqe[i + 1] + delta} is chosen.
         *
         * @return the index, or -1 if fewer than two bins exist or no total is below
         *         {@code Double.MAX_VALUE}
         */
        private int find_smallest_impl() {
            double minSQE = Double.MAX_VALUE;
            int minI = -1;
            int n = this._n;
            for (int i = 0; i < n - 1; ++i) {
                double derr = this.compute_delta_error(this._ths[i + 1], this.k(i + 1), this._ths[i], this.k(i));
                if (derr == 0.0) {
                    return i; // merging adds no error: take it immediately
                }
                double sqe = this._sqe[i] + this._sqe[i + 1] + derr;
                if (sqe < minSQE) {
                    minI = i;
                    minSQE = sqe;
                }
            }
            return minI;
        }

        /**
         * Returns true if some adjacent pair has zero merge error, e.g. thresholds equal at float
         * precision. It also caches the first such index in {@code _ssx}.
         */
        private boolean dups() {
            int n = this._n;
            for (int i = 0; i < n - 1; ++i) {
                double derr = this.compute_delta_error(this._ths[i + 1], this.k(i + 1), this._ths[i], this.k(i));
                if (derr == 0.0) {
                    this._ssx = i;
                    return true;
                }
            }
            return false;
        }

        /**
         * Returns the squared-error increase from merging a bin at {@code ths0} (weight
         * {@code n0}) with one at {@code ths1} (weight {@code n1}):
         * {@code (ths1 - ths0)^2 * n0 * n1 / (n0 + n1)}. Thresholds are differenced at float
         * precision, so values that round to the same float contribute zero error.
         */
        private double compute_delta_error(double ths1, double n1, double ths0, double n0) {
            double delta = (float) ths1 - (float) ths0;
            return delta * delta * n0 * n1 / (n0 + n1);
        }

        /** Returns the total weight (positives + negatives) in bin {@code idx}. */
        private double k(int idx) {
            return this._tps[idx] + this._fps[idx];
        }
    }

    /**
     * Distributed task that builds one {@link AUCBuilder} per chunk and reduces them.
     * Rows where either the prediction or the actual is NA are skipped; every row has weight 1.
     */
    private static class AUC_Impl extends MRTask<AUC_Impl> {
        final int _nBins;
        AUCBuilder _bldr;

        AUC_Impl(int nBins) {
            this._nBins = nBins;
        }

        @Override
        public void map(Chunk ps, Chunk as) {
            AUCBuilder bldr = this._bldr = new AUCBuilder(this._nBins);
            for (int row = 0; row < ps._len; ++row) {
                if (ps.isNA(row) || as.isNA(row)) {
                    continue;
                }
                bldr.perRow(ps.atd(row), (int) as.at8(row), 1.0);
            }
        }

        @Override
        public void reduce(AUC_Impl auc) {
            this._bldr.reduce(auc._bldr);
        }
    }

    /**
     * Metrics computed from a confusion matrix (tp, fp, fn, tn) at a single threshold.
     * The {@code max_criterion*} methods select the threshold that maximizes the metric.
     */
    public enum ThresholdCriterion {
        f1(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double prec = precision.exec(tp, fp, fn, tn);
                double recl = tpr.exec(tp, fp, fn, tn);
                return 2.0 * (prec * recl) / (prec + recl);
            }
        },
        f2(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double prec = precision.exec(tp, fp, fn, tn);
                double recl = tpr.exec(tp, fp, fn, tn);
                return 5.0 * (prec * recl) / (4.0 * prec + recl);
            }
        },
        f0point5(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double prec = precision.exec(tp, fp, fn, tn);
                double recl = tpr.exec(tp, fp, fn, tn);
                return 1.25 * (prec * recl) / (0.25 * prec + recl);
            }
        },
        accuracy(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return (tn + tp) / (tp + fn + tn + fp);
            }
        },
        precision(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp / (tp + fp);
            }
        },
        absolute_MCC(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                double mcc = tp * tn - fp * fn;
                // A zero factor in the denominator forces the numerator to zero as well,
                // so this early return also guards the division below.
                if (mcc == 0.0) {
                    return 0.0;
                }
                // Normalize in a plain statement: inside the assert it would be skipped
                // whenever assertions are disabled, returning the raw numerator instead.
                mcc /= Math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn));
                assert (Math.abs(mcc) <= 1.0) : tp + " " + fp + " " + fn + " " + tn;
                return Math.abs(mcc);
            }
        },
        min_per_class_accuracy(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return Math.min(tp / (tp + fn), tn / (tn + fp));
            }
        },
        tns(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tn;
            }
        },
        fns(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fn;
            }
        },
        fps(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fp;
            }
        },
        tps(true) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp;
            }
        },
        tnr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tn / (fp + tn);
            }
        },
        fnr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fn / (fn + tp);
            }
        },
        fpr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return fp / (fp + tn);
            }
        },
        tpr(false) {
            @Override
            double exec(double tp, double fp, double fn, double tn) {
                return tp / (tp + fn);
            }
        };

        /** True for criteria that return a raw count (tns, fns, fps, tps). */
        public final boolean _isInt;
        /** Cached {@link #values()}. */
        public static final ThresholdCriterion[] VALUES;

        ThresholdCriterion(boolean isInt) {
            this._isInt = isInt;
        }

        /** Evaluates the criterion on one confusion matrix. */
        abstract double exec(double tp, double fp, double fn, double tn);

        /** Evaluates the criterion on the confusion matrix of {@code auc} at bin {@code idx}. */
        public double exec(AUC2 auc, int idx) {
            return this.exec(auc.tp(idx), auc.fp(idx), auc.fn(idx), auc.tn(idx));
        }

        /**
         * Returns the maximum value of this criterion over all thresholds of {@code auc}. Returns
         * NaN when no threshold qualifies (e.g. the criterion is NaN everywhere); previously this
         * case threw ArrayIndexOutOfBoundsException.
         */
        public double max_criterion(AUC2 auc) {
            int idx = this.max_criterion_idx(auc);
            return idx == -1 ? Double.NaN : this.exec(auc, idx);
        }

        /**
         * Returns the first bin index whose criterion value is strictly the largest. Returns -1
         * if no value exceeds {@code -Double.MAX_VALUE}; NaN values are never selected.
         */
        public int max_criterion_idx(AUC2 auc) {
            double md = -Double.MAX_VALUE;
            int mx = -1;
            for (int i = 0; i < auc._nBins; ++i) {
                double d = this.exec(auc, i);
                if (d > md) {
                    md = d;
                    mx = i;
                }
            }
            return mx;
        }

        static {
            VALUES = ThresholdCriterion.values();
        }
    }
}

