/*
 * Decompiled with CFR 0.152.
 */
package hex;

import hex.Distribution;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsRegression;
import hex.ensemble.StackedEnsemble;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLMModel;
import hex.tree.drf.DRFModel;
import java.lang.reflect.Field;
import java.util.Arrays;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.nbhm.NonBlockingHashSet;
import water.util.Log;
import water.util.ReflectionUtils;

/**
 * Stacked Ensemble model: combines the predictions of several previously trained base models by
 * collecting them into a "level one" frame which is then scored by a metalearner model.
 *
 * <p>Most of the ensemble's configuration (model category, domains, response column,
 * cross-validation settings, ...) is inherited from the base models by
 * {@link #checkAndInheritModelProperties()}, which also verifies that the base models are
 * mutually consistent.
 */
public class StackedEnsembleModel
extends Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput> {
    /** Model category shared by all base models (inherited from the first base model found). */
    public ModelCategory modelCategory;
    /** Checksum of this model's training frame; every base model's training frame must match it. */
    public long trainingFrameChecksum = -1L;
    /** Response column shared by all base models. */
    public String responseColumn = null;
    /** Column names of the first base model found; the other base models must use the same set. */
    private NonBlockingHashSet<String> names = null;
    /** Ignored columns of the first base model found; the other base models must use the same set. */
    private NonBlockingHashSet<String> ignoredColumns = null;
    /** Cross-validation settings inherited from the first base model found. */
    public int nfolds = -1;
    public Model.Parameters.FoldAssignmentScheme fold_assignment;
    public String fold_column;
    public long seed = -1L;

    public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
        super(selfKey, (Model.Parameters)parms, (Model.Output)output);
    }

    /**
     * Scores {@code fr} with every base model, gathers their predictions into a level-one frame
     * (plus the response column taken from {@code adaptFrm}), and scores that frame with the
     * metalearner.
     *
     * <p>When {@code computeMetrics} is true, the metalearner's metrics are cloned onto this model
     * and {@code fr} and registered via {@code addModelMetrics}.
     *
     * @param fr              the frame to score
     * @param adaptFrm        the adapted frame; supplies the response column and, when computing
     *                        metrics, the response domain (its last column)
     * @param destination_key key for the returned prediction frame
     * @param j               job passed through to the scoring tasks
     * @param computeMetrics  whether to build and register model metrics
     * @return the metalearner's predictions, stored under {@code destination_key}
     */
    protected Frame predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job j, boolean computeMetrics) {
        StackedEnsembleParameters parms = (StackedEnsembleParameters)this._parms;
        StackedEnsembleOutput output = (StackedEnsembleOutput)this._output;
        String[] scoringNames = this.makeScoringNames();
        String[][] scoringDomains = new String[scoringNames.length][];
        // Column 0 gets a domain only when there is more than one output column: taken from the
        // training output when not computing metrics, otherwise from the adapted frame's last column.
        scoringDomains[0] = scoringNames.length == 1 ? null : (!computeMetrics ? output._domains[output._domains.length - 1] : adaptFrm.lastVec().domain());
        Frame levelOneFrame = new Frame(Key.make((String)("preds_levelone_" + this._key.toString() + fr._key)));
        for (Key<Model> baseKey : parms._base_models) {
            Model base = (Model)baseKey.get();
            Frame adaptedFrame = new Frame(fr);
            base.adaptTestForTrain(adaptedFrame, true, computeMetrics);
            // NOTE(review): (byte)3 is the output column type; presumably Vec.T_NUM -- confirm.
            Model.BigScore baseBs = (Model.BigScore)base.makeBigScoreTask(scoringDomains, scoringNames, adaptedFrame, computeMetrics, true, j).doAll(scoringNames.length, (byte)3, adaptedFrame);
            Frame basePreds = baseBs.outputFrame(Key.make((String)("preds_base_" + this._key.toString() + fr._key)), scoringNames, scoringDomains);
            StackedEnsemble.addModelPredictionsToLevelOneFrame(base, basePreds, levelOneFrame);
            // Only the frame's DKV entry is removed; its vecs are presumably still referenced by levelOneFrame.
            DKV.remove((Key)basePreds._key);
            Model.cleanup_adapt((Frame)adaptedFrame, (Frame)fr);
        }
        levelOneFrame.add(this.responseColumn, adaptFrm.vec(this.responseColumn));
        Log.info((Object[])new Object[]{"Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString()});
        Model metalearner = output._metalearner;
        Frame levelOneAdapted = new Frame(levelOneFrame);
        metalearner.adaptTestForTrain(levelOneAdapted, true, computeMetrics);
        String[] metaNames = metalearner.makeScoringNames();
        String[][] metaDomains = new String[metaNames.length][];
        metaDomains[0] = metaNames.length == 1 ? null : (!computeMetrics ? metalearner._output._domains[metalearner._output._domains.length - 1] : levelOneAdapted.lastVec().domain());
        Model.BigScore metaBs = (Model.BigScore)metalearner.makeBigScoreTask(metaDomains, metaNames, levelOneAdapted, computeMetrics, true, j).doAll(metaNames.length, (byte)3, levelOneAdapted);
        if (computeMetrics) {
            ModelMetrics mmMetalearner = metaBs._mb.makeModelMetrics(metalearner, levelOneFrame, levelOneAdapted, metaBs.outputFrame());
            // Re-attribute the metalearner's metrics to this ensemble and the original frame.
            ModelMetrics mmStackedEnsemble = mmMetalearner.deepCloneWithDifferentModelAndFrame((Model)this, fr);
            this.addModelMetrics(mmStackedEnsemble);
        }
        Model.cleanup_adapt((Frame)levelOneAdapted, (Frame)levelOneFrame);
        return metaBs.outputFrame(Key.make((String)destination_key), metaNames, metaDomains);
    }

    /**
     * Row-wise scoring is not supported for ensembles.
     *
     * @throws UnsupportedOperationException always
     */
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().");
    }

    /**
     * Creates a metric builder matching this ensemble's model category.
     *
     * @param domain response domain (used for the binomial builder)
     * @return a binomial or regression metric builder
     * @throws RuntimeException from {@code H2O.unimpl()} for any other category
     */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (((StackedEnsembleOutput)this._output).getModelCategory()) {
            case Binomial: {
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            }
            case Regression: {
                return new ModelMetricsRegression.MetricBuilderRegression();
            }
        }
        throw H2O.unimpl();
    }

    /**
     * Scores {@code frame} with metrics enabled, discards the predictions and returns the metrics
     * that scoring stored for this model and frame.
     */
    public ModelMetrics doScoreMetricsOneFrame(Frame frame, Job job) {
        Frame pred = this.predictScoreImpl(frame, new Frame(frame), null, job, true);
        pred.remove();
        return ModelMetrics.getFromDKV((Model)this, (Frame)frame);
    }

    /** Computes training metrics and, when a validation frame is set, validation metrics. */
    public void doScoreMetrics(Job job) {
        StackedEnsembleParameters parms = (StackedEnsembleParameters)this._parms;
        StackedEnsembleOutput output = (StackedEnsembleOutput)this._output;
        output._training_metrics = this.doScoreMetricsOneFrame(parms.train(), job);
        if (null != parms.valid()) {
            output._validation_metrics = this.doScoreMetricsOneFrame(parms.valid(), job);
        }
    }

    /**
     * Determines the distribution family of {@code aModel}.
     *
     * <p>DRF models are mapped by their output type. Otherwise the family is read reflectively:
     * first from a {@code _family} parameter (GLM-style; {@code binomial} maps to
     * {@code bernoulli}), then from a {@code _dist} field on the model, falling back to
     * {@code _parms._distribution}. {@code AUTO} resolves to {@code bernoulli} for binomial
     * classifiers and {@code gaussian} for non-classifiers.
     *
     * @throws H2OIllegalArgumentException if the family cannot be determined (e.g. multinomial
     *         classifiers) or reflection fails
     */
    private DistributionFamily distributionFamily(Model aModel) {
        if (aModel instanceof DRFModel) {
            if (aModel._output.isBinomialClassifier()) {
                return DistributionFamily.bernoulli;
            }
            if (aModel._output.isClassifier()) {
                throw new H2OIllegalArgumentException("Don't know how to set the distribution for a multinomial Random Forest classifier.");
            }
            return DistributionFamily.gaussian;
        }
        try {
            Field familyField = ReflectionUtils.findNamedField((Object)aModel._parms, (String)"_family");
            Field distributionField = familyField != null ? null : ReflectionUtils.findNamedField((Object)aModel, (String)"_dist");
            if (null != familyField) {
                GLMModel.GLMParameters.Family thisFamily = (GLMModel.GLMParameters.Family)((Object)familyField.get(aModel._parms));
                if (thisFamily == GLMModel.GLMParameters.Family.binomial) {
                    return DistributionFamily.bernoulli;
                }
                try {
                    return Enum.valueOf(DistributionFamily.class, thisFamily.toString());
                }
                catch (IllegalArgumentException e) {
                    throw new H2OIllegalArgumentException("Don't know how to find the right DistributionFamily for Family: " + (Object)((Object)thisFamily));
                }
            }
            if (null != distributionField) {
                Distribution distribution = (Distribution)distributionField.get(aModel);
                DistributionFamily distributionFamily = null != distribution ? distribution.distribution : aModel._parms._distribution;
                if (distributionFamily == DistributionFamily.AUTO) {
                    if (aModel._output.isBinomialClassifier()) {
                        distributionFamily = DistributionFamily.bernoulli;
                    } else {
                        if (aModel._output.isClassifier()) {
                            throw new H2OIllegalArgumentException("Don't know how to determine the distribution for a multinomial classifier.");
                        }
                        distributionFamily = DistributionFamily.gaussian;
                    }
                }
                return distributionFamily;
            }
            throw new H2OIllegalArgumentException("Don't know how to stack models that have neither a distribution hyperparameter nor a family hyperparameter.");
        }
        catch (H2OIllegalArgumentException e) {
            // Our own, already descriptive errors: propagate unchanged instead of re-wrapping them.
            throw e;
        }
        catch (Exception e) {
            // Reflection and other unexpected failures.
            throw new H2OIllegalArgumentException(e.toString(), e.toString());
        }
    }

    /**
     * Validates the base models and inherits the ensemble's properties from them.
     *
     * <p>The first base model found in the DKV supplies the ensemble's category, distribution,
     * domains, names, ignored columns, response column and cross-validation settings. Every
     * subsequent base model must agree with those. Base-model keys that cannot be resolved are
     * skipped with a warning.
     *
     * @throws H2OIllegalArgumentException if no base models are specified, none can be found, or
     *         the base models are inconsistent with each other or with the ensemble's parameters
     */
    public void checkAndInheritModelProperties() {
        StackedEnsembleParameters parms = (StackedEnsembleParameters)this._parms;
        if (null == parms._base_models || 0 == parms._base_models.length) {
            throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; found 0.");
        }
        boolean foundBaseModel = false;
        this.trainingFrameChecksum = parms.train().checksum();
        for (Key<Model> k : parms._base_models) {
            Model aModel = (Model)DKV.getGet(k);
            if (null == aModel) {
                Log.warn((Object[])new Object[]{"Failed to find base model; skipping: " + k});
                continue;
            }
            if (foundBaseModel) {
                this.checkConsistentWithInherited(aModel);
            } else {
                this.inheritPropertiesFrom(aModel);
                foundBaseModel = true;
            }
        }
        // Must use the flag: a trailing missing key would otherwise look like "nothing found".
        if (!foundBaseModel) {
            throw new H2OIllegalArgumentException("When creating a StackedEnsemble you must specify one or more models; " + parms._base_models.length + " were specified but none of those were found: " + Arrays.toString(parms._base_models));
        }
    }

    /** Copies the ensemble's shared properties from the first base model found. */
    private void inheritPropertiesFrom(Model aModel) {
        StackedEnsembleParameters parms = (StackedEnsembleParameters)this._parms;
        StackedEnsembleOutput output = (StackedEnsembleOutput)this._output;
        output._isSupervised = aModel.isSupervised();
        this.modelCategory = aModel._output.getModelCategory();
        this._dist = new Distribution(this.distributionFamily(aModel));
        output._domains = (String[][])Arrays.copyOf(aModel._output._domains, aModel._output._domains.length);
        output.setNames(aModel._output._names);
        this.names = new NonBlockingHashSet<String>();
        this.names.addAll(Arrays.asList(aModel._output._names));
        this.ignoredColumns = new NonBlockingHashSet<String>();
        if (null != aModel._parms._ignored_columns) {
            this.ignoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
        }
        if (null != parms._ignored_columns) {
            NonBlockingHashSet<String> ensembleIgnoredColumns = new NonBlockingHashSet<String>();
            ensembleIgnoredColumns.addAll(Arrays.asList(parms._ignored_columns));
            if (!ensembleIgnoredColumns.equals(this.ignoredColumns)) {
                throw new H2OIllegalArgumentException("A StackedEnsemble takes its ignored_columns list from the base models.  An inconsistent list of ignored_columns was specified for the ensemble model.");
            }
        }
        this.responseColumn = aModel._parms._response_column;
        if (!this.responseColumn.equals(parms._response_column)) {
            throw new H2OIllegalArgumentException("StackedModel response_column must match the response_column of each base model.  Found: " + this.responseColumn + " and: " + parms._response_column);
        }
        this.nfolds = aModel._parms._nfolds;
        this.fold_assignment = aModel._parms._fold_assignment;
        // AUTO is treated as Random for the consistency checks on later base models.
        if (this.fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO) {
            this.fold_assignment = Model.Parameters.FoldAssignmentScheme.Random;
        }
        this.fold_column = aModel._parms._fold_column;
        this.seed = aModel._parms._seed;
        parms._distribution = aModel._parms._distribution;
    }

    /**
     * Verifies that a subsequent base model agrees with the properties inherited from the first
     * one. A differing distribution only produces a warning (and is not checked for DRF models).
     */
    private void checkConsistentWithInherited(Model aModel) {
        StackedEnsembleParameters parms = (StackedEnsembleParameters)this._parms;
        StackedEnsembleOutput output = (StackedEnsembleOutput)this._output;
        if (output._isSupervised ^ aModel.isSupervised()) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of supervised and unsupervised models: " + Arrays.toString(parms._base_models));
        }
        if (this.modelCategory != aModel._output.getModelCategory()) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different categories of models: " + Arrays.toString(parms._base_models));
        }
        Frame aTrainingFrame = aModel._parms.train();
        if (this.trainingFrameChecksum != aTrainingFrame.checksum()) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different training frames.  Found checksums: " + this.trainingFrameChecksum + " and: " + aTrainingFrame.checksum() + ".");
        }
        NonBlockingHashSet<String> aNames = new NonBlockingHashSet<String>();
        aNames.addAll(Arrays.asList(aModel._output._names));
        if (!aNames.equals(this.names)) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different column lists.  Found: " + this.names + " and: " + aNames + ".");
        }
        NonBlockingHashSet<String> anIgnoredColumns = new NonBlockingHashSet<String>();
        if (null != aModel._parms._ignored_columns) {
            anIgnoredColumns.addAll(Arrays.asList(aModel._parms._ignored_columns));
        }
        if (!anIgnoredColumns.equals(this.ignoredColumns)) {
            // Print the set, not the raw array (whose toString() is just an identity hash).
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different ignored_column lists.  Found: " + this.ignoredColumns + " and: " + anIgnoredColumns + ".");
        }
        if (!this.responseColumn.equals(aModel._parms._response_column)) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different response columns.  Found: " + this.responseColumn + " and: " + aModel._parms._response_column + ".");
        }
        if (output._domains.length != aModel._output._domains.length) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: there is a mix of different numbers of domains (categorical levels): " + Arrays.toString(parms._base_models));
        }
        Model.Parameters.FoldAssignmentScheme aFoldAssignment = aModel._parms._fold_assignment;
        // AUTO and Random are considered equivalent fold assignments.
        boolean sameFoldAssignment = aFoldAssignment == this.fold_assignment
                || (aFoldAssignment == Model.Parameters.FoldAssignmentScheme.AUTO && this.fold_assignment == Model.Parameters.FoldAssignmentScheme.Random)
                || (aFoldAssignment == Model.Parameters.FoldAssignmentScheme.Random && this.fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO);
        if (!sameFoldAssignment) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different fold_assignments.");
        }
        if (aModel._parms._fold_column == null && this.nfolds != aModel._parms._nfolds) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different values for nfolds.");
        }
        if (aModel._parms._fold_column == null && aModel._parms._nfolds < 2) {
            throw new H2OIllegalArgumentException("Base model does not use cross-validation: " + aModel._parms._nfolds);
        }
        if (aModel._parms._fold_column != null && !aModel._parms._fold_column.equals(this.fold_column)) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use different fold_columns.");
        }
        if (aModel._parms._fold_column == null && this.fold_assignment == Model.Parameters.FoldAssignmentScheme.Random && aModel._parms._seed != this.seed) {
            throw new H2OIllegalArgumentException("Base models are inconsistent: they use random-seeded crossfold validation but have different seeds.");
        }
        if (!aModel._parms._keep_cross_validation_predictions) {
            throw new H2OIllegalArgumentException("Base model does not keep cross-validation predictions: " + aModel._parms._nfolds);
        }
        if (!(aModel instanceof DRFModel) && this.distributionFamily(aModel) != this.distributionFamily(this)) {
            Log.warn((Object[])new Object[]{"Base models are inconsistent; they use different distributions: " + this.distributionFamily(this) + " and: " + this.distributionFamily(aModel) + ". Is this intentional?"});
        }
    }

    /** Removes the metalearner from the DKV (when present), then delegates to the superclass. */
    protected Futures remove_impl(Futures fs) {
        StackedEnsembleOutput output = (StackedEnsembleOutput)this._output;
        if (output._metalearner != null) {
            DKV.remove((Key)output._metalearner._key, (Futures)fs);
        }
        return super.remove_impl(fs);
    }

    /**
     * Writes the metalearner's key followed by every base model's key, then delegates to the
     * superclass. NOTE(review): unlike remove_impl, this assumes _metalearner is non-null.
     */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(((StackedEnsembleOutput)this._output)._metalearner._key);
        for (Key<Model> ks : ((StackedEnsembleParameters)this._parms)._base_models) {
            ab.putKey(ks);
        }
        return super.writeAll_impl(ab);
    }

    /**
     * Reads back, in the same order as {@link #writeAll_impl(AutoBuffer)}, the metalearner's key
     * and every base model's key, then delegates to the superclass.
     */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(((StackedEnsembleOutput)this._output)._metalearner._key, fs);
        for (Key<Model> ks : ((StackedEnsembleParameters)this._parms)._base_models) {
            ab.getKey(ks, fs);
        }
        return super.readAll_impl(ab, fs);
    }

    /** Output of a Stacked Ensemble; holds the trained metalearner. */
    public static class StackedEnsembleOutput
    extends Model.Output {
        /** The model trained on the base models' level-one predictions. */
        public Model _metalearner;

        public StackedEnsembleOutput() {
        }

        public StackedEnsembleOutput(StackedEnsemble b) {
            super((ModelBuilder)b);
        }

        public StackedEnsembleOutput(Job job) {
            this._job = job;
        }
    }

    /** Parameters of a Stacked Ensemble: the keys of the base models to combine. */
    public static class StackedEnsembleParameters
    extends Model.Parameters {
        /** Keys of the base models; empty by default. */
        public Key<Model>[] _base_models = new Key[0];

        public String algoName() {
            return "StackedEnsemble";
        }

        public String fullName() {
            return "Stacked Ensemble";
        }

        public String javaName() {
            return StackedEnsembleModel.class.getName();
        }

        public long progressUnits() {
            return 1L;
        }
    }
}

