/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import java.util.Random;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/**
 * Rapids primitive {@code h2o.random_stratified_split}: labels every row of a single integer or
 * categorical column as belonging to either the "train" or the "test" split.
 *
 * <p>Each row is independently assigned to "test" with probability {@code test_frac}. The random
 * draw for a row is seeded from its absolute row index plus a per-class seed, so for a fixed
 * {@code seed} the assignment is deterministic. The result is a one-column frame named
 * {@code test_train_split} with levels {@code train} (0) and {@code test} (1), aligned
 * row-for-row with the input. NA input rows produce NA output rows.
 */
public class AstStratifiedSplit extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary", "test_frac", "seed"};
    }

    /** {@code apply} reads its arguments from {@code asts[1]}..{@code asts[3]}, hence one more than {@link #args()}. */
    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "h2o.random_stratified_split";
    }

    /**
     * Builds the train/test label column.
     *
     * @throws IllegalArgumentException if the frame does not have exactly one column, the column is
     *     neither categorical nor integer, or {@code test_frac} is not in [0, 1]
     */
    @Override
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame fr = stk.track(asts[1].exec(env)).getFrame();
        if (fr.numCols() != 1) {
            throw new IllegalArgumentException("Must give a single column to stratify against. Got: " + fr.numCols() + " columns.");
        }
        final Vec y = fr.anyVec();
        if (!(y.isCategorical() || y.isNumeric() && y.isInt())) {
            throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + y.get_type_str());
        }
        final double testFrac = asts[2].exec(env).getNum();
        // Written as a negated range check so that NaN (which fails both comparisons) is rejected too.
        if (!(testFrac >= 0 && testFrac <= 1)) {
            throw new IllegalArgumentException("test_frac must be in [0, 1]. Got: " + testFrac);
        }
        long seed = (long) asts[3].exec(env).getNum();
        if (seed == -1L) {
            seed = new Random().nextLong(); // -1 requests a non-reproducible split
        }

        final boolean categorical = y.isCategorical();
        // Categorical cells already hold a level code in [0, domain length), usable directly as a
        // class index. Integer columns need their sorted distinct values so that each value can be
        // mapped to a dense class index by binary search.
        final long[] classes = categorical
                ? null
                : ((VecUtils.CollectDomain) new VecUtils.CollectDomain().doAll(y)).domain();
        final int nClass = categorical ? y.domain().length : classes.length;
        final long[] seeds = new long[nClass];
        for (int i = 0; i < nClass; ++i) {
            seeds[i] = RandomUtils.getRNG(seed + i).nextLong();
        }

        String[] dom = new String[]{"train", "test"};
        Frame split = new MRTask() {

            /** True when the row with absolute index {@code row} goes to the test split. */
            private boolean isTest(long row, long classSeed) {
                // Strict '<' gives P(test) == testFrac exactly; testFrac == 0 yields no test rows.
                return RandomUtils.getRNG(row + classSeed).nextDouble() < testFrac;
            }

            /** Maps a non-NA cell value to its index into {@code seeds}. */
            private int classIndex(long value) {
                // For integer columns, classes was collected from this same column, so the lookup
                // is expected to hit.
                return categorical ? (int) value : Arrays.binarySearch(classes, value);
            }

            @Override
            public void map(Chunk col, NewChunk ss) {
                // Absolute row index kept as long: an int would overflow on frames with >= 2^31 rows.
                long start = col.start();
                // Exactly one output value per input row, in row order, so the result lines up
                // with the input column.
                for (int row = 0; row < col._len; ++row) {
                    if (col.isNA(row)) {
                        ss.addNA();
                        continue;
                    }
                    int cls = classIndex(col.at8(row));
                    ss.addNum(isTest(start + row, seeds[cls]) ? 1L : 0L, 0);
                }
            }
        }.doAll(1, Vec.T_NUM, new Frame(y)).outputFrame(new String[]{"test_train_split"}, new String[][]{dom});
        return new ValFrame(split);
    }
}

