/*
 * Decompiled with CFR 0.152.
 */
package hex.schemas;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import hex.Infogram.Infogram;
import hex.Infogram.InfogramModel;
import hex.Model;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.schemas.DRFV3;
import hex.schemas.DeepLearningV3;
import hex.schemas.GBMV3;
import hex.schemas.GLMV3;
import hex.schemas.ModelBuilderSchema;
import hex.schemas.XGBoostV3;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBMModel;
import hex.tree.xgboost.XGBoostModel;
import hex.util.DistributionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Properties;
import water.api.API;
import water.api.EnumValuesProvider;
import water.api.schemas3.KeyV3;
import water.api.schemas3.ModelParametersSchemaV3;

/**
 * REST API schema for the {@link Infogram} model builder.
 */
public class InfogramV3
extends ModelBuilderSchema<Infogram, InfogramV3, InfogramParametersV3> {

    /**
     * Supplies the allowed values of {@link InfogramModel.InfogramParameters.Algorithm} for the
     * {@code algorithm} parameter. The (misspelled) class name is part of the public API and is
     * referenced from the {@code algorithm} field's {@code @API} annotation, so it is kept as-is.
     */
    public static final class InfogramAlrogithmProvider
    extends EnumValuesProvider<InfogramModel.InfogramParameters.Algorithm> {
        public InfogramAlrogithmProvider() {
            super(InfogramModel.InfogramParameters.Algorithm.class);
        }
    }

    /**
     * Input/output schema for {@link InfogramModel.InfogramParameters}.
     */
    public static final class InfogramParametersV3
    extends ModelParametersSchemaV3<InfogramModel.InfogramParameters, InfogramParametersV3> {
        /** Names of the parameters handled by this schema (presumably read by the schema machinery — confirm). */
        public static final String[] fields = new String[]{"model_id", "training_frame", "validation_frame", "seed", "keep_cross_validation_models", "keep_cross_validation_predictions", "keep_cross_validation_fold_assignment", "nfolds", "fold_assignment", "fold_column", "response_column", "ignored_columns", "ignore_const_cols", "score_each_iteration", "offset_column", "weights_column", "standardize", "distribution", "plug_values", "max_iterations", "stopping_rounds", "stopping_metric", "stopping_tolerance", "balance_classes", "class_sampling_factors", "max_after_balance_size", "max_runtime_secs", "custom_metric_func", "auc_type", "algorithm", "algorithm_params", "protected_columns", "total_information_threshold", "net_information_threshold", "relevance_index_threshold", "safety_index_threshold", "data_fraction", "top_n_features"};
        @API(help="Seed for pseudo random number generator (if applicable).", gridable=true)
        public long seed;
        @API(help="Standardize numeric columns to have zero mean and unit variance.", level=API.Level.critical)
        public boolean standardize;
        @API(help="Plug Values (a single row frame containing values that will be used to impute missing values of the training/validation frame, use with conjunction missing_values_handling = PlugValues).", direction=API.Direction.INPUT)
        public KeyV3.FrameKeyV3 plug_values;
        @API(help="Maximum number of iterations.", level=API.Level.secondary)
        public int max_iterations;
        @API(help="Prior probability for y==1. To be used only for logistic regression iff the data has been sampled and the mean of response does not reflect reality.", level=API.Level.expert)
        public double prior;
        @API(help="Balance training data class counts via over/under-sampling (for imbalanced data).", level=API.Level.secondary, direction=API.Direction.INOUT)
        public boolean balance_classes;
        @API(help="Desired over/under-sampling ratios per class (in lexicographic order). If not specified, sampling factors will be automatically computed to obtain class balance during training. Requires balance_classes.", level=API.Level.expert, direction=API.Direction.INOUT)
        public float[] class_sampling_factors;
        @API(help="Maximum relative size of the training data after balancing class counts (can be less than 1.0). Requires balance_classes.", level=API.Level.expert, direction=API.Direction.INOUT)
        public float max_after_balance_size;
        @API(level=API.Level.critical, direction=API.Direction.INOUT, valuesProvider=InfogramAlrogithmProvider.class, help="Type of machine learning algorithm used to build the infogram. Options include 'AUTO' (gbm), 'deeplearning' (Deep Learning with default parameters), 'drf' (Random Forest with default parameters), 'gbm' (GBM with default parameters), 'glm' (GLM with default parameters), or 'xgboost' (if available, XGBoost with default parameters).")
        public InfogramModel.InfogramParameters.Algorithm algorithm;
        @API(help="Customized parameters for the machine learning algorithm specified in the algorithm parameter.", level=API.Level.expert, gridable=true)
        public String algorithm_params;
        @API(help="Columns that contain features that are sensitive and need to be protected (legally, or otherwise), if applicable. These features (e.g. race, gender, etc) should not drive the prediction of the response.", level=API.Level.secondary, gridable=true)
        public String[] protected_columns;
        @API(help="A number between 0 and 1 representing a threshold for total information, defaulting to 0.1. For a specific feature, if the total information is higher than this threshold, and the corresponding net information is also higher than the threshold ``net_information_threshold``, that feature will be considered admissible. The total information is the x-axis of the Core Infogram. Default is -1 which gets set to 0.1.", level=API.Level.secondary, gridable=true)
        public double total_information_threshold;
        @API(help="A number between 0 and 1 representing a threshold for net information, defaulting to 0.1.  For a specific feature, if the net information is higher than this threshold, and the corresponding total information is also higher than the total_information_threshold, that feature will be considered admissible. The net information is the y-axis of the Core Infogram. Default is -1 which gets set to 0.1.", level=API.Level.secondary, gridable=true)
        public double net_information_threshold;
        @API(help="A number between 0 and 1 representing a threshold for the relevance index, defaulting to 0.1.  This is only used when ``protected_columns`` is set by the user.  For a specific feature, if the relevance index value is higher than this threshold, and the corresponding safety index is also higher than the ``safety_index_threshold``, that feature will be considered admissible.  The relevance index is the x-axis of the Fair Infogram. Default is -1 which gets set to 0.1.", level=API.Level.secondary, gridable=true)
        public double relevance_index_threshold;
        @API(help="A number between 0 and 1 representing a threshold for the safety index, defaulting to 0.1.  This is only used when protected_columns is set by the user.  For a specific feature, if the safety index value is higher than this threshold, and the corresponding relevance index is also higher than the relevance_index_threshold, that feature will be considered admissible.  The safety index is the y-axis of the Fair Infogram. Default is -1 which gets set to 0.1.", level=API.Level.secondary, gridable=true)
        public double safety_index_threshold;
        @API(help="The fraction of training frame to use to build the infogram model. Defaults to 1.0, and any value greater than 0 and less than or equal to 1.0 is acceptable.", level=API.Level.secondary, gridable=true)
        public double data_fraction;
        @API(help="An integer specifying the number of columns to evaluate in the infogram.  The columns are ranked by variable importance, and the top N are evaluated.  Defaults to 50.", level=API.Level.secondary, gridable=true)
        public int top_n_features;

        /**
         * Copies this schema's values into {@code impl}. When {@code algorithm_params} is non-empty it is
         * parsed as JSON (see {@link #generateProperties(String)}), applied on top of the defaults of the
         * chosen {@code algorithm}, and the resulting model parameters are stored in
         * {@code impl._infogram_algorithm_parameters}.
         *
         * @param impl the parameters object to fill
         * @return {@code impl}, filled
         * @throws UnsupportedOperationException if {@code algorithm} is not a supported algorithm
         */
        @Override
        @SuppressWarnings({"rawtypes", "unchecked"})
        public InfogramModel.InfogramParameters fillImpl(InfogramModel.InfogramParameters impl) {
            super.fillImpl(impl);
            if (this.algorithm_params != null && !this.algorithm_params.isEmpty()) {
                Properties p = this.generateProperties(this.algorithm_params);
                ParamNParamSchema schemaParams = this.generateParamsSchema(this.algorithm);
                schemaParams._paramsSchema.init_meta();
                impl._infogram_algorithm_parameters = (Model.Parameters) schemaParams._paramsSchema
                        .fillFromImpl(schemaParams._params)
                        .fillFromParms(p, true)
                        .createAndFillImpl();
                // NOTE(review): this repeats the fill done above; kept to preserve existing behavior.
                // Confirm it is redundant before removing it.
                super.fillImpl(impl);
            }
            return impl;
        }

        /**
         * Builds the parameters of the algorithm selected by {@code parms._algorithm}, overrides them with
         * the values in {@code p}, and stores the result in {@code parms._infogram_algorithm_parameters}.
         *
         * <p>For {@code glm}, the GLM family is derived from {@code parms._distribution} and
         * {@code "_distribution"} is appended to {@code excludeList}. For {@code AUTO}, {@code gbm} and
         * {@code drf}, the stopping tolerance is set to 0.01 (and {@code "_stopping_tolerance"} appended to
         * {@code excludeList}) unless {@code "_stopping_tolerance"} is already in {@code excludeList}.
         *
         * @param parms       infogram parameters; read for algorithm/distribution and updated in place
         * @param p           parameter values to apply on top of the algorithm defaults
         * @param excludeList parameter names that are managed here (presumably the caller skips copying
         *                    them afterwards — confirm); may be appended to
         * @throws UnsupportedOperationException if {@code parms._algorithm} is not supported
         */
        @SuppressWarnings({"rawtypes", "unchecked"})
        public static void generateModelParams(InfogramModel.InfogramParameters parms, Properties p, ArrayList<String> excludeList) {
            ModelParametersSchemaV3 paramsSchema;
            Model.Parameters params;
            switch (parms._algorithm) {
                case glm: {
                    paramsSchema = new GLMV3.GLMParametersV3();
                    GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
                    excludeList.add("_distribution");
                    glmParams._family = DistributionUtils.distributionToFamily(parms._distribution);
                    params = glmParams;
                    break;
                }
                case AUTO:
                case gbm: {
                    paramsSchema = new GBMV3.GBMParametersV3();
                    params = new GBMModel.GBMParameters();
                    applyDefaultStoppingTolerance(params, excludeList);
                    break;
                }
                case drf: {
                    paramsSchema = new DRFV3.DRFParametersV3();
                    params = new DRFModel.DRFParameters();
                    applyDefaultStoppingTolerance(params, excludeList);
                    break;
                }
                case deeplearning: {
                    paramsSchema = new DeepLearningV3.DeepLearningParametersV3();
                    params = new DeepLearningModel.DeepLearningParameters();
                    break;
                }
                case xgboost: {
                    paramsSchema = new XGBoostV3.XGBoostParametersV3();
                    params = new XGBoostModel.XGBoostParameters();
                    break;
                }
                default: {
                    throw new UnsupportedOperationException("Unknown algo: " + parms._algorithm);
                }
            }
            paramsSchema.init_meta();
            parms._infogram_algorithm_parameters = (Model.Parameters) paramsSchema
                    .fillFromImpl(params)
                    .fillFromParms(p, true)
                    .createAndFillImpl();
        }

        /**
         * Sets the tree-model stopping tolerance to 0.01 and records {@code "_stopping_tolerance"} in
         * {@code excludeList}, unless that name is already present there.
         */
        private static void applyDefaultStoppingTolerance(Model.Parameters params, ArrayList<String> excludeList) {
            if (!excludeList.contains("_stopping_tolerance")) {
                params._stopping_tolerance = 0.01;
                excludeList.add("_stopping_tolerance");
            }
        }

        /**
         * Parses {@code algoParms}, a JSON object mapping parameter names to arrays of string values, into
         * {@link Properties}. A single-element array becomes its bare value. Any other array is rendered
         * with {@link Arrays#toString(Object[])} (e.g. {@code [a, b]}). Entries whose value is JSON
         * {@code null} are skipped.
         *
         * @param algoParms JSON text such as {@code {"ntrees": ["5"]}}
         * @return the parsed properties; empty when the JSON document is {@code null}
         */
        Properties generateProperties(String algoParms) {
            Properties p = new Properties();
            Map<String, String[]> parsed = new Gson().fromJson(algoParms,
                    new TypeToken<HashMap<String, String[]>>(){}.getType());
            if (parsed == null) {
                return p;
            }
            for (Map.Entry<String, String[]> entry : parsed.entrySet()) {
                String[] values = entry.getValue();
                if (values == null) {
                    continue;
                }
                p.setProperty(entry.getKey(), values.length == 1 ? values[0] : Arrays.toString(values));
            }
            return p;
        }

        /**
         * Creates an empty parameter schema and default model parameters for {@code chosenAlgo}.
         * {@code AUTO} maps to GBM, matching the {@code algorithm} help text and
         * {@link #generateModelParams}. GLM defaults use {@code Family.AUTO}.
         *
         * @param chosenAlgo the algorithm whose parameters are needed
         * @return the schema/parameters pair
         * @throws UnsupportedOperationException if {@code chosenAlgo} is not supported
         */
        ParamNParamSchema generateParamsSchema(InfogramModel.InfogramParameters.Algorithm chosenAlgo) {
            switch (chosenAlgo) {
                case AUTO:
                case gbm: {
                    return new ParamNParamSchema(new GBMV3.GBMParametersV3(), new GBMModel.GBMParameters());
                }
                case glm: {
                    GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
                    glmParams._family = GLMModel.GLMParameters.Family.AUTO;
                    return new ParamNParamSchema(new GLMV3.GLMParametersV3(), glmParams);
                }
                case drf: {
                    return new ParamNParamSchema(new DRFV3.DRFParametersV3(), new DRFModel.DRFParameters());
                }
                case deeplearning: {
                    return new ParamNParamSchema(new DeepLearningV3.DeepLearningParametersV3(),
                            new DeepLearningModel.DeepLearningParameters());
                }
                case xgboost: {
                    return new ParamNParamSchema(new XGBoostV3.XGBoostParametersV3(),
                            new XGBoostModel.XGBoostParameters());
                }
                default: {
                    throw new UnsupportedOperationException("Unknown given algo: " + chosenAlgo);
                }
            }
        }

        /**
         * Pairs a parameter schema with the default model parameters it is filled from.
         * Static because it needs no reference to the enclosing schema instance.
         */
        @SuppressWarnings("rawtypes")
        private static final class ParamNParamSchema {
            private final ModelParametersSchemaV3 _paramsSchema;
            private final Model.Parameters _params;

            ParamNParamSchema(ModelParametersSchemaV3 schema, Model.Parameters params) {
                this._paramsSchema = schema;
                this._params = params;
            }
        }
    }
}

