/*
 * Decompiled with CFR 0.152.
 */
package hex.rulefit;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsRegression;
import hex.ToEigenVec;
import hex.glm.GLMModel;
import hex.rulefit.RuleEnsemble;
import hex.rulefit.RuleFit;
import hex.rulefit.RuleFitUtils;
import hex.util.LinearAlgebraUtils;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.TwoDimTable;

public class RuleFitModel
extends Model<RuleFitModel, RuleFitParameters, RuleFitOutput> {
    GLMModel glmModel;
    RuleEnsemble ruleEnsemble;

    public ToEigenVec getToEigenVec() {
        return LinearAlgebraUtils.toEigen;
    }

    public RuleFitModel(Key<RuleFitModel> selfKey, RuleFitParameters parms, RuleFitOutput output, GLMModel glmModel, RuleEnsemble ruleEnsemble) {
        super(selfKey, (Model.Parameters)parms, (Model.Output)output);
        this.glmModel = glmModel;
        this.ruleEnsemble = ruleEnsemble;
    }

    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        assert (domain == null);
        switch (((RuleFitOutput)this._output).getModelCategory()) {
            case Binomial: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            }
            case Multinomial: {
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(((RuleFitOutput)this._output).nclasses(), domain);
            }
            case Regression: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
        }
        throw H2O.unimpl((String)("Invalid ModelCategory " + ((RuleFitOutput)this._output).getModelCategory()));
    }

    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("RuleFitModel doesn't support scoring on raw data. Use score() instead.");
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public Frame score(Frame fr, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) throws IllegalArgumentException {
        Frame adaptFrm = new Frame(fr);
        this.adaptTestForTrain(adaptFrm, true, false);
        Frame linearTest = new Frame(new Vec[0]);
        try {
            if (ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitParameters)this._parms)._model_type) || ModelType.RULES.equals((Object)((RuleFitParameters)this._parms)._model_type)) {
                linearTest.add(this.ruleEnsemble.createGLMTrainFrame(adaptFrm, ((RuleFitParameters)this._parms)._max_rule_length - ((RuleFitParameters)this._parms)._min_rule_length + 1, ((RuleFitParameters)this._parms)._rule_generation_ntrees));
            }
            if (ModelType.RULES_AND_LINEAR.equals((Object)((RuleFitParameters)this._parms)._model_type) || ModelType.LINEAR.equals((Object)((RuleFitParameters)this._parms)._model_type)) {
                linearTest.add(RuleFitUtils.getLinearNames(adaptFrm.numCols(), adaptFrm.names()), adaptFrm.vecs());
            }
            Frame scored = this.glmModel.score(linearTest, destination_key, null, true);
            this.updateModelMetrics(this.glmModel, fr);
            Frame frame = scored;
            return frame;
        }
        finally {
            Frame.deleteTempFrameAndItsNonSharedVecs((Frame)linearTest, (Frame)adaptFrm);
        }
    }

    protected Futures remove_impl(Futures fs, boolean cascade) {
        super.remove_impl(fs, cascade);
        if (cascade) {
            this.glmModel.remove(fs);
        }
        return fs;
    }

    void updateModelMetrics(GLMModel glmModel, Frame fr) {
        for (Key modelMetricsKey : ((GLMModel.GLMOutput)glmModel._output).getModelMetrics()) {
            this.addModelMetrics(((ModelMetrics)modelMetricsKey.get()).deepCloneWithDifferentModelAndFrame((Model)this, fr));
        }
    }

    public static class RuleFitOutput
    extends Model.Output {
        public double[] _intercept;
        public TwoDimTable _rule_importance = null;
        Key glmModelKey = null;

        public RuleFitOutput(RuleFit b) {
            super((ModelBuilder)b);
        }
    }

    public static class RuleFitParameters
    extends Model.Parameters {
        public Algorithm _algorithm = Algorithm.AUTO;
        public int _min_rule_length = 3;
        public int _max_rule_length = 3;
        public int _max_num_rules = -1;
        public ModelType _model_type = ModelType.RULES_AND_LINEAR;
        public int _rule_generation_ntrees = 50;

        public String algoName() {
            return "RuleFit";
        }

        public String fullName() {
            return "RuleFit";
        }

        public String javaName() {
            return RuleFitModel.class.getName();
        }

        public long progressUnits() {
            return 1000000L;
        }
    }

    public static enum ModelType {
        RULES,
        RULES_AND_LINEAR,
        LINEAR;

    }

    public static enum Algorithm {
        DRF,
        GBM,
        AUTO;

    }
}

