/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;
import water.DKV;
import water.Iced;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/**
 * Rapids primitive {@code h2o.random_stratified_split}.
 *
 * <p>Given a frame with a single integer or categorical column, it returns a one-column categorical frame that
 * labels every row as either {@code "train"} (code 0) or {@code "test"} (code 1). Test rows are sampled
 * separately within each distinct value of the input column, so every non-empty class contributes about
 * {@code test_frac} of its rows (and at least one row) to the test side.
 */
public class AstStratifiedSplit extends AstPrimitive {
    /** Name of the single column in the frame returned by {@link #apply}. */
    public static final String OUTPUT_COLUMN_NAME = "test_train_split";
    /** Domain of the returned column: code 0 is {@code "train"}, code 1 is {@code "test"}. */
    public static final String[] OUTPUT_COLUMN_DOMAIN = new String[]{"train", "test"};

    /** Names of the arguments read by {@link #apply} from {@code asts[1..3]}. */
    @Override
    public String[] args() {
        return new String[]{"ary", "test_frac", "seed"};
    }

    /** Total AST arity: the primitive itself at {@code asts[0]} plus its three arguments. */
    @Override
    public int nargs() {
        return 4;
    }

    /** Rapids-visible name of this primitive. */
    @Override
    public String str() {
        return "h2o.random_stratified_split";
    }

    /**
     * Evaluates {@code (h2o.random_stratified_split ary test_frac seed)}.
     *
     * @return a new single-column frame named {@value #OUTPUT_COLUMN_NAME} holding a train/test label per row
     * @throws IllegalArgumentException if {@code ary} does not have exactly one column, or if {@link #split}
     *                                  rejects the column or the fraction
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frame = stk.track(asts[1].exec(env)).getFrame();
        double testFrac = asts[2].exec(env).getNum();
        long seed = (long) asts[3].exec(env).getNum();
        if (frame.numCols() != 1) {
            throw new IllegalArgumentException("Must give a single column to stratify against. Got: " + frame.numCols() + " columns.");
        }
        Vec stratifyingColumn = frame.anyVec();
        Vec splitVec = split(stratifyingColumn, testFrac, seed, OUTPUT_COLUMN_DOMAIN);
        Frame result = new Frame(Key.make(), new String[]{OUTPUT_COLUMN_NAME}, new Vec[]{splitVec});
        return new ValFrame(result);
    }

    /**
     * Builds a categorical train/test indicator for {@code stratifyingColumn}.
     *
     * <p>Rows are grouped by their integer (or categorical code) value. From every non-empty group,
     * {@code max(round(groupSize * splittingFraction), 1)} distinct rows are drawn at random and set to code 1
     * ({@code splittingDom[1]}); every other row, including rows whose stratifying value is missing, keeps
     * code 0 ({@code splittingDom[0]}). The new vec is published with {@code DKV.put} before it is filled.
     *
     * @param stratifyingColumn integer or categorical column to stratify by
     * @param splittingFraction fraction of each class to mark with code 1; must lie within [0, 1]
     * @param randomizationSeed sampling seed; {@code -1} means "draw a fresh random seed"
     * @param splittingDom      domain assigned to the returned vec (code 0 first, code 1 second)
     * @return a new categorical vec with one entry per row of {@code stratifyingColumn}
     * @throws IllegalArgumentException if the column cannot be stratified by, or the fraction is NaN or
     *                                  outside [0, 1]
     */
    public static Vec split(Vec stratifyingColumn, double splittingFraction, long randomizationSeed, String[] splittingDom) {
        checkIfCanStratifyBy(stratifyingColumn);
        // A fraction above 1 asks for more distinct rows than a class holds, and the rejection-sampling loop
        // below would never terminate. The negated range test also rejects NaN.
        if (!(splittingFraction >= 0.0 && splittingFraction <= 1.0)) {
            throw new IllegalArgumentException("Splitting fraction must be between 0 and 1. Got: " + splittingFraction);
        }
        if (randomizationSeed == -1L) {
            randomizationSeed = new Random().nextLong();
        }
        long[] classes = ((VecUtils.CollectIntegerDomain) new VecUtils.CollectIntegerDomain().doAll(stratifyingColumn)).domain();
        int numClasses = stratifyingColumn.isNumeric() ? classes.length : stratifyingColumn.domain().length;
        Vec outputVec = stratifyingColumn.makeCon(0.0, Vec.T_CAT);
        outputVec.setDomain(splittingDom);
        DKV.put(outputVec);
        ClassIdxTask finTask = new ClassIdxTask(numClasses, classes).doAll(stratifyingColumn);
        // Every row lands in at most one class list, so a single set can collect the test rows of all classes;
        // a failed add() therefore always means "already drawn from this class".
        HashSet<Long> testRows = new HashSet<Long>();
        for (int classLabel = 0; classLabel < numClasses; ++classLabel) {
            LongAry indexAry = finTask._indexes[classLabel];
            int classSize = indexAry.size();
            // numClasses can exceed classes.length for categoricals (domain size vs. collected values), leaving
            // trailing lists empty; forcing a draw from an empty list would index past its end.
            if (classSize == 0) {
                continue;
            }
            long tnum = Math.max(Math.round((double) classSize * splittingFraction), 1L);
            int generated = 0;
            int count = 0;
            while (generated < tnum) {
                // A fresh RNG per draw, seeded with (count + seed), keeps the selection reproducible for a seed.
                int i = (int) (RandomUtils.getRNG((long) count + randomizationSeed).nextDouble() * (double) classSize);
                if (testRows.add(indexAry.get(i))) {
                    ++generated;
                }
                ++count;
            }
        }
        new ClassAssignMRTask(testRows).doAll(outputVec);
        return outputVec;
    }

