/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.utils;

import java.util.Arrays;
import java.util.List;
import java.util.Random;
import water.DKV;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.util.RandomUtils;

public class AutoMLUtils {
    public static Vec[] makeWeights(Vec responseVec, final double trainRatio, final double[] weightMult) {
        Vec[] weights = new Vec[2];
        weights[0] = responseVec.makeZero();
        final long seed = RandomUtils.getRNG((long[])new long[]{new Random().nextLong()}).nextLong();
        new MRTask(){

            public void map(Chunk[] c) {
                long start = c[0].start();
                RandomUtils.PCGRNG rng = new RandomUtils.PCGRNG(start, 1L);
                for (int i = 0; i < c[0]._len; ++i) {
                    int yval = (int)c[0].at8(i);
                    rng.setSeed(seed + start + (long)i);
                    c[1].set(i, (double)rng.nextFloat() < trainRatio ? (weightMult == null ? 1.0 : weightMult[yval]) : 0.0);
                }
            }
        }.doAll(new Vec[]{responseVec, weights[0]});
        if (null != weightMult) {
            weights[1] = new MRTask(){

                public void map(Chunk[] cs, NewChunk n) {
                    for (int i = 0; i < cs[0]._len; ++i) {
                        if (0L == cs[1].at8(i)) {
                            n.addNum(weightMult[(int)cs[0].at8(i)]);
                            continue;
                        }
                        n.addNum(0.0);
                    }
                }
            }.doAll((byte)3, new Frame(new Vec[]{responseVec, weights[0]})).outputFrame().anyVec();
        }
        return weights;
    }

    public static Vec[] makeStratifiedWeights(Vec responseVec, final double trainRatio, final double[] weightMult) {
        Vec[] weights = new Vec[2];
        long seed = RandomUtils.getRNG((long[])new long[]{new Random().nextLong()}).nextLong();
        final int nClass = responseVec.domain().length;
        final long[] seeds = new long[nClass];
        for (int i = 0; i < nClass; ++i) {
            seeds[i] = RandomUtils.getRNG((long[])new long[]{seed + (long)i}).nextLong();
        }
        weights[0] = new MRTask(){

            private boolean isTest(int row, long seed) {
                return RandomUtils.getRNG((long[])new long[]{(long)row + seed}).nextDouble() > trainRatio;
            }

            public void map(Chunk y, NewChunk ss) {
                int start = (int)y.start();
                for (int classLabel = 0; classLabel < nClass; ++classLabel) {
                    for (int row = 0; row < y._len; ++row) {
                        int yval = (int)y.at8(row);
                        if (yval != classLabel) continue;
                        ss.addNum(this.isTest(start + row, seeds[classLabel]) ? 0.0 : (weightMult == null ? 1.0 : weightMult[yval]));
                    }
                }
            }
        }.doAll((byte)3, new Vec[]{responseVec}).outputFrame().anyVec();
        if (null != weightMult) {
            weights[1] = weights[0].makeZero();
            new MRTask(){

                public void map(Chunk[] cs) {
                    for (int i = 0; i < cs[0]._len; ++i) {
                        if (0L != cs[1].at8(i)) continue;
                        cs[2].set(i, weightMult[(int)cs[0].at8(i)]);
                    }
                }
            }.doAll(new Vec[]{responseVec, weights[0], weights[1]});
        }
        return weights;
    }

    public static Frame[] makeTrainTest(Frame fr, String response, double trainRatio, boolean stratified, double[] weightMult) {
        Frame[] res = new Frame[2];
        Vec[] trainTestWeights = stratified ? AutoMLUtils.makeStratifiedWeights(fr.vec(response), trainRatio, weightMult) : AutoMLUtils.makeWeights(fr.vec(response), trainRatio, weightMult);
        Vec[] vecs = new Vec[fr.numCols() + 1];
        String[] names = new String[fr.numCols() + 1];
        System.arraycopy(fr.names(), 0, names, 0, fr.names().length);
        System.arraycopy(fr.vecs(), 0, vecs, 0, fr.vecs().length);
        names[names.length - 1] = "weight";
        vecs[vecs.length - 1] = trainTestWeights[0];
        res[0] = new Frame(Key.make(), (String[])names.clone(), (Vec[])vecs.clone());
        DKV.put((Keyed)res[0]);
        vecs = (Vec[])vecs.clone();
        vecs[vecs.length - 1] = trainTestWeights[1];
        res[1] = new Frame(Key.make(), (String[])names.clone(), (Vec[])vecs.clone());
        DKV.put((Keyed)res[1]);
        return res;
    }

    public static Frame[] makeTrainTestFromWeight(Frame fr, Vec[] trainTestWeight) {
        Frame[] res = new Frame[2];
        Vec[] vecs = new Vec[fr.numCols() + 1];
        String[] names = new String[fr.numCols() + 1];
        System.arraycopy(fr.names(), 0, names, 0, fr.names().length);
        System.arraycopy(fr.vecs(), 0, vecs, 0, fr.vecs().length);
        names[names.length - 1] = "weight";
        vecs[vecs.length - 1] = trainTestWeight[0];
        res[0] = new Frame(Key.make(), (String[])names.clone(), (Vec[])vecs.clone());
        DKV.put((Keyed)res[0]);
        vecs = (Vec[])vecs.clone();
        vecs[vecs.length - 1] = trainTestWeight[1];
        res[1] = new Frame(Key.make(), (String[])names.clone(), (Vec[])vecs.clone());
        DKV.put((Keyed)res[1]);
        return res;
    }

    public static int[] intListToA(List<Integer> list) {
        int[] a = new int[]{};
        if (list.size() > 0) {
            a = new int[list.size()];
            for (int i = 0; i < a.length; ++i) {
                a[i] = list.get(i);
            }
        }
        return a;
    }

    public static void cleanup_adapt(Frame adaptFr, Frame fr) {
        Key[] keys = adaptFr.keys();
        for (int i = 0; i < keys.length; ++i) {
            if (fr.find(keys[i]) != -1) continue;
            keys[i].remove();
        }
        DKV.remove((Key)adaptFr._key);
    }
}

