/*
 * Decompiled with CFR 0.152.
 */
package hex.rulefit;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelBuilderHelper;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.genmodel.utils.ArrayUtils;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.rulefit.Rule;
import hex.rulefit.RuleEnsemble;
import hex.rulefit.RuleFitModel;
import hex.rulefit.RuleFitUtils;
import hex.tree.SharedTree;
import hex.tree.SharedTreeModel;
import hex.tree.TreeStats;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.log4j.Logger;
import water.DKV;
import water.Iced;
import water.Key;
import water.Keyed;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.TwoDimTable;

public class RuleFit
extends ModelBuilder<RuleFitModel, RuleFitModel.RuleFitParameters, RuleFitModel.RuleFitOutput> {
    private static final Logger LOG = Logger.getLogger(RuleFit.class);
    protected static final long WORK_TOTAL = 1000000L;
    private SharedTreeModel.SharedTreeParameters treeParameters = null;
    private GLMModel.GLMParameters glmParameters = null;

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Regression, ModelCategory.Binomial, ModelCategory.Multinomial};
    }

    public boolean isSupervised() {
        return true;
    }

    protected RuleFitDriver trainModelImpl() {
        return new RuleFitDriver();
    }

    public RuleFit(RuleFitModel.RuleFitParameters parms) {
        super((Model.Parameters)parms);
        this.init(false);
    }

    public RuleFit(boolean startup_once) {
        super((Model.Parameters)new RuleFitModel.RuleFitParameters(), startup_once);
    }

    public void init(boolean expensive) {
        super.init(expensive);
        if (expensive) {
            if (((RuleFitModel.RuleFitParameters)this._parms)._fold_column != null) {
                this._train.remove(((RuleFitModel.RuleFitParameters)this._parms)._fold_column);
            }
            if (((RuleFitModel.RuleFitParameters)this._parms)._algorithm == RuleFitModel.Algorithm.AUTO) {
                ((RuleFitModel.RuleFitParameters)this._parms)._algorithm = RuleFitModel.Algorithm.DRF;
            }
            this.initTreeParameters();
            this.initGLMParameters();
            this.ignoreBadColumns(this.separateFeatureVecs(), true);
        }
    }

    private void initTreeParameters() {
        if (((RuleFitModel.RuleFitParameters)this._parms)._algorithm == RuleFitModel.Algorithm.GBM) {
            this.treeParameters = new GBMModel.GBMParameters();
        } else if (((RuleFitModel.RuleFitParameters)this._parms)._algorithm == RuleFitModel.Algorithm.DRF) {
            this.treeParameters = new DRFModel.DRFParameters();
        } else {
            throw new RuntimeException("Unsupported algorithm for tree building: " + (Object)((Object)((RuleFitModel.RuleFitParameters)this._parms)._algorithm));
        }
        this.treeParameters._response_column = ((RuleFitModel.RuleFitParameters)this._parms)._response_column;
        this.treeParameters._train = ((RuleFitModel.RuleFitParameters)this._parms)._train;
        this.treeParameters._ignored_columns = ((RuleFitModel.RuleFitParameters)this._parms)._ignored_columns;
        this.treeParameters._seed = ((RuleFitModel.RuleFitParameters)this._parms)._seed;
        this.treeParameters._weights_column = ((RuleFitModel.RuleFitParameters)this._parms)._weights_column;
        this.treeParameters._distribution = ((RuleFitModel.RuleFitParameters)this._parms)._distribution;
        this.treeParameters._ntrees = ((RuleFitModel.RuleFitParameters)this._parms)._rule_generation_ntrees;
    }

    private void initGLMParameters() {
        if (((RuleFitModel.RuleFitParameters)this._parms)._distribution == DistributionFamily.AUTO) {
            this.glmParameters = this._nclass < 2 ? new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian) : (this._nclass == 2 ? new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial) : new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial));
        } else {
            switch (((RuleFitModel.RuleFitParameters)this._parms)._distribution) {
                case bernoulli: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.binomial);
                    break;
                }
                case quasibinomial: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.quasibinomial);
                    break;
                }
                case multinomial: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.multinomial);
                    break;
                }
                case ordinal: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.ordinal);
                    break;
                }
                case gaussian: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gaussian);
                    break;
                }
                case poisson: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.poisson);
                    break;
                }
                case gamma: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.gamma);
                    break;
                }
                case tweedie: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.tweedie);
                    break;
                }
                case fractionalbinomial: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.fractionalbinomial);
                    break;
                }
                case negativebinomial: {
                    this.glmParameters = new GLMModel.GLMParameters(GLMModel.GLMParameters.Family.negativebinomial);
                    break;
                }
                default: {
                    this.error("_distribution", "Distribution not supported.");
                }
            }
        }
        if (RuleFitModel.ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitModel.RuleFitParameters)this._parms)._model_type) && ((RuleFitModel.RuleFitParameters)this._parms)._ignored_columns != null) {
            this.glmParameters._ignored_columns = ((RuleFitModel.RuleFitParameters)this._parms)._ignored_columns;
        }
        this.glmParameters._response_column = "linear." + ((RuleFitModel.RuleFitParameters)this._parms)._response_column;
        this.glmParameters._seed = ((RuleFitModel.RuleFitParameters)this._parms)._seed;
        this.glmParameters._alpha = new double[]{1.0};
        if (((RuleFitModel.RuleFitParameters)this._parms)._weights_column != null) {
            this.glmParameters._weights_column = "linear." + ((RuleFitModel.RuleFitParameters)this._parms)._weights_column;
        }
    }

    private TwoDimTable convertRulesToTable(Rule[] rules) {
        ArrayList<String> colHeaders = new ArrayList<String>();
        ArrayList<String> colTypes = new ArrayList<String>();
        ArrayList<String> colFormat = new ArrayList<String>();
        colHeaders.add("variable");
        colTypes.add("string");
        colFormat.add("%s");
        colHeaders.add("coefficient");
        colTypes.add("double");
        colFormat.add("%.5f");
        colHeaders.add("rule");
        colTypes.add("string");
        colFormat.add("%s");
        int rows = rules.length;
        TwoDimTable table = new TwoDimTable("Rule Importance", null, new String[rows], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), "");
        for (int row = 0; row < rows; ++row) {
            int col = 0;
            table.set(row, col++, (Object)rules[row].varName);
            table.set(row, col++, (Object)rules[row].coefficient);
            table.set(row, col, (Object)rules[row].languageRule);
        }
        return table;
    }

    protected int nTreeEnsemblesInParallel(int numDepths) {
        if (((RuleFitModel.RuleFitParameters)this._parms)._algorithm == RuleFitModel.Algorithm.GBM) {
            return this.nModelsInParallel(numDepths, 2);
        }
        return this.nModelsInParallel(numDepths, 1);
    }

    TwoDimTable generateSummary(GLMModel glmModel, int ruleEnsembleSize, TreeStats overallTreeStats, int ntrees) {
        ArrayList<String> colHeaders = new ArrayList<String>();
        ArrayList<String> colTypes = new ArrayList<String>();
        ArrayList<String> colFormats = new ArrayList<String>();
        TwoDimTable glmModelSummary = ((GLMModel.GLMOutput)glmModel._output)._model_summary;
        String[] glmColHeaders = glmModelSummary.getColHeaders();
        String[] glmColTypes = glmModelSummary.getColTypes();
        String[] glmColFormats = glmModelSummary.getColFormats();
        for (int i = 0; i < glmModelSummary.getColDim(); ++i) {
            if ("Training Frame".equals(glmColHeaders[i])) continue;
            colHeaders.add(glmColHeaders[i]);
            colTypes.add(glmColTypes[i]);
            colFormats.add(glmColFormats[i]);
        }
        colHeaders.add("Rule Ensemble Size");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Number of Trees");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Number of Internal Trees");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Min. Depth");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Max. Depth");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Mean Depth");
        colTypes.add("double");
        colFormats.add("%.5f");
        colHeaders.add("Min. Leaves");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Max. Leaves");
        colTypes.add("long");
        colFormats.add("%d");
        colHeaders.add("Mean Leaves");
        colTypes.add("double");
        colFormats.add("%.5f");
        boolean rows = true;
        TwoDimTable summary = new TwoDimTable("Rulefit Model Summary", null, new String[1], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormats.toArray(new String[0]), "");
        int col = 0;
        int row = 0;
        for (int i = 0; i < glmModelSummary.getColDim(); ++i) {
            if ("Training Frame".equals(glmColHeaders[i])) continue;
            summary.set(row, col++, glmModelSummary.get(row, i));
        }
        summary.set(row, col++, (Object)ruleEnsembleSize);
        summary.set(row, col++, (Object)ntrees);
        summary.set(row, col++, (Object)overallTreeStats._num_trees);
        summary.set(row, col++, (Object)overallTreeStats._min_depth);
        summary.set(row, col++, (Object)overallTreeStats._max_depth);
        summary.set(row, col++, (Object)Float.valueOf(overallTreeStats._mean_depth));
        summary.set(row, col++, (Object)overallTreeStats._min_leaves);
        summary.set(row, col++, (Object)overallTreeStats._max_leaves);
        summary.set(row, col++, (Object)Float.valueOf(overallTreeStats._mean_leaves));
        return summary;
    }

    private final class RuleFitDriver
    extends ModelBuilder.Driver {
        private RuleFitDriver() {
            super((ModelBuilder)RuleFit.this);
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void computeImpl() {
            RuleFitModel model = null;
            RuleEnsemble ruleEnsemble = null;
            int ntrees = 0;
            TreeStats overallTreeStats = new TreeStats();
            RuleFit.this.init(true);
            if (RuleFit.this.error_count() > 0) {
                throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)RuleFit.this);
            }
            try {
                Frame linearTrain = new Frame(Key.make((String)("paths_frame" + RuleFit.this._result)));
                Frame linearValid = RuleFit.this._valid != null ? new Frame(Key.make((String)("valid_paths_frame" + RuleFit.this._result))) : null;
                Frame trainAdapted = new Frame(RuleFit.this._train);
                int[] depths = this.range(((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._min_rule_length, ((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._max_rule_length);
                if (RuleFitModel.ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._model_type) || RuleFitModel.ModelType.RULES.equals((Object)((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._model_type)) {
                    DKV.put((Key)trainAdapted._key, (Iced)trainAdapted);
                    ((RuleFit)RuleFit.this).treeParameters._train = trainAdapted._key;
                    long startAllTreesTime = System.nanoTime();
                    SharedTree[] builders = (SharedTree[])ModelBuilderHelper.trainModelsParallel((ModelBuilder[])this.makeTreeModelBuilders(((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._algorithm, depths), (int)RuleFit.this.nTreeEnsemblesInParallel(depths.length));
                    ArrayList<Rule> rulesList = new ArrayList<Rule>();
                    for (int modelId = 0; modelId < builders.length; ++modelId) {
                        long startModelTime = System.nanoTime();
                        SharedTreeModel treeModel = (SharedTreeModel)builders[modelId].get();
                        long endModelTime = System.nanoTime() - startModelTime;
                        LOG.info((Object)("Tree model n." + modelId + " trained in " + (double)endModelTime / 1.0E9 + "s."));
                        rulesList.addAll(Rule.extractRulesListFromModel(treeModel, modelId, RuleFit.this.nclasses()));
                        overallTreeStats.mergeWith(((SharedTreeModel.SharedTreeOutput)treeModel._output)._treeStats);
                        ntrees += ((SharedTreeModel.SharedTreeOutput)treeModel._output)._ntrees;
                        treeModel.delete();
                    }
                    long endAllTreesTime = System.nanoTime() - startAllTreesTime;
                    LOG.info((Object)("All tree models trained in " + (double)endAllTreesTime / 1.0E9 + "s."));
                    LOG.info((Object)"Extracting rules from trees...");
                    ruleEnsemble = new RuleEnsemble(rulesList.toArray(new Rule[0]));
                    linearTrain.add(ruleEnsemble.createGLMTrainFrame(RuleFit.this._train, depths.length, ((RuleFit)RuleFit.this).treeParameters._ntrees));
                    if (RuleFit.this._valid != null) {
                        linearValid.add(ruleEnsemble.createGLMTrainFrame(RuleFit.this._valid, depths.length, ((RuleFit)RuleFit.this).treeParameters._ntrees));
                    }
                }
                if (RuleFitModel.ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._model_type) || RuleFitModel.ModelType.LINEAR.equals((Object)((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._model_type)) {
                    String[] names = ((RuleFit)RuleFit.this)._train._names;
                    linearTrain.add(RuleFitUtils.getLinearNames(names.length, names), RuleFit.this._train.vecs(names));
                    if (RuleFit.this._valid != null) {
                        linearValid.add(RuleFitUtils.getLinearNames(names.length, names), RuleFit.this._valid.vecs(names));
                    }
                } else {
                    linearTrain.add(((RuleFit)RuleFit.this).glmParameters._response_column, RuleFit.this._train.vec(((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._response_column));
                    if (RuleFit.this._valid != null) {
                        linearValid.add(((RuleFit)RuleFit.this).glmParameters._response_column, RuleFit.this._valid.vec(((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._response_column));
                    }
                    if (((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._weights_column != null) {
                        linearTrain.add(((RuleFit)RuleFit.this).glmParameters._weights_column, RuleFit.this._train.vec(((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._weights_column));
                        if (RuleFit.this._valid != null) {
                            linearValid.add(((RuleFit)RuleFit.this).glmParameters._weights_column, RuleFit.this._valid.vec(((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._weights_column));
                        }
                    }
                }
                DKV.put((Keyed)linearTrain);
                if (RuleFit.this._valid != null) {
                    DKV.put((Keyed)linearValid);
                    ((RuleFit)RuleFit.this).glmParameters._valid = linearValid._key;
                }
                ((RuleFit)RuleFit.this).glmParameters._train = linearTrain._key;
                if (((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._max_num_rules > 0) {
                    ((RuleFit)RuleFit.this).glmParameters._max_active_predictors = ((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._max_num_rules + 1;
                    if (((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._distribution != DistributionFamily.multinomial) {
                        ((RuleFit)RuleFit.this).glmParameters._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
                    }
                } else {
                    ((RuleFit)RuleFit.this).glmParameters._lambda = this.getOptimalLambda();
                }
                long startGLMTime = System.nanoTime();
                GLM job = new GLM(RuleFit.this.glmParameters);
                GLMModel glmModel = (GLMModel)job.trainModel().get();
                long endGLMTime = System.nanoTime() - startGLMTime;
                LOG.info((Object)("GLM trained in " + (double)endGLMTime / 1.0E9 + "s."));
                DKV.put((Keyed)glmModel);
                DKV.remove((Key)linearTrain._key);
                if (linearValid != null) {
                    DKV.remove((Key)linearValid._key);
                }
                DKV.remove((Key)trainAdapted._key);
                model = new RuleFitModel((Key<RuleFitModel>)RuleFit.this.dest(), (RuleFitModel.RuleFitParameters)RuleFit.this._parms, new RuleFitModel.RuleFitOutput(RuleFit.this), glmModel, ruleEnsemble);
                ((RuleFitModel.RuleFitOutput)model._output).glmModelKey = glmModel._key;
                ((RuleFitModel.RuleFitOutput)model._output)._intercept = this.getIntercept(glmModel);
                ((RuleFitModel.RuleFitOutput)model._output)._rule_importance = RuleFit.this.convertRulesToTable(this.getRules(glmModel.coefficients(), ruleEnsemble));
                ((RuleFitModel.RuleFitOutput)model._output)._model_summary = RuleFit.this.generateSummary(glmModel, ruleEnsemble.size(), overallTreeStats, ntrees);
                this.fillModelMetrics(model, glmModel);
                model.delete_and_lock(RuleFit.this._job);
                model.update(RuleFit.this._job);
            }
            finally {
                if (model != null) {
                    model.unlock(RuleFit.this._job);
                }
            }
        }

        void fillModelMetrics(RuleFitModel model, GLMModel glmModel) {
            ((RuleFitModel.RuleFitOutput)model._output)._validation_metrics = ((GLMModel.GLMOutput)glmModel._output)._validation_metrics;
            ((RuleFitModel.RuleFitOutput)model._output)._training_metrics = ((GLMModel.GLMOutput)glmModel._output)._training_metrics;
            ((RuleFitModel.RuleFitOutput)model._output)._cross_validation_metrics = ((GLMModel.GLMOutput)glmModel._output)._cross_validation_metrics;
            ((RuleFitModel.RuleFitOutput)model._output)._cross_validation_metrics_summary = ((GLMModel.GLMOutput)glmModel._output)._cross_validation_metrics_summary;
            Frame inputTrain = (Frame)((RuleFitModel.RuleFitParameters)model._parms)._train.get();
            for (Key modelMetricsKey : ((GLMModel.GLMOutput)glmModel._output).getModelMetrics()) {
                model.addModelMetrics(((ModelMetrics)modelMetricsKey.get()).deepCloneWithDifferentModelAndFrame((Model)model, inputTrain));
            }
        }

        int[] range(int min, int max) {
            int[] array = new int[max - min + 1];
            int i = min;
            int j = 0;
            while (i <= max) {
                array[j] = i++;
                ++j;
            }
            return array;
        }

        SharedTree<?, ?, ?> makeTreeModelBuilder(RuleFitModel.Algorithm algorithm, int maxDepth) {
            SharedTree builder;
            SharedTreeModel.SharedTreeParameters p = (SharedTreeModel.SharedTreeParameters)RuleFit.this.treeParameters.clone();
            p._max_depth = maxDepth;
            if (algorithm.equals((Object)RuleFitModel.Algorithm.DRF)) {
                builder = new DRF((DRFModel.DRFParameters)p);
            } else if (algorithm.equals((Object)RuleFitModel.Algorithm.GBM)) {
                builder = new GBM((GBMModel.GBMParameters)p);
            } else {
                throw new RuntimeException("Unsupported algorithm for tree building: " + (Object)((Object)((RuleFitModel.RuleFitParameters)RuleFit.this._parms)._algorithm));
            }
            return builder;
        }

        SharedTree<?, ?, ?>[] makeTreeModelBuilders(RuleFitModel.Algorithm algorithm, int[] depths) {
            SharedTree[] builders = new SharedTree[depths.length];
            for (int i = 0; i < depths.length; ++i) {
                builders[i] = this.makeTreeModelBuilder(algorithm, depths[i]);
            }
            return builders;
        }

        double[] getOptimalLambda() {
            int bestLambdaIndex;
            ((RuleFit)RuleFit.this).glmParameters._lambda_search = true;
            GLM job = new GLM(RuleFit.this.glmParameters);
            GLMModel lambdaModel = (GLMModel)job.trainModel().get();
            ((RuleFit)RuleFit.this).glmParameters._lambda_search = false;
            GLMModel.RegularizationPath regularizationPath = lambdaModel.getRegularizationPath();
            double[] deviance = regularizationPath._explained_deviance_train;
            double[] lambdas = regularizationPath._lambdas;
            if (deviance.length < 5) {
                bestLambdaIndex = deviance.length - 1;
            } else {
                bestLambdaIndex = this.getBestLambdaIndex(deviance);
                if (bestLambdaIndex >= lambdas.length) {
                    bestLambdaIndex = this.getBestLambdaIndexCornerCase(deviance, lambdas);
                }
            }
            lambdaModel.remove();
            return new double[]{lambdas[bestLambdaIndex]};
        }

        int getBestLambdaIndex(double[] deviance) {
            int bestLambdaIndex = deviance.length - 1;
            if (deviance.length >= 5) {
                double[] array = ArrayUtils.difference((double[])ArrayUtils.signum((double[])ArrayUtils.difference((double[])ArrayUtils.difference((double[])deviance))));
                for (int i = 0; i < array.length; ++i) {
                    if (array[i] == 0.0 || i <= 0) continue;
                    bestLambdaIndex = 3 * i;
                    break;
                }
            }
            return bestLambdaIndex;
        }

        int getBestLambdaIndexCornerCase(double[] deviance, double[] lambdas) {
            double[] leftUpPoint = new double[]{deviance[0], lambdas[0]};
            double[] rightLowPoint = new double[]{deviance[deviance.length - 1], lambdas[lambdas.length - 1]};
            double[] leftActualPoint = new double[2];
            double[] rightActualPoint = new double[2];
            int leftActualId = 0;
            int rightActualId = deviance.length - 1;
            while (leftActualId < deviance.length && rightActualId < deviance.length && leftActualId < rightActualId) {
                leftActualPoint[0] = deviance[leftActualId];
                leftActualPoint[1] = lambdas[leftActualId];
                double leftVolume = (leftUpPoint[1] - leftActualPoint[1]) * (leftActualPoint[0] - leftUpPoint[0]);
                rightActualPoint[0] = deviance[rightActualId];
                rightActualPoint[1] = lambdas[rightActualId];
                double rightVolume = (rightActualPoint[1] - rightLowPoint[1]) * (rightLowPoint[0] - rightActualPoint[0]);
                if (Math.abs(leftVolume) > Math.abs(rightVolume)) {
                    --rightActualId;
                    continue;
                }
                ++leftActualId;
            }
            return rightActualId;
        }

        double[] getIntercept(GLMModel glmModel) {
            HashMap<String, Double> glmCoefficients = glmModel.coefficients();
            double[] intercept = RuleFit.this.nclasses() > 2 ? new double[RuleFit.this.nclasses()] : new double[1];
            int i = 0;
            for (Map.Entry<String, Double> coefficient : glmCoefficients.entrySet()) {
                if (!"Intercept".equals(coefficient.getKey()) && !coefficient.getKey().contains("Intercept_")) continue;
                intercept[i] = coefficient.getValue();
                ++i;
            }
            return intercept;
        }

        Rule[] getRules(HashMap<String, Double> glmCoefficients, RuleEnsemble ruleEnsemble) {
            Map<String, Double> filteredRules = glmCoefficients.entrySet().stream().filter(e -> !"Intercept".equals(e.getKey()) && !((String)e.getKey()).contains("Intercept_") && 0.0 != (Double)e.getValue()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            ArrayList<Rule> rules = new ArrayList<Rule>();
            for (Map.Entry<String, Double> entry : filteredRules.entrySet()) {
                Rule rule = !entry.getKey().startsWith("linear.") ? ruleEnsemble.getRuleByVarName(entry.getKey().substring(entry.getKey().lastIndexOf(".") + 1)) : new Rule(null, entry.getValue(), entry.getKey());
                rule.setCoefficient(entry.getValue());
                rules.add(rule);
            }
            Comparator<Rule> ruleAbsCoefficientComparator = Comparator.comparingDouble(Rule::getAbsCoefficient).reversed();
            rules.sort(ruleAbsCoefficientComparator);
            return rules.toArray(new Rule[0]);
        }
    }
}

