/*
 * Decompiled with CFR 0.152.
 */
package hex.deepwater;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ToEigenVec;
import hex.deepwater.DeepWaterModel;
import hex.deepwater.DeepWaterModelInfo;
import hex.deepwater.DeepWaterModelOutput;
import hex.deepwater.DeepWaterParameters;
import hex.deepwater.DeepWaterTask;
import hex.deepwater.DeepWaterTask2;
import hex.genmodel.algos.deepwater.DeepwaterMojoModel;
import hex.util.LinearAlgebraUtils;
import java.util.ArrayList;
import java.util.Arrays;
import water.DKV;
import water.H2O;
import water.H2ONode;
import water.Iced;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.Keyed;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.MRUtils;
import water.util.PrettyPrint;

/**
 * Model builder for Deep Water models, i.e. neural networks whose training is delegated to a
 * native backend (see {@link DeepWaterParameters.Backend}).
 *
 * <p>Builds regression, binomial and multinomial models, or an autoencoder when
 * {@code _autoencoder} is set. MOJO export is supported; POJO export is not.
 */
public class DeepWater extends ModelBuilder<DeepWaterModel, DeepWaterParameters, DeepWaterModelOutput> {

    /** Banner line framing the end-of-training log output. */
    private static final String LOG_BANNER = "==============================================================================================================================================================================";

    /**
     * Creates a builder for the given parameters and runs the cheap (non-expensive) validation.
     *
     * @param parms user-supplied model parameters
     */
    public DeepWater(DeepWaterParameters parms) {
        super(parms);
        this.init(false);
    }

    /**
     * Creates the prototype builder used during registration at startup.
     *
     * @param startup_once flag forwarded to {@link ModelBuilder}'s registration constructor
     */
    public DeepWater(boolean startup_once) {
        super(new DeepWaterParameters(), startup_once);
    }

    /**
     * Reports whether at least one Deep Water backend can be instantiated.
     *
     * @return {@code true} if {@link DeepwaterMojoModel#createDeepWaterBackend} returns a
     *     non-null backend for any {@link DeepWaterParameters.Backend} value
     */
    static boolean haveBackend() {
        for (DeepWaterParameters.Backend b : DeepWaterParameters.Backend.values()) {
            if (DeepwaterMojoModel.createDeepWaterBackend(b.toString()) != null) {
                return true;
            }
        }
        return false;
    }

    /** Stable when a backend is available on this node, Experimental otherwise. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return haveBackend() ? ModelBuilder.BuilderVisibility.Stable : ModelBuilder.BuilderVisibility.Experimental;
    }

    /** Problem categories this builder can produce models for. */
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    /** Deep Water models can be exported as MOJOs. */
    public boolean haveMojo() {
        return true;
    }

    /** Deep Water models cannot be exported as POJOs. */
    public boolean havePojo() {
        return false;
    }

    /** Categorical-to-eigenvector conversion used by this builder. */
    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    /** Supervised unless the model is configured as an autoencoder. */
    public boolean isSupervised() {
        return !_parms._autoencoder;
    }

    /** Cross-validation models are trained one at a time. */
    protected int nModelsInParallel() {
        return 1;
    }

    /** Returns the driver that performs the actual training. */
    protected DeepWaterDriver trainModelImpl() {
        return new DeepWaterDriver();
    }

    /**
     * Validates the parameters. When {@code expensive} is set and no errors were found, also
     * checks the memory footprint.
     *
     * @param expensive whether to run the expensive checks
     */
    public void init(boolean expensive) {
        super.init(expensive);
        _parms.validate(this, expensive);
        if (expensive && error_count() == 0) {
            checkMemoryFootPrint();
        }
    }

    /** String columns are ignored only for the {@code dataset} problem type. */
    protected boolean ignoreStringColumns() {
        return _parms.guessProblemType() == DeepWaterParameters.ProblemType.dataset;
    }