    /**
     * Validates that {@code vec} can be stratified by: it must be categorical or integer-valued numeric, and
     * its length must fit in an {@code int}.
     *
     * @throws IllegalArgumentException if either condition is violated
     */
    static void checkIfCanStratifyBy(Vec vec) {
        if (!(vec.isCategorical() || vec.isNumeric() && vec.isInt())) {
            throw new IllegalArgumentException("Stratification only applies to integer and categorical columns. Got: " + vec.get_type_str());
        }
        if (vec.length() > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Cannot stratified the frame because it is too long: nrows=" + vec.length());
        }
    }

    /** Growable array of primitive longs, used to collect row indices without boxing. */
    public static class LongAry extends Iced<LongAry> {
        /** Backing storage; only the first {@link #_sz} entries are meaningful. */
        long[] _ary = new long[4];
        /** Number of valid entries in {@link #_ary}. */
        int _sz;

        /** Wraps {@code vals} directly (no copy) as the initial contents. */
        public LongAry(long ... vals) {
            this._ary = vals;
            this._sz = vals.length;
        }

        /** Appends {@code i}, doubling the capacity (minimum 4) when full. */
        public void add(long i) {
            if (this._sz == this._ary.length) {
                this._ary = Arrays.copyOf(this._ary, Math.max(4, this._ary.length * 2));
            }
            this._ary[this._sz++] = i;
        }

        /** Appends every element of {@code other}, in order, with a single bulk copy. */
        public void addAll(LongAry other) {
            int newSize = this._sz + other._sz;
            if (newSize > this._ary.length) {
                this._ary = Arrays.copyOf(this._ary, Math.max(newSize, this._ary.length * 2));
            }
            System.arraycopy(other._ary, 0, this._ary, this._sz, other._sz);
            this._sz = newSize;
        }

        /**
         * Returns the element at position {@code i}.
         *
         * @throws ArrayIndexOutOfBoundsException if {@code i} is not below {@link #size()}
         */
        public long get(int i) {
            if (i >= this._sz) {
                throw new ArrayIndexOutOfBoundsException(i);
            }
            return this._ary[i];
        }

        /** Number of stored elements. */
        public int size() {
            return this._sz;
        }

        /** Returns a trimmed copy of the stored elements. */
        public long[] toArray() {
            return Arrays.copyOf(this._ary, this._sz);
        }

        /** Logically empties the array; capacity is retained. */
        public void clear() {
            this._sz = 0;
        }
    }

    /**
     * Groups row numbers by class: after {@code doAll}, {@code _indexes[k]} holds the row numbers
     * ({@code chunk.start() + offset}) whose value equals {@code classes[k]}. Rows with a missing value, or with a
     * value not listed in {@code classes}, are not recorded.
     */
    public static class ClassIdxTask extends MRTask<ClassIdxTask> {
        /** Per-class row numbers; allocated in {@link #map}. */
        LongAry[] _indexes;
        /** Number of per-class lists to allocate (may exceed {@code _classes.length}). */
        private final int _nclasses;
        /** Distinct class values; position in this array is the class index. */
        private long[] _classes;
        /** Node-local lookup from class value to class index, built in {@link #setupLocal}. */
        private transient HashMap<Long, Integer> _classMap;

        public ClassIdxTask(int nclasses, long[] classes) {
            this._nclasses = nclasses;
            this._classes = classes;
        }

        @Override
        protected void setupLocal() {
            this._classMap = new HashMap<>(2 * this._classes.length);
            for (int i = 0; i < this._classes.length; ++i) {
                this._classMap.put(this._classes[i], i);
            }
        }

        @Override
        public void map(Chunk[] ck) {
            Chunk column = ck[0];
            long firstRow = column.start();
            this._indexes = new LongAry[this._nclasses];
            for (int k = 0; k < this._nclasses; ++k) {
                this._indexes[k] = new LongAry(new long[0]);
            }
            for (int row = 0; row < column.len(); ++row) {
                // at8() rejects missing values; NA rows belong to no class and are never sampled into test.
                if (column.isNA(row)) {
                    continue;
                }
                Integer classIdx = this._classMap.get(column.at8(row));
                if (classIdx == null) {
                    continue;
                }
                this._indexes[classIdx].add(firstRow + (long) row);
            }
            // No longer needed once mapped; presumably cleared to keep it out of the reduced result.
            this._classes = null;
        }

        @Override
        public void reduce(ClassIdxTask c) {
            // _indexes is only allocated in map(), so tolerate a partner (or self) that has none.
            if (c._indexes == null) {
                return;
            }
            if (this._indexes == null) {
                this._indexes = c._indexes;
                return;
            }
            for (int k = 0; k < c._indexes.length; ++k) {
                this._indexes[k].addAll(c._indexes[k]);
            }
        }
    }

    /** Sets the value 1.0 (the "test" code) on every row whose row number is in the given set. */
    public static class ClassAssignMRTask extends MRTask<ClassAssignMRTask> {
        /** Row numbers to mark. */
        HashSet<Long> _idx;

        ClassAssignMRTask(HashSet<Long> idx) {
            this._idx = idx;
        }

        @Override
        public void map(Chunk ck) {
            for (int i = 0; i < ck.len(); ++i) {
                if (!this._idx.contains(ck.start() + (long) i)) continue;
                ck.set(i, 1.0);
            }
            // Presumably dropped so the (possibly large) set is not shipped back with the task result.
            this._idx = null;
        }
    }
}

