/*
 * Decompiled with CFR 0.152.
 */
package hex.deepwater;

import hex.Model;
import hex.ScoreKeeper;
import hex.deepwater.DeepWater;
import hex.deepwater.DeepWaterModel;
import hex.genmodel.utils.DistributionFamily;
import java.io.File;
import java.lang.reflect.Field;
import java.net.URL;
import java.util.Arrays;
import java.util.Objects;
import javax.imageio.ImageIO;
import water.H2O;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Vec;
import water.parser.BufferedString;
import water.util.ArrayUtils;
import water.util.Log;

public class DeepWaterParameters
extends Model.Parameters {
    public double _clip_gradient = 10.0;
    public boolean _gpu = true;
    public int[] _device_id = new int[]{0};
    public Network _network = Network.auto;
    public Backend _backend = Backend.mxnet;
    public String _network_definition_file;
    public String _network_parameters_file;
    public String _export_native_parameters_prefix;
    public ProblemType _problem_type = ProblemType.auto;
    public int[] _image_shape = new int[]{0, 0};
    public int _channels = 3;
    public String _mean_image_file;
    public boolean _overwrite_with_best_model = true;
    public boolean _autoencoder = false;
    public boolean _sparse = false;
    public boolean _use_all_factor_levels = true;
    public MissingValuesHandling _missing_values_handling = MissingValuesHandling.MeanImputation;
    public boolean _standardize = true;
    public double _epochs = 10.0;
    public Activation _activation = null;
    public int[] _hidden = null;
    public double _input_dropout_ratio = 0.0;
    public double[] _hidden_dropout_ratios = null;
    public long _train_samples_per_iteration = -2L;
    public double _target_ratio_comm_to_comp = 0.05;
    public double _learning_rate = 0.005;
    public double _learning_rate_annealing = 1.0E-6;
    public double _momentum_start = 0.9;
    public double _momentum_ramp = 10000.0;
    public double _momentum_stable = 0.99;
    public double _score_interval = 5.0;
    public long _score_training_samples = 10000L;
    public long _score_validation_samples = 0L;
    public double _score_duty_cycle = 0.1;
    public boolean _quiet_mode = false;
    public boolean _replicate_training_data = true;
    public boolean _single_node_mode = false;
    public boolean _shuffle_training_data = true;
    public int _mini_batch_size = 32;
    protected boolean _cache_data = true;

    public String algoName() {
        return "DeepWater";
    }

    public String fullName() {
        return "Deep Water";
    }

    public String javaName() {
        return DeepWaterModel.class.getName();
    }

    protected double defaultStoppingTolerance() {
        return 0.0;
    }

    public DeepWaterParameters() {
        this._stopping_rounds = 5;
    }

    public long progressUnits() {
        if (this.train() == null) {
            return 1L;
        }
        return (long)Math.ceil(this._epochs * (double)this.train().numRows());
    }

    public float learningRate(double n) {
        return (float)(this._learning_rate / (1.0 + this._learning_rate_annealing * n));
    }

    public final float momentum(double n) {
        double m = this._momentum_start;
        if (this._momentum_ramp > 0.0) {
            m = n >= this._momentum_ramp ? this._momentum_stable : (m += (this._momentum_stable - this._momentum_start) * n / this._momentum_ramp);
        }
        return (float)m;
    }

    void validate(DeepWater dl, boolean expensive) {
        Vec w;
        boolean classification;
        boolean bl = expensive || dl.nclasses() != 0 ? dl.isClassifier() : (classification = this._distribution == DistributionFamily.bernoulli || this._distribution == DistributionFamily.bernoulli);
        if (this._mini_batch_size < 1) {
            dl.error("_mini_batch_size", "Mini-batch size must be >= 1");
        }
        if (this._weights_column != null && expensive && (!(w = this.train().vec(this._weights_column)).isInt() || w.max() > 1.0 || w.min() < 0.0)) {
            dl.error("_weights_column", "only supporting weights of 0 or 1 right now");
        }
        if (this._clip_gradient <= 0.0) {
            dl.error("_clip_gradient", "Clip gradient must be >= 0");
        }
        if (this._hidden != null && this._network_definition_file != null) {
            dl.error("_hidden", "Cannot provide hidden layers and a network definition file at the same time.");
        }
        if (this._activation != null && this._network_definition_file != null) {
            dl.error("_activation", "Cannot provide activation functions and a network definition file at the same time.");
        }
        if (this._problem_type == ProblemType.image) {
            if (this._image_shape.length != 2) {
                dl.error("_image_shape", "image_shape must have 2 dimensions (width, height)");
            }
            if (this._image_shape[0] < 0) {
                dl.error("_image_shape", "image_shape[0] must be >=1 or automatic (0).");
            }
            if (this._image_shape[1] < 0) {
                dl.error("_image_shape", "image_shape[1] must be >=1 or automatic (0).");
            }
            if (this._channels != 1 && this._channels != 3) {
                dl.error("_channels", "channels must be either 1 or 3.");
            }
        } else if (this._problem_type != ProblemType.auto) {
            dl.warn("_image_shape", "image shape is ignored, only used for image_classification");
            dl.warn("_channels", "channels shape is ignored, only used for image_classification");
            dl.warn("_mean_image_file", "mean_image_file shape is ignored, only used for image_classification");
        }
        if (this._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.Enum) {
            dl.error("_categorical_encoding", "categorical encoding scheme cannot be Enum: the neural network must have numeric columns as input.");
        }
        if (this._autoencoder) {
            dl.error("_autoencoder", "Autoencoder is not supported right now.");
        }
        if (this._network == Network.user) {
            if (this._network_definition_file == null || this._network_definition_file.isEmpty()) {
                dl.error("_network_definition_file", "network_definition_file must be provided if the network is user-specified.");
            } else if (!new File(this._network_definition_file).exists()) {
                dl.error("_network_definition_file", "network_definition_file " + this._network_definition_file + " not found.");
            }
        } else if (this._network_definition_file != null && !this._network_definition_file.isEmpty() && this._network != Network.auto) {
            dl.error("_network_definition_file", "network_definition_file cannot be provided if a pre-defined network is chosen.");
        }
        if (this._network_parameters_file != null && !this._network_parameters_file.isEmpty() && !new File(this._network_parameters_file).exists()) {
            dl.error("_network_parameters_file", "network_parameters_file " + this._network_parameters_file + " not found.");
        }
        if (this._checkpoint != null) {
            DeepWaterModel other = (DeepWaterModel)this._checkpoint.get();
            if (other == null) {
                dl.error("_width", "Invalid checkpoint provided: width mismatch.");
            }
            if (!Arrays.equals(this._image_shape, other.get_params()._image_shape)) {
                dl.error("_width", "Invalid checkpoint provided: width mismatch.");
            }
        }
        if (!this._autoencoder) {
            if (classification) {
                dl.hide("_regression_stop", "regression_stop is used only with regression.");
            } else {
                dl.hide("_classification_stop", "classification_stop is used only with classification.");
            }
            if (!classification && this._valid != null || this._valid == null) {
                dl.hide("_score_validation_sampling", "score_validation_sampling requires classification and a validation frame.");
            }
        } else if (this._nfolds > 1) {
            dl.error("_nfolds", "N-fold cross-validation is not supported for Autoencoder.");
        }
        if (H2O.CLOUD.size() == 1 && this._replicate_training_data) {
            dl.hide("_replicate_training_data", "replicate_training_data is only valid with cloud size greater than 1.");
        }
        if (this._single_node_mode && (H2O.CLOUD.size() == 1 || !this._replicate_training_data)) {
            dl.hide("_single_node_mode", "single_node_mode is only used with multi-node operation with replicated training data.");
        }
        if (H2O.ARGS.client && this._single_node_mode) {
            dl.error("_single_node_mode", "Cannot run on a single node in client mode");
        }
        if (this._autoencoder) {
            dl.hide("_use_all_factor_levels", "use_all_factor_levels is mandatory in combination with autoencoder.");
        }
        if (this._nfolds != 0) {
            dl.hide("_overwrite_with_best_model", "overwrite_with_best_model is unsupported in combination with n-fold cross-validation.");
        }
        if (expensive) {
            dl.checkDistributions();
        }
        if (this._score_training_samples < 0L) {
            dl.error("_score_training_samples", "Number of training samples for scoring must be >= 0 (0 for all).");
        }
        if (this._score_validation_samples < 0L) {
            dl.error("_score_validation_samples", "Number of training samples for scoring must be >= 0 (0 for all).");
        }
        if (classification && dl.hasOffsetCol()) {
            dl.error("_offset_column", "Offset is only supported for regression.");
        }
        if (expensive) {
            if (!classification && this._balance_classes) {
                dl.error("_balance_classes", "balance_classes requires classification.");
            }
            if (this._class_sampling_factors != null && !this._balance_classes) {
                dl.error("_class_sampling_factors", "class_sampling_factors requires balance_classes to be enabled.");
            }
            if (this._replicate_training_data && null != this.train() && (double)this.train().byteSize() > 0.9 * (double)H2O.CLOUD.free_mem() / (double)H2O.CLOUD.size() && H2O.CLOUD.size() > 1) {
                dl.error("_replicate_training_data", "Compressed training dataset takes more than 90% of avg. free available memory per node (" + 0.9 * (double)H2O.CLOUD.free_mem() / (double)H2O.CLOUD.size() + "), cannot run with replicate_training_data.");
            }
        }
        if (this._autoencoder && this._stopping_metric != ScoreKeeper.StoppingMetric.AUTO && this._stopping_metric != ScoreKeeper.StoppingMetric.MSE) {
            dl.error("_stopping_metric", "Stopping metric must either be AUTO or MSE for autoencoder.");
        }
    }

    ProblemType guessProblemType() {
        if (this._problem_type == ProblemType.auto) {
            boolean image = false;
            boolean text = false;
            String first = null;
            Vec v = this.train().vec(0);
            if (v.isString() || v.isCategorical()) {
                BufferedString bs = new BufferedString();
                first = v.atStr(bs, 0L).toString();
                try {
                    ImageIO.read(new File(first));
                    image = true;
                }
                catch (Throwable t) {
                    // empty catch block
                }
                try {
                    ImageIO.read(new URL(first));
                    image = true;
                }
                catch (Throwable t) {
                    // empty catch block
                }
            }
            if (first != null) {
                if (!image && (first.endsWith(".jpg") || first.endsWith(".png") || first.endsWith(".tif"))) {
                    image = true;
                    Log.warn((Object[])new Object[]{"Cannot read first image at " + first + " - Check data."});
                } else if (v.isString() && this.train().numCols() <= 4) {
                    text = true;
                }
            }
            if (image) {
                return ProblemType.image;
            }
            if (text) {
                return ProblemType.text;
            }
            return ProblemType.dataset;
        }
        return this._problem_type;
    }

    static class Sanity {
        private static final transient String[] cp_modifiable = new String[]{"_seed", "_checkpoint", "_epochs", "_score_interval", "_train_samples_per_iteration", "_target_ratio_comm_to_comp", "_score_duty_cycle", "_score_training_samples", "_score_validation_samples", "_score_validation_sampling", "_classification_stop", "_regression_stop", "_stopping_rounds", "_stopping_metric", "_quiet_mode", "_max_confusion_matrix_size", "_max_hit_ratio_k", "_diagnostics", "_variable_importances", "_replicate_training_data", "_shuffle_training_data", "_single_node_mode", "_overwrite_with_best_model", "_mini_batch_size", "_network_parameters_file", "_clip_gradient", "_learning_rate", "_learning_rate_annealing", "_gpu", "_sparse", "_device_id", "_input_dropout_ratio", "_hidden_dropout_ratios", "_cache_data", "_export_native_parameters_prefix", "_image_shape"};
        private static final transient String[] cp_not_modifiable = new String[]{"_drop_na20_cols", "_missing_values_handling", "_response_column", "_activation", "_use_all_factor_levels", "_problem_type", "_channels", "_standardize", "_autoencoder", "_network", "_backend", "_momentum_start", "_momentum_ramp", "_momentum_stable", "_ignore_const_cols", "_max_categorical_features", "_nfolds", "_distribution", "_network_definition_file", "_mean_image_file"};

        Sanity() {
        }

        static void checkCompleteness() {
            for (Field f : DeepWaterParameters.class.getDeclaredFields()) {
                if (ArrayUtils.contains((String[])cp_not_modifiable, (String)f.getName()) || ArrayUtils.contains((String[])cp_modifiable, (String)f.getName()) || f.getName().equals("_hidden") || f.getName().equals("_ignored_columns") || f.getName().equals("$jacocoData")) continue;
                throw H2O.unimpl((String)("Please add " + f.getName() + " to either cp_modifiable or cp_not_modifiable"));
            }
        }

        static void checkIfParameterChangeAllowed(DeepWaterParameters oldP, DeepWaterParameters newP) {
            Sanity.checkCompleteness();
            if (newP._nfolds != 0) {
                throw new UnsupportedOperationException("nfolds must be 0: Cross-validation is not supported during checkpoint restarts.");
            }
            if (newP._valid == null != (oldP._valid == null)) {
                throw new H2OIllegalArgumentException("Presence of validation dataset must agree with the checkpointed model.");
            }
            if (!(newP._autoencoder || newP._response_column != null && newP._response_column.equals(oldP._response_column))) {
                throw new H2OIllegalArgumentException("Response column (" + newP._response_column + ") is not the same as for the checkpointed model: " + oldP._response_column);
            }
            if (!Arrays.equals(newP._ignored_columns, oldP._ignored_columns)) {
                throw new H2OIllegalArgumentException("Ignored columns must be the same as for the checkpointed model.");
            }
            for (Field fBefore : ((Object)((Object)oldP)).getClass().getFields()) {
                if (!ArrayUtils.contains((String[])cp_not_modifiable, (String)fBefore.getName())) continue;
                for (Field fAfter : ((Object)((Object)newP)).getClass().getFields()) {
                    if (!fBefore.equals(fAfter)) continue;
                    try {
                        if (fAfter.get((Object)newP) != null && fBefore.get((Object)oldP) != null && fBefore.get((Object)oldP).toString().equals(fAfter.get((Object)newP).toString()) || fBefore.get((Object)oldP) == null && fAfter.get((Object)newP) == null) continue;
                        throw new H2OIllegalArgumentException("Cannot change parameter: '" + fBefore.getName() + "': " + fBefore.get((Object)oldP) + " -> " + fAfter.get((Object)newP));
                    }
                    catch (IllegalAccessException e) {
                        e.printStackTrace();
                    }
                }
            }
        }

        static void updateParametersDuringCheckpointRestart(DeepWaterParameters srcParms, DeepWaterParameters tgtParms, boolean doIt, boolean quiet) {
            for (Field fTarget : ((Object)((Object)tgtParms)).getClass().getFields()) {
                if (!ArrayUtils.contains((String[])cp_modifiable, (String)fTarget.getName())) continue;
                for (Field fSource : ((Object)((Object)srcParms)).getClass().getFields()) {
                    if (!fTarget.equals(fSource)) continue;
                    try {
                        if (fSource.get((Object)srcParms) != null && fTarget.get((Object)tgtParms) != null && fTarget.get((Object)tgtParms).toString().equals(fSource.get((Object)srcParms).toString()) || fTarget.get((Object)tgtParms) == null && fSource.get((Object)srcParms) == null) continue;
                        if (!tgtParms._quiet_mode && !quiet) {
                            Log.info((Object[])new Object[]{"Applying user-requested modification of '" + fTarget.getName() + "': " + fTarget.get((Object)tgtParms) + " -> " + fSource.get((Object)srcParms)});
                        }
                        if (!doIt) continue;
                        fTarget.set((Object)tgtParms, fSource.get((Object)srcParms));
                    }
                    catch (IllegalAccessException e) {
                        e.printStackTrace();
                    }
                }
            }
        }

        static void modifyParms(DeepWaterParameters fromParms, DeepWaterParameters toParms, int nClasses) {
            if (H2O.CLOUD.size() == 1 && fromParms._replicate_training_data) {
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_replicate_training_data: Disabling replicate_training_data on 1 node."});
                }
                toParms._replicate_training_data = false;
            }
            if (fromParms._distribution == DistributionFamily.AUTO) {
                toParms._distribution = nClasses > 1 ? (nClasses == 2 ? DistributionFamily.bernoulli : DistributionFamily.multinomial) : DistributionFamily.gaussian;
            }
            if (fromParms._single_node_mode && (H2O.CLOUD.size() == 1 || !fromParms._replicate_training_data)) {
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_single_node_mode: Disabling single_node_mode (only for multi-node operation with replicated training data)."});
                }
                toParms._single_node_mode = false;
            }
            if (fromParms._overwrite_with_best_model && fromParms._nfolds != 0) {
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_overwrite_with_best_model: Disabling overwrite_with_best_model in combination with n-fold cross-validation."});
                }
                toParms._overwrite_with_best_model = false;
            }
            if (fromParms._problem_type == ProblemType.auto) {
                toParms._problem_type = fromParms.guessProblemType();
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_problem_type: Automatically selecting problem_type: " + toParms._problem_type.toString()});
                }
            }
            if (fromParms._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.AUTO) {
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_categorical_encoding: Automatically enabling OneHotInternal categorical encoding."});
                }
                toParms._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.OneHotInternal;
            }
            if (fromParms._nfolds != 0 && fromParms._overwrite_with_best_model) {
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_overwrite_with_best_model: Automatically disabling overwrite_with_best_model, since the final model is the only scored model with n-fold cross-validation."});
                }
                toParms._overwrite_with_best_model = false;
            }
            if (fromParms._network == Network.auto || fromParms._network == null) {
                if (fromParms._network_definition_file != null && !fromParms._network_definition_file.equals("")) {
                    if (!fromParms._quiet_mode) {
                        Log.info((Object[])new Object[]{"_network_definition_file: Automatically setting network type to 'user', since a network definition file was provided."});
                    }
                    toParms._network = Network.user;
                } else {
                    if (toParms._problem_type == ProblemType.image) {
                        toParms._network = Network.inception_bn;
                    }
                    if (toParms._problem_type == ProblemType.text || toParms._problem_type == ProblemType.dataset) {
                        toParms._network = null;
                        if (fromParms._hidden == null) {
                            toParms._hidden = new int[]{200, 200};
                            toParms._activation = Activation.Rectifier;
                            toParms._hidden_dropout_ratios = new double[toParms._hidden.length];
                        }
                    }
                    if (!fromParms._quiet_mode && toParms._network != null && toParms._network != Network.user) {
                        Log.info((Object[])new Object[]{"_network: Using " + (Object)((Object)toParms._network) + " model by default."});
                    }
                }
            }
            if (fromParms._autoencoder && fromParms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO) {
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"_stopping_metric: Automatically setting stopping_metric to MSE for autoencoder."});
                }
                toParms._stopping_metric = ScoreKeeper.StoppingMetric.MSE;
            }
            if (toParms._hidden != null) {
                if (toParms._hidden_dropout_ratios == null) {
                    if (!fromParms._quiet_mode) {
                        Log.info((Object[])new Object[]{"_hidden_dropout_ratios: Automatically setting hidden_dropout_ratios to 0 for all layers."});
                    }
                    toParms._hidden_dropout_ratios = new double[toParms._hidden.length];
                }
                if (toParms._activation == null) {
                    toParms._activation = Activation.Rectifier;
                    if (!fromParms._quiet_mode) {
                        Log.info((Object[])new Object[]{"_activation: Automatically setting activation to " + (Object)((Object)toParms._activation) + " for all layers."});
                    }
                }
                if (!fromParms._quiet_mode) {
                    Log.info((Object[])new Object[]{"Hidden layers: " + Arrays.toString(toParms._hidden)});
                    Log.info((Object[])new Object[]{"Activation function: " + (Object)((Object)toParms._activation)});
                    Log.info((Object[])new Object[]{"Input dropout ratio: " + toParms._input_dropout_ratio});
                    Log.info((Object[])new Object[]{"Hidden layer dropout ratio: " + Arrays.toString(toParms._hidden_dropout_ratios)});
                }
            }
        }
    }

    public static enum Activation {
        Rectifier,
        Tanh;

    }

    public static enum MissingValuesHandling {
        Skip,
        MeanImputation;

    }

    public static enum ProblemType {
        auto,
        image,
        text,
        dataset;

    }

    public static enum Backend {
        unknown,
        mxnet,
        caffe,
        tensorflow;

    }

    public static enum Network {
        auto,
        user,
        lenet,
        alexnet,
        vgg,
        googlenet,
        inception_bn,
        resnet;

    }
}