    /**
     * Configures the main (non-CV) model after cross-validation.
     *
     * <p>Best-model overwriting is always disabled. If early stopping or a runtime limit was
     * active, the number of epochs is set to the mean of the epoch counters at which the CV models
     * were last scored, and both stopping criteria are turned off.
     *
     * @param cvModelBuilders builders of the finished cross-validation models
     */
    public void cv_computeAndSetOptimalParameters(ModelBuilder<DeepWaterModel, DeepWaterParameters, DeepWaterModelOutput>[] cvModelBuilders) {
        _parms._overwrite_with_best_model = false;
        if (_parms._stopping_rounds == 0 && _parms._max_runtime_secs == 0.0) {
            // Neither stopping criterion was active, so keep the user-specified epochs.
            return;
        }
        _parms._stopping_rounds = 0;
        _parms._max_runtime_secs = 0.0;
        double sum = 0.0;
        for (ModelBuilder<DeepWaterModel, DeepWaterParameters, DeepWaterModelOutput> cvmb : cvModelBuilders) {
            DeepWaterModel cvModel = (DeepWaterModel) DKV.getGet(cvmb.dest());
            sum += cvModel.last_scored().epoch_counter;
        }
        _parms._epochs = sum / (double) cvModelBuilders.length;
        if (!_parms._quiet_mode) {
            warn("_epochs", "Setting optimal _epochs to " + _parms._epochs + " for cross-validation main model based on early stopping of cross-validation models.");
            warn("_stopping_rounds", "Disabling convergence-based early stopping for cross-validation main model.");
            warn("_max_runtime_secs", "Disabling maximum allowed runtime for cross-validation main model.");
        }
    }

    /** Driver that validates, builds (or resumes) and trains a {@link DeepWaterModel}. */
    public class DeepWaterDriver extends ModelBuilder.Driver {
        public DeepWaterDriver() {
            super();
        }

        /**
         * Runs the expensive validation, then builds and trains the model.
         *
         * @throws H2OModelBuilderIllegalArgumentException if validation reported errors
         */
        public void computeImpl() {
            DeepWater.this.init(true);
            long cs = DeepWater.this._parms.checksum();
            if (DeepWater.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DeepWater.this);
            }
            this.buildModel();
            // Training must not mutate the user-facing parameters.
            long cs2 = DeepWater.this._parms.checksum();
            assert cs == cs2;
        }

        /**
         * Creates a fresh model, or a continuation of the checkpointed one if
         * {@code _checkpoint} is set, and trains it.
         */
        final void buildModel() {
            DeepWaterModel cp;
            if (DeepWater.this._parms._checkpoint == null) {
                cp = new DeepWaterModel(DeepWater.this._result, DeepWater.this._parms, new DeepWaterModelOutput(DeepWater.this), DeepWater.this.train(), DeepWater.this.valid(), DeepWater.this.nclasses());
            } else {
                cp = resumeFromCheckpoint();
            }
            this.trainModel(cp);
        }

