/*
 * Decompiled with CFR 0.152.
 */
package water.util;

import java.util.Arrays;
import java.util.Random;
import water.Key;
import water.MRTask;
import water.fvec.C16Chunk;
import water.fvec.CStrChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.IcedAtomicInt;
import water.util.IcedDouble;
import water.util.IcedHashMap;
import water.util.Log;
import water.util.PrettyPrint;
import water.util.RandomUtils;

public class MRUtils {
    /**
     * Draws a random sample of roughly {@code rows} rows from a frame.
     * <p>
     * A row is kept when a random draw falls below {@code rows / fr.numRows()}. The generator is
     * re-seeded for every row from {@code seed}, the row's offset inside its chunk and the chunk's
     * start, so the selection is a function of the seed and the row position only. If nothing has
     * been kept by the time the last row of a chunk is reached, that last row is kept regardless
     * of the draw.
     *
     * @param fr   frame to sample from; {@code null} yields {@code null}
     * @param rows requested number of rows; when it is not positive, or the resulting fraction is
     *             at least 1, {@code fr} itself is returned
     * @param seed base seed for the per-row random draws
     * @return the sampled frame, {@code fr} itself when no down-sampling applies, or {@code null}
     */
    public static Frame sampleFrame(Frame fr, long rows, final long seed) {
        if (fr == null) {
            return null;
        }
        final float fraction = rows > 0L ? (float)rows / (float)fr.numRows() : 1.0f;
        if (fraction >= 1.0f) {
            return fr;
        }

        // Derive the result key from the source key (when there is one) and the sampled percentage.
        Key newKey = null;
        if (fr._key != null) {
            String base = fr._key.toString();
            String infix = base.contains("temporary") ? ".sample." : ".temporary.sample.";
            newKey = Key.make(base + infix + PrettyPrint.formatPct(fraction).replace(" ", ""));
        }

        MRTask sampler = new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                final Random rng = RandomUtils.getRNG(0L);
                final BufferedString scratch = new BufferedString();
                final int lastRow = cs[0]._len - 1;
                int kept = 0;
                for (int row = 0; row <= lastRow; ++row) {
                    rng.setSeed(seed + (long)row + cs[0].start());
                    boolean drawn = rng.nextFloat() < fraction;
                    // Nothing kept so far and this is the chunk's final row: keep it anyway.
                    boolean rescue = kept == 0 && row == lastRow;
                    if (!drawn && !rescue) {
                        continue;
                    }
                    ++kept;
                    for (int col = 0; col < ncs.length; ++col) {
                        Chunk src = cs[col];
                        NewChunk dst = ncs[col];
                        if (src.isNA(row)) {
                            dst.addNA();
                        } else if (src instanceof CStrChunk) {
                            dst.addStr(src.atStr(scratch, row));
                        } else if (src instanceof C16Chunk) {
                            dst.addUUID(src.at16l(row), src.at16h(row));
                        } else {
                            dst.addNum(src.atd(row));
                        }
                    }
                }
            }
        };
        Frame r = ((MRTask)sampler.doAll(fr.types(), fr)).outputFrame(newKey, fr.names(), fr.domains());

        if (r.numRows() != 0L) {
            return r;
        }
        // Empty result: log it and retry with the next seed.
        Log.warn("You asked for " + rows + " rows (out of " + fr.numRows() + "), but you got none (seed=" + seed + ").");
        Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
        return MRUtils.sampleFrame(fr, rows, seed + 1L);
    }

    /**
     * Returns a copy of the frame in which the rows of every chunk are randomly permuted.
     * Rows never move between chunks; each chunk builds its permutation from a generator
     * obtained with the same {@code seed}.
     * <p>
     * String columns are copied through {@code addStr}, UUID columns ({@link C16Chunk}) through
     * {@code addUUID} (missing UUIDs become NAs), and every other column through its
     * {@code double} value.
     *
     * @param fr   frame to shuffle; {@code null} yields {@code null}
     * @param seed seed handed to the random generator used for the permutation
     * @return a new frame with the same names and domains as {@code fr}, or {@code null}
     */
    public static Frame shuffleFramePerChunk(Frame fr, final long seed) {
        if (fr == null) {
            return null;
        }
        return ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                // Identity permutation of the chunk-local row indices, shuffled in place.
                int[] idx = new int[cs[0]._len];
                for (int r = 0; r < idx.length; ++r) {
                    idx[r] = r;
                }
                ArrayUtils.shuffleArray(idx, RandomUtils.getRNG(seed));
                for (int row : idx) {
                    for (int i = 0; i < ncs.length; ++i) {
                        if (cs[i] instanceof CStrChunk) {
                            // This addStr overload is given the frame-wide row number, not the chunk-local one.
                            ncs[i].addStr(cs[i], cs[i].start() + (long)row);
                        } else if (cs[i] instanceof C16Chunk) {
                            // UUID columns need the 128-bit accessors, same as in sampleFrame.
                            if (cs[i].isNA(row)) {
                                ncs[i].addNA();
                            } else {
                                ncs[i].addUUID(cs[i].at16l(row), cs[i].at16h(row));
                            }
                        } else {
                            ncs[i].addNum(cs[i].atd(row));
                        }
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(fr.names(), fr.domains());
    }

    /**
     * Stratified sampling for a categorical label: determines a sampling ratio per class, caps the
     * expected result size at {@code maxrows}, and delegates the row sampling to
     * {@link #sampleFrameStratified(Frame, Vec, Vec, float[], long, boolean)}.
     * <p>
     * When {@code sampling_ratios} is {@code null} or all zeros, ratios are computed as
     * {@code numRows / numClasses / classCount} and then divided by their minimum (if that minimum
     * is finite and not NaN). Without oversampling every ratio is clamped to at most 1. If the
     * expected row count differs from the capped count, all ratios are rescaled accordingly.
     * The caller's {@code sampling_ratios} array is never modified; a clone is used.
     *
     * @param fr                frame to sample; {@code null} yields {@code null}
     * @param label             categorical label column
     * @param weights           optional observation weights used for the class distribution, may be {@code null}
     * @param sampling_ratios   per-class ratios, or {@code null} / all zeros to have them computed
     * @param maxrows           upper bound on the expected number of result rows; raised to the
     *                          number of classes when smaller
     * @param seed              seed passed on to the row sampler
     * @param allowOversampling whether ratios above 1 are permitted
     * @param verbose           whether to log per-class details
     * @return the sampled frame, or {@code null} when {@code fr} is {@code null}
     * @throws IllegalArgumentException if the expected number of rows evaluates to NaN
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long maxrows, long seed, boolean allowOversampling, boolean verbose) {
        if (fr == null) {
            return null;
        }
        assert (label.isCategorical());
        if (maxrows < (long)label.domain().length) {
            Log.warn("Attempting to do stratified sampling to fewer samples than there are class labels - automatically increasing to #rows == #labels (" + label.domain().length + ").");
            maxrows = label.domain().length;
        }

        // Class distribution of the input, weighted when weights are supplied.
        ClassDist cd = new ClassDist(label);
        double[] dist = weights != null ? ((ClassDist)cd.doAll(label, weights)).dist() : ((ClassDist)cd.doAll(label)).dist();
        assert (dist.length > 0);
        Log.info("Doing stratified sampling for data set containing " + fr.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
        if (verbose) {
            for (int i = 0; i < dist.length; ++i) {
                Log.info("Class " + label.factor(i) + ": count: " + dist[i] + " prior: " + (float)dist[i] / (float)fr.numRows());
            }
        }

        // Work on a private copy so the caller's array stays untouched.
        sampling_ratios = sampling_ratios == null ? new float[dist.length] : (float[])sampling_ratios.clone();
        assert (sampling_ratios.length == dist.length);
        if (ArrayUtils.minValue(sampling_ratios) == 0.0f && ArrayUtils.maxValue(sampling_ratios) == 0.0f) {
            // No ratios given: aim for equal class sizes, then normalise by the smallest ratio.
            for (int i = 0; i < dist.length; ++i) {
                sampling_ratios[i] = (float)fr.numRows() / (float)label.domain().length / (float)dist[i];
            }
            float inv_scale = ArrayUtils.minValue(sampling_ratios);
            if (!Float.isNaN(inv_scale) && !Float.isInfinite(inv_scale)) {
                ArrayUtils.div(sampling_ratios, inv_scale);
            }
        }
        if (!allowOversampling) {
            for (int i = 0; i < sampling_ratios.length; ++i) {
                sampling_ratios[i] = Math.min(1.0f, sampling_ratios[i]);
            }
        }

        // Expected number of rows after sampling with the current ratios.
        float numrows = 0.0f;
        for (int i = 0; i < sampling_ratios.length; ++i) {
            numrows = (float)((double)numrows + (double)sampling_ratios[i] * dist[i]);
        }
        if (Float.isNaN(numrows)) {
            throw new IllegalArgumentException("Error during sampling - too few points?");
        }
        // Round via double: Math.round(float) returns int and would saturate at Integer.MAX_VALUE.
        long actualnumrows = Math.min(maxrows, Math.round((double)numrows));
        assert (actualnumrows >= 0L);
        Log.info("Stratified sampling to a total of " + String.format("%,d", actualnumrows) + " rows" + ((float)actualnumrows < numrows ? " (limited by max_after_balance_size)." : "."));
        if ((float)actualnumrows != numrows) {
            ArrayUtils.mult(sampling_ratios, (float)actualnumrows / numrows);
            if (verbose) {
                Log.info("Downsampling majority class by " + (float)actualnumrows / numrows + " to limit number of rows to " + String.format("%,d", maxrows));
            }
        }
        for (int i = 0; i < label.domain().length; ++i) {
            Log.info("Class '" + label.domain()[i] + "' sampling ratio: " + sampling_ratios[i]);
        }
        return MRUtils.sampleFrameStratified(fr, label, weights, sampling_ratios, seed, verbose);
    }

    /**
     * Stratified sampling with explicit per-class ratios. Delegates to the private overload,
     * starting its retry counter at zero.
     *
     * @param fr              frame to sample
     * @param label           categorical label column
     * @param weights         optional weights column, may be {@code null}
     * @param sampling_ratios sampling ratio for each class of {@code label}
     * @param seed            seed for the row sampler
     * @param debug           whether to log the resulting class distribution
     * @return the frame produced by the private overload
     */
    public static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, float[] sampling_ratios, long seed, boolean debug) {
        final int firstAttempt = 0;
        return sampleFrameStratified(fr, label, weights, sampling_ratios, seed, debug, firstAttempt);
    }

    /**
     * Performs the stratified row sampling and shuffles the result per chunk.
     * <p>
     * Rows with a missing label are dropped. Every other row is emitted
     * {@code floor(ratio)} times, plus once more with probability equal to the fractional part of
     * its class ratio; the generator is re-seeded for each row from the chunk start, the row
     * offset and {@code seed}. If a class ends up with a zero count the sampling is repeated with
     * {@code seed + 1}, for at most 10 retries. The intermediate frame is removed before returning.
     *
     * @param fr              frame to sample; {@code null} yields {@code null}
     * @param label           categorical label column, expected to be part of {@code fr}
     * @param weights         optional weights column, may be {@code null}
     * @param sampling_ratios sampling ratio for each class of {@code label}
     * @param seed            seed for the row sampler
     * @param debug           whether to log the resulting class distribution
     * @param count           number of retries performed so far
     * @return the shuffled sample; {@code fr} itself when no class distribution could be
     *         computed on the sample; or {@code null} when {@code fr} is {@code null}
     */
    private static Frame sampleFrameStratified(Frame fr, Vec label, Vec weights, final float[] sampling_ratios, final long seed, boolean debug, int count) {
        if (fr == null) {
            return null;
        }
        assert (label.isCategorical());
        assert (sampling_ratios != null && sampling_ratios.length == label.domain().length);
        final int labelidx = fr.find(label);
        assert (labelidx >= 0);
        int weightsidx = fr.find(weights);
        Frame r = ((MRTask)new MRTask(){

            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                Random rng = RandomUtils.getRNG(seed);
                for (int row = 0; row < cs[0]._len; ++row) {
                    if (cs[labelidx].isNA(row)) {
                        continue;
                    }
                    rng.setSeed(cs[0].start() + (long)row + seed);
                    int cls = (int)cs[labelidx].at8(row);
                    assert (sampling_ratios.length > cls && cls >= 0);
                    // Integer part = guaranteed copies; fractional part = chance of one extra copy.
                    float remainder = sampling_ratios[cls] - (float)((int)sampling_ratios[cls]);
                    int sampling_reps = (int)sampling_ratios[cls] + (rng.nextFloat() < remainder ? 1 : 0);
                    for (int i = 0; i < ncs.length; ++i) {
                        if (cs[i] instanceof CStrChunk) {
                            for (int j = 0; j < sampling_reps; ++j) {
                                ncs[i].addStr(cs[i], cs[0].start() + (long)row);
                            }
                        } else if (cs[i] instanceof C16Chunk) {
                            // UUID columns need the 128-bit accessors, same as in sampleFrame.
                            boolean missing = cs[i].isNA(row);
                            for (int j = 0; j < sampling_reps; ++j) {
                                if (missing) {
                                    ncs[i].addNA();
                                } else {
                                    ncs[i].addUUID(cs[i].at16l(row), cs[i].at16h(row));
                                }
                            }
                        } else {
                            for (int j = 0; j < sampling_reps; ++j) {
                                ncs[i].addNum(cs[i].atd(row));
                            }
                        }
                    }
                }
            }
        }.doAll(fr.types(), fr)).outputFrame(fr.names(), fr.domains());

        // Class distribution of the sample, used to validate the draw.
        Vec lab = r.vecs()[labelidx];
        Vec wei = weightsidx != -1 ? r.vecs()[weightsidx] : null;
        double[] dist = wei != null ? ((ClassDist)new ClassDist(lab).doAll(lab, wei)).dist() : ((ClassDist)new ClassDist(lab).doAll(lab)).dist();
        if (dist == null) {
            // Falling back to the input frame: release the sample like every other exit path does.
            r.remove();
            return fr;
        }
        if (debug) {
            double sumdist = ArrayUtils.sum(dist);
            Log.info("After stratified sampling: " + sumdist + " rows.");
            for (int i = 0; i < dist.length; ++i) {
                Log.info("Class " + r.vecs()[labelidx].factor(i) + ": count: " + dist[i] + " sampling ratio: " + sampling_ratios[i] + " actual relative frequency: " + (double)((float)dist[i]) / sumdist * (double)dist.length);
            }
        }
        if (ArrayUtils.minValue(dist) == 0.0 && count < 10) {
            Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
            r.remove();
            return MRUtils.sampleFrameStratified(fr, label, weights, sampling_ratios, seed + 1L, debug, ++count);
        }
        Frame shuffled = MRUtils.shuffleFramePerChunk(r, seed + 92339987L);
        r.remove();
        return shuffled;
    }

    /**
     * Task that counts how often each distinct non-missing value occurs in a column.
     * The counts are kept in a map from value to counter; {@link #keys()} and {@link #dist()}
     * expose the values and their counts as arrays.
     */
    public static class Dist extends MRTask<Dist> {
        private IcedHashMap<IcedDouble, IcedAtomicInt> _dist;

        /** Counts the non-missing values of one chunk into a fresh map. */
        @Override
        public void map(Chunk ys) {
            _dist = new IcedHashMap<IcedDouble, IcedAtomicInt>();
            // Reusable lookup key; a separate key object is allocated only on insertion.
            final IcedDouble probe = new IcedDouble(0.0);
            for (int row = 0; row < ys._len; row++) {
                if (!ys.isNA(row)) {
                    probe._val = ys.atd(row);
                    IcedAtomicInt counter = (IcedAtomicInt)_dist.get(probe);
                    if (counter == null) {
                        // First sighting: insert a counter starting at 1; a non-null return
                        // means a counter was already present and must be bumped instead.
                        counter = _dist.putIfAbsent(new IcedDouble(probe._val), new IcedAtomicInt(1));
                    }
                    if (counter != null) {
                        counter.incrementAndGet();
                    }
                }
            }
        }

        /** Merges the smaller of the two count maps into the larger one and keeps the result. */
        @Override
        public void reduce(Dist mrt) {
            if (_dist == mrt._dist) {
                return;
            }
            final boolean keepMine = _dist.size() >= mrt._dist.size();
            IcedHashMap<IcedDouble, IcedAtomicInt> target = keepMine ? _dist : mrt._dist;
            IcedHashMap<IcedDouble, IcedAtomicInt> source = keepMine ? mrt._dist : _dist;
            for (IcedDouble key : source.keySet()) {
                IcedAtomicInt incoming = (IcedAtomicInt)source.get(key);
                IcedAtomicInt existing = target.putIfAbsent(key, incoming);
                if (existing != null) {
                    existing.addAndGet(incoming.get());
                }
            }
            _dist = target;
            mrt._dist = null;
        }

        /** @return the counts, in the iteration order of the map's values */
        public double[] dist() {
            double[] counts = new double[_dist.size()];
            int pos = 0;
            for (IcedAtomicInt counter : _dist.values()) {
                counts[pos] = counter.get();
                pos++;
            }
            return counts;
        }

        /** @return the distinct values, in the iteration order of the map's key set */
        public double[] keys() {
            double[] values = new double[_dist.size()];
            int pos = 0;
            for (IcedDouble key : _dist.keySet()) {
                values[pos] = key._val;
                pos++;
            }
            return values;
        }
    }

    /**
     * Task that accumulates per-class totals for a label column: a plain row count with the
     * single-column {@code map}, or the sum of the accompanying weights with the two-column one.
     * Rows whose label is missing are skipped.
     */
    public static class ClassDist extends MRTask<ClassDist> {
        final int _nclass;
        protected double[] _ys;

        /** Sizes the distribution from the length of the label's domain. */
        public ClassDist(Vec label) {
            _nclass = label.domain().length;
        }

        /** Sizes the distribution for {@code n} classes. */
        public ClassDist(int n) {
            _nclass = n;
        }

        /** @return the accumulated per-class totals (the internal array, not a copy) */
        public final double[] dist() {
            return _ys;
        }

        /** @return a copy of the per-class totals divided by their sum */
        public final double[] rel_dist() {
            final double total = ArrayUtils.sum(_ys);
            final double[] copy = Arrays.copyOf(_ys, _ys.length);
            return ArrayUtils.div(copy, total);
        }

        /** Unweighted: adds 1 to the class slot of every row with a non-missing label. */
        @Override
        public void map(Chunk ys) {
            _ys = new double[_nclass];
            for (int row = 0; row < ys._len; row++) {
                if (!ys.isNA(row)) {
                    final int cls = (int)ys.at8(row);
                    _ys[cls] += 1.0;
                }
            }
        }

        /** Weighted: adds the row's weight to the class slot of every row with a non-missing label. */
        @Override
        public void map(Chunk ys, Chunk ws) {
            _ys = new double[_nclass];
            for (int row = 0; row < ys._len; row++) {
                if (!ys.isNA(row)) {
                    final int cls = (int)ys.at8(row);
                    _ys[cls] += ws.atd(row);
                }
            }
        }

        /** Adds the other task's per-class totals into this one. */
        @Override
        public void reduce(ClassDist that) {
            ArrayUtils.add(_ys, that._ys);
        }
    }
}

