/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.quantile.Quantile;
import hex.quantile.QuantileModel;
import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
import water.DKV;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.Scope;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

/**
 * Computes a gains/lift table from a vector of predicted values and a vector of
 * binary class labels, optionally using per-row observation weights.
 *
 * <p>Predictions are bucketed by descending quantile thresholds; for every bucket the
 * number of observations, the number of events (label == 1) and the response rate are
 * collected by {@link GainsLiftBuilder}. {@link #createTwoDimTable()} turns those
 * numbers into a printable table.
 */
public class GainsLift
extends Iced {
    /** Distinct quantile thresholds of the predictions, sorted in descending order. */
    private double[] _quantiles;
    /** Requested number of quantile groups; a non-positive value selects the built-in default probabilities. */
    public int _groups = -1;
    /** Actual class labels (converted to a categorical vector during initialization). */
    public Vec _labels;
    /** Predicted values that get bucketed by quantile. */
    public Vec _preds;
    /** Optional per-row observation weights; may be {@code null}. */
    public Vec _weights;
    /** Per-group response rate (events / observations), filled in by {@link #exec(Job)}. */
    public double[] response_rates;
    /** Overall response rate across all groups, filled in by {@link #exec(Job)}. */
    public double avg_response_rate;
    /** Per-group (weighted) number of rows whose label is 1. */
    public long[] events;
    /** Per-group (weighted) number of rows. */
    public long[] observations;
    /** Most recently built table, cached by {@link #createTwoDimTable()}. */
    TwoDimTable table;

    /**
     * Creates an unweighted gains/lift computation.
     *
     * @param preds  predicted values
     * @param labels actual class labels
     */
    public GainsLift(Vec preds, Vec labels) {
        this(preds, labels, null);
    }

    /**
     * Creates a gains/lift computation with optional observation weights.
     *
     * @param preds   predicted values
     * @param labels  actual class labels
     * @param weights per-row weights, or {@code null} for none
     */
    public GainsLift(Vec preds, Vec labels, Vec weights) {
        this._preds = preds;
        this._labels = labels;
        this._weights = weights;
    }

    /**
     * Validates the input vectors, aligns them with the labels if needed and computes
     * the descending quantile thresholds of the predictions.
     *
     * @param job enclosing job handed to the quantile computation, or {@code null}
     * @throws IllegalArgumentException if a required vector is missing or any vector
     *                                  fails validation
     */
    private void init(Job job) throws IllegalArgumentException {
        // Check for missing inputs BEFORE touching _labels; dereferencing first would
        // surface a NullPointerException instead of the intended argument error.
        if (this._labels == null || this._preds == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs!");
        }
        Vec categoricalLabels = this._labels.toCategoricalVec();
        if (categoricalLabels == null) {
            throw new IllegalArgumentException("Missing actualLabels or predictedProbs!");
        }
        this._labels = categoricalLabels;

        if (this._labels.length() != this._preds.length()) {
            throw new IllegalArgumentException("Both arguments must have the same length (" + this._labels.length() + "!=" + this._preds.length() + ")!");
        }
        if (!this._labels.isInt()) {
            throw new IllegalArgumentException("Actual column must be integer class labels!");
        }
        if (this._labels.cardinality() != -1 && this._labels.cardinality() != 2) {
            throw new IllegalArgumentException("Actual column must contain binary class labels, but found cardinality " + this._labels.cardinality() + "!");
        }
        if (this._preds.isCategorical()) {
            throw new IllegalArgumentException("Predicted probabilities cannot be class labels, expect probabilities.");
        }
        if (this._weights != null && !this._weights.isNumeric()) {
            throw new IllegalArgumentException("Observation weights must be numeric.");
        }

        // Vectors from a different group than the labels are re-aligned against the labels.
        // The aligned copies are registered with Scope (presumably so that Scope.exit in
        // exec() releases them - confirm against water.Scope).
        // NOTE(review): weights are only aligned when the predictions' group differs from
        // the labels' group; a weights vector in yet another group is not handled here.
        if (!this._labels.group().equals(this._preds.group())) {
            this._preds = this._labels.align(this._preds);
            Scope.track(this._preds);
            if (this._weights != null) {
                this._weights = this._labels.align(this._weights);
                Scope.track(this._weights);
            }
        }

        this._quantiles = this.computeQuantiles(job);
    }

    /**
     * Trains a temporary quantile model on the predictions (and weights, if any) and
     * returns the distinct quantile values in descending order. The temporary frame and
     * model are removed again before returning, whether or not training succeeded.
     *
     * @param job enclosing job; when non-null and not done the quantile model is trained
     *            nested inside it, otherwise it is trained as its own job
     * @return distinct quantiles of the predictions, largest first
     */
    private double[] computeQuantiles(Job job) {
        Frame fr = null;
        QuantileModel qm = null;
        try {
            QuantileModel.QuantileParameters qp = new QuantileModel.QuantileParameters();
            if (this._weights == null) {
                fr = new Frame(Key.<Frame>make(), new String[]{"predictions"}, new Vec[]{this._preds});
            } else {
                fr = new Frame(Key.<Frame>make(), new String[]{"predictions", "weights"}, new Vec[]{this._preds, this._weights});
                qp._weights_column = "weights";
            }
            DKV.put(fr);
            qp._train = fr._key;
            qp._probs = quantileProbabilities(this._groups);

            if (job != null && !job.isDone()) {
                qm = (QuantileModel)new Quantile(qp, job).trainModelNested(null);
            } else {
                qm = (QuantileModel)new Quantile(qp).trainModel().get();
            }
            double[] rawQuantiles = ((QuantileModel.QuantileOutput)qm._output)._quantiles[0];
            // De-duplicate inside the try so the values are read before the model is removed.
            return distinctDescending(rawQuantiles);
        }
        finally {
            if (qm != null) {
                qm.remove();
            }
            if (fr != null) {
                DKV.remove(fr._key);
            }
        }
    }

    /**
     * Returns the quantile probabilities to request, in descending order.
     *
     * @param groups requested number of groups; when positive, {@code groups} evenly
     *               spaced probabilities from {@code (groups-1)/groups} down to 0 are
     *               produced, otherwise a fixed 16-value default list is used
     * @return probabilities, largest first, ending with 0.0
     */
    private static double[] quantileProbabilities(int groups) {
        if (groups <= 0) {
            return new double[]{0.99, 0.98, 0.97, 0.96, 0.95, 0.9, 0.85, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0};
        }
        double[] probs = new double[groups];
        for (int i = 0; i < groups; ++i) {
            probs[i] = ((double)(groups - i) - 1.0) / (double)groups;
        }
        return probs;
    }

    /**
     * Removes duplicate values and sorts the remainder from largest to smallest.
     *
     * @param values input values (not modified)
     * @return new array holding each distinct value once, in descending order
     */
    private static double[] distinctDescending(double[] values) {
        TreeSet<Double> distinct = new TreeSet<Double>();
        for (double v : values) {
            distinct.add(v);
        }
        double[] result = new double[distinct.size()];
        int idx = 0;
        for (Iterator<Double> it = distinct.descendingIterator(); it.hasNext();) {
            result[idx++] = it.next();
        }
        return result;
    }

    /** Runs the computation without an enclosing job. */
    public void exec() {
        this.exec(null);
    }

    /**
     * Validates the inputs, computes the quantile thresholds and fills
     * {@link #response_rates}, {@link #avg_response_rate}, {@link #events} and
     * {@link #observations}.
     *
     * @param job enclosing job, or {@code null}
     * @throws IllegalArgumentException if the inputs fail validation
     */
    public void exec(Job job) {
        Scope.enter();
        try {
            // init() is inside the try so that Scope.exit is always paired with
            // Scope.enter, even when validation or the quantile computation throws.
            this.init(job);
            GainsLiftBuilder gt = new GainsLiftBuilder(this._quantiles);
            if (this._weights != null) {
                gt = (GainsLiftBuilder)gt.doAll(this._labels, this._preds, this._weights);
            } else {
                gt = (GainsLiftBuilder)gt.doAll(this._labels, this._preds);
            }
            this.response_rates = gt.response_rates();
            this.avg_response_rate = gt.avg_response_rate();
            this.events = gt.events();
            this.observations = gt.observations();
        }
        finally {
            Scope.exit(new Key[0]);
        }
    }

    /**
     * @return the rendered gains/lift table, or an empty string when no table can be built
     */
    public String toString() {
        TwoDimTable t = this.createTwoDimTable();
        return t == null ? "" : t.toString();
    }

    /**
     * Builds the gains/lift table from the per-group statistics and caches it in
     * {@link #table}.
     *
     * @return the table, or {@code null} when no response rates are available or the
     *         average response rate is NaN
     */
    public TwoDimTable createTwoDimTable() {
        if (this.response_rates == null || Double.isNaN(this.avg_response_rate)) {
            return null;
        }
        final int groupCount = this.events.length;
        final String[] colHeaders = new String[]{"Group", "Cumulative Data Fraction", "Lower Threshold", "Lift", "Cumulative Lift", "Response Rate", "Cumulative Response Rate", "Capture Rate", "Cumulative Capture Rate", "Gain", "Cumulative Gain"};
        final String[] colTypes = new String[]{"int", "double", "double", "double", "double", "double", "double", "double", "double", "double", "double"};
        final String[] colFormats = new String[]{"%d", "%.8f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f", "%5f"};
        TwoDimTable result = new TwoDimTable("Gains/Lift Table", "Avg response rate: " + PrettyPrint.formatPct(this.avg_response_rate), new String[groupCount], colHeaders, colTypes, colFormats, "");

        final double overallRate = this.avg_response_rate;
        final long totalObs = ArrayUtils.sum(this.observations);
        // Total event count is reconstructed from the overall rate, not summed from events[].
        final long totalEvents = Math.round((double)totalObs * overallRate);
        long cumEvents = 0L;
        long cumObs = 0L;
        for (int g = 0; g < groupCount; ++g) {
            long groupEvents = this.events[g];
            long groupObs = this.observations[g];
            double groupRate = this.response_rates[g];
            double lift = groupRate / overallRate;
            cumEvents += groupEvents;
            cumObs += groupObs;
            double cumRate = (double)cumEvents / (double)cumObs;
            double cumLift = cumRate / overallRate;

            result.set(g, 0, g + 1);
            result.set(g, 1, (double)cumObs / (double)totalObs);
            result.set(g, 2, this._quantiles[g]);
            result.set(g, 3, lift);
            result.set(g, 4, cumLift);
            result.set(g, 5, groupRate);
            result.set(g, 6, cumRate);
            result.set(g, 7, (double)groupEvents / (double)totalEvents);
            result.set(g, 8, (double)cumEvents / (double)totalEvents);
            result.set(g, 9, 100.0 * (lift - 1.0));
            result.set(g, 10, 100.0 * (cumLift - 1.0));

            // Sanity checks on the final row: the cumulative columns must add up to the totals.
            if (g == groupCount - 1) {
                assert (cumObs == totalObs) : "Cumulative data fraction must be 1.0, but is " + (double)cumObs / (double)totalObs;
                assert (cumEvents == totalEvents) : "Cumulative capture rate must be 1.0, but is " + (double)cumEvents / (double)totalEvents;
                if (!Double.isNaN(cumLift)) {
                    assert (Math.abs(cumLift - 1.0) < 1.0E-8) : "Cumulative lift must be 1.0, but is " + cumLift;
                }
                assert (Math.abs(cumRate - this.avg_response_rate) < 1.0E-8) : "Cumulative response rate must be " + this.avg_response_rate + ", but is " + cumRate;
            }
        }
        this.table = result;
        return this.table;
    }

    /**
     * Map/reduce task that counts, per threshold bucket, the weighted number of rows and
     * the weighted number of rows labelled 1.
     *
     * <p>Bucket {@code t} receives a prediction {@code pr} when
     * {@code thresh[t] <= pr < thresh[t-1]} (bucket 0 has no upper bound), so the
     * thresholds are expected in descending order.
     */
    public static class GainsLiftBuilder
    extends MRTask<GainsLiftBuilder> {
        /** Lower bucket thresholds (private copy of the constructor argument). */
        private final double[] _thresh;
        private long[] _events;
        private long[] _observations;
        /** Weighted count of rows labelled 1, across all buckets. */
        private long _avg_response;
        private double _avg_response_rate;
        private double[] _response_rates;

        /** @return per-bucket events / observations (0.0 for empty buckets); set in {@link #postGlobal()} */
        public final double[] response_rates() {
            return this._response_rates;
        }

        /** @return overall events / observations; set in {@link #postGlobal()} */
        public final double avg_response_rate() {
            return this._avg_response_rate;
        }

        /** @return per-bucket weighted count of rows labelled 1 */
        public final long[] events() {
            return this._events;
        }

        /** @return per-bucket weighted count of rows */
        public final long[] observations() {
            return this._observations;
        }

        /**
         * @param thresh lower bucket thresholds; the array is cloned
         */
        public GainsLiftBuilder(double[] thresh) {
            this._thresh = (double[])thresh.clone();
        }

        /** Unweighted variant: every row counts with weight 1.0. */
        @Override
        public void map(Chunk ca, Chunk cp) {
            this.map(ca, cp, (Chunk)null);
        }

        /**
         * Accumulates the counts for one chunk.
         *
         * @param ca actual labels; every non-missing value must be 0 or 1
         * @param cp predictions
         * @param cw weights, or {@code null} to use 1.0 for every row
         * @throws IllegalArgumentException if a label other than 0 or 1 is found
         */
        @Override
        public void map(Chunk ca, Chunk cp, Chunk cw) {
            this._events = new long[this._thresh.length];
            this._observations = new long[this._thresh.length];
            this._avg_response = 0L;
            final int rows = Math.min(ca._len, cp._len);
            for (int row = 0; row < rows; ++row) {
                if (ca.isNA(row)) {
                    continue;
                }
                int actual = (int)ca.at8(row);
                // The label is validated even if the prediction on this row is missing.
                if (actual != 0 && actual != 1) {
                    throw new IllegalArgumentException("Invalid values in actualLabels: must be binary (0 or 1).");
                }
                if (cp.isNA(row)) {
                    continue;
                }
                double weight = cw != null ? cw.atd(row) : 1.0;
                this.perRow(cp.atd(row), actual, weight);
            }
        }

        /**
         * Adds a single row to the bucket its prediction falls into.
         *
         * <p>NOTE(review): counters are {@code long}; each update adds the weight in
         * double precision and truncates back to {@code long}, so a fractional weight
         * is truncated on every row rather than accumulated.
         *
         * @param pr prediction for the row
         * @param a  actual label (0 or 1)
         * @param w  row weight; rows with weight 0 are ignored
         */
        public void perRow(double pr, int a, double w) {
            if (w == 0.0) {
                return;
            }
            assert (!Double.isNaN(pr));
            assert (!Double.isNaN(w));
            for (int t = 0; t < this._thresh.length; ++t) {
                boolean atOrAboveLower = pr >= this._thresh[t];
                boolean belowUpper = t == 0 || pr < this._thresh[t - 1];
                if (!atOrAboveLower || !belowUpper) {
                    continue;
                }
                this._observations[t] = (long)((double)this._observations[t] + w);
                if (a == 1) {
                    this._events[t] = (long)((double)this._events[t] + w);
                }
                // A row belongs to exactly one bucket.
                break;
            }
            if (a == 1) {
                this._avg_response = (long)((double)this._avg_response + w);
            }
        }

        /** Merges the counts of another partial result into this one. */
        @Override
        public void reduce(GainsLiftBuilder other) {
            ArrayUtils.add(this._events, other._events);
            ArrayUtils.add(this._observations, other._observations);
            this._avg_response += other._avg_response;
        }

        /** Derives the per-bucket and overall response rates from the final counts. */
        @Override
        public void postGlobal() {
            this._response_rates = new double[this._thresh.length];
            for (int i = 0; i < this._response_rates.length; ++i) {
                // Empty buckets report a rate of 0 instead of dividing by zero.
                this._response_rates[i] = this._observations[i] == 0L ? 0.0 : (double)this._events[i] / (double)this._observations[i];
            }
            this._avg_response_rate = (double)this._avg_response / (double)ArrayUtils.sum(this._observations);
        }
    }
}