        /**
         * Builds a new model initialised from the checkpointed model, after checking that the
         * checkpoint is compatible with the current data and parameters.
         *
         * @return the new, unlocked model stored under {@code dest()}
         * @throws IllegalArgumentException if the checkpoint key does not resolve to a model
         * @throws H2OIllegalArgumentException if response type, model type, columns, factor
         *     levels, predictor count or epoch budget are incompatible with the checkpoint; the
         *     partially built model is deleted in that case
         */
        private DeepWaterModel resumeFromCheckpoint() {
            final DeepWaterParameters parms = DeepWater.this._parms;
            DeepWaterModel previous = (DeepWaterModel) DKV.getGet(parms._checkpoint);
            if (previous == null) {
                throw new IllegalArgumentException("Checkpoint not found.");
            }
            Log.info("Resuming from checkpoint.");
            DeepWater.this._job.update(0L, "Resuming from checkpoint");
            if (DeepWater.this.isClassifier() != previous._output.isClassifier()) {
                throw new H2OIllegalArgumentException("Response type must be the same as for the checkpointed model.");
            }
            if (DeepWater.this.isSupervised() != previous._output.isSupervised()) {
                throw new H2OIllegalArgumentException("Model type must be the same as for the checkpointed model.");
            }
            DeepWaterParameters.Sanity.checkIfParameterChangeAllowed(previous._parms, parms);

            DeepWaterModel cp = null;
            DataInfo dinfo = null;
            // Temporary DKV entries to delete once the new model has been set up.
            ArrayList<Key> removeMe = new ArrayList<>();
            try {
                for (String st : previous.adaptTestForTrain(DeepWater.this._train, true, false)) {
                    Log.warn(st);
                }
                for (String st : previous.adaptTestForTrain(DeepWater.this._valid, true, false)) {
                    Log.warn(st);
                }
                if (previous.model_info()._dataInfo != null) {
                    dinfo = DeepWaterModel.makeDataInfo(DeepWater.this._train, DeepWater.this._valid, parms);
                    DKV.put(dinfo);
                    removeMe.add(dinfo._key);
                }
                cp = new DeepWaterModel(DeepWater.this.dest(), parms, previous, dinfo);
                cp.write_lock(DeepWater.this._job);
                if (!Arrays.equals(cp._output._names, previous._output._names)) {
                    throw new H2OIllegalArgumentException("The columns of the training data must be the same as for the checkpointed model. Check ignored columns (or disable ignore_const_cols).");
                }
                if (!Arrays.deepEquals(cp._output._domains, previous._output._domains)) {
                    throw new H2OIllegalArgumentException("Categorical factor levels of the training data must be the same as for the checkpointed model.");
                }
                if (dinfo != null && dinfo.fullN() != previous.model_info()._dataInfo.fullN()) {
                    throw new H2OIllegalArgumentException("Total number of predictors is different than for the checkpointed model.");
                }
                if (parms._epochs <= previous.epoch_counter) {
                    throw new H2OIllegalArgumentException("Total number of epochs must be larger than the number of epochs already trained for the checkpointed model (" + previous.epoch_counter + ").");
                }
                // The new model must own a private copy of the parameters.
                DeepWaterParameters actualParms = cp.model_info().get_params();
                assert actualParms != previous.model_info().get_params();
                assert actualParms != parms;
                assert actualParms != previous._parms;
                DeepWaterParameters.Sanity.updateParametersDuringCheckpointRestart(parms, previous._parms, false, false);
                DeepWaterParameters.Sanity.updateParametersDuringCheckpointRestart(parms, cp.model_info().get_params(), true, true);
                DeepWaterParameters.Sanity.modifyParms(parms, cp.model_info().get_params(), DeepWater.this.nclasses());
                Log.info("Continuing training after " + String.format("%.3f", previous.epoch_counter) + " epochs from the checkpointed model.");
                cp.update(DeepWater.this._job);
            } catch (H2OIllegalArgumentException ex) {
                // Incompatible checkpoint: discard the half-built model before propagating.
                if (cp != null) {
                    cp.unlock(DeepWater.this._job);
                    cp.delete();
                    cp = null;
                }
                throw ex;
            } finally {
                if (cp != null) {
                    cp.unlock(DeepWater.this._job);
                }
                for (Key k : removeMe) {
                    DKV.remove(k);
                }
            }
            return cp;
        }

        /**
         * Fraction of rows each training task should process per iteration.
         *
         * @param numRows number of training rows
         * @param train_samples_per_iteration samples to process per iteration
         * @param replicate_training_data if set, the work is divided across all cloud members
         * @return {@code train_samples_per_iteration / numRows}, divided by the cloud size when
         *     the data is replicated
         */
        private float computeRowUsageFraction(long numRows, long train_samples_per_iteration, boolean replicate_training_data) {
            float rowUsageFraction = (float) train_samples_per_iteration / (float) numRows;
            if (replicate_training_data) {
                rowUsageFraction /= (float) H2O.CLOUD.size();
            }
            assert rowUsageFraction > 0.0f;
            return rowUsageFraction;
        }

