/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang.ArrayUtils;
import water.DKV;
import water.Iced;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

public class AstStratifiedSplit
extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary", "test_frac", "seed"};
    }

    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "h2o.random_stratified_split";
    }

    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Key inputFrKey;
        Frame origfr = stk.track(asts[1].exec(env)).getFrame();
        Frame fr = origfr.deepCopy((inputFrKey = Key.make()).toString());
        if (fr.numCols() != 1) {
            throw new IllegalArgumentException("Must give a single column to stratify against. Got: " + fr.numCols() + " columns.");
        }
        Vec y = fr.anyVec();
        if (!(y.isCategorical() || y.isNumeric() && y.isInt())) {
            throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + y.get_type_str());
        }
        double testFrac = asts[2].exec(env).getNum();
        long seed = (long)asts[3].exec(env).getNum();
        seed = seed == -1L ? new Random().nextLong() : seed;
        long[] classes = ((VecUtils.CollectDomain)new VecUtils.CollectDomain().doAll(y)).domain();
        int nClass = y.isNumeric() ? classes.length : y.domain().length;
        String[] domains = y.domain();
        long[] seeds = new long[nClass];
        for (int i = 0; i < nClass; ++i) {
            seeds[i] = RandomUtils.getRNG(seed + (long)i).nextLong();
        }
        String[] dom = new String[]{"train", "test"};
        Key<Frame> k1 = Key.make();
        Vec resVec = Vec.makeCon(0L, fr.anyVec().length());
        resVec.setDomain(new String[]{"train", "test"});
        Frame result = new Frame(k1, new String[]{"test_train_split"}, new Vec[]{resVec});
        DKV.put(result);
        ClassIdxTask finTask = (ClassIdxTask)new ClassIdxTask(nClass, classes).doAll(fr);
        HashSet<Long> usedIdxs = new HashSet<Long>();
        for (int classLabel = 0; classLabel < nClass; ++classLabel) {
            long tnum = Math.max(Math.round((double)finTask._iarray[classLabel].size() * testFrac), 1L);
            HashSet<Long> tmpIdxs = new HashSet<Long>();
            int generated = 0;
            int count = 0;
            while ((long)generated < tnum) {
                int i = (int)(RandomUtils.getRNG((long)count + seed).nextDouble() * (double)finTask._iarray[classLabel].size());
                if (tmpIdxs.contains(finTask._iarray[classLabel].get(i))) {
                    ++count;
                    continue;
                }
                tmpIdxs.add(finTask._iarray[classLabel].get(i));
                ++generated;
                ++count;
            }
            usedIdxs.addAll(tmpIdxs);
        }
        new ClassAssignMRTask(usedIdxs).doAll(result.anyVec());
        fr.delete();
        return new ValFrame(result);
    }

    public class LongAry
    extends Iced<LongAry> {
        long[] _ary = new long[4];
        int _sz;

        public LongAry(long ... vals) {
            this._ary = vals;
            this._sz = vals.length;
        }

        public void add(long i) {
            if (this._sz == this._ary.length) {
                this._ary = Arrays.copyOf(this._ary, Math.max(4, this._ary.length * 2));
            }
            this._ary[this._sz++] = i;
        }

        public long get(int i) {
            if (i >= this._sz) {
                throw new ArrayIndexOutOfBoundsException(i);
            }
            return this._ary[i];
        }

        public int size() {
            return this._sz;
        }

        public long[] toArray() {
            return Arrays.copyOf(this._ary, this._sz);
        }

        public void clear() {
            this._sz = 0;
        }
    }

    public class ClassIdxTask
    extends MRTask<ClassIdxTask> {
        LongAry[] _iarray;
        int _nclasses;
        ArrayList<Long> _classes;

        ClassIdxTask(int nclasses, long[] classes) {
            this._nclasses = nclasses;
            Long[] boxed = ArrayUtils.toObject((long[])classes);
            this._classes = new ArrayList<Long>(Arrays.asList(boxed));
        }

        @Override
        public void map(Chunk[] ck) {
            int i;
            this._iarray = new LongAry[this._nclasses];
            for (i = 0; i < this._nclasses; ++i) {
                this._iarray[i] = new LongAry(new long[0]);
            }
            for (i = 0; i < ck[0].len(); ++i) {
                long clas = ck[0].at8(i);
                int clas_idx = this._classes.indexOf(clas);
                this._iarray[clas_idx].add(ck[0].start() + (long)i);
            }
        }

        @Override
        public void reduce(ClassIdxTask c) {
            for (int i = 0; i < c._iarray.length; ++i) {
                for (int j = 0; j < c._iarray[i].size(); ++j) {
                    this._iarray[i].add(c._iarray[i].get(j));
                }
            }
        }
    }

    public static class ClassAssignMRTask
    extends MRTask<ClassAssignMRTask> {
        HashSet<Long> _idx;

        ClassAssignMRTask(HashSet<Long> idx) {
            this._idx = idx;
        }

        @Override
        public void map(Chunk ck) {
            for (int i = 0; i < ck.len(); ++i) {
                if (!this._idx.contains(ck.start() + (long)i)) continue;
                ck.set(i, 1.0);
            }
        }
    }
}

