/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.automl.modeling;

import ai.h2o.automl.Algo;
import ai.h2o.automl.AutoML;
import ai.h2o.automl.AutoMLBuildSpec;
import ai.h2o.automl.ModelParametersProvider;
import ai.h2o.automl.ModelingStep;
import ai.h2o.automl.ModelingSteps;
import ai.h2o.automl.ModelingStepsProvider;
import ai.h2o.automl.StepDefinition;
import ai.h2o.automl.WorkAllocations;
import ai.h2o.automl.events.EventLogEntry;
import ai.h2o.automl.preprocessing.PreprocessingConfig;
import ai.h2o.automl.preprocessing.TargetEncoding;
import hex.KeyValue;
import hex.Model;
import hex.ensemble.Metalearner;
import hex.ensemble.StackedEnsembleModel;
import hex.glm.GLMModel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import water.DKV;
import water.Job;
import water.Key;
import water.util.PojoUtils;

public class StackedEnsembleStepsProvider
implements ModelingStepsProvider<StackedEnsembleSteps>,
ModelParametersProvider<StackedEnsembleModel.StackedEnsembleParameters> {
    /**
     * Returns the name identifying this provider.
     *
     * @return the value of {@code StackedEnsembleSteps.NAME}, which is initialised from
     *         {@code Algo.StackedEnsemble.name()}
     */
    @Override
    public String getName() {
        return StackedEnsembleSteps.NAME;
    }

    /**
     * Creates a fresh set of stacked-ensemble modeling steps for the given AutoML instance.
     *
     * @param aml the AutoML instance handed to the {@code StackedEnsembleSteps} constructor
     * @return a newly constructed {@code StackedEnsembleSteps}
     */
    @Override
    public StackedEnsembleSteps newInstance(AutoML aml) {
        return new StackedEnsembleSteps(aml);
    }

    /**
     * Creates a stacked-ensemble parameters object using its no-argument constructor.
     *
     * @return a new {@code StackedEnsembleModel.StackedEnsembleParameters} instance
     */
    @Override
    public StackedEnsembleModel.StackedEnsembleParameters newDefaultParameters() {
        return new StackedEnsembleModel.StackedEnsembleParameters();
    }

    public static class StackedEnsembleSteps
    extends ModelingSteps {
        static final String NAME = Algo.StackedEnsemble.name();
        private final ModelingStep[] defaults;
        private final ModelingStep[] optionals;

        /**
         * Runs the inherited cleanup, then asks every stacked-ensemble model found on the
         * leaderboard to delete its base model predictions.
         */
        @Override
        protected void cleanup() {
            super.cleanup();
            Arrays.stream(aml().leaderboard().getModels())
                    // only stacked ensembles expose deleteBaseModelPredictions()
                    .filter(StackedEnsembleModel.class::isInstance)
                    .map(StackedEnsembleModel.class::cast)
                    .forEach(StackedEnsembleModel::deleteBaseModelPredictions);
        }

        /**
         * Builds the default and optional stacked-ensemble (SE) steps for an AutoML run.
         * <p>
         * If the modeling plan holds no step definition named {@code NAME}, both step arrays are empty.
         * Otherwise:
         * <ul>
         *   <li>the priority groups used by the non-SE algorithms of the plan are collected, and for each
         *       of them a {@code best_of_family_<group>} and an {@code all_<group>} default step are
         *       created in that same group;</li>
         *   <li>when the highest collected group is positive, the optional steps (monotonic, best of
         *       family / all with various metalearners, and best N) are created in the group that
         *       immediately follows it.</li>
         * </ul>
         *
         * @param autoML the AutoML instance these steps are attached to
         */
        public StackedEnsembleSteps(AutoML autoML) {
            super(autoML);
            final StepDefinition[] plan = aml().getBuildSpec().build_models.modeling_plan;
            if (Stream.of(plan).noneMatch(def -> def.getName().equals(NAME))) {
                // No SE definition in the plan: expose no step at all.
                defaults = new ModelingStep[0];
                optionals = new ModelingStep[0];
                return;
            }

            // Names of every algorithm except StackedEnsemble itself.
            final Set<String> baseAlgoNames = Stream.of(Algo.values())
                    .filter(algo -> algo != Algo.StackedEnsemble)
                    .map(Enum::name)
                    .collect(Collectors.toSet());

            // Distinct, ascending priority groups referenced by the base algorithms of the plan.
            final int[] baseAlgoGroups = Stream.of(plan)
                    .filter(def -> baseAlgoNames.contains(def.getName()))
                    .flatMapToInt(def -> {
                        if (def.getAlias() == StepDefinition.Alias.defaults) {
                            return IntStream.of(1);
                        }
                        if (def.getAlias() == StepDefinition.Alias.grids) {
                            return IntStream.of(2);
                        }
                        if (def.getAlias() == StepDefinition.Alias.all) {
                            return IntStream.of(1, 2);
                        }
                        // Explicit steps: a group of -1 contributes both groups 1 and 2.
                        return def.getSteps().stream()
                                .flatMapToInt(step -> step.getGroup() == -1
                                        ? IntStream.of(1, 2)
                                        : IntStream.of(step.getGroup()));
                    })
                    .distinct()
                    .sorted()
                    .toArray();

            final ArrayList<StackedEnsembleModelStep> defaultSteps = new ArrayList<>();
            for (int baseGroup : baseAlgoGroups) {
                defaultSteps.add(new BestOfFamilySEModelStep("best_of_family_" + baseGroup, baseGroup, aml()));
                defaultSteps.add(new AllSEModelStep("all_" + baseGroup, baseGroup, aml()));
            }
            defaults = defaultSteps.toArray(new ModelingStep[0]);

            // The groups are sorted ascending, so the maximum (0 when there is none) is the last entry.
            final int maxBaseGroup = baseAlgoGroups.length == 0 ? 0 : baseAlgoGroups[baseAlgoGroups.length - 1];

            final ArrayList<StackedEnsembleModelStep> optionalSteps = new ArrayList<>();
            if (maxBaseGroup > 0) {
                final int seGroup = maxBaseGroup + 1;
                optionalSteps.add(new MonotonicSEModelStep("monotonic", seGroup, aml()));
                optionalSteps.add(new BestOfFamilySEModelStep("best_of_family", seGroup, aml()));
                optionalSteps.add(new AllSEModelStep("all", seGroup, aml()));
                if (Algo.XGBoost.enabled()) {
                    optionalSteps.add(new BestOfFamilySEModelStep("best_of_family_xgboost", Metalearner.Algorithm.xgboost, seGroup, aml()));
                    optionalSteps.add(new AllSEModelStep("all_xgboost", Metalearner.Algorithm.xgboost, seGroup, aml()));
                }
                optionalSteps.add(new BestOfFamilySEModelStep("best_of_family_gbm", Metalearner.Algorithm.gbm, seGroup, aml()));
                optionalSteps.add(new AllSEModelStep("all_gbm", Metalearner.Algorithm.gbm, seGroup, aml()));

                // Best-of-family variant whose GLM metalearner runs a lambda search; never treated as a duplicate.
                optionalSteps.add(new BestOfFamilySEModelStep("best_of_family_xglm", seGroup, aml()) {

                    @Override
                    protected boolean hasDoppelganger(Key<Model>[] baseModelsKeys) {
                        return false;
                    }

                    @Override
                    protected void setMetalearnerParameters(StackedEnsembleModel.StackedEnsembleParameters params) {
                        super.setMetalearnerParameters(params);
                        ((GLMModel.GLMParameters) params._metalearner_parameters)._lambda_search = true;
                    }
                });

                // All-models variant with lambda search; reported as a duplicate when every base model
                // has a distinct model type.
                optionalSteps.add(new AllSEModelStep("all_xglm", seGroup, aml()) {

                    @Override
                    protected boolean hasDoppelganger(Key<Model>[] baseModelsKeys) {
                        final Set<String> seenTypes = new HashSet<>();
                        for (Key<Model> baseKey : baseModelsKeys) {
                            // add() returns false when the type was already recorded
                            if (!seenTypes.add(getModelType(baseKey))) {
                                return false;
                            }
                        }
                        return true;
                    }

                    @Override
                    protected void setMetalearnerParameters(StackedEnsembleModel.StackedEnsembleParameters params) {
                        super.setMetalearnerParameters(params);
                        ((GLMModel.GLMParameters) params._metalearner_parameters)._lambda_search = true;
                    }
                });

                // Cap on the number of base models: 1000 up to 2 response levels,
                // otherwise 1000 / (levels - 1) but never below 100.
                final int responseLevels = aml().getResponseColumn().cardinality();
                final int bestN;
                if (responseLevels <= 2) {
                    bestN = 1000;
                } else {
                    bestN = Math.max(100, 1000 / (responseLevels - 1));
                }
                optionalSteps.add(new BestNModelsSEModelStep("best_N", bestN, seGroup, aml()));
            }
            optionals = optionalSteps.toArray(new ModelingStep[0]);
        }

        /**
         * Returns the provider name associated with these steps.
         *
         * @return {@code NAME}, initialised from {@code Algo.StackedEnsemble.name()}
         */
        @Override
        public String getProvider() {
            return NAME;
        }

        /**
         * Returns the default stacked-ensemble steps computed by the constructor.
         *
         * @return the internal {@code defaults} array (not copied); empty when the modeling plan has no
         *         StackedEnsemble definition
         */
        @Override
        protected ModelingStep[] getDefaultModels() {
            return this.defaults;
        }

        /**
         * Returns the optional stacked-ensemble steps computed by the constructor.
         *
         * @return the internal {@code optionals} array (not copied); empty when the modeling plan has no
         *         StackedEnsemble definition or when no positive base group was found
         */
        @Override
        protected ModelingStep[] getOptionals() {
            return this.optionals;
        }

        /**
         * Stacked-ensemble step whose base models are the trained, non-SE models that carry
         * monotone constraints.
         */
        static class MonotonicSEModelStep
        extends StackedEnsembleModelStep {
            /**
             * Creates the step with the {@code AUTO} metalearner and a weight of 10.
             *
             * @param id            step id; {@code null} falls back to {@code "monotonic"}
             * @param priorityGroup priority group of the step
             * @param autoML        owning AutoML instance
             */
            public MonotonicSEModelStep(String id, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, priorityGroup, 10, autoML);
            }

            /**
             * Creates the step and appends the metalearner name to the step description.
             *
             * @param id            step id; {@code null} falls back to {@code "monotonic"}
             * @param algo          metalearner algorithm
             * @param priorityGroup priority group of the step
             * @param weight        weight of the step
             * @param autoML        owning AutoML instance
             */
            public MonotonicSEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super(id != null ? id : "monotonic", algo, priorityGroup, weight, autoML);
                _description += " (built with " + algo.name() + " metalearner, using monotonically constrained AutoML models)";
            }

            /**
             * Tells whether the model behind the key has a non-empty {@code _monotone_constraints}
             * parameter.
             *
             * @param modelKey key of the model to inspect
             * @return {@code true} when the parameter exists and holds at least one entry; {@code false}
             *         when it is null/empty or when reading it raises {@link IllegalArgumentException}
             */
            boolean hasMonotoneConstrains(Key<Model> modelKey) {
                final Model fetched = (Model) DKV.getGet(modelKey);
                final KeyValue[] constraints;
                try {
                    constraints = (KeyValue[]) PojoUtils.getFieldValue(fetched._parms, "_monotone_constraints", PojoUtils.FieldNaming.CONSISTENT);
                } catch (IllegalArgumentException ignored) {
                    // the field could not be read from these parameters: treat as "no constraints"
                    return false;
                }
                return constraints != null && constraints.length > 0;
            }

            /**
             * Runs the inherited checks, then requires at least two trained models with monotone
             * constraints. When fewer are found, the allocated work is consumed, the skip is logged
             * and {@code false} is returned.
             */
            @Override
            public boolean canRun() {
                if (!super.canRun()) {
                    return false;
                }
                int monotoneCount = 0;
                for (Key<Model> trainedKey : getTrainedModelsKeys()) {
                    // stop scanning as soon as a second monotone model shows up
                    if (hasMonotoneConstrains(trainedKey) && ++monotoneCount >= 2) {
                        return true;
                    }
                }

                final boolean exactlyOne = monotoneCount == 1;
                final String progressMessage = exactlyOne
                        ? "Only one monotonic base model; skipping this StackedEnsemble"
                        : "No monotonic base model; skipping this StackedEnsemble";
                final String eventFormat = exactlyOne
                        ? "Skipping StackedEnsemble '%s' since there is only one monotonic model to stack"
                        : "Skipping StackedEnsemble '%s' since there is no monotonic model to stack";
                aml().job().update((long) getAllocatedWork().consume(), progressMessage);
                aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format(eventFormat, _id));
                return false;
            }

            /**
             * @return keys of the trained models that are not stacked ensembles and have monotone
             *         constraints, in the order given by {@code getTrainedModelsKeys()}
             */
            @Override
            protected Key<Model>[] getBaseModels() {
                return Arrays.stream(getTrainedModelsKeys())
                        .filter(key -> !isStackedEnsemble(key))
                        .filter(key -> hasMonotoneConstrains(key))
                        .toArray(Key[]::new);
            }

            /**
             * Starts the ensemble training under the name {@code <provider>_Monotonic}.
             */
            @Override
            protected Job<StackedEnsembleModel> startJob() {
                final String modelName = _provider + "_Monotonic";
                return stack(modelName, getBaseModels(), false);
            }
        }

        /**
         * Stacked-ensemble step whose base models are all trained models that are not themselves
         * stacked ensembles.
         */
        static class AllSEModelStep
        extends StackedEnsembleModelStep {
            /**
             * Creates the step with the {@code AUTO} metalearner and a weight of 10.
             *
             * @param id            step id; {@code null} falls back to {@code "all_" + algo.name()}
             * @param priorityGroup priority group of the step
             * @param autoML        owning AutoML instance
             */
            public AllSEModelStep(String id, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, priorityGroup, autoML);
            }

            /**
             * Creates the step with the given metalearner and a weight of 10.
             *
             * @param id            step id; {@code null} falls back to {@code "all_" + algo.name()}
             * @param algo          metalearner algorithm
             * @param priorityGroup priority group of the step
             * @param autoML        owning AutoML instance
             */
            public AllSEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, AutoML autoML) {
                this(id, algo, priorityGroup, 10, autoML);
            }

            /**
             * Creates the step and appends the metalearner name to the step description.
             *
             * @param id            step id; {@code null} falls back to {@code "all_" + algo.name()}
             * @param algo          metalearner algorithm
             * @param priorityGroup priority group of the step
             * @param weight        weight of the step
             * @param autoML        owning AutoML instance
             */
            public AllSEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super(id != null ? id : "all_" + algo.name(), algo, priorityGroup, weight, autoML);
                _description += " (built with " + algo.name() + " metalearner, using all AutoML models)";
            }

            /**
             * @return keys of every trained model that is not a stacked ensemble, in the order given by
             *         {@code getTrainedModelsKeys()}
             */
            @Override
            protected Key<Model>[] getBaseModels() {
                return Arrays.stream(getTrainedModelsKeys())
                        .filter(key -> !isStackedEnsemble(key))
                        .toArray(Key[]::new);
            }

            /**
             * Starts the ensemble training under the name {@code <provider>_AllModels}.
             */
            @Override
            protected Job<StackedEnsembleModel> startJob() {
                final String modelName = _provider + "_AllModels";
                return stack(modelName, getBaseModels(), false);
            }
        }

        /**
         * Stacked-ensemble step whose base models are the first {@code N} trained models that are not
         * stacked ensembles.
         */
        static class BestNModelsSEModelStep
        extends StackedEnsembleModelStep {
            /** Maximum number of base models passed to the ensemble. */
            private final int _N;

            /**
             * Creates the step with the {@code AUTO} metalearner and a weight of 10.
             *
             * @param id            step id; {@code null} falls back to {@code "best_" + N + "_" + algo.name()}
             * @param N             maximum number of base models
             * @param priorityGroup priority group of the step
             * @param autoML        owning AutoML instance
             */
            public BestNModelsSEModelStep(String id, int N, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, N, priorityGroup, 10, autoML);
            }

            /**
             * Creates the step and appends the metalearner name and {@code N} to the step description.
             *
             * @param id            step id; {@code null} falls back to {@code "best_" + N + "_" + algo.name()}
             * @param algo          metalearner algorithm
             * @param N             maximum number of base models
             * @param priorityGroup priority group of the step
             * @param weight        weight of the step
             * @param autoML        owning AutoML instance
             */
            public BestNModelsSEModelStep(String id, Metalearner.Algorithm algo, int N, int priorityGroup, int weight, AutoML autoML) {
                super(id != null ? id : "best_" + N + "_" + algo.name(), algo, priorityGroup, weight, autoML);
                _N = N;
                _description += " (built with " + algo.name() + " metalearner, using best " + N + " non-SE models)";
            }

            /**
             * @return keys of at most {@code N} trained models that are not stacked ensembles, taken in
             *         the order given by {@code getTrainedModelsKeys()}
             */
            @Override
            protected Key<Model>[] getBaseModels() {
                return Arrays.stream(getTrainedModelsKeys())
                        .filter(key -> !isStackedEnsemble(key))
                        .limit(_N)
                        .toArray(Key[]::new);
            }

            /**
             * Starts the ensemble training under the name {@code <provider>_Best<N>}.
             */
            @Override
            protected Job<StackedEnsembleModel> startJob() {
                final String modelName = _provider + "_Best" + _N;
                return stack(modelName, getBaseModels(), false);
            }
        }

        /**
         * Stacked-ensemble step whose base models are the first trained, non-SE model encountered for
         * each model type.
         */
        static class BestOfFamilySEModelStep
        extends StackedEnsembleModelStep {
            /**
             * Creates the step with the {@code AUTO} metalearner and a weight of 10.
             *
             * @param id            step id; {@code null} falls back to {@code "best_of_family_" + algo.name()}
             * @param priorityGroup priority group of the step
             * @param autoML        owning AutoML instance
             */
            public BestOfFamilySEModelStep(String id, int priorityGroup, AutoML autoML) {
                this(id, Metalearner.Algorithm.AUTO, priorityGroup, autoML);
            }

            /**
             * Creates the step with the given metalearner and a weight of 10.
             *
             * @param id            step id; {@code null} falls back to {@code "best_of_family_" + algo.name()}
             * @param algo          metalearner algorithm
             * @param priorityGroup priority group of the step
             * @param autoML        owning AutoML instance
             */
            public BestOfFamilySEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, AutoML autoML) {
                this(id, algo, priorityGroup, 10, autoML);
            }

            /**
             * Creates the step and appends the metalearner name to the step description.
             *
             * @param id            step id; {@code null} falls back to {@code "best_of_family_" + algo.name()}
             * @param algo          metalearner algorithm
             * @param priorityGroup priority group of the step
             * @param weight        weight of the step
             * @param autoML        owning AutoML instance
             */
            public BestOfFamilySEModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super(id != null ? id : "best_of_family_" + algo.name(), algo, priorityGroup, weight, autoML);
                _description += " (built with " + algo.name() + " metalearner, using top model from each algorithm type)";
            }

            /**
             * Walks the trained models in the order given by {@code getTrainedModelsKeys()} and keeps
             * the first non-SE key seen for each model type.
             *
             * @return the retained keys, one per model type
             */
            @Override
            protected Key<Model>[] getBaseModels() {
                final ArrayList<Key<Model>> firstOfEachType = new ArrayList<>();
                final Set<String> seenTypes = new HashSet<>();
                for (Key<Model> candidate : getTrainedModelsKeys()) {
                    // the type is resolved for every key, stacked ensembles included
                    final String candidateType = getModelType(candidate);
                    // add() is only reached for non-SE keys and tells whether the type is new
                    if (!isStackedEnsemble(candidate) && seenTypes.add(candidateType)) {
                        firstOfEachType.add(candidate);
                    }
                }
                return firstOfEachType.toArray(new Key[0]);
            }

            /**
             * Starts the ensemble training under the name {@code <provider>_BestOfFamily}.
             */
            @Override
            protected Job<StackedEnsembleModel> startJob() {
                final String modelName = _provider + "_BestOfFamily";
                return stack(modelName, getBaseModels(), false);
            }
        }

        static abstract class StackedEnsembleModelStep
        extends ModelingStep.ModelStep<StackedEnsembleModel> {
            protected final Metalearner.Algorithm _metalearnerAlgo;

            StackedEnsembleModelStep(String id, Metalearner.Algorithm algo, int priorityGroup, int weight, AutoML autoML) {
                super(NAME, Algo.StackedEnsemble, id, priorityGroup, weight, autoML);
                this._metalearnerAlgo = algo;
                this._ignoredConstraints = new AutoML.Constraint[]{AutoML.Constraint.MODEL_COUNT, AutoML.Constraint.FAILURE_COUNT};
            }

            @Override
            protected void setCrossValidationParams(Model.Parameters params) {
            }

            @Override
            protected void setWeightingParams(Model.Parameters params) {
            }

            @Override
            protected void setClassBalancingParams(Model.Parameters params) {
            }

            @Override
            protected PreprocessingConfig getPreprocessingConfig() {
                PreprocessingConfig config = super.getPreprocessingConfig();
                config.put(TargetEncoding.CONFIG_ENABLED, false);
                return config;
            }

            @Override
            public boolean canRun() {
                Key<Model>[] keys = this.getBaseModels();
                WorkAllocations.Work seWork = this.getAllocatedWork();
                if (!super.canRun()) {
                    this.aml().job().update(0L, "Skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Skipping StackedEnsemble '%s' due to the exclude_algos option or it is already trained.", this._id));
                    return false;
                }
                if (keys.length == 0) {
                    this.aml().job().update((long)seWork.consume(), "No base models; skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("No base models, due to timeouts or the exclude_algos option. Skipping StackedEnsemble '%s'.", this._id));
                    return false;
                }
                if (keys.length == 1) {
                    this.aml().job().update((long)seWork.consume(), "Only one base model; skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Skipping StackedEnsemble '%s' since there is only one model to stack", this._id));
                    return false;
                }
                if (!this.isCVEnabled() && this.aml().getBlendingFrame() == null) {
                    this.aml().job().update((long)seWork.consume(), "Cross-validation disabled by the user and no blending frame provided; Skipping this StackedEnsemble");
                    this.aml().eventLog().info(EventLogEntry.Stage.ModelTraining, String.format("Cross-validation is disabled by the user and no blending frame was provided; skipping StackedEnsemble '%s'.", this._id));
                    return false;
                }
                return !this.hasDoppelganger(keys);
            }

            protected boolean hasDoppelganger(Key<Model>[] baseModelsKeys) {
                Key[] seModels = (Key[])Arrays.stream(this.getTrainedModelsKeys()).filter(k -> this.isStackedEnsemble((Key<Model>)k)).toArray(Key[]::new);
                HashSet<Key<Model>> keySet = new HashSet<Key<Model>>(Arrays.asList(baseModelsKeys));
                for (Key seKey : seModels) {
                    StackedEnsembleModelStep seStep = (StackedEnsembleModelStep)this.aml().session().getModelingStep(seKey);
                    if (seStep._metalearnerAlgo != this._metalearnerAlgo) continue;
                    StackedEnsembleModel.StackedEnsembleParameters seParams = (StackedEnsembleModel.StackedEnsembleParameters)((StackedEnsembleModel)seKey.get())._parms;
                    Key[] seBaseModels = seParams._base_models;
                    if (seBaseModels.length != baseModelsKeys.length || !keySet.equals(new HashSet<Key>(Arrays.asList(seBaseModels)))) continue;
                    return true;
                }
                return false;
            }

            protected abstract Key<Model>[] getBaseModels();

            protected String getModelType(Key<Model> key) {
                String keyStr = key.toString();
                return keyStr.substring(0, keyStr.indexOf(95));
            }

            protected boolean isStackedEnsemble(Key<Model> key) {
                ModelingStep step = this.aml().session().getModelingStep(key);
                return step != null && step.getAlgo() == Algo.StackedEnsemble;
            }

            public StackedEnsembleModel.StackedEnsembleParameters prepareModelParameters() {
                StackedEnsembleModel.StackedEnsembleParameters params = new StackedEnsembleModel.StackedEnsembleParameters();
                params._valid = this.aml().getValidationFrame() == null ? null : this.aml().getValidationFrame()._key;
                params._blending = this.aml().getBlendingFrame() == null ? null : this.aml().getBlendingFrame()._key;
                params._keep_levelone_frame = true;
                return params;
            }

            protected void setMetalearnerParameters(StackedEnsembleModel.StackedEnsembleParameters params) {
                AutoMLBuildSpec buildSpec = this.aml().getBuildSpec();
                params._metalearner_fold_column = buildSpec.input_spec.fold_column;
                params._metalearner_nfolds = buildSpec.build_control.nfolds;
                params.initMetalearnerParams(this._metalearnerAlgo);
                params._metalearner_parameters._keep_cross_validation_models = buildSpec.build_control.keep_cross_validation_models;
                params._metalearner_parameters._keep_cross_validation_predictions = buildSpec.build_control.keep_cross_validation_predictions;
            }

            Job<StackedEnsembleModel> stack(String modelName, Key<Model>[] baseModels, boolean isLast) {
                StackedEnsembleModel.StackedEnsembleParameters params = this.prepareModelParameters();
                params._base_models = baseModels;
                params._keep_base_model_predictions = !isLast;
                this.setMetalearnerParameters(params);
                if (this._metalearnerAlgo == Metalearner.Algorithm.AUTO) {
                    this.setAutoMetalearnerSEParameters(params);
                }
                return this.stack(modelName, params);
            }

            Job<StackedEnsembleModel> stack(String modelName, StackedEnsembleModel.StackedEnsembleParameters stackedEnsembleParameters) {
                Key modelKey = this.makeKey(modelName, true);
                return this.trainModel(modelKey, (Model.Parameters)stackedEnsembleParameters);
            }

            protected void setAutoMetalearnerSEParameters(StackedEnsembleModel.StackedEnsembleParameters stackedEnsembleParameters) {
                GLMModel.GLMParameters metalearnerParams = (GLMModel.GLMParameters)stackedEnsembleParameters._metalearner_parameters;
                metalearnerParams._alpha = new double[]{0.5, 1.0};
                if (this.aml().getResponseColumn().isCategorical()) {
                    stackedEnsembleParameters._metalearner_transform = StackedEnsembleModel.StackedEnsembleParameters.MetalearnerTransform.Logit;
                }
            }
        }
    }
}