        /** Row usage fraction derived from the model's {@code actual_train_samples_per_iteration}. */
        private float rowFraction(Frame train, DeepWaterParameters p, DeepWaterModel m) {
            return this.computeRowUsageFraction(train.numRows(), m.actual_train_samples_per_iteration, p._replicate_training_data);
        }

        /**
         * Trains the given model until scoring reports completion or the job times out. If
         * requested, the model is then reset to the best model seen during training.
         *
         * <p>Regardless of outcome, native state is converted back to Java and released, the
         * model is unlocked, and the intermediate best-model key is removed from the DKV.
         *
         * @param model model to train, or {@code null} to load the one stored under {@code dest()}
         * @return the trained model
         * @throws Job.JobCancelledException if a stop was requested before the job timed out
         * @throws IllegalArgumentException if {@code class_sampling_factors} has the wrong length
         */
        public final DeepWaterModel trainModel(DeepWaterModel model) {
            boolean cache = false;
            try {
                if (model == null) {
                    model = (DeepWaterModel) DKV.get(DeepWater.this.dest()).get();
                }
                final DeepWaterParameters parms = DeepWater.this._parms;
                Log.info("Model category: " + (parms._autoencoder ? "Auto-Encoder" : (DeepWater.this.isClassifier() ? "Classification" : "Regression")));
                long model_size = model.model_info().size();
                Log.info("Approximate number of model parameters (weights/biases/aux): " + String.format("%,d", model_size / 4L));
                model.write_lock(DeepWater.this._job);
                DeepWater.this._job.update(0L, "Setting up training data...");

                // Training uses the model's own copy of the parameters from here on.
                final DeepWaterParameters mp = model.model_info().get_params();
                Frame tra_fr = new Frame(mp._train, DeepWater.this._train.names(), DeepWater.this._train.vecs());
                Frame val_fr = DeepWater.this._valid != null ? new Frame(mp._valid, DeepWater.this._valid.names(), DeepWater.this._valid.vecs()) : null;
                Frame train = tra_fr;
                if (model._output.isClassifier() && mp._balance_classes) {
                    train = balanceClasses(train, mp, model);
                }
                model.training_rows = train.numRows();

                // Always record the effective number of samples per iteration. rowFraction() and
                // the shuffling heuristics below read it, so it must also be set when the user
                // supplied an explicit positive value.
                long requested = parms._train_samples_per_iteration;
                model.actual_train_samples_per_iteration = requested > 0L
                    ? requested
                    : (requested == -2L ? 32L * parms._mini_batch_size : DeepWater.this._train.numRows());

                // Rows with weight 0 (0/1 integer weights) do not count towards an epoch.
                if (DeepWater.this._weights != null && DeepWater.this._weights.min() == 0.0 && DeepWater.this._weights.max() == 1.0 && DeepWater.this._weights.isInt()) {
                    model.training_rows = Math.round((double) train.numRows() * DeepWater.this._weights.mean());
                    Log.warn("Not counting " + (train.numRows() - model.training_rows) + " rows with weight=0 towards an epoch.");
                }
                Log.info("One epoch corresponds to " + model.training_rows + " training data rows.");

                Frame trainScoreFrame = MRUtils.sampleFrame(train, mp._score_training_samples, mp._seed);
                if (trainScoreFrame != train) {
                    Scope.track(trainScoreFrame);
                }
                if (!parms._quiet_mode) {
                    Log.info("Number of chunks of the training data: " + train.anyVec().nChunks());
                }
                Frame validScoreFrame = null;
                if (val_fr != null) {
                    model.validation_rows = val_fr.numRows();
                    DeepWater.this._job.update(0L, "Sampling validation data...");
                    validScoreFrame = MRUtils.sampleFrame(val_fr, mp._score_validation_samples, mp._seed + 1L);
                    if (validScoreFrame != val_fr) {
                        Scope.track(validScoreFrame);
                    }
                    if (!parms._quiet_mode) {
                        Log.info("Number of chunks of the validation data: " + validScoreFrame.anyVec().nChunks());
                    }
                }
                enableShufflingIfNeeded(mp, model, train);

                long now = System.currentTimeMillis();
                model._timeLastIterationEnter = now;
                if (parms._autoencoder) {
                    DeepWater.this._job.update(0L, "Scoring null model of autoencoder...");
                    if (!mp._quiet_mode) {
                        Log.info("Scoring the null model of the autoencoder.");
                    }
                    model.doScoring(trainScoreFrame, validScoreFrame, DeepWater.this._job._key, 0, false);
                }
                model.update(DeepWater.this._job);
                model.total_setup_time_ms += now - DeepWater.this._job.start_time();
                Log.info("Total setup time: " + PrettyPrint.msecs((long) model.total_setup_time_ms, true));
                Log.info("Starting to train the Deep Learning model.");
                DeepWater.this._job.update(0L, "Training...");
                cache = decideDataCaching(mp, model, train);

                while (true) {
                    ++model.iterations;
                    model.set_model_info(mp._epochs == 0.0 ? model.model_info() : runIteration(train, mp, model));
                    long before = System.currentTimeMillis();
                    String exportPrefix = parms._export_native_parameters_prefix;
                    if (exportPrefix != null && !exportPrefix.isEmpty()) {
                        Log.info("Saving model state.");
                        model.exportNativeModel(exportPrefix, model.iterations);
                    }
                    model.time_for_iteration_overhead_ms = System.currentTimeMillis() - before;
                    if (DeepWater.this.stop_requested() && !DeepWater.this.timeout()) {
                        throw new Job.JobCancelledException();
                    }
                    if (!model.doScoring(trainScoreFrame, validScoreFrame, DeepWater.this._job._key, model.iterations, false)) {
                        break; // scoring reports that training is finished
                    }
                    if (DeepWater.this.timeout()) {
                        // Out of time: mark progress as complete and stop.
                        DeepWater.this._job.update((long) (mp._epochs * (double) train.numRows()));
                        break;
                    }
                }
                restoreBestModelIfBetter(model, trainScoreFrame, validScoreFrame);
            } finally {
                try {
                    releaseNativeResources(model, cache);
                } finally {
                    // Always log and release the write lock, even if native cleanup failed.
                    logTrainingFinished(model);
                    unlockAndDropBestModel(model);
                }
            }
            return model;
        }

