/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.AUC;
import hex.ConfusionMatrix2;
import hex.tree.DTree;
import hex.tree.SharedTree;
import java.util.Arrays;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.ModelUtils;

/**
 * Scoring pass over a tree model while it is being built.
 *
 * <p>Each {@code map} call scores the rows of one chunk set. It accumulates:
 * <ul>
 *   <li>the sum of squared errors ({@code _sum});</li>
 *   <li>the number of scored rows ({@code _snrows});</li>
 *   <li>for categorical responses, an actual-by-predicted confusion matrix ({@code _cm});</li>
 *   <li>for binomial problems, one 2x2 confusion matrix per entry of
 *       {@code ModelUtils.DEFAULT_THRESHOLDS} ({@code _cms}), used to build an AUC.</li>
 * </ul>
 * {@code reduce} adds partial results together.
 */
public class Score
extends MRTask<Score> {
    final SharedTree _bldr;
    final int _nclass;
    final int _ncols;
    final boolean _oob;
    final boolean _validation;
    final int _cmlen;
    double _sum;
    long _snrows;
    long[][] _cm;
    long[][][] _cms;

    /**
     * Computes R^2 as {@code 1 - MSE / var}.
     *
     * <p>MSE is {@code _sum / _snrows}. The variance is the squared {@code sigma()} of the Vec
     * stored in the DKV under {@code _bldr._response_key}.
     *
     * @return the R^2 value, or NaN when no rows were scored (0/0 MSE)
     */
    public double r2() {
        double mse = this._sum / (double)this._snrows;
        double stddev = ((Vec)DKV.get((Key)this._bldr._response_key).get()).sigma();
        double var = stddev * stddev;
        return 1.0 - mse / var;
    }

    /**
     * Wraps the accumulated confusion matrix.
     *
     * @return the actual-by-predicted confusion matrix, or null when none was collected
     *     (numeric response)
     */
    public ConfusionMatrix2 cm() {
        return this._cm == null ? null : new ConfusionMatrix2(this._cm);
    }

    /**
     * Builds an AUC from the per-threshold 2x2 confusion matrices.
     *
     * @return the AUC for binomial problems, or null when the problem is not binomial or no
     *     threshold matrices were collected
     */
    public AUC auc() {
        // Guard against _cms == null as well (e.g. map never ran), mirroring cm().
        if (this._nclass != 2 || this._cms == null) {
            return null;
        }
        ConfusionMatrix2[] res = new ConfusionMatrix2[this._cms.length];
        for (int i = 0; i < res.length; ++i) {
            res[i] = new ConfusionMatrix2(this._cms[i]);
        }
        return new AUC(res, ModelUtils.DEFAULT_THRESHOLDS, this._bldr.vresponse().domain());
    }

    /**
     * Creates a scoring task for the given builder.
     *
     * @param bldr the tree builder whose current model is scored
     * @param oob  when true, rows for which {@code bldr.outOfBagRow} returns true are skipped
     */
    public Score(SharedTree bldr, boolean oob) {
        this._bldr = bldr;
        this._nclass = bldr._nclass;
        this._ncols = bldr._ncols;
        this._oob = oob;
        this._validation = bldr._parms._valid != null;
        this._cmlen = bldr.vresponse().cardinality();
    }

    /**
     * Runs the scoring pass on the validation frame if one was configured, else on the
     * training frame.
     *
     * @param build_tree_one_node forwarded unchanged to {@code doAll}
     * @return this task, with results accumulated
     */
    Score doIt(boolean build_tree_one_node) {
        // Use the same flag map() uses to choose its column layout, so frame and reader agree.
        if (this._validation) {
            // Validation rows go through the model's own scoring first. map() then reads
            // columns 1..nclass of this frame as the class scores.
            Frame res = this._bldr._model.score(this._bldr.valid(), false);
            try {
                this.doAll(res, build_tree_one_node);
            } finally {
                // Free the temporary prediction frame even if the pass fails.
                res.delete();
            }
        } else {
            this.doAll(this._bldr.train(), build_tree_one_node);
        }
        return this;
    }

    /**
     * Scores every row of this chunk set that has a non-NA response and accumulates error
     * statistics.
     *
     * @param chks the chunks of the frame passed to {@code doAll}
     */
    public void map(Chunk[] chks) {
        Chunk ys = this._bldr.chk_resp(chks);
        // Per-class scores for the current row are kept in fs[1..nclass].
        float[] fs = new float[this._nclass + 1];
        this._cm = this._cmlen == -1 ? null : new long[this._cmlen][this._cmlen];
        this._cms = this._cmlen == -1 ? null : new long[ModelUtils.DEFAULT_THRESHOLDS.length][2][2];
        for (int row = 0; row < ys._len; ++row) {
            if (ys.isNA0(row)) continue; // rows without a response cannot be scored
            float sum;
            if (this._validation) {
                for (int i = 0; i < this._nclass; ++i) {
                    fs[i + 1] = (float)chks[i + 1].at0(row);
                }
                // For classification the validation columns are treated as already summing to 1.
                sum = this._nclass > 1 ? 1.0f : fs[1];
            } else {
                sum = this._bldr.score2(chks, fs, row);
            }
            int yact = 0;
            if (this._oob && this._bldr.outOfBagRow(chks, row)) continue;
            float err;
            if (this._nclass > 1) {
                yact = (int)ys.at80(row);
                if (sum == 0.0f) {
                    // No signal for any class: use the error of a uniform guess, 1 - 1/nclass.
                    err = 1.0f - 1.0f / (float)this._nclass;
                } else {
                    assert (0 <= yact && yact < this._nclass) : "weird ycls=" + yact + ", y=" + ys.at0(row);
                    // If the total overflowed to infinity, the error is 0 when the actual class is
                    // the infinite one and 1 otherwise. Otherwise it is 1 - P(actual class).
                    err = Float.isInfinite(sum)
                        ? (Float.isInfinite(fs[yact + 1]) ? 0.0f : 1.0f)
                        : 1.0f - fs[yact + 1] / sum;
                }
                assert (!Double.isNaN(err)) : "fs[cls]=" + fs[yact + 1] + ", sum=" + sum;
            } else {
                err = (float)ys.at0(row) - sum;
            }
            this._sum += (double)(err * err);
            assert (!Double.isNaN(this._sum));
            if (this._nclass > 1) {
                if (this._nclass == 2) {
                    // Probability of the second class, with the same overflow handling as above.
                    float snd = this._validation
                        ? fs[2]
                        : (!Float.isInfinite(sum) ? fs[2] / sum : (Float.isInfinite(fs[2]) ? 1.0f : 0.0f));
                    for (int i = 0; i < ModelUtils.DEFAULT_THRESHOLDS.length; ++i) {
                        int p = snd >= ModelUtils.DEFAULT_THRESHOLDS[i] ? 1 : 0;
                        this._cms[i][yact][p]++;
                    }
                }
                int ypred = this._validation
                    ? (int)this._bldr.chk_work(chks, 0).at80(row)
                    : ModelUtils.getPrediction(fs, row);
                this._cm[yact][ypred]++;
            }
            ++this._snrows;
        }
    }

    /**
     * Merges another partial result into this one: sums, row counts and confusion matrices
     * are added element-wise.
     *
     * @param t the partial result to fold in
     */
    public void reduce(Score t) {
        this._sum += t._sum;
        this._snrows += t._snrows;
        if (this._cm != null) {
            ArrayUtils.add(this._cm, t._cm);
        }
        if (this._cms != null) {
            for (int i = 0; i < this._cms.length; ++i) {
                ArrayUtils.add(this._cms[i], t._cms[i]);
            }
        }
    }

    /**
     * Logs R^2, the average node count of the given trees and, for classification, the error
     * count and confusion matrix.
     *
     * @param ntrees number of trees built so far (only used in the log line)
     * @param trees  trees whose {@code _len} values are summed for the node average; may be
     *     null, and null entries are skipped
     * @return this task
     */
    public Score report(int ntrees, DTree[] trees) {
        assert (!Double.isNaN(this._sum));
        Log.info("============================================================== ");
        int lcnt = 0;
        if (trees != null) {
            for (DTree t : trees) {
                if (t != null) {
                    lcnt += t._len;
                }
            }
        }
        Log.info("r2 is " + this.r2() + ", with " + ntrees + "x" + this._nclass + " trees (average of " + (float)lcnt / (float)this._nclass + " nodes)");
        if (this._nclass > 1) {
            // Errors are all scored rows minus those on the confusion-matrix diagonal.
            long err = this._snrows;
            for (int c = 0; c < this._nclass; ++c) {
                err -= this._cm[c][c];
            }
            Log.info("Total of " + err + " errors on " + this._snrows + " rows, CM= " + Arrays.deepToString(this._cm));
        } else {
            Log.info("Reported on " + this._snrows + " rows.");
        }
        return this;
    }
}

