/*
 * Decompiled with CFR 0.152.
 */
package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendTrain;
import hex.FrameTask;
import hex.deepwater.DeepWaterDatasetIterator;
import hex.deepwater.DeepWaterImageIterator;
import hex.deepwater.DeepWaterIterator;
import hex.deepwater.DeepWaterModel;
import hex.deepwater.DeepWaterModelInfo;
import hex.deepwater.DeepWaterParameters;
import hex.deepwater.DeepWaterTextIterator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.concurrent.Future;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RandomUtils;

/**
 * One distributed Deep Water training iteration, run as an H2O {@link FrameTask}.
 *
 * <p>All work happens in {@link #setupLocal()}: the node takes ownership of the model passed to the
 * constructor, samples rows of the training frame and feeds mini-batches to the native backend.
 * {@link #map} is intentionally empty. Per-node models are merged in {@link #reduce(DeepWaterTask)}
 * and published again by {@link #postGlobal()}, after which {@link #model_info()} returns the result.
 */
public class DeepWaterTask
extends FrameTask<DeepWaterTask> {
    /**
     * Fixed length argument passed to {@link DeepWaterTextIterator}; presumably the maximum
     * sequence length — TODO confirm against the iterator's constructor.
     */
    private static final int TEXT_ITERATOR_LENGTH = 56;

    /** Model being trained on this node; set in setupLocal, cleared in postGlobal. */
    private DeepWaterModelInfo _localmodel;
    /** Model handed in by the constructor, and the model published by postGlobal. */
    private DeepWaterModelInfo _sharedmodel;
    /** Number of nodes whose local models have been combined by reduce. */
    private int _chunk_node_count = 1;
    /** Requested sample count as a multiple of the frame's row count; its integer part is the number of full passes. */
    private float _useFraction;
    /** Whether the collected samples are shuffled before batching. */
    private boolean _shuffle;
    /** Job polled for cancellation between mini-batches. */
    private final Job _job;
    /** Time (ms) of the last "nodes not contributing" warning; used to rate-limit it. */
    private static long _lastWarn;
    /** Number of "nodes not contributing" warnings emitted so far (at most 3 are logged). */
    private static long _warnCount;

    /**
     * Returns the shared model: the input model before the task runs, the trained model after
     * postGlobal.
     */
    public final DeepWaterModelInfo model_info() {
        assert (this._sharedmodel != null);
        return this._sharedmodel;
    }

    /**
     * @param inputModel model to train; also supplies the DataInfo for the frame task
     * @param fraction   number of samples to draw, as a multiple of the training frame's row count
     * @param job        owning job, checked for cancellation during training
     */
    DeepWaterTask(DeepWaterModelInfo inputModel, float fraction, Job job) {
        super((Key<Job>)job._key, inputModel._dataInfo);
        this._sharedmodel = inputModel;
        this._useFraction = fraction;
        this._shuffle = this.model_info().get_params()._shuffle_training_data;
        this._job = job;
    }

    /**
     * Takes ownership of the shared model and runs one training pass on this node.
     *
     * <p>Rows with zero weight are skipped. First {@code (int) _useFraction} full passes are made over
     * the frame. Rows are then drawn uniformly at random until at least {@code _useFraction * numRows}
     * samples exist and the sample count is a multiple of the mini-batch size. Each mini-batch is
     * scored first (to detect numerical instability) and then submitted to the backend for training.
     *
     * @throws UnsupportedOperationException if the predictions of a mini-batch contain NaN
     * @throws RuntimeException wrapping an {@link IOException} raised while building or advancing the
     *     data iterator
     */
    @Override
    protected void setupLocal() {
        assert (this._localmodel == null);
        this._localmodel = this._sharedmodel;
        this._sharedmodel = null;
        this._localmodel.set_processed_local(0L);

        final DeepWaterParameters params = this._localmodel.get_params();
        final int weightIdx = this._fr.find(params._weights_column);
        final int respIdx = this._fr.find(params._response_column);
        final int batchSize = params._mini_batch_size;
        // Seed depends on global progress so that successive iterations sample differently but reproducibly.
        final long seed = 912559L + 53261L * this._localmodel.get_processed_global();
        final Random rng = RandomUtils.getRNG(new long[]{seed});
        if (this._fr.numRows() > Integer.MAX_VALUE) {
            throw H2O.unimpl("Need to implement batching into int-sized chunks.");
        }

        final DeepWaterParameters.ProblemType type = params._problem_type;
        final boolean stringData = type == DeepWaterParameters.ProblemType.image
                || type == DeepWaterParameters.ProblemType.text;
        final ArrayList<Float> trainLabels = new ArrayList<Float>();
        // Image/text problems train on the strings in data column 0; dataset problems on row indices.
        final ArrayList<String> trainStrings = new ArrayList<String>();
        final ArrayList<Integer> trainRows = new ArrayList<Integer>();
        final Futures fs = new Futures();
        try {
            if (stringData) {
                collectStringSamples(weightIdx, respIdx, batchSize, rng, trainStrings, trainLabels);
            } else if (type == DeepWaterParameters.ProblemType.dataset) {
                collectRowSamples(weightIdx, respIdx, batchSize, rng, trainRows, trainLabels);
            }
            if (this._shuffle) {
                final ArrayList<?> samples = stringData ? trainStrings : trainRows;
                assert samples.size() == trainLabels.size();
                // Re-seeding before each shuffle applies the same permutation to both equally sized
                // lists, so every sample stays paired with its label.
                rng.setSeed(seed);
                Collections.shuffle(trainLabels, rng);
                rng.setSeed(seed);
                Collections.shuffle(samples, rng);
            }
            final DeepWaterIterator iter = createIterator(type, trainStrings, trainRows, trainLabels, batchSize);
            while (iter.Next(fs) && !this._job.isStopping()) {
                final long n = this._localmodel.get_processed_total();
                this._localmodel._backend.setParameter(this._localmodel._model, "learning_rate", params.learningRate(n));
                this._localmodel._backend.setParameter(this._localmodel._model, "momentum", params.momentum(n));
                float[] preds = this._localmodel._backend.predict(this._localmodel._model, iter.getData());
                if (Float.isNaN(ArrayUtils.sum(preds))) {
                    Log.err(DeepWaterModel.unstable_msg);
                    throw new UnsupportedOperationException(DeepWaterModel.unstable_msg);
                }
                // Training runs asynchronously; fs tracks the outstanding backend calls.
                NativeTrainTask ntt = new NativeTrainTask(this._localmodel._backend, this._localmodel._model, iter.getData(), iter.getLabel());
                fs.add(H2O.submitTask(ntt));
                this._localmodel.add_processed_local(iter._batch_size);
            }
            fs.blockForPending();
        }
        catch (IOException e) {
            // Previously only printed: the node then silently contributed nothing. Surface the failure.
            throw new RuntimeException("Deep Water training failed while reading training data", e);
        }
    }

    /** Returns the weight of {@code row}, or 1.0 when no weight column is configured. */
    private double rowWeight(int weightIdx, long row) {
        return weightIdx == -1 ? 1.0 : this._fr.vec(weightIdx).at(row);
    }

    /**
     * Collects the data strings (column 0) and responses for image/text problems.
     * NOTE: the top-up loop never terminates if no row has a non-zero weight and a non-missing
     * data value while more samples are still required.
     */
    private void collectStringSamples(int weightIdx, int respIdx, int batchSize, Random rng,
                                      ArrayList<String> data, ArrayList<Float> labels) {
        final int dataIdx = 0;
        final DeepWaterParameters.ProblemType type = this._localmodel.get_params()._problem_type;
        Log.debug("Using column " + this._fr.name(dataIdx) + " for " + (type == DeepWaterParameters.ProblemType.image ? "path to image data" : (type == DeepWaterParameters.ProblemType.text ? "text data" : "path to arbitrary bytes")));
        final BufferedString bs = new BufferedString();
        final int len = (int)this._fr.numRows();
        final int fullPasses = (int)this._useFraction;
        for (int pass = 0; pass < fullPasses; ++pass) {
            for (int row = 0; row < len; ++row) {
                addStringSample(row, dataIdx, weightIdx, respIdx, bs, data, labels);
            }
        }
        while ((float)data.size() < this._useFraction * (float)len || data.size() % batchSize != 0) {
            assert (this._shuffle);
            addStringSample(rng.nextInt(len), dataIdx, weightIdx, respIdx, bs, data, labels);
        }
    }

    /**
     * Appends the data string and response of {@code row}, unless the row has zero weight or a
     * missing data value. Skipping the row entirely keeps {@code data} and {@code labels}
     * index-aligned.
     */
    private void addStringSample(int row, int dataIdx, int weightIdx, int respIdx, BufferedString bs,
                                 ArrayList<String> data, ArrayList<Float> labels) {
        if (rowWeight(weightIdx, row) == 0.0) {
            return;
        }
        BufferedString value = this._fr.vec(dataIdx).atStr(bs, (long)row);
        if (value == null) {
            return;
        }
        data.add(value.toString());
        labels.add((float)this._fr.vec(respIdx).at((long)row));
    }

    /**
     * Collects the row indices and normalized responses for dataset problems.
     * NOTE: the top-up loop never terminates if every row has zero weight while more samples are
     * still required.
     */
    private void collectRowSamples(int weightIdx, int respIdx, int batchSize, Random rng,
                                   ArrayList<Integer> rows, ArrayList<Float> labels) {
        final double mul = this._localmodel._dataInfo._normRespMul != null ? this._localmodel._dataInfo._normRespMul[0] : 1.0;
        final double sub = this._localmodel._dataInfo._normRespSub != null ? this._localmodel._dataInfo._normRespSub[0] : 0.0;
        final int len = (int)this._fr.numRows();
        final int fullPasses = (int)this._useFraction;
        for (int pass = 0; pass < fullPasses; ++pass) {
            for (int row = 0; row < len; ++row) {
                addRowSample(row, weightIdx, respIdx, mul, sub, rows, labels);
            }
        }
        while ((float)rows.size() < this._useFraction * (float)len || rows.size() % batchSize != 0) {
            addRowSample(rng.nextInt(len), weightIdx, respIdx, mul, sub, rows, labels);
        }
    }

    /**
     * Appends {@code row} and its response transformed as {@code (y - sub) / mul}, unless the row
     * has zero weight.
     * NOTE(review): confirm that dividing by _normRespMul matches the inverse transform applied to
     * predictions in DeepWaterModel.
     */
    private void addRowSample(int row, int weightIdx, int respIdx, double mul, double sub,
                              ArrayList<Integer> rows, ArrayList<Float> labels) {
        if (rowWeight(weightIdx, row) == 0.0) {
            return;
        }
        float response = (float)((this._fr.vec(respIdx).at((long)row) - sub) / mul);
        rows.add(row);
        labels.add(response);
    }

    /**
     * Builds the mini-batch iterator for the given problem type.
     *
     * @throws IOException if the iterator cannot be created
     * @throws RuntimeException (from {@code H2O.unimpl}) for problem types this task does not support
     */
    private DeepWaterIterator createIterator(DeepWaterParameters.ProblemType type, ArrayList<String> strings,
                                             ArrayList<Integer> rows, ArrayList<Float> labels, int batchSize) throws IOException {
        final DeepWaterParameters params = this._localmodel.get_params();
        if (type == DeepWaterParameters.ProblemType.image) {
            return new DeepWaterImageIterator(strings, labels, this._localmodel._meanData, batchSize, this._localmodel._width, this._localmodel._height, this._localmodel._channels, params._cache_data);
        }
        if (type == DeepWaterParameters.ProblemType.dataset) {
            assert (this._localmodel._dataInfo != null);
            return new DeepWaterDatasetIterator(rows, labels, this._localmodel._dataInfo, batchSize, params._cache_data);
        }
        if (type == DeepWaterParameters.ProblemType.text) {
            return new DeepWaterTextIterator(strings, labels, batchSize, TEXT_ITERATOR_LENGTH, params._cache_data);
        }
        throw H2O.unimpl("Unsupported problem type for Deep Water training: " + type);
    }

    /** Intentionally empty: all training is performed in setupLocal. */
    @Override
    public void map(Chunk[] chunks, NewChunk[] outputs) {
    }

    /** Drops the reference to the shared model on this node. */
    @Override
    protected void closeLocal() {
        this._sharedmodel = null;
    }

    /**
     * Merges another task's locally trained model into this one.
     *
     * <p>Models that processed no rows are ignored. If this task has processed nothing yet, it adopts
     * the other model and its node count. Otherwise the models are combined with
     * {@code DeepWaterModelInfo.add} and the node counts are summed; postGlobal later divides by that
     * count.
     */
    @Override
    public void reduce(DeepWaterTask other) {
        if (this._localmodel != null && other._localmodel != null && other._localmodel.get_processed_local() > 0L && other._localmodel != this._localmodel) {
            if (this._localmodel.get_processed_local() == 0L) {
                this._localmodel = other._localmodel;
                this._chunk_node_count = other._chunk_node_count;
            } else {
                this._localmodel.add(other._localmodel);
                this._chunk_node_count += other._chunk_node_count;
            }
        }
    }

    /**
     * Finalizes the iteration and publishes the trained model as the shared model.
     *
     * <p>On a multi-node cloud without replicated training data, it warns when some nodes did not
     * contribute. The warning is logged at most 3 times and at least 5 seconds apart. For distributed
     * runs, the local row count is folded into the global count and the merged model is divided by
     * the number of contributing nodes (presumably averaging the sums produced by reduce).
     */
    @Override
    protected void postGlobal() {
        DeepWaterParameters dlp = this._localmodel.get_params();
        if (H2O.CLOUD.size() > 1 && !dlp._replicate_training_data) {
            long now = System.currentTimeMillis();
            if (this._chunk_node_count < H2O.CLOUD.size() && now - _lastWarn > 5000L && _warnCount < 3L) {
                Log.warn(H2O.CLOUD.size() - this._chunk_node_count + " node(s) (out of " + H2O.CLOUD.size() + ") are not contributing to model updates. Consider setting replicate_training_data to true or using a larger training dataset (or fewer H2O nodes).");
                _lastWarn = now;
                ++_warnCount;
            }
        }
        assert ((!dlp._replicate_training_data || H2O.CLOUD.size() == 1) == !this._run_local);
        if (!this._run_local) {
            this._localmodel.add_processed_global(this._localmodel.get_processed_local());
            this._localmodel.set_processed_local(0L);
            if (this._chunk_node_count > 1) {
                this._localmodel.div(this._chunk_node_count);
            }
        } else {
            this._sharedmodel = this._localmodel;
        }
        if (this._sharedmodel == null) {
            this._sharedmodel = this._localmodel;
        }
        this._localmodel = null;
    }

    /** Asynchronous task that trains the backend model on one mini-batch. */
    private static class NativeTrainTask
    extends H2O.H2OCountedCompleter<NativeTrainTask> {
        /** Wall-clock time (ms) spent in the backend train call. */
        long _timeInMillis;
        final BackendTrain _backend;
        final BackendModel _model;
        float[] _data;
        float[] _labels;

        NativeTrainTask(BackendTrain backend, BackendModel model, float[] data, float[] label) {
            this._backend = backend;
            this._model = model;
            this._data = data;
            this._labels = label;
        }

        /** Runs one backend training step, records its duration and signals completion. */
        @Override
        public void compute2() {
            long start = System.currentTimeMillis();
            this._backend.train(this._model, this._data, this._labels);
            long end = System.currentTimeMillis();
            this._timeInMillis += end - start;
            this.tryComplete();
        }
    }
}