        /**
         * Stratified resampling of the training frame according to the class sampling factors,
         * capped at {@code _max_after_balance_size} times the original row count. Also records
         * the resulting class distribution on the model output.
         *
         * @return the resampled training frame
         * @throws IllegalArgumentException if {@code class_sampling_factors} has the wrong length
         */
        private Frame balanceClasses(Frame train, DeepWaterParameters mp, DeepWaterModel model) {
            DeepWater.this._job.update(0L, "Balancing class distribution of training data...");
            int numClasses = train.lastVec().domain().length;
            float[] trainSamplingFactors = new float[numClasses];
            if (mp._class_sampling_factors != null) {
                if (mp._class_sampling_factors.length != numClasses) {
                    throw new IllegalArgumentException("class_sampling_factors must have " + numClasses + " elements");
                }
                trainSamplingFactors = mp._class_sampling_factors.clone();
            }
            String weightsName = model._output.weightsName();
            long maxRows = (long) (mp._max_after_balance_size * (float) train.numRows());
            Frame balanced = MRUtils.sampleFrameStratified(train, train.lastVec(), train.vec(weightsName), trainSamplingFactors, maxRows, mp._seed, true, false);
            Vec label = balanced.lastVec();
            Vec weights = balanced.vec(weightsName);
            MRUtils.ClassDist cd = new MRUtils.ClassDist(label);
            model._output._modelClassDist = DeepWater.this._weights != null
                ? ((MRUtils.ClassDist) cd.doAll(new Vec[]{label, weights})).rel_dist()
                : ((MRUtils.ClassDist) cd.doAll(new Vec[]{label})).rel_dist();
            return balanced;
        }

