/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.drf;

import hex.genmodel.GenModel;
import hex.tree.CompressedTree;
import hex.tree.DTreeScorer;
import java.util.Arrays;
import java.util.Random;
import water.Iced;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.ModelUtils;
import water.util.RandomUtils;

public class TreeMeasuresCollector
extends MRTask<TreeMeasuresCollector> {
    // Row sampling rate: sizes the sampled-row buffer in map() and is passed to ModelUtils.sampleOOBRows.
    private final float _rate;
    // Trees to score, indexed as [tree][class] (the constructor asserts trees[0].length == nclasses).
    // Set to null at the end of map().
    private CompressedTree[][] _trees;
    // Index of the column whose values are replaced by values from shuffled rows before scoring;
    // a negative value means no column is replaced.
    private final int _var;
    // Always set to true by the constructor; selects the per-tree RNG in rngForTree().
    private final boolean _oob;
    // Number of columns copied into the scoring row; the response chunk sits at index _ncols.
    private final int _ncols;
    // Number of classes; the prediction buffer in map() has _nclasses + 1 entries.
    private final int _nclasses;
    // True when _nclasses > 1.
    private final boolean _classification;
    // Threshold handed to GenModel.getPrediction for classification.
    private final double _threshold;
    // Number of trees (trees.length at construction time).
    private final int _ntrees;
    // Per-tree count of scored rows whose prediction equals the response (classification only, else null).
    private long[] _votes;
    // Per-tree count of scored rows.
    private long[] _nrows;
    // Per-tree sum of squared prediction errors (regression only, else null).
    private float[] _sse;

    /**
     * Creates a collector over the given trees.
     *
     * @param trees     trees to score, indexed as [tree][class]; must be non-empty
     * @param nclasses  number of classes; must equal {@code trees[0].length}
     * @param ncols     number of columns copied into the scoring row
     * @param rate      row sampling rate used when selecting rows to score
     * @param variable  index of the column to replace with shuffled values, or negative for none
     * @param threshold threshold forwarded to {@code GenModel.getPrediction}
     */
    private TreeMeasuresCollector(CompressedTree[][] trees, int nclasses, int ncols, float rate, int variable, double threshold) {
        assert (trees.length > 0);
        assert (nclasses == trees[0].length);
        // Tree-related state.
        _trees = trees;
        _ntrees = trees.length;
        // Problem shape.
        _nclasses = nclasses;
        _classification = nclasses > 1;
        _ncols = ncols;
        // Scoring parameters.
        _rate = rate;
        _threshold = threshold;
        _var = variable;
        // This collector always uses the tree-provided RNG (see rngForTree).
        _oob = true;
    }

    /**
     * Scores, for every tree, a sampled subset of this chunk's rows and accumulates per-tree measures.
     *
     * <p>For each tree the rows to score are obtained from {@code ModelUtils.sampleOOBRows}; the
     * returned array holds the row count at index 0 and the row indices at indices 1..count.
     * When {@code _var >= 0}, the value of column {@code _var} in each scored row is replaced by
     * the value from a shuffled row of the same sample. Rows with a missing response are skipped.
     *
     * <p>Classification accumulates matching predictions in {@code _votes}; regression
     * accumulates squared error in {@code _sse}. {@code _nrows} counts the scored rows.
     *
     * @param chks the first {@code _ncols} chunks are the scored columns, chunk {@code _ncols} is the response
     */
    public void map(Chunk[] chks) {
        final double[] rowData = new double[_ncols];
        final double[] scores = new double[_nclasses + 1];
        final Chunk response = chk_resp(chks);
        final int rowCount = response._len;
        // Initial capacity for the sampled-row buffer; sampleOOBRows returns the buffer to use.
        int[] sampled = new int[2 + Math.round((1.0f - _rate) * (float) rowCount * 1.2f + 0.5f)];
        int[] shuffled = null;

        _nrows = new long[_ntrees];
        if (_classification) {
            _votes = new long[_ntrees];
            _sse = null;
        } else {
            _votes = null;
            _sse = new float[_ntrees];
        }

        final long shuffleSeed = ShuffleTask.seed(response.cidx());
        final boolean permute = _var >= 0;

        for (int t = 0; t < _ntrees; t++) {
            final CompressedTree[] tree = _trees[t];
            final Random rng = rngForTree(tree, response.cidx());
            sampled = ModelUtils.sampleOOBRows(rowCount, _rate, rng, sampled);
            final int sampledCount = sampled[0];

            if (permute) {
                // Grow the shuffle buffer only when the current sample does not fit.
                if (shuffled == null || shuffled.length < sampledCount) {
                    shuffled = new int[sampledCount];
                }
                ArrayUtils.shuffleArray(sampled, sampledCount, shuffled, shuffleSeed, 1);
            }

            // Row indices live at positions 1..sampledCount of the sample buffer.
            for (int k = 1; k < 1 + sampledCount; k++) {
                final int row = sampled[k];
                if (response.isNA(row)) {
                    continue;
                }
                for (int c = 0; c < _ncols; c++) {
                    rowData[c] = chks[c].atd(row);
                }
                if (permute) {
                    // shuffled is zero-based while k starts at 1.
                    rowData[_var] = chks[_var].atd(shuffled[k - 1]);
                } else {
                    assert (shuffled == null);
                }

                Arrays.fill(scores, 0.0);
                score0(rowData, scores, tree);

                if (_classification) {
                    final int predicted = GenModel.getPrediction(scores, rowData, _threshold);
                    final int actual = (int) response.at8(row);
                    if (predicted == actual) {
                        _votes[t]++;
                    }
                } else {
                    final double predicted = scores[0];
                    final double actual = response.atd(row);
                    final double err = actual - predicted;
                    // Sum is formed in double and narrowed back to float by the compound assignment.
                    _sse[t] += err * err;
                }
                _nrows[t]++;
            }
        }
        // Drop the tree reference once this chunk has been processed.
        _trees = null;
    }

    /**
     * Merges the per-tree counters of another collector into this one via {@code ArrayUtils.add}.
     *
     * <p>NOTE(review): map() leaves either {@code _votes} or {@code _sse} null depending on the
     * problem type, so one of these calls receives null arrays; this presumably relies on
     * {@code ArrayUtils.add} tolerating null input — confirm.
     *
     * @param t the collector whose counters are added to this instance
     */
    public void reduce(TreeMeasuresCollector t) {
        final long[] otherVotes = t._votes;
        final long[] otherRows = t._nrows;
        final float[] otherSse = t._sse;
        ArrayUtils.add(_votes, otherVotes);
        ArrayUtils.add(_nrows, otherRows);
        ArrayUtils.add(_sse, otherSse);
    }

    /**
     * Wraps the accumulated vote counts and row counts in a {@link TreeVotes}.
     *
     * @return a new {@code TreeVotes} sharing this collector's {@code _votes} and {@code _nrows} arrays
     */
    public TreeVotes resultVotes() {
        final TreeVotes votes = new TreeVotes(_votes, _nrows, _ntrees);
        return votes;
    }

    /**
     * Wraps the accumulated squared errors and row counts in a {@link TreeSSE}.
     *
     * @return a new {@code TreeSSE} sharing this collector's {@code _sse} and {@code _nrows} arrays
     */
    public TreeSSE resultSSE() {
        final TreeSSE sse = new TreeSSE(_sse, _nrows, _ntrees);
        return sse;
    }

    /**
     * Scores one row against one tree by delegating to {@code DTreeScorer.scoreTree}.
     *
     * @param data  column values of the row to score
     * @param preds output buffer receiving the tree's predictions
     * @param ts    the per-class compressed trees of a single tree index
     */
    private void score0(double[] data, double[] preds, CompressedTree[] ts) {
        final CompressedTree[] perClassTrees = ts;
        DTreeScorer.scoreTree(data, preds, perClassTrees);
    }

    /**
     * Returns the response chunk, which is stored right after the {@code _ncols} scored columns.
     *
     * @param chks chunks passed to map()
     * @return the chunk at index {@code _ncols}
     */
    private Chunk chk_resp(Chunk[] chks) {
        final int responseIdx = _ncols;
        return chks[responseIdx];
    }

    /**
     * Picks the random generator used to sample rows for a tree on a given chunk.
     *
     * @param ts   the per-class compressed trees of one tree index; only {@code ts[0]} is consulted
     * @param cidx index of the chunk being processed
     * @return the tree's own generator for the chunk when {@code _oob} is set, otherwise a
     *         {@link DummyRandom}
     */
    private Random rngForTree(CompressedTree[] ts, int cidx) {
        if (_oob) {
            return ts[0].rngForChunk(cidx);
        }
        return new DummyRandom();
    }

    /**
     * Runs a collector for a single tree over a frame and returns the resulting vote counts.
     *
     * @param tree      per-class compressed trees of the tree to score
     * @param nclasses  number of classes; must equal {@code tree.length}
     * @param f         frame to run the task over
     * @param ncols     number of columns copied into the scoring row
     * @param rate      row sampling rate
     * @param variable  index of the column to replace with shuffled values, or negative for none
     * @param threshold threshold forwarded to prediction selection
     * @return the per-tree votes and row counts gathered by the task
     */
    public static TreeVotes collectVotes(CompressedTree[] tree, int nclasses, Frame f, int ncols, float rate, int variable, double threshold) {
        final CompressedTree[][] singleTree = new CompressedTree[][]{tree};
        final TreeMeasuresCollector collector =
                new TreeMeasuresCollector(singleTree, nclasses, ncols, rate, variable, threshold);
        final TreeMeasuresCollector finished = (TreeMeasuresCollector) collector.doAll(f);
        return finished.resultVotes();
    }

    /**
     * Runs a collector for a single tree over a frame and returns the resulting squared errors.
     *
     * @param tree      per-class compressed trees of the tree to score
     * @param nclasses  number of classes; must equal {@code tree.length}
     * @param f         frame to run the task over
     * @param ncols     number of columns copied into the scoring row
     * @param rate      row sampling rate
     * @param variable  index of the column to replace with shuffled values, or negative for none
     * @param threshold threshold forwarded to prediction selection
     * @return the per-tree squared errors and row counts gathered by the task
     */
    public static TreeSSE collectSSE(CompressedTree[] tree, int nclasses, Frame f, int ncols, float rate, int variable, double threshold) {
        final CompressedTree[][] singleTree = new CompressedTree[][]{tree};
        final TreeMeasuresCollector collector =
                new TreeMeasuresCollector(singleTree, nclasses, ncols, rate, variable, threshold);
        final TreeMeasuresCollector finished = (TreeMeasuresCollector) collector.doAll(f);
        return finished.resultSSE();
    }

    /**
     * Downcasts a generic measure holder to {@link TreeVotes}.
     *
     * @param tm the measures to cast
     * @return {@code tm} viewed as {@code TreeVotes}
     * @throws ClassCastException if {@code tm} is not a {@code TreeVotes}
     */
    public static TreeVotes asVotes(TreeMeasures tm) {
        final TreeVotes votes = (TreeVotes) tm;
        return votes;
    }

    /**
     * Downcasts a generic measure holder to {@link TreeSSE}.
     *
     * @param tm the measures to cast
     * @return {@code tm} viewed as {@code TreeSSE}
     * @throws ClassCastException if {@code tm} is not a {@code TreeSSE}
     */
    public static TreeSSE asSSE(TreeMeasures tm) {
        final TreeSSE sse = (TreeSSE) tm;
        return sse;
    }

    /**
     * Per-tree sum of squared errors together with the number of rows each sum was computed over.
     */
    public static class TreeSSE
    extends TreeMeasures<TreeSSE> {
        /** Sum of squared errors per tree; entries {@code 0.._ntrees-1} are in use. */
        private float[] _sse;

        /**
         * Creates an empty holder that can take up to {@code initialCapacity} trees via append.
         *
         * @param initialCapacity number of slots to allocate
         */
        public TreeSSE(int initialCapacity) {
            super(initialCapacity);
            this._sse = new float[initialCapacity];
        }

        /**
         * Wraps existing arrays without copying them.
         *
         * @param sse    per-tree sum of squared errors
         * @param nrows  per-tree row counts
         * @param ntrees number of valid entries
         */
        public TreeSSE(float[] sse, long[] nrows, int ntrees) {
            super(nrows, ntrees);
            this._sse = sse;
        }

        /**
         * Returns the squared error of a tree divided by its row count.
         *
         * <p>The division is carried out in double precision so the row count is not rounded
         * to float first, matching the arithmetic used by {@link #imp(TreeSSE)}.
         *
         * @param tidx tree index
         * @return {@code _sse[tidx] / _nrows[tidx]}
         */
        @Override
        public double accuracy(int tidx) {
            return (double)this._sse[tidx] / (double)this._nrows[tidx];
        }

        /**
         * Compares this measure with another one tree by tree.
         *
         * <p>For every tree the delta is {@code (this.sse - right.sse) / nrows}. The result holds
         * the mean delta and {@code sqrt(variance / ntrees)}, where the variance is computed as
         * mean-of-squares minus squared-mean.
         *
         * @param right the measure to compare against; must hold the same number of trees
         * @return a two-element array {@code {mean, sqrt(variance / ntrees)}}
         */
        @Override
        public double[] imp(TreeSSE right) {
            assert (this.npredictors() == right.npredictors());
            int ntrees = this.npredictors();
            double imp = 0.0;
            double sd = 0.0;
            for (int tidx = 0; tidx < ntrees; ++tidx) {
                assert (right.nrows()[tidx] == this.nrows()[tidx]);
                double delta = (double)(this._sse[tidx] - right._sse[tidx]) / (double)this.nrows()[tidx];
                imp += delta;
                sd += delta * delta;
            }
            double av = imp / (double)ntrees;
            // Rounding can push E[x^2] - E[x]^2 slightly below zero when the deltas are (nearly)
            // identical, which would make sqrt return NaN; clamp at zero. NaN inputs still propagate.
            double variance = Math.max(0.0, sd / (double)ntrees - av * av);
            double csd = Math.sqrt(variance / (double)ntrees);
            return new double[]{av, csd};
        }

        /**
         * Appends every tree of another holder to this one.
         *
         * @param t the holder whose entries are appended
         * @return this instance
         */
        @Override
        public TreeSSE append(TreeSSE t) {
            for (int i = 0; i < t.npredictors(); ++i) {
                this.append(t._sse[i], t._nrows[i]);
            }
            return this;
        }

        /**
         * Appends one tree's squared error and row count in the next free slot.
         *
         * @param sse     sum of squared errors of the tree
         * @param allRows number of rows the sum was computed over
         * @return this instance
         */
        public TreeSSE append(float sse, long allRows) {
            assert (this._sse.length > this._ntrees && this._sse.length == this._nrows.length) : "TreeSSE inconsistency!";
            this._sse[this._ntrees] = sse;
            this._nrows[this._ntrees] = allRows;
            ++this._ntrees;
            return this;
        }
    }

    /**
     * Per-tree count of matching predictions together with the number of rows each tree was scored on.
     */
    public static class TreeVotes
    extends TreeMeasures<TreeVotes> {
        /** Number of matching predictions per tree; entries {@code 0.._ntrees-1} are in use. */
        private long[] _votes;

        /**
         * Creates an empty holder that can take up to {@code initialCapacity} trees via append.
         *
         * @param initialCapacity number of slots to allocate
         */
        public TreeVotes(int initialCapacity) {
            super(initialCapacity);
            this._votes = new long[initialCapacity];
        }

        /**
         * Wraps existing arrays without copying them.
         *
         * @param votes  per-tree vote counts
         * @param nrows  per-tree row counts
         * @param ntrees number of valid entries
         */
        public TreeVotes(long[] votes, long[] nrows, int ntrees) {
            super(nrows, ntrees);
            this._votes = votes;
        }

        /**
         * @return the backing array of per-tree vote counts (not a copy)
         */
        public final long[] votes() {
            return this._votes;
        }

        /**
         * Returns the fraction of scored rows that a tree predicted correctly.
         *
         * @param tidx tree index
         * @return {@code _votes[tidx] / _nrows[tidx]}
         */
        @Override
        public final double accuracy(int tidx) {
            assert (tidx < this._nrows.length && tidx < this._votes.length);
            return (double)this._votes[tidx] / (double)this._nrows[tidx];
        }

        /**
         * Compares this measure with another one tree by tree.
         *
         * <p>For every tree the delta is {@code (right.votes - this.votes) / nrows}. The result
         * holds the mean delta and {@code sqrt(variance / ntrees)}, where the variance is computed
         * as mean-of-squares minus squared-mean.
         *
         * @param right the measure to compare against; must hold the same number of trees
         * @return a two-element array {@code {mean, sqrt(variance / ntrees)}}
         */
        @Override
        public final double[] imp(TreeVotes right) {
            assert (this.npredictors() == right.npredictors());
            int ntrees = this.npredictors();
            double imp = 0.0;
            double sd = 0.0;
            for (int tidx = 0; tidx < ntrees; ++tidx) {
                assert (right.nrows()[tidx] == this.nrows()[tidx]);
                double delta = (double)(right.votes()[tidx] - this.votes()[tidx]) / (double)this.nrows()[tidx];
                imp += delta;
                sd += delta * delta;
            }
            double av = imp / (double)ntrees;
            // Rounding can push E[x^2] - E[x]^2 slightly below zero when the deltas are (nearly)
            // identical, which would make sqrt return NaN; clamp at zero. NaN inputs still propagate.
            double variance = Math.max(0.0, sd / (double)ntrees - av * av);
            double csd = Math.sqrt(variance / (double)ntrees);
            return new double[]{av, csd};
        }

        /**
         * Appends one tree's vote count and row count in the next free slot.
         *
         * @param rightVotes number of matching predictions of the tree
         * @param allRows    number of rows the tree was scored on
         * @return this instance
         */
        public TreeVotes append(long rightVotes, long allRows) {
            assert (this._votes.length > this._ntrees && this._votes.length == this._nrows.length) : "TreeVotes inconsistency!";
            this._votes[this._ntrees] = rightVotes;
            this._nrows[this._ntrees] = allRows;
            ++this._ntrees;
            return this;
        }

        /**
         * Appends every tree of another holder to this one.
         *
         * @param tv the holder whose entries are appended
         * @return this instance
         */
        @Override
        public TreeVotes append(TreeVotes tv) {
            for (int i = 0; i < tv.npredictors(); ++i) {
                this.append(tv._votes[i], tv._nrows[i]);
            }
            return this;
        }
    }

    /**
     * Base holder for a per-tree measure: keeps the number of trees recorded so far and the
     * number of rows each tree was scored on. Subclasses add the measure itself.
     *
     * @param <T> the concrete measure type
     */
    public static abstract class TreeMeasures<T extends TreeMeasures>
    extends Iced {
        /** Number of trees recorded so far. */
        protected int _ntrees;
        /** Number of scored rows per tree. */
        protected long[] _nrows;

        /**
         * Creates an empty holder with room for {@code initialCapacity} trees.
         *
         * @param initialCapacity number of slots to allocate
         */
        public TreeMeasures(int initialCapacity) {
            _nrows = new long[initialCapacity];
        }

        /**
         * Wraps an existing row-count array without copying it.
         *
         * @param nrows  per-tree row counts
         * @param ntrees number of valid entries
         */
        public TreeMeasures(long[] nrows, int ntrees) {
            _ntrees = ntrees;
            _nrows = nrows;
        }

        /**
         * @return the backing array of per-tree row counts (not a copy)
         */
        public final long[] nrows() {
            return _nrows;
        }

        /**
         * @return the number of trees recorded so far
         */
        public final int npredictors() {
            return _ntrees;
        }

        /**
         * Returns the measure of a single tree normalised by its row count.
         *
         * @param tidx tree index
         * @return the per-tree value as defined by the subclass
         */
        public abstract double accuracy(int tidx);

        /**
         * Evaluates {@link #accuracy(int)} for every recorded tree.
         *
         * @return a new array with one entry per recorded tree
         */
        public final double[] accuracy() {
            final int count = _ntrees;
            final double[] perTree = new double[count];
            int idx = 0;
            while (idx < count) {
                perTree[idx] = accuracy(idx);
                idx++;
            }
            return perTree;
        }

        /**
         * Compares this measure with another one of the same kind.
         *
         * @param right the measure to compare against
         * @return comparison statistics as defined by the subclass
         */
        public abstract double[] imp(T right);

        /**
         * Appends the entries of another holder to this one.
         *
         * @param other the holder whose entries are appended
         * @return the holder as defined by the subclass
         */
        public abstract T append(T other);
    }

    /**
     * A {@link Random} whose {@link #nextFloat()} always yields {@code 1.0f}; all other methods
     * keep their inherited behaviour.
     */
    private static final class DummyRandom
    extends Random {
        /** The constant returned by every call to {@link #nextFloat()}. */
        private static final float CONSTANT_FLOAT = 1.0f;

        private DummyRandom() {
            super();
        }

        /**
         * @return always {@code 1.0f}
         */
        @Override
        public final float nextFloat() {
            return CONSTANT_FLOAT;
        }
    }

    /**
     * Task that writes a randomly permuted copy of each input chunk into the matching output chunk.
     * The random generator is seeded from the chunk index, so a chunk is always permuted the same way.
     */
    public static class ShuffleTask
    extends MRTask<ShuffleTask> {
        /** Base value combined with the chunk index to form the per-chunk seed. */
        private static final long SEED_BASE = -2291796408025514455L;

        /**
         * Fills {@code oc} with a permutation of the values of {@code ic}.
         *
         * <p>Each new input row is placed at a random position {@code j} in {@code [0, row]};
         * the value previously stored at {@code j} is first moved to position {@code row}.
         *
         * @param ic input chunk (read only)
         * @param oc output chunk receiving the permuted values
         */
        public void map(Chunk ic, Chunk oc) {
            final int len = ic._len;
            if (len == 0) {
                return;
            }
            final long chunkSeed = ShuffleTask.seed(ic.cidx());
            final Random rng = RandomUtils.getRNG(new long[]{chunkSeed});
            oc.set(0, ic.atd(0));
            for (int src = 1; src < len; src++) {
                final int dst = rng.nextInt(src + 1);
                if (dst != src) {
                    // Make room at dst by moving its current value to the newly opened slot.
                    oc.set(src, oc.atd(dst));
                }
                oc.set(dst, ic.atd(src));
            }
        }

        /**
         * Derives the seed for a chunk.
         *
         * @param cidx chunk index
         * @return a fixed base constant plus the chunk index shifted into the upper 32 bits
         */
        public static long seed(int cidx) {
            final long shifted = (long) cidx << 32;
            return SEED_BASE + shifted;
        }

        /**
         * Creates a new vector holding the chunk-wise shuffled values of {@code ivec}.
         *
         * @param ivec vector to shuffle; it is only read
         * @return a vector created by {@code ivec.makeZero()} and filled by this task
         */
        public static Vec shuffle(Vec ivec) {
            final Vec ovec = ivec.makeZero();
            final Vec[] inOut = new Vec[]{ivec, ovec};
            new ShuffleTask().doAll(inOut);
            return ovec;
        }
    }
}

