/*
 * Decompiled with CFR 0.152.
 */
package hex.modelselection;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsRegression;
import hex.deeplearning.DeepLearningModel;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.modelselection.ModelSelection;
import hex.modelselection.ModelSelectionUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Stream;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.fvec.Frame;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.TwoDimTable;

/**
 * Result of a ModelSelection run.
 *
 * <p>This model cannot score data: {@code score0} and {@code score} always throw
 * {@link UnsupportedOperationException}. Instead it references the GLM models kept by the search
 * (through {@link ModelSelectionModelOutput#_best_model_ids}). It also carries per-model statistics:
 * best R2 values, or z-values and p-values in backward mode.
 */
public class ModelSelectionModel
extends Model<ModelSelectionModel, ModelSelectionParameters, ModelSelectionModelOutput> {
    /** Message shared by every scoring entry point, since this model does not score data. */
    private static final String NO_SCORING_MESSAGE =
            "ModelSelection does not support scoring on data.  It only provide information on predictor relevance";

    /**
     * Creates the model.
     *
     * @param selfKey DKV key of this model
     * @param parms   parameters used for the selection run
     * @param output  output holding the selected GLM model keys and statistics
     */
    public ModelSelectionModel(Key<ModelSelectionModel> selfKey, ModelSelectionParameters parms, ModelSelectionModelOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Creates the metric builder for this model; only the regression category is supported.
     *
     * @param domain response domain; asserted to be {@code null}
     * @return a regression metric builder
     * @throws RuntimeException (built by {@code H2O.unimpl}) for any other model category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        assert domain == null;
        ModelCategory category = _output.getModelCategory();
        switch (category) {
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            default:
                throw H2O.unimpl("Invalid ModelCategory " + category);
        }
    }

    /**
     * Row-level scoring is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException(NO_SCORING_MESSAGE);
    }

    /**
     * Frame-level scoring is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        throw new UnsupportedOperationException(NO_SCORING_MESSAGE);
    }

    /**
     * Builds a new frame summarising the selected models.
     *
     * @return a frame with one row per stored model (see {@link ModelSelectionModelOutput})
     */
    public Frame result() {
        return _output.generateResultFrame();
    }

    /**
     * Removes this model. When {@code cascade} is true, every non-null GLM model referenced by
     * {@code _best_model_ids} is removed as well.
     */
    protected Futures remove_impl(Futures fs, boolean cascade) {
        super.remove_impl(fs, cascade);
        Key[] modelIds = _output._best_model_ids;
        if (cascade && modelIds != null) {
            for (Key modelId : modelIds) {
                if (modelId != null) {
                    Keyed.remove(modelId, fs, cascade);
                }
            }
        }
        return fs;
    }

    /**
     * Writes every referenced GLM model (null keys skipped), then the superclass state.
     * The order must mirror {@link #readAll_impl}.
     */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        Key[] modelIds = _output._best_model_ids;
        if (modelIds != null) {
            for (Key modelId : modelIds) {
                if (modelId != null) {
                    ab.putKey(modelId);
                }
            }
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Reads back the GLM models written by {@link #writeAll_impl}, in the same order (null keys
     * skipped), then the superclass state. Relies on {@code _output._best_model_ids} already being
     * populated when this runs.
     */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        Key[] modelIds = _output._best_model_ids;
        if (modelIds != null) {
            for (Key modelId : modelIds) {
                if (modelId != null) {
                    ab.getKey(modelId, fs);
                }
            }
        }
        return super.readAll_impl(ab, fs);
    }

    /**
     * Returns the non-standardized coefficients of every stored model.
     *
     * @return one coefficient map per entry of {@code _best_model_ids}
     */
    public HashMap<String, Double>[] coefficients() {
        return coefficients(false);
    }

    /**
     * Returns the coefficients of every stored model.
     *
     * @param standardize whether to return standardized coefficients
     * @return one coefficient map per entry of {@code _best_model_ids}, in the same order
     */
    public HashMap<String, Double>[] coefficients(boolean standardize) {
        int numModel = _output._best_model_ids.length;
        @SuppressWarnings("unchecked") // generic array creation is impossible; every slot holds a HashMap<String, Double>
        HashMap<String, Double>[] coeffs = new HashMap[numModel];
        for (int index = 0; index < numModel; index++) {
            coeffs[index] = coefficients(index + 1, standardize);
        }
        return coeffs;
    }

    /**
     * Returns the non-standardized coefficients of the model stored at position {@code predictorSize - 1}.
     */
    public HashMap<String, Double> coefficients(int predictorSize) {
        return coefficients(predictorSize, false);
    }

    /**
     * Returns the coefficients of the model stored at position {@code predictorSize - 1} of
     * {@code _best_model_ids}.
     *
     * @param predictorSize 1-based position of the model
     * @param standardize   whether to return standardized coefficients
     * @return coefficient name to value map of that GLM model
     * @throws IllegalArgumentException if {@code predictorSize} is outside {@code [1, number of stored models]}
     * @throws IllegalStateException    if the model key is null or the model is no longer in DKV
     */
    public HashMap<String, Double> coefficients(int predictorSize, boolean standardize) {
        int numModel = _output._best_model_ids.length;
        if (predictorSize <= 0 || predictorSize > numModel) {
            throw new IllegalArgumentException("predictorSize must be between 1 and maximum size of predictor subset size.");
        }
        return _output.fetchBestModel(predictorSize - 1).coefficients(standardize);
    }

    /**
     * Output of a ModelSelection run: the keys of the retained GLM models plus per-model statistics.
     */
    public static class ModelSelectionModelOutput
    extends Model.Output {
        /** GLM family setting. */
        GLMModel.GLMParameters.Family _family;
        /** Data layout given at construction; its adapted frame is handed to the superclass. */
        DataInfo _dinfo;
        /** Per stored model: coefficient names without a trailing "Intercept" entry. */
        String[][] _best_model_predictors;
        /** Per stored model: R2 (cross-validation summary value when nfolds > 0, training R2 otherwise). */
        double[] _best_r2_values;
        /** DKV keys of the retained GLM models, one per stored model. */
        public Key[] _best_model_ids;
        /** Per stored model: all coefficient names reported by GLM, intercept included. */
        String[][] _coefficient_names;
        /** Per stored model: coefficient p-values (filled by extractPredictors4NextModel). */
        double[][] _coef_p_values;
        /** Per stored model: coefficient z-values; a non-null array selects backward mode in the result frame. */
        double[][] _z_values;

        /**
         * @param b     builder producing this output
         * @param dinfo data layout; its adapted frame is used as the training frame for the superclass
         */
        public ModelSelectionModelOutput(ModelSelection b, DataInfo dinfo) {
            super(b, dinfo._adaptedFrame);
            _dinfo = dinfo;
        }

        /** @return per-model coefficient names (intercept included); the internal array, not a copy */
        public String[][] coefficientNames() {
            return _coefficient_names;
        }

        /** @return a copy of {@code beta()} of every stored GLM model, in {@code _best_model_ids} order */
        public double[][] beta() {
            int numModels = _best_model_ids.length;
            double[][] coeffs = new double[numModels][];
            for (int index = 0; index < numModels; index++) {
                coeffs[index] = fetchBestModel(index)._output.beta().clone();
            }
            return coeffs;
        }

        /** @return a copy of {@code getNormBeta()} of every stored GLM model, in {@code _best_model_ids} order */
        public double[][] getNormBeta() {
            int numModels = _best_model_ids.length;
            double[][] coeffs = new double[numModels][];
            for (int index = 0; index < numModels; index++) {
                coeffs[index] = fetchBestModel(index)._output.getNormBeta().clone();
            }
            return coeffs;
        }

        /** @return always {@link ModelCategory#Regression} */
        public ModelCategory getModelCategory() {
            return ModelCategory.Regression;
        }

        /**
         * Looks up the GLM model stored at {@code index} of {@code _best_model_ids}.
         *
         * @throws IllegalStateException if the key is null or the model is not in DKV
         */
        private GLMModel fetchBestModel(int index) {
            Key modelId = _best_model_ids[index];
            GLMModel model = modelId == null ? null : (GLMModel) DKV.getGet(modelId);
            if (model == null) {
                throw new IllegalStateException("No GLM model available at index " + index + " (key: " + modelId + ").");
            }
            return model;
        }

        /**
         * Builds the result frame, one row per stored model.
         *
         * <p>In backward mode (when {@code _z_values} is non-null) the columns are
         * model_name, model_id, z_values, p_values and coefficient_names. Otherwise they are
         * model_name, model_id, best_r2_value and predictor_names. Per-coefficient values are joined
         * into one string per row.
         */
        private Frame generateResultFrame() {
            int numRows = _best_model_predictors.length;
            boolean backwardMode = _z_values != null;
            String[] modelNames = new String[numRows];
            String[] predNames = new String[numRows];
            String[] zvalues = new String[numRows];
            String[] pvalues = new String[numRows];
            for (int index = 0; index < numRows; index++) {
                modelNames[index] = "best " + _best_model_predictors[index].length + " predictor(s) model";
                if (backwardMode) {
                    predNames[index] = String.join(", ", _coefficient_names[index]);
                    zvalues[index] = ModelSelectionUtils.joinDouble(_z_values[index]);
                    pvalues[index] = ModelSelectionUtils.joinDouble(_coef_p_values[index]);
                } else {
                    predNames[index] = String.join(", ", _best_model_predictors[index]);
                }
            }
            String[] modelIds = Stream.of(_best_model_ids).map(Key::toString).toArray(String[]::new);

            // All columns share one vector group so they can live in the same frame.
            Vec.VectorGroup vg = Vec.VectorGroup.VG_LEN1;
            Vec modNames = Vec.makeVec(modelNames, vg.addVec());
            Vec modelIDV = Vec.makeVec(modelIds, vg.addVec());
            if (backwardMode) {
                Vec zval = Vec.makeVec(zvalues, vg.addVec());
                Vec pval = Vec.makeVec(pvalues, vg.addVec());
                Vec predN = Vec.makeVec(predNames, vg.addVec());
                String[] colNames = new String[]{"model_name", "model_id", "z_values", "p_values", "coefficient_names"};
                return new Frame(Key.make(), colNames, new Vec[]{modNames, modelIDV, zval, pval, predN});
            }
            Vec r2 = Vec.makeVec(_best_r2_values, vg.addVec());
            Vec predN = Vec.makeVec(predNames, vg.addVec());
            String[] colNames = new String[]{"model_name", "model_id", "best_r2_value", "predictor_names"};
            return new Frame(Key.make(), colNames, new Vec[]{modNames, modelIDV, r2, predN});
        }

        /**
         * Trims the per-model arrays when fewer models were built than were allocated.
         * The trimming itself is delegated to {@code ModelSelectionUtils}.
         *
         * @param numModelsBuilt number of models actually built
         */
        public void shrinkArrays(int numModelsBuilt) {
            if (_best_model_predictors.length > numModelsBuilt) {
                _best_model_predictors = ModelSelectionUtils.shrinkStringArray(_best_model_predictors, numModelsBuilt);
                _coefficient_names = ModelSelectionUtils.shrinkStringArray(_coefficient_names, numModelsBuilt);
                // z/p-value arrays are only filled by extractPredictors4NextModel; guard against null.
                if (_z_values != null) {
                    _z_values = ModelSelectionUtils.shrinkDoubleArray(_z_values, numModelsBuilt);
                }
                if (_coef_p_values != null) {
                    _coef_p_values = ModelSelectionUtils.shrinkDoubleArray(_coef_p_values, numModelsBuilt);
                }
                _best_model_ids = ModelSelectionUtils.shrinkKeyArray(_best_model_ids, numModelsBuilt);
            }
        }

        /**
         * Builds the model summary table from the best R2 values.
         * Row {@code i} is labelled "with i+1 predictors".
         */
        public void generateSummary() {
            int numModels = _best_r2_values.length;
            String[] names = new String[]{"best r2 value", "predictor names"};
            String[] types = new String[]{"double", "string"};
            // "%d" only accepts integral values; the R2 column holds doubles.
            String[] formats = new String[]{"%.5f", "%s"};
            String[] rowHeaders = new String[numModels];
            for (int index = 0; index < numModels; index++) {
                rowHeaders[index] = "with " + (index + 1) + " predictors";
            }
            _model_summary = new TwoDimTable("ModelSelection Model Summary", "summary", rowHeaders, names, types, formats, "");
            for (int rIndex = 0; rIndex < numModels; rIndex++) {
                _model_summary.set(rIndex, 0, _best_r2_values[rIndex]);
                _model_summary.set(rIndex, 1, String.join(", ", _best_model_predictors[rIndex]));
            }
        }

        /**
         * Builds the model summary table from coefficient names, z-values and p-values.
         *
         * @param numModels number of rows (stored models) to include
         */
        public void generateSummary(int numModels) {
            String[] names = new String[]{"coefficient names", "z values", "p values"};
            String[] types = new String[]{"string", "string", "string"};
            String[] formats = new String[]{"%s", "%s", "%s"};
            String[] rowHeaders = new String[numModels];
            for (int index = 0; index < numModels; index++) {
                rowHeaders[index] = "with " + _best_model_predictors[index].length + " predictors";
            }
            _model_summary = new TwoDimTable("ModelSelection Model Summary", "summary", rowHeaders, names, types, formats, "");
            for (int rIndex = 0; rIndex < numModels; rIndex++) {
                _model_summary.set(rIndex, 0, String.join(", ", _coefficient_names[rIndex]));
                _model_summary.set(rIndex, 1, ModelSelectionUtils.joinDouble(_z_values[rIndex]));
                _model_summary.set(rIndex, 2, ModelSelectionUtils.joinDouble(_coef_p_values[rIndex]));
            }
        }

        /**
         * Records {@code bestModel} at {@code index}: its key, its R2 and its coefficient names.
         */
        void updateBestModels(GLMModel bestModel, int index) {
            _best_model_ids[index] = bestModel.getKey();
            _best_r2_values[index] = bestModel._parms._nfolds > 0 ? crossValidatedR2(bestModel) : bestModel.r2();
            extractCoeffs(bestModel, index);
        }

        /**
         * Reads column 0 of the "r2" row of the model's cross-validation metrics summary.
         *
         * @throws IllegalStateException if the summary has no "r2" row
         */
        private static double crossValidatedR2(GLMModel model) {
            TwoDimTable cvSummary = model._output._cross_validation_metrics_summary;
            int r2Index = Arrays.asList(cvSummary.getRowHeaders()).indexOf("r2");
            if (r2Index < 0) {
                throw new IllegalStateException("Cross-validation metrics summary of model " + model.getKey() + " has no r2 row.");
            }
            // Read through Number so the cell's boxed type (Float or Double) does not matter.
            return ((Number) cvSummary.get(r2Index, 0)).doubleValue();
        }

        /**
         * Stores the coefficient names of {@code model} at {@code index}: all of them in
         * {@code _coefficient_names}, predictors only in {@code _best_model_predictors}.
         */
        void extractCoeffs(GLMModel model, int index) {
            String[] allNames = model._output.coefficientNames();
            _coefficient_names[index] = allNames.clone();
            int numPredictors = allNames.length;
            // A trailing "Intercept" is a coefficient, not a predictor; leave it out of the predictor list.
            if (numPredictors > 0 && "Intercept".equals(allNames[numPredictors - 1])) {
                numPredictors--;
            }
            _best_model_predictors[index] = Arrays.copyOf(allNames, numPredictors);
        }

        /**
         * Records {@code model} at {@code index}: its coefficients, key, z-values and p-values.
         * It also removes the predictor chosen by {@code ModelSelectionUtils.findMinZValue} from
         * {@code predIndices}.
         */
        void extractPredictors4NextModel(GLMModel model, int index, List<String> predNames, List<Integer> predIndices, List<String> numPredNames, List<String> catPredNames) {
            extractCoeffs(model, index);
            _best_model_ids[index] = model.getKey();
            int predIndex2Remove = ModelSelectionUtils.findMinZValue(model, numPredNames, catPredNames, predNames);
            // indexOf boxes the value; remove(int) then removes by position, not by value.
            predIndices.remove(predIndices.indexOf(predIndex2Remove));
            _z_values[index] = model._output.zValues().clone();
            _coef_p_values[index] = model._output.pValues().clone();
        }
    }

    /**
     * Parameters of a ModelSelection run.
     */
    public static class ModelSelectionParameters
    extends Model.Parameters {
        // GLM-related settings (presumably forwarded to each GLM the search builds -- confirm in ModelSelection).
        public double[] _alpha;
        public double[] _lambda;
        public boolean _standardize = true;
        GLMModel.GLMParameters.Family _family = GLMModel.GLMParameters.Family.AUTO;
        public boolean _lambda_search;
        public GLMModel.GLMParameters.Link _link = GLMModel.GLMParameters.Link.family_default;
        public GLMModel.GLMParameters.Solver _solver = GLMModel.GLMParameters.Solver.IRLSM;
        public String[] _interactions = null;
        /** Either a GLM or a DeepLearning MissingValuesHandling value; see {@link #missingValuesHandling()}. */
        public Serializable _missing_values_handling = GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
        public boolean _compute_p_values = false;
        public boolean _remove_collinear_columns = false;
        public int _nfolds = 0;
        /** Frame of plug values; required when missing value handling is PlugValues (see makeImputer). */
        public Key<Frame> _plug_values = null;
        public int _max_predictor_number = 1;
        public int _min_predictor_number = 1;
        public int _nparallelism = 0;
        public double _p_values_threshold = 0.0;
        public double _tweedie_variance_power;
        public double _tweedie_link_power;
        public GLMModel.GLMParameters.GLMType _glmType = GLMModel.GLMParameters.GLMType.glm;
        /** Search strategy; defaults to maxr. */
        public Mode _mode = Mode.maxr;
        public double _beta_epsilon = 1.0E-4;
        public double _objective_epsilon = -1.0;

        public String algoName() {
            return "ModelSelection";
        }

        public String fullName() {
            return "Model Selection";
        }

        public String javaName() {
            return ModelSelectionModel.class.getName();
        }

        public long progressUnits() {
            return 1L;
        }

        /**
         * Normalises {@code _missing_values_handling} to the GLM enum.
         * A DeepLearning MeanImputation or Skip value is mapped to its GLM counterpart.
         *
         * @throws IllegalStateException for any DeepLearning value without a GLM counterpart
         */
        public GLMModel.GLMParameters.MissingValuesHandling missingValuesHandling() {
            if (_missing_values_handling instanceof GLMModel.GLMParameters.MissingValuesHandling) {
                return (GLMModel.GLMParameters.MissingValuesHandling) _missing_values_handling;
            }
            assert _missing_values_handling instanceof DeepLearningModel.DeepLearningParameters.MissingValuesHandling;
            switch ((DeepLearningModel.DeepLearningParameters.MissingValuesHandling) _missing_values_handling) {
                case MeanImputation:
                    return GLMModel.GLMParameters.MissingValuesHandling.MeanImputation;
                case Skip:
                    return GLMModel.GLMParameters.MissingValuesHandling.Skip;
                default:
                    break;
            }
            throw new IllegalStateException("Unsupported missing values handling value: " + _missing_values_handling);
        }

        /** @return true when missing values are imputed (MeanImputation or PlugValues) */
        public boolean imputeMissing() {
            GLMModel.GLMParameters.MissingValuesHandling handling = missingValuesHandling();
            return handling == GLMModel.GLMParameters.MissingValuesHandling.MeanImputation
                    || handling == GLMModel.GLMParameters.MissingValuesHandling.PlugValues;
        }

        /**
         * Creates the imputer matching the configured missing value handling.
         *
         * @return a plug-values imputer for PlugValues, a mean imputer otherwise
         * @throws IllegalStateException if PlugValues is selected but no plug-values frame is available
         */
        public DataInfo.Imputer makeImputer() {
            if (missingValuesHandling() != GLMModel.GLMParameters.MissingValuesHandling.PlugValues) {
                return new DataInfo.MeanImputer();
            }
            Frame plugValues = _plug_values == null ? null : _plug_values.get();
            if (plugValues == null) {
                throw new IllegalStateException("Plug values frame needs to be specified when Missing Value Handling = PlugValues.");
            }
            return new GLM.PlugValuesImputer(plugValues);
        }

        /** Available search strategies. */
        public enum Mode {
            allsubsets,
            maxr,
            backward
        }
    }
}