        /**
         * Turns on {@code _shuffle_training_data} in two cases. The first is multi-node
         * replicated training where every iteration covers the full dataset on every node. The
         * second is a single-chunk frame trained one full pass per iteration. In both cases rows
         * would otherwise be seen in the same order every time.
         */
        private void enableShufflingIfNeeded(DeepWaterParameters mp, DeepWaterModel model, Frame train) {
            int replicas = mp._single_node_mode ? 1 : H2O.CLOUD.size();
            if (mp._replicate_training_data && model.actual_train_samples_per_iteration == model.training_rows * (long) replicas && !mp._shuffle_training_data && H2O.CLOUD.size() > 1) {
                if (!mp._quiet_mode) {
                    Log.info("Enabling training data shuffling, because all nodes train on the full dataset (replicated training data).");
                }
                mp._shuffle_training_data = true;
            }
            if (!mp._shuffle_training_data && model.actual_train_samples_per_iteration == model.training_rows && train.anyVec() != null && train.anyVec().nChunks() == 1) {
                if (!mp._quiet_mode) {
                    Log.info("Enabling training data shuffling to avoid training rows in the same order over and over (no Hogwild since there's only 1 chunk)."); 
                }
                mp._shuffle_training_data = true;
            }
        }

        /**
         * Decides whether data caching stays enabled. Caching is kept only if the estimated data
         * size is below a quarter of the cloud's free memory. Otherwise {@code mp._cache_data} is
         * cleared.
         *
         * @return {@code true} if caching is enabled for this training run
         */
        private boolean decideDataCaching(DeepWaterParameters mp, DeepWaterModel model, Frame train) {
            if (!mp._cache_data) {
                return false;
            }
            // Image problems: estimate 4 bytes per pixel value rather than the raw frame size.
            long bytes = DeepWater.this._parms._problem_type == DeepWaterParameters.ProblemType.image
                ? train.numRows() * (long) model.model_info()._width * (long) model.model_info()._height * (long) model.model_info()._channels * 4L
                : train.byteSize();
            if (bytes < H2O.CLOUD.free_mem() / 4L) {
                Log.info("Automatically enabling data caching, expecting to require " + PrettyPrint.bytes(bytes) + ".");
                return true;
            }
            Log.info("Automatically disabling data caching, since it would require too much space: " + PrettyPrint.bytes(bytes) + ".");
            mp._cache_data = false;
            return false;
        }

        /**
         * Runs one distributed training iteration and returns the updated model info. Replicated
         * data on a multi-node cloud uses {@link DeepWaterTask2}, either on this node only or on
         * all nodes. Otherwise the rows are processed in place with {@link DeepWaterTask}.
         */
        private DeepWaterModelInfo runIteration(Frame train, DeepWaterParameters mp, DeepWaterModel model) {
            if (H2O.CLOUD.size() > 1 && mp._replicate_training_data) {
                DeepWaterTask2 task = new DeepWaterTask2(DeepWater.this._job._key, train, model.model_info(), this.rowFraction(train, mp, model), model.iterations);
                DeepWaterTask2 done = mp._single_node_mode
                    ? (DeepWaterTask2) task.doAll(new Key[]{Key.make(H2O.SELF)})
                    : (DeepWaterTask2) task.doAllNodes();
                return done.model_info();
            }
            DeepWaterTask task = new DeepWaterTask(model.model_info(), this.rowFraction(train, mp, model), DeepWater.this._job);
            return ((DeepWaterTask) task.doAll(train)).model_info();
        }

