/*
 * Decompiled with CFR 0.152.
 */
package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner;
import hex.ensemble.MetalearnerProvider;
import hex.glm.GLM;
import hex.glm.GLMModel;
import java.util.ServiceLoader;
import java.util.function.Supplier;
import water.Job;
import water.Key;
import water.exceptions.H2OIllegalArgumentException;
import water.nbhm.NonBlockingHashMap;

public class Metalearners {
    static final NonBlockingHashMap<String, MetalearnerProvider> providersByName;

    static Metalearner.Algorithm getActualMetalearnerAlgo(Metalearner.Algorithm algo) {
        Metalearners.assertAvailable(algo.name());
        return algo == Metalearner.Algorithm.AUTO ? Metalearner.Algorithm.glm : algo;
    }

    static Model.Parameters createParameters(String name) {
        Metalearners.assertAvailable(name);
        return ((ModelBuilder)Metalearners.createInstance((String)name).createBuilder())._parms;
    }

    static Metalearner createInstance(String name) {
        Metalearners.assertAvailable(name);
        return ((MetalearnerProvider)providersByName.get((Object)name)).newInstance();
    }

    private static void assertAvailable(String algo) {
        if (!providersByName.containsKey((Object)algo)) {
            throw new H2OIllegalArgumentException("'" + algo + "' metalearner is not supported or available.");
        }
    }

    static {
        LocalProvider[] localProviders;
        providersByName = new NonBlockingHashMap();
        for (LocalProvider provider : localProviders = new LocalProvider[]{new LocalProvider<AUTOMetalearner>(Metalearner.Algorithm.AUTO, AUTOMetalearner::new), new LocalProvider<DLMetalearner>(Metalearner.Algorithm.deeplearning, DLMetalearner::new), new LocalProvider<DRFMetalearner>(Metalearner.Algorithm.drf, DRFMetalearner::new), new LocalProvider<GBMMetalearner>(Metalearner.Algorithm.gbm, GBMMetalearner::new), new LocalProvider<GLMMetalearner>(Metalearner.Algorithm.glm, GLMMetalearner::new), new LocalProvider<NaiveBayesMetalearner>(Metalearner.Algorithm.naivebayes, NaiveBayesMetalearner::new)}) {
            providersByName.put((Object)provider.getName(), (Object)provider);
        }
        ServiceLoader<MetalearnerProvider> extensionProviders = ServiceLoader.load(MetalearnerProvider.class);
        for (MetalearnerProvider provider : extensionProviders) {
            providersByName.put((Object)provider.getName(), (Object)provider);
        }
    }

    static class AUTOMetalearner
    extends GLMMetalearner {
        AUTOMetalearner() {
        }

        @Override
        protected void setCustomParams(GLMModel.GLMParameters parms) {
            super.setCustomParams(parms);
            parms._non_negative = true;
            parms._standardize = false;
            if (parms._valid != null) {
                parms._lambda_search = true;
                parms._early_stopping = false;
            }
        }
    }

    static class NaiveBayesMetalearner
    extends SimpleMetalearner {
        public NaiveBayesMetalearner() {
            super(Metalearner.Algorithm.naivebayes.name());
        }
    }

    static class GLMMetalearner
    extends Metalearner<GLM, GLMModel, GLMModel.GLMParameters> {
        GLMMetalearner() {
        }

        @Override
        GLM createBuilder() {
            return (GLM)ModelBuilder.make((String)"GLM", (Job)this._metalearnerJob, (Key)this._metalearnerKey);
        }

        @Override
        protected void setCustomParams(GLMModel.GLMParameters parms) {
            if (this._model.modelCategory == ModelCategory.Regression) {
                parms._family = GLMModel.GLMParameters.Family.gaussian;
            } else if (this._model.modelCategory == ModelCategory.Binomial) {
                parms._family = GLMModel.GLMParameters.Family.binomial;
            } else if (this._model.modelCategory == ModelCategory.Multinomial) {
                parms._family = GLMModel.GLMParameters.Family.multinomial;
            } else {
                throw new H2OIllegalArgumentException("Family " + this._model.modelCategory + "  is not supported.");
            }
        }
    }

    static class GBMMetalearner
    extends SimpleMetalearner {
        public GBMMetalearner() {
            super(Metalearner.Algorithm.gbm.name());
        }
    }

    static class DRFMetalearner
    extends SimpleMetalearner {
        public DRFMetalearner() {
            super(Metalearner.Algorithm.drf.name());
        }
    }

    static class DLMetalearner
    extends SimpleMetalearner {
        public DLMetalearner() {
            super(Metalearner.Algorithm.deeplearning.name());
        }
    }

    public static class SimpleMetalearner
    extends Metalearner {
        private String _algo;

        protected SimpleMetalearner(String algo) {
            this._algo = algo;
        }

        ModelBuilder createBuilder() {
            return ModelBuilder.make((String)this._algo, (Job)this._metalearnerJob, (Key)this._metalearnerKey);
        }
    }

    static class LocalProvider<M extends Metalearner>
    implements MetalearnerProvider<M> {
        private Metalearner.Algorithm _algorithm;
        private Supplier<M> _instanceFactory;

        public LocalProvider(Metalearner.Algorithm algorithm, Supplier<M> instanceFactory) {
            this._algorithm = algorithm;
            this._instanceFactory = instanceFactory;
        }

        @Override
        public String getName() {
            return this._algorithm.name();
        }

        @Override
        public M newInstance() {
            return (M)((Metalearner)this._instanceFactory.get());
        }
    }
}

