/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.LinkFunction;
import hex.LinkFunctionFactory;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.ensemble.Metalearner;
import hex.ensemble.Metalearners;
import hex.ensemble.StackedEnsemble;
import hex.ensemble.StackedEnsembleMojoWriter;
import hex.genmodel.utils.DistributionFamily;
import hex.genmodel.utils.LinkFunctionType;
import java.util.Arrays;
import java.util.HashSet;
import java.util.stream.Stream;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.LocalMR;
import water.MRTask;
import water.MrFun;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.udf.CFuncRef;
import water.util.Log;
import water.util.MRUtils;
import water.util.TwoDimTable;

/**
 * Model produced by the Stacked Ensemble algorithm: a metalearner model trained on the
 * predictions ("level one" data) of a set of base models.
 *
 * <p>Scoring goes through {@link #predictScoreImpl}. Every base model that the metalearner
 * actually uses scores the input frame. Their predictions are collected into a level-one frame,
 * which may be transformed. The metalearner then scores that frame. {@link #score0} and
 * {@link #makeMetricBuilder} are intentionally unsupported.
 */
public class StackedEnsembleModel
extends Model<StackedEnsembleModel, StackedEnsembleParameters, StackedEnsembleOutput> {
    /** Category of this model; {@link #isUsefulBaseModel} handles {@code Multinomial} specially. */
    public ModelCategory modelCategory;
    /** Row count of the training frame; defaults to -1 (NOTE(review): presumably set by the builder — confirm). */
    public long trainingFrameRows = -1L;
    /** Name of the response column; defaults to null. */
    public String responseColumn = null;

    /**
     * Creates a Stacked Ensemble model.
     *
     * @param selfKey key of this model
     * @param parms   parameters the ensemble is built with
     * @param output  output holding the metalearner and related artifacts
     */
    public StackedEnsembleModel(Key selfKey, StackedEnsembleParameters parms, StackedEnsembleOutput output) {
        super(selfKey, (Model.Parameters)parms, (Model.Output)output);
    }

    /**
     * Resolves AUTO parameter values to concrete ones. In addition to the base-class behavior,
     * an AUTO metalearner fold assignment is resolved to {@code Random}.
     */
    public void initActualParamValues() {
        super.initActualParamValues();
        if (((StackedEnsembleParameters)this._parms)._metalearner_fold_assignment == Model.Parameters.FoldAssignmentScheme.AUTO) {
            ((StackedEnsembleParameters)this._parms)._metalearner_fold_assignment = Model.Parameters.FoldAssignmentScheme.Random;
        }
    }

    /**
     * Reports whether a MOJO can be produced. The ensemble itself must support MOJOs, and so must
     * every base model the metalearner uses (see {@link #isUsefulBaseModel}).
     */
    public boolean haveMojo() {
        return super.haveMojo() && Stream.of(((StackedEnsembleParameters)this._parms)._base_models).filter(this::isUsefulBaseModel).map(DKV::getGet).allMatch(Model::haveMojo);
    }

    /**
     * Scores a frame with the ensemble.
     *
     * <p>Each useful base model scores {@code fr} in parallel, and the predictions are assembled
     * into a level-one frame. If a metalearner transform other than NONE is configured, it is
     * applied to that frame. The non-predictor columns from {@code adaptFrm} are then added, and
     * the metalearner scores the result.
     *
     * @param fr               frame to score
     * @param adaptFrm         adapted version of {@code fr}; its non-predictor columns are copied
     *                         into the level-one frame, and its vecs are kept when the level-one
     *                         frame is deleted
     * @param destination_key  key for the prediction frame produced by the metalearner
     * @param j                job passed to the base-model and metalearner scoring calls
     * @param computeMetrics   whether to compute metrics; if true, the metalearner's latest metrics
     *                         are cloned as this model's metrics
     * @param customMetricFunc not used here: the metalearner is scored with the custom metric
     *                         function from this model's parameters
     * @return the metalearner predictions together with the ensemble metrics (null when
     *         {@code computeMetrics} is false)
     * @throws H2OIllegalArgumentException if a metalearner transform is set for a model that is
     *         neither a binomial nor a multinomial classifier
     */
    protected Model.PredictScoreResult predictScoreImpl(final Frame fr, Frame adaptFrm, String destination_key, final Job j, boolean computeMetrics, CFuncRef customMetricFunc) {
        StackedEnsembleParameters.MetalearnerTransform transform;
        if (((StackedEnsembleParameters)this._parms)._metalearner_transform != null && ((StackedEnsembleParameters)this._parms)._metalearner_transform != StackedEnsembleParameters.MetalearnerTransform.NONE) {
            if (!((StackedEnsembleOutput)this._output).isBinomialClassifier() && !((StackedEnsembleOutput)this._output).isMultinomialClassifier()) {
                throw new H2OIllegalArgumentException("Metalearner transform is supported only for classification!");
            }
            transform = ((StackedEnsembleParameters)this._parms)._metalearner_transform;
        } else {
            transform = null;
        }
        final String seKey = this._key.toString();
        Key levelOneFrameKey = Key.make((String)("preds_levelone_" + seKey + fr._key));
        // Without a transform, the level-one frame itself carries the key. With a transform, this
        // frame is only an intermediate and stays keyless; the transformed frame receives the key.
        Frame levelOneFrame = transform == null ? new Frame(levelOneFrameKey) : new Frame(new Vec[0]);
        final Model[] usefulBaseModels = (Model[])Stream.of(((StackedEnsembleParameters)this._parms)._base_models).filter(this::isUsefulBaseModel).map(Key::get).toArray(Model[]::new);
        if (usefulBaseModels.length > 0) {
            final Frame[] baseModelPredictions = new Frame[usefulBaseModels.length];
            // Score all useful base models in parallel; each task writes only its own array slot.
            ((LocalMR)H2O.submitTask((H2O.H2OCountedCompleter)new LocalMR(new MrFun(){

                protected void map(int id) {
                    baseModelPredictions[id] = usefulBaseModels[id].score(fr, "preds_base_" + seKey + usefulBaseModels[id]._key + fr._key, j, false);
                }
            }, usefulBaseModels.length))).join();
            for (int i = 0; i < usefulBaseModels.length; ++i) {
                StackedEnsemble.addModelPredictionsToLevelOneFrame(usefulBaseModels[i], baseModelPredictions[i], levelOneFrame);
                // The prediction frame is no longer needed. Delete it, but keep any vecs it now
                // shares with the level-one frame.
                DKV.remove((Key)baseModelPredictions[i]._key);
                Frame.deleteTempFrameAndItsNonSharedVecs((Frame)baseModelPredictions[i], (Frame)levelOneFrame);
            }
        }
        if (transform != null) {
            Frame oldLOF = levelOneFrame;
            levelOneFrame = transform.transform(this, levelOneFrame, (Key<Frame>)levelOneFrameKey);
            oldLOF.remove();
        }
        StackedEnsemble.addNonPredictorsToLevelOneFrame((StackedEnsembleParameters)this._parms, adaptFrm, levelOneFrame, false);
        Log.info((Object[])new Object[]{"Finished creating \"level one\" frame for scoring: " + levelOneFrame.toString()});
        Model metalearner = ((StackedEnsembleOutput)this._output)._metalearner;
        Frame predictFr = metalearner.score(levelOneFrame, destination_key, j, computeMetrics, CFuncRef.from((String)((StackedEnsembleParameters)this._parms)._custom_metric_func));
        ModelMetrics mmStackedEnsemble = null;
        if (computeMetrics) {
            // The metalearner's most recently computed metrics are the ensemble's metrics. Re-bind
            // them to this model and the original frame, then drop them from the metalearner.
            Key[] mms = metalearner._output.getModelMetrics();
            ModelMetrics lastComputedMetric = (ModelMetrics)mms[mms.length - 1].get();
            mmStackedEnsemble = lastComputedMetric.deepCloneWithDifferentModelAndFrame((Model)this, fr);
            this.addModelMetrics(mmStackedEnsemble);
            for (Key mm : metalearner._output.clearModelMetrics(true)) {
                DKV.remove((Key)mm);
            }
        }
        Frame.deleteTempFrameAndItsNonSharedVecs((Frame)levelOneFrame, (Frame)adaptFrm);
        return new StackedEnsemblePredictScoreResult(predictFr, mmStackedEnsemble);
    }

    /**
     * Reports whether the metalearner uses any feature derived from the given base model.
     *
     * <p>For multinomial models, a base model contributes several level-one columns named
     * {@code "<baseModelKey>/..."}, and the model is useful if any of them is used. Otherwise the
     * base model contributes a single column named after its key.
     *
     * @param baseModelKey key of the base model to check
     * @return true if the metalearner uses at least one of the base model's features
     */
    boolean isUsefulBaseModel(Key<Model> baseModelKey) {
        Model metalearner = ((StackedEnsembleOutput)this._output)._metalearner;
        assert (metalearner != null) : "can't use isUsefulBaseModel during training";
        if (this.modelCategory == ModelCategory.Multinomial) {
            for (String feature : metalearner._output._names) {
                if (!feature.startsWith(baseModelKey.toString().concat("/")) || !metalearner.isFeatureUsedInPredict(feature)) continue;
                return true;
            }
            return false;
        }
        return metalearner.isFeatureUsedInPredict(baseModelKey.toString());
    }

    /** Unsupported: row-wise scoring is not available; use {@link #predictScoreImpl}. */
    protected double[] score0(double[] data, double[] preds) {
        throw new UnsupportedOperationException("StackedEnsembleModel.score0() should never be called: the code paths that normally go here should call predictScoreImpl().");
    }

    /** Unsupported: metrics come from the metalearner (see {@link #predictScoreImpl}). */
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        throw new UnsupportedOperationException("StackedEnsembleModel.makeMetricBuilder should never be called!");
    }

    /**
     * Scores {@code frame} with the ensemble and returns the resulting metrics.
     *
     * <p>When {@code _score_training_samples} is positive and smaller than the frame's row count,
     * a sample of that many rows (seeded with {@code _seed}) is scored instead. The sample is
     * deleted afterwards. The prediction frame is deleted as well; only the metrics are kept.
     *
     * <p>NOTE(review): this method was reconstructed by the CFR decompiler, which warned that a
     * self-catching try block was removed. Verify the exception handling against the original source.
     *
     * @param frame frame to score
     * @param job   job passed through to {@link #predictScoreImpl}
     * @return metrics computed on {@code frame} or its sample
     */
    private ModelMetrics doScoreTrainingMetrics(Frame frame, Job job) {
        Frame scoredFrame = ((StackedEnsembleParameters)this._parms)._score_training_samples > 0L && ((StackedEnsembleParameters)this._parms)._score_training_samples < frame.numRows() ? MRUtils.sampleFrame((Frame)frame, (long)((StackedEnsembleParameters)this._parms)._score_training_samples, (long)((StackedEnsembleParameters)this._parms)._seed) : frame;
        try {
            Frame adaptedFrame = new Frame(scoredFrame);
            Model.PredictScoreResult result = this.predictScoreImpl(scoredFrame, adaptedFrame, null, job, true, CFuncRef.from((String)((StackedEnsembleParameters)this._parms)._custom_metric_func));
            result.getPredictions().delete();
            ModelMetrics modelMetrics = result.makeModelMetrics(scoredFrame, adaptedFrame);
            return modelMetrics;
        }
        finally {
            // Only delete the frame if it is a sample created above, never the caller's frame.
            if (scoredFrame != frame) {
                scoredFrame.delete();
            }
        }
    }

    /**
     * Fills in this model's training, validation and cross-validation metrics.
     *
     * <p>Training metrics are recomputed by scoring the training frame with the whole ensemble.
     * Validation metrics are copied from the metalearner. Cross-validation metrics, when present,
     * are cloned from the metalearner and bound to this model, together with the metalearner's
     * CV summary table when that table exists.
     *
     * @param job NOTE(review): currently not forwarded; training metrics are computed with a null
     *            Job. Confirm this is intended.
     */
    void doScoreOrCopyMetrics(Job job) {
        ((StackedEnsembleOutput)this._output)._training_metrics = this.doScoreTrainingMetrics(((StackedEnsembleParameters)this._parms).train(), null);
        ((StackedEnsembleOutput)this._output)._validation_metrics = ((StackedEnsembleOutput)this._output)._metalearner._output._validation_metrics;
        if (null != ((StackedEnsembleOutput)this._output)._metalearner._output._cross_validation_metrics) {
            ((StackedEnsembleOutput)this._output)._cross_validation_metrics = ((StackedEnsembleOutput)this._output)._metalearner._output._cross_validation_metrics.deepCloneWithDifferentModelAndFrame((Model)this, ((StackedEnsembleOutput)this._output)._metalearner._parms.train());
            TwoDimTable cvSummary = ((StackedEnsembleOutput)this._output)._metalearner._output._cross_validation_metrics_summary;
            // Guard the clone: the outer check only covers the metrics, not the summary table.
            if (cvSummary != null) {
                ((StackedEnsembleOutput)this._output)._cross_validation_metrics_summary = (TwoDimTable)cvSummary.clone();
            }
        }
    }

    /**
     * Deletes the stored base-model prediction frames and clears their keys.
     *
     * <p>If a level-one frame is retained and the prediction frame still exists, only the vecs it
     * does not share with the level-one frame are deleted. Otherwise the key is removed outright.
     */
    public void deleteBaseModelPredictions() {
        if (((StackedEnsembleOutput)this._output)._base_model_predictions_keys != null) {
            for (Key<Frame> key : ((StackedEnsembleOutput)this._output)._base_model_predictions_keys) {
                if (((StackedEnsembleOutput)this._output)._levelone_frame_id != null && key.get() != null) {
                    Frame.deleteTempFrameAndItsNonSharedVecs((Frame)((Frame)key.get()), (Frame)((StackedEnsembleOutput)this._output)._levelone_frame_id);
                    continue;
                }
                Keyed.remove(key);
            }
            ((StackedEnsembleOutput)this._output)._base_model_predictions_keys = null;
        }
    }

    /**
     * Removes this model's artifacts: base-model predictions, the metalearner and the level-one
     * frame. Then delegates to the base implementation.
     */
    protected Futures remove_impl(Futures fs, boolean cascade) {
        this.deleteBaseModelPredictions();
        if (((StackedEnsembleOutput)this._output)._metalearner != null) {
            ((StackedEnsembleOutput)this._output)._metalearner.remove(fs);
        }
        if (((StackedEnsembleOutput)this._output)._levelone_frame_id != null) {
            ((StackedEnsembleOutput)this._output)._levelone_frame_id.remove(fs);
        }
        return super.remove_impl(fs, cascade);
    }

    /**
     * Writes the metalearner key, then every base-model key, before the base-class data.
     * {@link #readAll_impl} must read them back in the same order.
     */
    protected AutoBuffer writeAll_impl(AutoBuffer ab) {
        ab.putKey(((StackedEnsembleOutput)this._output)._metalearner._key);
        for (Key<Model> ks : ((StackedEnsembleParameters)this._parms)._base_models) {
            ab.putKey(ks);
        }
        return super.writeAll_impl(ab);
    }

    /** Reads back, in write order, the keys written by {@link #writeAll_impl}. */
    protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
        ab.getKey(((StackedEnsembleOutput)this._output)._metalearner._key, fs);
        for (Key<Model> ks : ((StackedEnsembleParameters)this._parms)._base_models) {
            ab.getKey(ks, fs);
        }
        return super.readAll_impl(ab, fs);
    }

    /** Returns a MOJO writer for this ensemble. */
    public StackedEnsembleMojoWriter getMojo() {
        return new StackedEnsembleMojoWriter(this);
    }

    /** Deletes the metalearner's cross-validation models, if a metalearner exists. */
    public void deleteCrossValidationModels() {
        if (((StackedEnsembleOutput)this._output)._metalearner != null) {
            ((StackedEnsembleOutput)this._output)._metalearner.deleteCrossValidationModels();
        }
    }

    /** Deletes the metalearner's cross-validation predictions, if a metalearner exists. */
    public void deleteCrossValidationPreds() {
        if (((StackedEnsembleOutput)this._output)._metalearner != null) {
            ((StackedEnsembleOutput)this._output)._metalearner.deleteCrossValidationPreds();
        }
    }

    /** Deletes the metalearner's cross-validation fold assignment, if a metalearner exists. */
    public void deleteCrossValidationFoldAssignment() {
        if (((StackedEnsembleOutput)this._output)._metalearner != null) {
            ((StackedEnsembleOutput)this._output)._metalearner.deleteCrossValidationFoldAssignment();
        }
    }

    /**
     * Scoring result that carries metrics computed ahead of time. The ensemble has no metric
     * builder of its own.
     */
    private class StackedEnsemblePredictScoreResult
    extends Model.PredictScoreResult {
        private final ModelMetrics _modelMetrics;

        /**
         * @param preds        prediction frame (passed to the base class in both frame positions —
         *                     NOTE(review): confirm the base-class parameter meanings)
         * @param modelMetrics precomputed metrics; may be null
         */
        public StackedEnsemblePredictScoreResult(Frame preds, ModelMetrics modelMetrics) {
            super((Model)StackedEnsembleModel.this, null, preds, preds);
            this._modelMetrics = modelMetrics;
        }

        /** Returns the precomputed metrics; both arguments are ignored. */
        public ModelMetrics makeModelMetrics(Frame fr, Frame adaptFrm) {
            return this._modelMetrics;
        }

        /** Unsupported: there is no metric builder for the ensemble. */
        public ModelMetrics.MetricBuilder<?> getMetricBuilder() {
            throw new UnsupportedOperationException("Stacked Ensemble model doesn't implement MetricBuilder infrastructure code, retrieve your metrics by calling getOrMakeMetrics method.");
        }
    }

    /** Output of a Stacked Ensemble model. */
    public static class StackedEnsembleOutput
    extends Model.Output {
        /** Metalearner model trained on the level-one frame. */
        public Model _metalearner;
        /** Retained level-one frame, or null. */
        public Frame _levelone_frame_id;
        /** How the level-one data was produced. */
        public StackingStrategy _stacking_strategy;
        /** Keys of retained base-model prediction frames, or null. */
        public Key<Frame>[] _base_model_predictions_keys;

        public StackedEnsembleOutput() {
        }

        public StackedEnsembleOutput(StackedEnsemble b) {
            super((ModelBuilder)b);
        }

        public StackedEnsembleOutput(Job job) {
            this._job = job;
        }

        /**
         * Returns the base-class feature count, minus one when the metalearner has a fold column.
         * NOTE(review): presumably the fold column is counted among the output columns but is not
         * a feature — confirm. Requires {@code _metalearner} to be set.
         */
        public int nfeatures() {
            return super.nfeatures() - (this._metalearner._parms._fold_column == null ? 0 : 1);
        }
    }

    /** Parameters of the Stacked Ensemble algorithm. */
    public static class StackedEnsembleParameters
    extends Model.Parameters {
        /** Keys of the base models to stack. */
        public Key<Model>[] _base_models = new Key[0];
        public boolean _keep_levelone_frame = false;
        public boolean _keep_base_model_predictions = false;
        public int _metalearner_nfolds;
        /** Fold assignment for the metalearner; AUTO resolves to Random in {@code initActualParamValues}. */
        public Model.Parameters.FoldAssignmentScheme _metalearner_fold_assignment;
        /** Fold column for the metalearner; treated as a non-predictor (see {@link #getNonPredictors}). */
        public String _metalearner_fold_column;
        /** Key of an optional blending frame. */
        public Key<Frame> _blending;
        /** Transform applied to the level-one frame before metalearner scoring (classification only). */
        public MetalearnerTransform _metalearner_transform = MetalearnerTransform.NONE;
        public Metalearner.Algorithm _metalearner_algorithm = Metalearner.Algorithm.AUTO;
        /** Metalearner parameters as a string; defaults to empty. */
        public String _metalearner_params = "";
        /** Metalearner parameters object, created by {@link #initMetalearnerParams}. */
        public Model.Parameters _metalearner_parameters;
        /** Maximum number of rows scored to compute training metrics; non-positive disables sampling. */
        public long _score_training_samples = 10000L;

        public String algoName() {
            return "StackedEnsemble";
        }

        public String fullName() {
            return "Stacked Ensemble";
        }

        public String javaName() {
            return StackedEnsembleModel.class.getName();
        }

        public long progressUnits() {
            return 1L;
        }

        /** (Re)creates the metalearner parameters for the currently selected algorithm. */
        public void initMetalearnerParams() {
            this.initMetalearnerParams(this._metalearner_algorithm);
        }

        /**
         * Selects the metalearner algorithm and creates fresh default parameters for it.
         *
         * @param algo metalearner algorithm
         */
        public void initMetalearnerParams(Metalearner.Algorithm algo) {
            this._metalearner_algorithm = algo;
            this._metalearner_parameters = Metalearners.createParameters(algo.name());
        }

        /** Returns the blending frame, or null when no blending key is set. */
        public final Frame blending() {
            return this._blending == null ? null : (Frame)this._blending.get();
        }

        /** Returns the base-class non-predictors plus the metalearner fold column, if set. */
        public String[] getNonPredictors() {
            HashSet<String> nonPredictors = new HashSet<String>();
            nonPredictors.addAll(Arrays.asList(super.getNonPredictors()));
            if (null != this._metalearner_fold_column) {
                nonPredictors.add(this._metalearner_fold_column);
            }
            return nonPredictors.toArray(new String[0]);
        }

        /** Returns the metalearner's distribution family when its parameters exist, else the base-class value. */
        public DistributionFamily getDistributionFamily() {
            if (this._metalearner_parameters != null) {
                return this._metalearner_parameters.getDistributionFamily();
            }
            return super.getDistributionFamily();
        }

        /** Sets the distribution family on the metalearner parameters, which must already exist. */
        public void setDistributionFamily(DistributionFamily distributionFamily) {
            assert (this._metalearner_parameters != null);
            this._metalearner_parameters.setDistributionFamily(distributionFamily);
        }

        /** Transform applied to the level-one frame before the metalearner scores it. */
        public static enum MetalearnerTransform {
            NONE,
            Logit;


            /**
             * Applies this transform to every column of {@code frame}.
             *
             * <p>For {@code Logit}, each value is clamped to [1e-9, 0.999999999] and mapped through
             * the logit link function.
             *
             * @param model   the ensemble (not used by the current transforms)
             * @param frame   level-one frame to transform
             * @param destKey key of the output frame
             * @return a new frame with the transformed columns, keeping the input column names
             * @throws UnsupportedOperationException for transforms that cannot be applied, such as NONE
             */
            public Frame transform(StackedEnsembleModel model, Frame frame, Key<Frame> destKey) {
                if (this == Logit) {
                    return new MRTask(){

                        public void map(Chunk[] cs, NewChunk[] ncs) {
                            LinkFunction logitLink = LinkFunctionFactory.getLinkFunction((LinkFunctionType)LinkFunctionType.logit);
                            for (int c = 0; c < cs.length; ++c) {
                                for (int i = 0; i < cs[c]._len; ++i) {
                                    // Clamp away from 0 and 1 so the logit stays finite.
                                    double p = Math.min(0.999999999, Math.max(cs[c].atd(i), 1.0E-9));
                                    ncs[c].addNum(logitLink.link(p));
                                }
                            }
                        }
                    // (byte)3: output column type, presumably Vec.T_NUM (numeric) — confirm.
                    }.doAll(frame.numCols(), (byte)3, frame).outputFrame(destKey, frame._names, (String[][])null);
                }
                throw new UnsupportedOperationException("Metalearner transform " + this + " cannot be applied to a frame.");
            }
        }
    }

    /** How the level-one training data for the metalearner is produced. */
    public static enum StackingStrategy {
        cross_validation,
        blending;

    }
}

