/*
 * Decompiled with CFR 0.152.
 */
package water.rapids;

import java.util.Random;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.AST;
import water.rapids.ASTPrim;
import water.rapids.Env;
import water.rapids.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/**
 * Rapids primitive {@code kfold_column}: given a frame, a fold count and a seed, returns a
 * single-column frame assigning each row a fold id in {@code [0, nfolds)}.
 *
 * <p>Also exposes static helpers for random, modulo and stratified fold assignment.
 */
public class ASTKFold
extends ASTPrim {
    @Override
    public String[] args() {
        return new String[]{"ary", "nfolds", "seed"};
    }

    /** Arity of the primitive; presumably counts the operator itself plus the three {@link #args()} — confirm against ASTPrim. */
    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "kfold_column";
    }

    /**
     * Rejects fold counts that cannot produce a valid fold id.
     *
     * @throws IllegalArgumentException if {@code nfolds < 1}
     */
    private static void checkNFolds(int nfolds) {
        if (nfolds < 1) {
            throw new IllegalArgumentException("nfolds must be at least 1. Got: " + nfolds);
        }
    }

    /**
     * Deterministic pseudo-random fold id for one row.
     *
     * <p>Takes the remainder before the absolute value: {@code Math.abs(Integer.MIN_VALUE)} is
     * still negative, so {@code Math.abs(r) % n} could yield a negative fold. For every other
     * {@code r} this equals {@code Math.abs(r) % n}, so existing assignments are unchanged.
     *
     * @param row   absolute row index (kept as {@code long} so large frames are not truncated)
     * @param seed  seed mixed into the per-row RNG
     * @param nfolds number of folds, must be {@code >= 1}
     * @return a fold id in {@code [0, nfolds)}
     */
    private static int foldId(long row, long seed, int nfolds) {
        return Math.abs(RandomUtils.getRNG(row + seed).nextInt() % nfolds);
    }

    /**
     * Overwrites every element of {@code v} with a pseudo-random fold id in {@code [0, nfolds)},
     * derived from {@code seed} and the row's absolute index.
     *
     * @param v      vector written in place
     * @param nfolds number of folds, must be {@code >= 1}
     * @param seed   RNG seed
     * @return {@code v} itself, after modification
     * @throws IllegalArgumentException if {@code nfolds < 1}
     */
    public static Vec kfoldColumn(Vec v, final int nfolds, final long seed) {
        checkNFolds(nfolds);
        new MRTask(){

            @Override
            public void map(Chunk c) {
                final long start = c.start();
                for (int i = 0; i < c._len; ++i) {
                    c.set(i, foldId(start + i, seed, nfolds));
                }
            }
        }.doAll(v);
        return v;
    }

    /**
     * Overwrites every element of {@code v} with {@code rowIndex % nfolds}, giving a round-robin
     * fold assignment.
     *
     * @param v      vector written in place
     * @param nfolds number of folds, must be {@code >= 1}
     * @return {@code v} itself, after modification
     * @throws IllegalArgumentException if {@code nfolds < 1}
     */
    public static Vec moduloKfoldColumn(Vec v, final int nfolds) {
        checkNFolds(nfolds);
        new MRTask(){

            @Override
            public void map(Chunk c) {
                final long start = c.start();
                for (int i = 0; i < c._len; ++i) {
                    c.set(i, (int)((start + (long)i) % (long)nfolds));
                }
            }
        }.doAll(v);
        return v;
    }

    /**
     * Builds a new fold column for {@code y}, drawing fold ids per class so each class value is
     * spread across the folds independently.
     *
     * <p>Each observed class gets its own seed, derived from {@code seed} and the class position.
     * A row's fold is then a pseudo-random function of its absolute index and its class seed.
     * Rows whose response is missing keep fold {@code 0}.
     *
     * @param y      integer or categorical response column (not modified)
     * @param nfolds number of folds, must be {@code >= 1}
     * @param seed   base RNG seed
     * @return a new vector holding the fold ids
     * @throws IllegalArgumentException if {@code y} is neither categorical nor integer, or if
     *     {@code nfolds < 1}
     */
    public static Vec stratifiedKFoldColumn(Vec y, final int nfolds, long seed) {
        if (!(y.isCategorical() || y.isNumeric() && y.isInt())) {
            throw new IllegalArgumentException("stratification only applies to integer and categorical columns. Got: " + y.get_type_str());
        }
        checkNFolds(nfolds);
        final long[] classes = ((VecUtils.CollectDomain)new VecUtils.CollectDomain().doAll(y)).domain();
        // Iterate over the classes actually observed. Using y.domain().length for categoricals
        // indexed past the end of `classes` whenever a level was unused.
        final int nClass = classes == null ? y.domain().length : classes.length;
        final long[] seeds = new long[nClass];
        for (int i = 0; i < nClass; ++i) {
            seeds[i] = RandomUtils.getRNG(seed + (long)i).nextLong();
        }
        MRTask task = new MRTask(){

            /** Position of {@code value} among the classes, or -1 if it matches none. */
            private int classIndex(long value) {
                for (int k = 0; k < nClass; ++k) {
                    if (value == (classes == null ? (long)k : classes[k])) {
                        return k;
                    }
                }
                return -1;
            }

            @Override
            public void map(Chunk[] cs) {
                final Chunk response = cs[0];
                final Chunk folds = cs[1];
                final long start = response.start();
                // Each row matches exactly one class and receives exactly one fold, so a single
                // pass per row is enough; no need to loop over every (fold, class) pair.
                for (int row = 0; row < response._len; ++row) {
                    if (response.isNA(row)) {
                        continue; // missing response: leave the zero-initialized fold
                    }
                    final int classIdx = classIndex(response.at8(row));
                    if (classIdx < 0) {
                        continue;
                    }
                    folds.set(row, foldId(start + row, seeds[classIdx], nfolds));
                }
            }
        }.doAll(new Frame(new Vec[]{y, y.makeZero()}));
        return task._fr.vec(1);
    }

    /**
     * Evaluates {@code (kfold_column ary nfolds seed)}.
     *
     * <p>Builds a zero vector shaped like {@code ary} and fills it with random fold ids. A seed of
     * {@code -1} requests a fresh random seed.
     */
    @Override
    ValFrame apply(Env env, Env.StackHelp stk, AST[] asts) {
        Frame fr = stk.track(asts[1].exec(env)).getFrame();
        int nfolds = (int)asts[2].exec(env).getNum();
        long seed = (long)asts[3].exec(env).getNum();
        // Validate before allocating the output vec so bad input does not leave an orphan vec.
        checkNFolds(nfolds);
        Vec foldVec = fr.anyVec().makeZero();
        return new ValFrame(new Frame(ASTKFold.kfoldColumn(foldVec, nfolds, seed == -1L ? new Random().nextLong() : seed)));
    }
}

