/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.Distribution;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ensemble.StackedEnsemble;
import hex.ensemble.StackedEnsembleMojoWriter;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.tree.drf.DRFModel;
import java.lang.reflect.Field;
import java.util.Arrays;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.Job;
import water.Key;
import water.Keyed;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.nbhm.NonBlockingHashSet;
import water.udf.CFuncRef;
import water.util.Log;
import water.util.ReflectionUtils;

public class StackedEnsembleModel
extends Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput> {
    // Model category of the ensemble; copied from the first base model found in
    // checkAndInheritModelProperties() and compared against every later base model.
    public ModelCategory modelCategory;
    // Row count of this ensemble's training frame; -1 until checkAndInheritModelProperties() sets it.
    public long trainingFrameRows = -1L;
    // Response column name copied from the first base model found; null until then.
    public String responseColumn = null;
    // Entries of the first base model's _output._names; populated in checkAndInheritModelProperties().
    private NonBlockingHashSet<String> names = null;
    // nfolds of the first base model found; -1 until then.
    public int basemodel_nfolds = -1;
    // Fold assignment of the first base model found (AUTO is stored as Random).
    public Model.Parameters.FoldAssignmentScheme basemodel_fold_assignment;
    // Fold column of the first base model found (may be null).
    public String basemodel_fold_column;
    // Seed of the first base model found; -1 until then.
    public long seed = -1L;

    /**
     * Creates the ensemble model by handing its key, parameters and output
     * straight to the {@code Model} superclass constructor.
     *
     * @param selfKey key identifying this model
     * @param parms   ensemble parameters
     * @param output  ensemble output holder
     */
    public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Scores a frame with the ensemble: every base model scores {@code fr}, the
     * resulting predictions are assembled into a "level one" frame together with
     * the response column, and the metalearner scores that frame.
     *
     * @param fr               frame to score; each base model adapts its own {@code new Frame(fr)} copy
     * @param adaptFrm         adapted frame; supplies the response vec for the level-one frame and,
     *                         when metrics are computed, the response domain
     * @param destination_key  name of the key under which the returned prediction frame is made
     * @param j                job handed to the scoring tasks
     * @param computeMetrics   whether to build metrics; when true the metalearner's metrics are
     *                         cloned for this model and {@code fr} and registered via
     *                         {@code addModelMetrics}
     * @param customMetricFunc custom metric reference used when scoring the base models
     * @return the metalearner's prediction frame
     */
    protected Frame predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        String[] names = this.makeScoringNames();
        String[][] domains = new String[names.length][];
        // The response is only guaranteed to be the last vec of the adapted frame when
        // metrics are being computed; otherwise fall back to the domain recorded on the
        // model. This mirrors the metalearner domain selection further down (the original
        // condition here was inverted).
        if (names.length == 1) {
            domains[0] = null;
        } else if (computeMetrics) {
            domains[0] = adaptFrm.lastVec().domain();
        } else {
            String[][] trainedDomains = this._output._domains;
            domains[0] = trainedDomains[trainedDomains.length - 1];
        }

        Frame levelOneFrame = new Frame(Key.<Frame>make("preds_levelone_" + this._key.toString() + fr._key));
        for (Key<Model> baseKey : this._parms._base_models) {
            Model base = baseKey.get();
            Frame adaptedFrame = new Frame(fr);
            base.adaptTestForTrain(adaptedFrame, true, computeMetrics);
            // (byte) 3 is the vec type passed for the output columns -- presumably Vec.T_NUM; confirm.
            Model.BigScore baseBs = (Model.BigScore) base.makeBigScoreTask(domains, names, adaptedFrame, computeMetrics, true, j, customMetricFunc).doAll(names.length, (byte) 3, adaptedFrame);
            Frame basePreds = baseBs.outputFrame(Key.<Frame>make("preds_base_" + this._key.toString() + fr._key), names, domains);
            // For multinomial base models only the remaining columns feed the metalearner.
            if (base._output.isMultinomialClassifier()) {
                basePreds.remove("predict");
            }
            StackedEnsemble.addModelPredictionsToLevelOneFrame(base, basePreds, levelOneFrame);
            // Drop the temporary base-prediction frame key and the per-model adapted copy of fr.
            DKV.remove(basePreds._key);
            Frame.deleteTempFrameAndItsNonSharedVecs(adaptedFrame, fr);
        }

        levelOneFrame.add(this.responseColumn, adaptFrm.vec(this.responseColumn));
        Log.info("Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString());

        Model metalearner = this._output._metalearner;
        Frame levelOneAdapted = new Frame(levelOneFrame);
        metalearner.adaptTestForTrain(levelOneAdapted, true, computeMetrics);
        String[] metaNames = metalearner.makeScoringNames();
        String[][] metaDomains = new String[metaNames.length][];
        if (metaNames.length == 1) {
            metaDomains[0] = null;
        } else if (computeMetrics) {
            metaDomains[0] = levelOneAdapted.lastVec().domain();
        } else {
            metaDomains[0] = metalearner._output._domains[metalearner._output._domains.length - 1];
        }
        // The metalearner is scored with the custom metric configured on this ensemble's
        // parameters, not with the customMetricFunc argument.
        Model.BigScore metaBs = (Model.BigScore) metalearner.makeBigScoreTask(metaDomains, metaNames, levelOneAdapted, computeMetrics, true, j, CFuncRef.from(this._parms._custom_metric_func)).doAll(metaNames.length, (byte) 3, levelOneAdapted);
        if (computeMetrics) {
            ModelMetrics mmMetalearner = metaBs._mb.makeModelMetrics(metalearner, levelOneFrame, levelOneAdapted, metaBs.outputFrame());
            // Re-home the metalearner metrics onto this ensemble and the caller's frame.
            ModelMetrics mmStackedEnsemble = mmMetalearner.deepCloneWithDifferentModelAndFrame(this, fr);
            this.addModelMetrics(mmStackedEnsemble);
        }
        Frame.deleteTempFrameAndItsNonSharedVecs(levelOneAdapted, levelOneFrame);
        return metaBs.outputFrame(Key.<Frame>make(destination_key), metaNames, metaDomains);
    }

    /**
     * Row-level scoring is not supported for stacked ensembles; scoring goes
     * through {@code predictScoreImpl()} instead.
     *
     * @throws UnsupportedOperationException always
     */
    protected double[] score0(double[] data, double[] preds) {
        final String message = "StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Not supported for stacked ensembles.
     *
     * @throws UnsupportedOperationException always
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        final String message = "StackedEnsembleModel.makeMetricBuilder should never be called!";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Scores {@code frame} with metrics enabled, discards the prediction frame,
     * and returns the metrics looked up from the DKV for this model and frame.
     *
     * @param frame frame to score
     * @param job   job handed to the scoring call
     * @return the result of {@code ModelMetrics.getFromDKV(this, frame)}
     */
    public ModelMetrics doScoreMetricsOneFrame(Frame frame, Job job) {
        final Frame frameCopy = new Frame(frame);
        final CFuncRef customMetric = CFuncRef.from(this._parms._custom_metric_func);
        final Frame predictions = this.predictScoreImpl(frame, frameCopy, null, job, true, customMetric);
        // Only the metrics registered as a side effect of scoring are needed.
        predictions.remove();
        return ModelMetrics.getFromDKV(this, frame);
    }

    /**
     * Fills in this model's metrics: training metrics come from re-scoring the
     * training frame, validation metrics are taken as-is from the metalearner,
     * and cross-validation metrics (when the metalearner has any) are cloned
     * from the metalearner for this model.
     *
     * @param job job handed to the scoring call
     */
    public void doScoreOrCopyMetrics(Job job) {
        this._output._training_metrics = this.doScoreMetricsOneFrame(this._parms.train(), job);

        final Model.Output metaOutput = this._output._metalearner._output;
        this._output._validation_metrics = metaOutput._validation_metrics;

        final ModelMetrics metaCvMetrics = metaOutput._cross_validation_metrics;
        if (metaCvMetrics == null) {
            return;
        }
        // The clone is attached to this model and to the metalearner's training frame.
        this._output._cross_validation_metrics = metaCvMetrics.deepCloneWithDifferentModelAndFrame(this, this._output._metalearner._parms.train());
    }

    /**
     * Works out the distribution family of a model.
     * <ul>
     *   <li>DRF models: derived from the output's classifier flags.</li>
     *   <li>Models whose parameters have a {@code _family} field (read reflectively and
     *       cast to the GLM family enum): {@code binomial} maps to {@code bernoulli},
     *       anything else is matched by name.</li>
     *   <li>Otherwise, models with a {@code _dist} field: that distribution's family, or
     *       the {@code _distribution} parameter when the field is null; {@code AUTO} is
     *       resolved from the output's classifier flags.</li>
     * </ul>
     *
     * @param aModel model to inspect
     * @return the resolved distribution family
     * @throws H2OIllegalArgumentException if no family can be determined; any exception
     *         raised during the reflective lookup is wrapped in one as well
     */
    private DistributionFamily distributionFamily(Model aModel) {
        if (aModel instanceof DRFModel) {
            return this.familyFromOutputCategory(aModel);
        }
        try {
            final Field familyField = ReflectionUtils.findNamedField(aModel._parms, "_family");
            if (familyField != null) {
                final GLMModel.GLMParameters.Family glmFamily = (GLMModel.GLMParameters.Family) familyField.get(aModel._parms);
                if (glmFamily == GLMModel.GLMParameters.Family.binomial) {
                    return DistributionFamily.bernoulli;
                }
                try {
                    return DistributionFamily.valueOf(glmFamily.toString());
                } catch (IllegalArgumentException e) {
                    throw new H2OIllegalArgumentException("Don't know how to find the right DistributionFamily for Family: " + glmFamily);
                }
            }

            // Only consulted when there is no _family field.
            final Field distField = ReflectionUtils.findNamedField(aModel, "_dist");
            if (distField == null) {
                throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
            }
            final Distribution dist = (Distribution) distField.get(aModel);
            DistributionFamily resolved;
            if (dist != null) {
                resolved = dist.distribution;
            } else {
                resolved = aModel._parms._distribution;
            }
            if (resolved == DistributionFamily.AUTO) {
                resolved = this.familyFromOutputCategory(aModel);
            }
            return resolved;
        } catch (Exception e) {
            // Every failure, including the H2OIllegalArgumentExceptions thrown above, is re-wrapped here.
            throw new H2OIllegalArgumentException(e.toString(), e.toString());
        }
    }

    /** Maps a model's output to bernoulli (binomial classifier), multinomial (other classifier) or gaussian. */
    private DistributionFamily familyFromOutputCategory(Model aModel) {
        if (aModel._output.isBinomialClassifier()) {
            return DistributionFamily.bernoulli;
        }
        if (aModel._output.isClassifier()) {
            return DistributionFamily.multinomial;
        }
        return DistributionFamily.gaussian;
    }

    /**
     * Validates the configured base models and copies shared properties from the
     * first base model that can be found onto this ensemble.
     * <p>
     * The first base model found supplies the model category, distribution, domains,
     * names, response column, fold settings and seed. Every base model found after it
     * is checked for consistency against those values.
     *
     * @throws H2OIllegalArgumentException if no base models are configured, if both a
     *         metalearner fold column and metalearner nfolds are set, if a base model is
     *         not supervised, or if a base model is inconsistent with the first one
     */
    public void checkAndInheritModelProperties() {
        if (null == ((StackedEnsembleParameters)this._parms)._base_models || 0 == ((StackedEnsembleParameters)this._parms)._base_models.length) {
            throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");
        }
        // A metalearner fold column and a non-zero metalearner nfolds are mutually exclusive.
        if (null != ((StackedEnsembleParameters)this._parms)._metalearner_fold_column && 0 != ((StackedEnsembleParameters)this._parms)._metalearner_nfolds) {
            throw new H2OIllegalArgumentException("Cannot specify fold_column and nfolds at the same time.");
        }
        Model aModel = null;
        // False until the first base model has been found and its properties inherited.
        boolean beenHere = false;
        this.trainingFrameRows = ((StackedEnsembleParameters)this._parms).train().numRows();
        for (Key<Model> k : ((StackedEnsembleParameters)this._parms)._base_models) {
            aModel = (Model)DKV.getGet(k);
            if (null == aModel) {
                Log.warn((Object[])new Object[]{"Failed to find base model; skipping: " + k});
                continue;
            }
            if (!aModel.isSupervised()) {
                throw new H2OIllegalArgumentException("Base model is not supervised: " + aModel._key.toString());
            }
            if (beenHere) {
                // Second and later base models: compare against what the first one established.
                // Note that the cross-validation checks in this branch (nfolds >= 2, kept CV
                // predictions) are therefore not applied to the first base model itself.
                if (this.modelCategory != aModel._output.getModelCategory()) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models: " + Arrays.toString(((StackedEnsembleParameters)this._parms)._base_models));
                }
                Frame aTrainingFrame = aModel._parms.train();
                // Row counts must match the ensemble's training frame unless this is a CV model.
                if (this.trainingFrameRows != aTrainingFrame.numRows() && !((StackedEnsembleParameters)this._parms)._is_cv_model) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use different size(number of rows) training frames.  Found number of rows: " + this.trainingFrameRows + " and: " + aTrainingFrame.numRows() + ".");
                }
                if (!this.responseColumn.equals(aModel._parms._response_column)) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns.  Found: " + this.responseColumn + " and: " + aModel._parms._response_column + ".");
                }
                // Fold assignments must be equal; AUTO and Random are treated as interchangeable.
                if (!(aModel._parms._fold_assignment == this.basemodel_fold_assignment || aModel._parms._fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO && this.basemodel_fold_assignment == Model.Parameters.FoldAssignmentScheme.Random || aModel._parms._fold_assignment == Model.Parameters.FoldAssignmentScheme.Random && this.basemodel_fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO)) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use different fold_assignments.");
                }
                // Without a fold column, nfolds must match the first model's and be at least 2.
                if (aModel._parms._fold_column == null && this.basemodel_nfolds != aModel._parms._nfolds) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds.");
                }
                if (aModel._parms._fold_column == null && aModel._parms._nfolds < 2) {
                    throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);
                }
                // With a fold column, it must be the same column the first model used.
                if (aModel._parms._fold_column != null && !aModel._parms._fold_column.equals(this.basemodel_fold_column)) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use different fold_columns.");
                }
                // Random fold assignment without a fold column requires identical seeds.
                if (aModel._parms._fold_column == null && this.basemodel_fold_assignment == Model.Parameters.FoldAssignmentScheme.Random && aModel._parms._seed != this.seed) {
                    throw new H2OIllegalArgumentException("Base models are inconsistent: they use random-seeded k-fold cross-validation but have different seeds.");
                }
                if (!aModel._parms._keep_cross_validation_predictions) {
                    throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + aModel._parms._nfolds);
                }
                // A distribution mismatch only produces a warning; DRF models are exempt from the comparison.
                if (aModel instanceof DRFModel || this.distributionFamily(aModel) == this.distributionFamily(this)) continue;
                Log.warn((Object[])new Object[]{"Base models are inconsistent; they use different distributions: " + this.distributionFamily(this) + " and: " + this.distributionFamily(aModel) + ". Is this intentional?"});
                continue;
            }
            // First base model found: inherit its properties.
            this.modelCategory = aModel._output.getModelCategory();
            this._dist = new Distribution(this.distributionFamily(aModel));
            ((StackedEnsembleOutput)this._output)._domains = (String[][])Arrays.copyOf(aModel._output._domains, aModel._output._domains.length);
            ((StackedEnsembleOutput)this._output).setNames(aModel._output._names);
            this.names = new NonBlockingHashSet();
            this.names.addAll(Arrays.asList(aModel._output._names));
            this.responseColumn = aModel._parms._response_column;
            if (!this.responseColumn.equals(((StackedEnsembleParameters)this._parms)._response_column)) {
                throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model.  Found: " + this.responseColumn + " and: " + ((StackedEnsembleParameters)this._parms)._response_column);
            }
            this.basemodel_nfolds = aModel._parms._nfolds;
            this.basemodel_fold_assignment = aModel._parms._fold_assignment;
            // AUTO is recorded as Random.
            if (this.basemodel_fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO) {
                this.basemodel_fold_assignment = Model.Parameters.FoldAssignmentScheme.Random;
            }
            this.basemodel_fold_column = aModel._parms._fold_column;
            this.seed = aModel._parms._seed;
            ((StackedEnsembleParameters)this._parms)._distribution = aModel._parms._distribution;
            beenHere = true;
        }
        // NOTE(review): aModel holds the result of the LAST lookup only, so this fires when the
        // last key is missing even if earlier base models were found, and does not fire when
        // only the last key resolves. The message suggests "!beenHere" was intended -- confirm
        // before changing, since later code may not tolerate missing base models.
        if (null == aModel) {
            throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + ((StackedEnsembleParameters)this._parms)._base_models.length + " were specified but none of those were found: " + Arrays.toString(((StackedEnsembleParameters)this._parms)._base_models));
        }
    }

    /**
     * Removes the metalearner and the level-one frame (each only when present)
     * before delegating to the superclass removal.
     *
     * @param fs futures passed to each removal
     * @return the result of {@code super.remove_impl(fs)}
     */
    protected Futures remove_impl(Futures fs) {
        final Model metalearner = this._output._metalearner;
        if (metalearner != null) {
            metalearner.remove(fs);
        }
        final Frame levelOne = this._output._levelone_frame_id;
        if (levelOne != null) {
            levelOne.remove(fs);
        }
        return super.remove_impl(fs);
    }

    /**
     * Writes the metalearner key followed by every base model key, then
     * delegates to the superclass.
     *
     * @param ab buffer to write to
     * @return the result of {@code super.writeAll_impl(ab)}
     */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(this._output._metalearner._key);
        final Key<Model>[] baseModelKeys = this._parms._base_models;
        for (int i = 0; i < baseModelKeys.length; i++) {
            ab.putKey(baseModelKeys[i]);
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Reads the metalearner key followed by every base model key (the same
     * order {@code writeAll_impl} writes them), then delegates to the superclass.
     *
     * @param ab buffer to read from
     * @param fs futures passed to each key read
     * @return the result of {@code super.readAll_impl(ab, fs)}
     */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(this._output._metalearner._key, fs);
        final Key<Model>[] baseModelKeys = this._parms._base_models;
        for (int i = 0; i < baseModelKeys.length; i++) {
            ab.getKey(baseModelKeys[i], fs);
        }
        return super.readAll_impl(ab, fs);
    }

    /**
     * Creates a MOJO writer for this ensemble.
     *
     * @return a new {@code StackedEnsembleMojoWriter} constructed with this model
     */
    public StackedEnsembleMojoWriter getMojo() {
        final StackedEnsembleMojoWriter writer = new StackedEnsembleMojoWriter(this);
        return writer;
    }

    /**
     * Deletes every cross-validation model referenced by the metalearner's
     * output that can still be found in the DKV. Does nothing when the
     * metalearner has no cross-validation model keys.
     */
    public void deleteCrossValidationModels() {
        final Key[] cvModelKeys = this._output._metalearner._output._cross_validation_models;
        if (cvModelKeys == null) {
            return;
        }
        for (Key cvKey : cvModelKeys) {
            Model cvModel = DKV.getGet(cvKey);
            if (cvModel != null) {
                cvModel.delete();
            }
        }
    }

    /**
     * Deletes the metalearner's per-fold cross-validation prediction frames that
     * can still be found in the DKV, then removes its holdout prediction frame id
     * when one is set.
     */
    public void deleteCrossValidationPreds() {
        final Key[] cvPredictionKeys = this._output._metalearner._output._cross_validation_predictions;
        if (cvPredictionKeys != null) {
            for (Key cvKey : cvPredictionKeys) {
                Frame cvPredictions = DKV.getGet(cvKey);
                if (cvPredictions != null) {
                    cvPredictions.delete();
                }
            }
        }
        // Handled independently of the per-fold predictions above.
        if (this._output._metalearner._output._cross_validation_holdout_predictions_frame_id != null) {
            this._output._metalearner._output._cross_validation_holdout_predictions_frame_id.remove();
        }
    }

    /**
     * Output of a stacked ensemble: the standard {@code Model.Output} plus the
     * metalearner model and the level-one frame.
     */
    public static class StackedEnsembleOutput
    extends Model.Output {
        // Metalearner model; the enclosing model scores the level-one frame with it.
        public Model _metalearner;
        // Level-one frame; removed together with the enclosing model when non-null.
        public Frame _levelone_frame_id;

        /** Creates an empty output. */
        public StackedEnsembleOutput() {
        }

        /** Creates an output from the given builder via the {@code Model.Output} constructor. */
        public StackedEnsembleOutput(StackedEnsemble b) {
            super((ModelBuilder)b);
        }

        /** Creates an output that only records the given job. */
        public StackedEnsembleOutput(Job job) {
            this._job = job;
        }
    }

    /**
     * Parameters of a stacked ensemble: the base model keys plus the settings
     * for the metalearner.
     */
    public static class StackedEnsembleParameters
    extends Model.Parameters {
        // Keys of the base models to ensemble; defaults to an empty array.
        public Key<Model>[] _base_models = new Key[0];
        // NOTE(review): presumably controls whether the level-one frame is retained after training -- not read in this file; confirm.
        public boolean _keep_levelone_frame = false;
        // Metalearner cross-validation settings. A fold column and a non-zero nfolds
        // may not both be set (enforced in checkAndInheritModelProperties()).
        public int _metalearner_nfolds;
        public Model.Parameters.FoldAssignmentScheme _metalearner_fold_assignment;
        public String _metalearner_fold_column;
        // Algorithm for the metalearner; defaults to AUTO.
        public MetalearnerAlgorithm _metalearner_algorithm = MetalearnerAlgorithm.AUTO;
        // Metalearner parameters in string form; defaults to an empty string.
        public String _metalearner_params = new String();
        // Metalearner parameters as a parameters object.
        public Model.Parameters _metalearner_parameters;
        // NOTE(review): base models are read via aModel._parms._seed typed as Model.Parameters,
        // so this declaration presumably hides an inherited _seed field -- confirm.
        public long _seed;

        /** @return the algorithm name, {@code "StackedEnsemble"} */
        public String algoName() {
            return "StackedEnsemble";
        }

        /** @return the display name, {@code "Stacked Ensemble"} */
        public String fullName() {
            return "Stacked Ensemble";
        }

        /** @return the fully qualified class name of {@code StackedEnsembleModel} */
        public String javaName() {
            return StackedEnsembleModel.class.getName();
        }

        /** @return a fixed value of 1 progress unit */
        public long progressUnits() {
            return 1L;
        }

        /** Algorithms selectable for the metalearner. */
        public static enum MetalearnerAlgorithm {
            AUTO,
            glm,
            gbm,
            drf,
            deeplearning;

        }
    }
}