        /**
         * Replaces the model's parameters with those of the best model seen during training when
         * all of the following hold:
         * <ul>
         *   <li>{@code _overwrite_with_best_model} is set;</li>
         *   <li>no stop was requested and this is not a cross-validated build;</li>
         *   <li>the best model has a strictly lower loss than the current one;</li>
         *   <li>both models use the same network.</li>
         * </ul>
         * The model is rescored afterwards.
         */
        private void restoreBestModelIfBetter(DeepWaterModel model, Frame trainScoreFrame, Frame validScoreFrame) {
            final DeepWaterParameters parms = DeepWater.this._parms;
            if (DeepWater.this.stop_requested() || !parms._overwrite_with_best_model || model.actual_best_model_key == null || parms._nfolds != 0) {
                return;
            }
            DeepWaterModel best_model = (DeepWaterModel) DKV.getGet(model.actual_best_model_key);
            if (best_model == null || !(best_model.loss() < model.loss()) || !Arrays.equals(best_model.model_info()._network, model.model_info()._network)) {
                return;
            }
            if (!parms._quiet_mode) {
                Log.info("Setting the model to be the best model so far (based on scoring history).");
                Log.info("Best model's loss: " + best_model.loss() + " vs this model's loss (before overwriting it with the best model): " + model.loss());
            }
            model.model_info().nativeToJava();
            model.removeNativeState();
            DeepWaterModelInfo mi = (DeepWaterModelInfo) IcedUtils.deepCopy(best_model.model_info());
            // Keep the processed-sample counters of the current run.
            mi.set_processed_global(model.model_info().get_processed_global());
            mi.set_processed_local(model.model_info().get_processed_local());
            model.set_model_info(mi);
            model.update(DeepWater.this._job);
            model.doScoring(trainScoreFrame, validScoreFrame, DeepWater.this._job._key, model.iterations, true);
            if (!parms._quiet_mode) {
                Log.info("  Note: best model was at " + (float) best_model.epoch_counter + " (out of " + (float) model.epoch_counter + ") epochs.");
            }
            // A relative loss difference after restoring suggests the data changed underneath us.
            if (Math.abs(best_model.loss() - model.loss()) >= 1.0E-5 * Math.abs(model.loss() + best_model.loss())) {
                Log.info("Best model's loss: " + best_model.loss() + " vs this model's loss (after overwriting it with the best model) : " + model.loss());
                Log.warn("Even though the model was reset to the previous best model, we observe different scoring results. Most likely, the data set has changed during a checkpoint restart. If so, please compare the metrics to observe your data shift.");
            }
        }

        /**
         * Converts native backend state back to Java (when a backend is attached), clears the
         * data cache if it was enabled, and drops the native state. Does nothing for a
         * {@code null} model.
         */
        private void releaseNativeResources(DeepWaterModel model, boolean cache) {
            if (model == null) {
                return;
            }
            if (model.model_info() != null && model.model_info()._backend != null) {
                model.model_info().nativeToJava();
            }
            if (cache) {
                model.cleanUpCache();
            }
            model.removeNativeState();
        }

        /** Logs the end-of-training banner unless quiet mode is on. */
        private void logTrainingFinished(DeepWaterModel model) {
            if (DeepWater.this._parms._quiet_mode) {
                return;
            }
            Log.info(LOG_BANNER);
            if (DeepWater.this.stop_requested()) {
                Log.info("Deep Water model training was interrupted.");
            } else {
                Log.info("Finished training the Deep Water model.");
                Log.info(model);
            }
            Log.info(LOG_BANNER);
        }

        /**
         * Releases the model's write lock and removes the intermediate best-model copy from the
         * DKV, if one exists. Does nothing for a {@code null} model.
         */
        private void unlockAndDropBestModel(DeepWaterModel model) {
            if (model == null) {
                return;
            }
            model.unlock(DeepWater.this._job);
            if (model.actual_best_model_key != null) {
                assert model.actual_best_model_key != model._key;
                DKV.remove(model.actual_best_model_key);
            }
        }
    }
}

