/*
 * Decompiled with CFR 0.152.
 */
package water.util;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import jsr166y.CountedCompleter;
import water.DTask;
import water.H2O;
import water.H2ONode;
import water.Key;
import water.MRTask;
import water.RPC;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.nbhm.NonBlockingHashMap;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.RandomUtils;

public class MRUtils {
    /**
     * Draws a random sample of roughly {@code rows} rows from {@code fr}.
     *
     * <p>Each row is kept independently with probability {@code rows / fr.numRows()}. In
     * addition, every non-empty chunk contributes at least one row: its last row is kept if
     * nothing else in the chunk was. The random stream of each chunk is seeded with
     * {@code seed + chunkIndex}.
     *
     * <p>When {@code fr} has a key, the sample is created under a derived key that contains the
     * sampled percentage; otherwise it is created with a {@code null} key.
     *
     * @param fr   frame to sample; {@code null} yields {@code null}
     * @param rows requested number of rows; a non-positive value, or one that is (in float
     *             arithmetic) at least the frame's row count, returns {@code fr} itself
     * @param seed base random seed
     * @return a new frame with the sampled rows, or {@code fr} when no sampling is needed
     */
    public static Frame sampleFrame(Frame fr, long rows, final long seed) {
        if (fr == null) {
            return null;
        }
        final float fraction = rows > 0L ? (float)rows / (float)fr.numRows() : 1.0f;
        if (fraction >= 1.0f) {
            return fr;
        }
        Key newKey = fr._key != null ? Key.make(fr._key.toString() + (fr._key.toString().contains("temporary") ? ".sample." : ".temporary.sample.") + PrettyPrint.formatPct(fraction).replace(" ", "")) : null;
        Frame r = ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                Random rng = RandomUtils.getRNG(seed + (long)cs[0].cidx());
                int len = cs[0]._len;
                int count = 0;
                for (int row = 0; row < len; ++row) {
                    // Keep the row if it was drawn. Also keep it if it is the chunk's last row and
                    // nothing was kept yet, so no non-empty chunk comes back empty. The RNG is
                    // always consulted first, so the random stream is the same for every row.
                    boolean keep = rng.nextFloat() < fraction || (count == 0 && row == len - 1);
                    if (!keep) {
                        continue;
                    }
                    ++count;
                    for (int i = 0; i < ncs.length; ++i) {
                        ncs[i].addNum(cs[i].atd(row));
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(newKey, fr.names(), fr.domains());
        if (r.numRows() == 0L) {
            Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ").");
            Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
            // Discard the empty sample before retrying so it is not leaked.
            r.delete();
            return MRUtils.sampleFrame(fr, rows, seed + 1L);
        }
        return r;
    }

    /**
     * Returns a new frame whose rows are randomly permuted within each chunk; rows never move
     * between chunks.
     *
     * <p>Each chunk is shuffled with its own random stream, seeded with
     * {@code seed + chunkIndex}. This matches the other per-chunk tasks in this class. It also
     * means equally sized chunks no longer share one identical permutation.
     *
     * <p>Values are copied via {@code atd}/{@code addNum}. The output columns are created with
     * type byte {@code 3} (presumably {@code Vec.T_NUM} — confirm) and given the input's names
     * and domains.
     *
     * @param fr   frame to shuffle
     * @param seed base random seed
     * @return a new, shuffled frame
     */
    public static Frame shuffleFramePerChunk(Frame fr, final long seed) {
        return ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                int len = cs[0]._len;
                int[] idx = new int[len];
                for (int r = 0; r < len; ++r) {
                    idx[r] = r;
                }
                // Mix the chunk index into the seed so every chunk gets an independent permutation.
                ArrayUtils.shuffleArray(idx, RandomUtils.getRNG(seed + (long)cs[0].cidx()));
                for (int row : idx) {
                    for (int i = 0; i < ncs.length; ++i) {
                        ncs[i].addNum(cs[i].atd(row));
                    }
                }
            }
        }.doAll(fr.numCols(), (byte)3, fr)).outputFrame(fr.names(), fr.domains());
    }

    /**
     * Stratified sampling of {@code fr} by the categorical column {@code label}, with optional
     * automatic class balancing.
     *
     * <p>If {@code sampling_ratios} is {@code null} or all zeros, per-class ratios are derived
     * from the observed class distribution: {@code ratio_k = N / (K * n_k)}, divided by the
     * smallest ratio. The largest class therefore keeps ratio 1, and every other class is drawn
     * up to the size of the largest. If {@code allowOversampling} is false, every ratio is then
     * clamped to at most 1. Finally, the ratios are rescaled so the expected number of output
     * rows equals {@code min(maxrows, round(expected rows))}.
     *
     * @param fr                frame to sample; {@code null} yields {@code null}
     * @param label             categorical column to stratify by
     * @param weights           optional row weights used for the class distribution; may be {@code null}
     * @param sampling_ratios   one ratio per level of {@code label}, or {@code null}/all-zero for
     *                          automatic balancing; the caller's array is never modified
     * @param maxrows           upper bound on the expected output size (raised to the number of
     *                          classes if smaller)
     * @param seed              random seed
     * @param allowOversampling whether ratios above 1 are permitted
     * @param verbose           log per-class details
     * @return the frame produced by the ratio-based {@code sampleFrameStratified} overload
     * @throws IllegalArgumentException if the number of ratios differs from the number of
     *                                  classes, or the expected row count is NaN
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long maxrows, long seed, boolean allowOversampling, boolean verbose) {
        if (fr == null) {
            return null;
        }
        assert (label.isCategorical());
        int nclass = label.domain().length;
        if (maxrows < (long)nclass) {
            Log.warn("Attempting to do stratified sampling to fewer samples than there are class labels - automatically increasing to #rows == #labels (" + nclass + ").");
            maxrows = nclass;
        }
        ClassDist cd = new ClassDist(label);
        double[] dist = weights != null ? ((ClassDist)cd.doAll(label, weights)).dist() : ((ClassDist)cd.doAll(label)).dist();
        assert (dist.length > 0);
        Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
        if (verbose) {
            for (int i = 0; i < dist.length; ++i) {
                Log.info("Class " + label.factor(i) + ": count: " + dist[i] + " prior: " + (float)dist[i] / (float)fr.numRows());
            }
        }
        // Work on a private copy so the caller's array is never modified.
        sampling_ratios = sampling_ratios == null ? new float[dist.length] : sampling_ratios.clone();
        if (sampling_ratios.length != dist.length) {
            throw new IllegalArgumentException("Expected " + dist.length + " sampling ratios (one per class), but got " + sampling_ratios.length + ".");
        }
        if (ArrayUtils.minValue(sampling_ratios) == 0.0f && ArrayUtils.maxValue(sampling_ratios) == 0.0f) {
            // Automatic balancing: ratio_k = N / (K * n_k), then divide by the smallest ratio
            // (that of the largest class) so the largest class keeps ratio 1.
            for (int i = 0; i < dist.length; ++i) {
                sampling_ratios[i] = (float)fr.numRows() / (float)nclass / (float)dist[i];
            }
            float inv_scale = ArrayUtils.minValue(sampling_ratios);
            if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale)) {
                ArrayUtils.div(sampling_ratios, inv_scale);
            }
        }
        if (!allowOversampling) {
            for (int i = 0; i < sampling_ratios.length; ++i) {
                sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
            }
        }
        // Accumulate in double and round with Math.round(double), which returns a long.
        // Math.round(float) returns an int and would silently saturate at Integer.MAX_VALUE
        // for very large (e.g. heavily oversampled) targets.
        double numrows = 0.0;
        for (int i = 0; i < sampling_ratios.length; ++i) {
            numrows += (double)sampling_ratios[i] * dist[i];
        }
        if (Double.isNaN(numrows)) {
            throw new IllegalArgumentException("Error during sampling - too few points?");
        }
        long target = Math.round(numrows);
        long actualnumrows = Math.min(maxrows, target);
        assert (actualnumrows >= 0L);
        // Only report the cap when maxrows actually reduced the target, not when we merely rounded.
        Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + (actualnumrows < target ? " (limited by max_after_balance_size)." : "."));
        if ((double)actualnumrows != numrows) {
            float scale = (float)((double)actualnumrows / numrows);
            ArrayUtils.mult(sampling_ratios, scale);
            if (verbose) {
                Log.info("Downsampling majority class by " + scale + " to limit number of rows to " + String.format("%,d", actualnumrows));
            }
        }
        for (int i = 0; i < nclass; ++i) {
            Log.info("Class '" + label.domain()[i] + "' sampling ratio: " + sampling_ratios[i]);
        }
        return MRUtils.sampleFrameStratified(fr, label, weights, sampling_ratios, seed, verbose);
    }

    /**
     * Stratified sampling with explicit per-class ratios.
     *
     * <p>Delegates to the private retrying implementation, starting its retry counter at zero.
     *
     * @param fr              frame to sample
     * @param label           categorical column whose levels index {@code sampling_ratios}
     * @param weights         optional weight column (may be {@code null})
     * @param sampling_ratios expected number of copies of each row, per class
     * @param seed            random seed
     * @param debug           log per-class statistics of the result
     * @return the sampled frame
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long seed, boolean debug) {
        final int firstAttempt = 0;
        return sampleFrameStratified(fr, label, weights, sampling_ratios, seed, debug, firstAttempt);
    }

    /**
     * Samples {@code fr} with fixed per-class ratios, redrawing on unlucky draws.
     *
     * <p>Each row with a non-NA label of class {@code k} is emitted {@code (int) ratio[k]}
     * times. It is emitted once more with probability equal to the fractional part of
     * {@code ratio[k]}. Rows with an NA label are dropped. Each chunk's random stream is seeded
     * with {@code seed + chunkIndex}.
     *
     * <p>If some class ends up with a zero (weighted) count, the sample is deleted and redrawn
     * with {@code seed + 1}, as long as {@code count < 10}. The accepted sample is shuffled
     * within each chunk, and the unshuffled intermediate is deleted. If the class distribution
     * of the sample comes back {@code null}, the sample is deleted and {@code fr} is returned
     * unchanged.
     *
     * @param fr              frame to sample; {@code null} yields {@code null}
     * @param label           categorical column (must be a column of {@code fr})
     * @param weights         optional weight column of {@code fr}, used only for the class counts
     * @param sampling_ratios one ratio per level of {@code label}
     * @param seed            random seed
     * @param debug           log per-class statistics of the result
     * @param count           number of redraws already performed
     * @return the sampled and shuffled frame, or {@code fr} as described above
     */
    private static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, final float[] sampling_ratios, final long seed, boolean debug, int count) {
        if (fr == null) {
            return null;
        }
        assert (label.isCategorical());
        assert (sampling_ratios != null && sampling_ratios.length == label.domain().length);
        final int labelidx = fr.find(label);
        assert (labelidx >= 0);
        int weightsidx = fr.find(weights);
        Frame r = ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                Random rng = RandomUtils.getRNG(seed + (long)cs[0].cidx());
                for (int row = 0; row < cs[0]._len; ++row) {
                    if (cs[labelidx].isNA(row)) continue;
                    int cls = (int)cs[labelidx].at8(row);
                    assert (sampling_ratios.length > cls && cls >= 0);
                    // Integer part = guaranteed copies; fractional part = chance of one extra copy.
                    float remainder = sampling_ratios[cls] - (float)((int)sampling_ratios[cls]);
                    int sampling_reps = (int)sampling_ratios[cls] + (rng.nextFloat() < remainder ? 1 : 0);
                    for (int i = 0; i < ncs.length; ++i) {
                        for (int j = 0; j < sampling_reps; ++j) {
                            ncs[i].addNum(cs[i].atd(row));
                        }
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(fr.names(), fr.domains());
        Vec lab = r.vecs()[labelidx];
        Vec wei = weightsidx != -1 ? r.vecs()[weightsidx] : null;
        double[] dist = wei != null ? ((ClassDist)new ClassDist(lab).doAll(lab, wei)).dist() : ((ClassDist)new ClassDist(lab).doAll(lab)).dist();
        if (dist == null) {
            // The sample is not returned on this path, so release it instead of leaking it.
            r.delete();
            return fr;
        }
        if (debug) {
            double sumdist = ArrayUtils.sum(dist);
            Log.info("After stratified sampling: " + sumdist + " rows.");
            for (int i = 0; i < dist.length; ++i) {
                Log.info("Class " + lab.factor(i) + ": count: " + dist[i] + " sampling ratio: " + sampling_ratios[i] + " actual relative frequency: " + (double)((float)dist[i]) / sumdist * (double)dist.length);
            }
        }
        if (ArrayUtils.minValue(dist) == 0.0 && count < 10) {
            Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
            r.delete();
            return MRUtils.sampleFrameStratified(fr, label, weights, sampling_ratios, seed + 1L, debug, count + 1);
        }
        Frame shuffled = MRUtils.shuffleFramePerChunk(r, seed + 92339987L);
        r.delete();
        return shuffled;
    }

    /**
     * Runs an array of {@link DTask}s on the cloud with at most {@code _maxP} of them in flight.
     *
     * <p>The first {@code min(_maxP, _tasks.length)} tasks are placed round-robin over
     * {@code H2O.CLOUD._memary}: task {@code i} goes to node {@code i % cloudSize}. Every task,
     * local or remote, finishes through a {@link Callback}. That callback forks the next unstarted
     * task on the same node, so the in-flight limit is kept and every task eventually runs. This
     * completer completes after all tasks have completed; an empty task array completes
     * immediately.
     *
     * @param <T> task type
     */
    public static class ParallelTasks<T extends DTask<T>>
    extends H2O.H2OCountedCompleter {
        /** Tasks to run, started in array order. */
        public final transient T[] _tasks;
        /** Maximum number of tasks in flight; values below 1 are treated as 1. */
        public final transient int _maxP;
        /** Index of the next task to fork; initialised in {@link #compute2()}. */
        private transient AtomicInteger _nextTask;

        /** Runs {@code tsks} with at most one task in flight per cloud node. */
        public ParallelTasks(H2O.H2OCountedCompleter cmp, T[] tsks) {
            this(cmp, tsks, H2O.CLOUD.size());
        }

        public ParallelTasks(H2O.H2OCountedCompleter cmp, T[] tsks, int maxP) {
            super(cmp);
            this._maxP = maxP;
            this._tasks = tsks;
            // Each task completion reaches this completer, and the last one completes it; hence
            // length - 1 pending. For an empty array the count must not go negative, otherwise
            // this completer could never complete.
            if (this._tasks.length > 0) {
                this.addToPendingCount(this._tasks.length - 1);
            }
        }

        private void forkDTask(int i) {
            int nodeId = i % H2O.CLOUD.size();
            this.forkDTask(i, H2O.CLOUD._memary[nodeId]);
        }

        private void forkDTask(int i, H2ONode n) {
            // Local and remote tasks both finish through a Callback, so a completion on any node
            // frees its slot for the next pending task. Previously only local tasks did this.
            Callback cb = new Callback(n, i);
            if (n == H2O.SELF) {
                ((CountedCompleter)this._tasks[i]).setCompleter(cb);
                H2O.submitTask(this._tasks[i]);
            } else {
                new RPC<T>(n, this._tasks[i]).addCompleter(cb).call();
            }
        }

        @Override
        public void compute2() {
            if (this._tasks.length == 0) {
                // Nothing to run: complete now instead of waiting forever.
                this.tryComplete();
                return;
            }
            int n = Math.min(Math.max(this._maxP, 1), this._tasks.length);
            this._nextTask = new AtomicInteger(n);
            for (int i = 0; i < n; ++i) {
                this.forkDTask(i);
            }
        }

        /**
         * Per-task completer. When its task finishes, it forks the next unstarted task (if any)
         * on the same node. Its own completer is the enclosing {@code ParallelTasks}.
         */
        class Callback
        extends H2O.H2OCallback<H2O.H2OCountedCompleter> {
            final int i;
            final H2ONode n;

            public Callback(H2ONode n, int i) {
                super(ParallelTasks.this);
                this.n = n;
                this.i = i;
            }

            @Override
            public void callback(H2O.H2OCountedCompleter cc) {
                Log.info("callback for task " + this.i);
                int nextI = ParallelTasks.this._nextTask.getAndIncrement();
                if (nextI < ParallelTasks.this._tasks.length) {
                    ParallelTasks.this.forkDTask(nextI, this.n);
                }
            }
        }
    }

    /**
     * Counts how often each distinct non-NA value occurs in a single column.
     *
     * <p>{@link #keys()} and {@link #dist()} each walk the underlying map independently. They
     * are meant to be paired index by index, which relies on the map not changing in between.
     *
     * <p>NOTE(review): {@code _dist} is {@code transient}, so it is presumably not serialized
     * with the task. Confirm this behaves as intended on multi-node clouds.
     */
    public static class Dist
    extends MRTask<Dist> {
        private transient NonBlockingHashMap<Double, Integer> _dist;

        @Override
        public void map(Chunk ys) {
            NonBlockingHashMap<Double, Integer> counts = new NonBlockingHashMap<Double, Integer>();
            this._dist = counts;
            for (int row = 0; row < ys._len; ++row) {
                if (ys.isNA(row)) {
                    continue;
                }
                Double value = ys.atd(row);
                Integer seen = counts.putIfAbsent(value, 1);
                if (seen != null) {
                    counts.put(value, seen + 1);
                }
            }
        }

        /** Folds the smaller of the two maps into the larger one and keeps the result here. */
        @Override
        public void reduce(Dist mrt) {
            if (this._dist == mrt._dist) {
                return;
            }
            NonBlockingHashMap<Double, Integer> big = this._dist;
            NonBlockingHashMap<Double, Integer> small = mrt._dist;
            if (big.size() < small.size()) {
                NonBlockingHashMap<Double, Integer> tmp = big;
                big = small;
                small = tmp;
            }
            for (Double key : small.keySet()) {
                Integer add = small.get(key);
                Integer prev = big.putIfAbsent(key, add);
                if (prev != null) {
                    big.put(key, prev + add);
                }
            }
            this._dist = big;
            mrt._dist = null;
        }

        /** Occurrence counts, in the map's value-iteration order. */
        public double[] dist() {
            double[] counts = new double[this._dist.size()];
            int k = 0;
            for (Integer c : this._dist.values()) {
                counts[k++] = c.intValue();
            }
            return counts;
        }

        /** Distinct values, in the map's key-iteration order. */
        public double[] keys() {
            double[] values = new double[this._dist.size()];
            int k = 0;
            for (Double key : this._dist.keySet()) {
                values[k++] = key.doubleValue();
            }
            return values;
        }
    }

    public static class ClassDist
    extends MRTask<ClassDist> {
        final int _nclass;
        protected double[] _ys;

        public ClassDist(Vec label) {
            this._nclass = label.domain().length;
        }

        public ClassDist(int n) {
            this._nclass = n;
        }

        public final double[] dist() {
            return this._ys;
        }

        public final double[] rel_dist() {
            double sum = ArrayUtils.sum(this._ys);
            return ArrayUtils.div(Arrays.copyOf(this._ys, this._ys.length), sum);
        }

        @Override
        public void map(Chunk ys) {
            this._ys = new double[this._nclass];
            for (int i = 0; i < ys._len; ++i) {
                if (ys.isNA(i)) continue;
                int n = (int)ys.at8(i);
                this._ys[n] = this._ys[n] + 1.0;
            }
        }

        @Override
        public void map(Chunk ys, Chunk ws) {
            this._ys = new double[this._nclass];
            for (int i = 0; i < ys._len; ++i) {
                if (ys.isNA(i)) continue;
                int n = (int)ys.at8(i);
                this._ys[n] = this._ys[n] + ws.atd(i);
            }
        }

        @Override
        public void reduce(ClassDist that) {
            ArrayUtils.add(this._ys, that._ys);
        }
    }
}

