package hex.ensemble;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ensemble.Metalearner.Algorithm;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMParameters;
import water.exceptions.H2OIllegalArgumentException;
import water.nbhm.NonBlockingHashMap;

import java.util.ServiceLoader;
import java.util.function.Supplier;

/**
 * Entry point class to load and access the supported metalearners.
 * <p>
 * Most of them are defined in this class, but others can be loaded dynamically from the classpath
 * through {@link ServiceLoader} by registering a {@link MetalearnerProvider} implementation;
 * this is for example the case with the XGBoostMetalearner.
 * <p>
 * Providers are indexed by {@link MetalearnerProvider#getName()}. Extension providers are registered
 * after the local ones, so an extension provider using the same name as a local one replaces it.
 */
public class Metalearners {

    /** All registered metalearner providers, keyed by provider name. Populated once in the static initializer. */
    static final NonBlockingHashMap<String, MetalearnerProvider> providersByName = new NonBlockingHashMap<>();

    static {
        // Wildcard array: LocalProvider<?> is reifiable, so this avoids the raw LocalProvider[] type.
        LocalProvider<?>[] localProviders = new LocalProvider<?>[] {
                new LocalProvider<>(Algorithm.AUTO, AUTOMetalearner::new),
                new LocalProvider<>(Algorithm.deeplearning, DLMetalearner::new),
                new LocalProvider<>(Algorithm.drf, DRFMetalearner::new),
                new LocalProvider<>(Algorithm.gbm, GBMMetalearner::new),
                new LocalProvider<>(Algorithm.glm, GLMMetalearner::new),
                new LocalProvider<>(Algorithm.naivebayes, NaiveBayesMetalearner::new),
        };
        for (MetalearnerProvider<?> provider : localProviders) {
            providersByName.put(provider.getName(), provider);
        }

        // Registered last on purpose: an extension provider overrides a local provider with the same name.
        ServiceLoader<MetalearnerProvider> extensionProviders = ServiceLoader.load(MetalearnerProvider.class);
        for (MetalearnerProvider provider : extensionProviders) {
            providersByName.put(provider.getName(), provider);
        }
    }

    /**
     * Resolves the algorithm actually used for the given metalearner algorithm:
     * {@link Algorithm#AUTO} maps to {@link Algorithm#glm}, any other algorithm is returned unchanged.
     *
     * @param algo the requested metalearner algorithm.
     * @return the effective metalearner algorithm.
     * @throws H2OIllegalArgumentException if no provider is registered under {@code algo.name()}.
     */
    static Algorithm getActualMetalearnerAlgo(Algorithm algo) {
        assertAvailable(algo.name());
        return algo == Algorithm.AUTO ? Algorithm.glm : algo;
    }

    /**
     * Returns the parameters of a model builder created by a new instance of the named metalearner.
     *
     * @param name the metalearner (provider) name.
     * @return the builder's parameters.
     * @throws H2OIllegalArgumentException if no provider is registered under {@code name}.
     */
    static Model.Parameters createParameters(String name) {
        // createInstance already validates that the metalearner is available.
        return createInstance(name).createBuilder()._parms;
    }

    /**
     * Creates a new metalearner instance using the provider registered under the given name.
     *
     * @param name the metalearner (provider) name.
     * @return a new metalearner instance.
     * @throws H2OIllegalArgumentException if no provider is registered under {@code name}.
     */
    static Metalearner createInstance(String name) {
        assertAvailable(name);
        return providersByName.get(name).newInstance();
    }

    /**
     * @param algo the metalearner (provider) name to check.
     * @throws H2OIllegalArgumentException if no provider is registered under {@code algo}.
     */
    private static void assertAvailable(String algo) {
        if (!providersByName.containsKey(algo))
            throw new H2OIllegalArgumentException("'"+algo+"' metalearner is not supported or available.");
    }

    /**
     * A local implementation of {@link MetalearnerProvider} to expose the {@link Metalearner}s defined in this class.
     * The provider name is the name of the wrapped {@link Algorithm}, and each call to {@link #newInstance()}
     * delegates to the given factory.
     */
    static class LocalProvider<M extends Metalearner> implements MetalearnerProvider<M> {

        private final Algorithm _algorithm;
        private final Supplier<M> _instanceFactory;

        public LocalProvider(Algorithm algorithm,
                             Supplier<M> instanceFactory) {
            _algorithm = algorithm;
            _instanceFactory = instanceFactory;
        }

        @Override
        public String getName() {
            return _algorithm.name();
        }

        @Override
        public M newInstance() {
            return _instanceFactory.get();
        }
    }

    /**
     * A simple implementation of {@link Metalearner} suitable for any algo; it is just using the algo with its default parameters.
     */
    public static class SimpleMetalearner extends Metalearner {
        /** Algorithm name passed to {@link ModelBuilder#make}. */
        private final String _algo;

        protected SimpleMetalearner(String algo) {
            _algo = algo;
        }

        @Override
        ModelBuilder createBuilder() {
            return ModelBuilder.make(_algo, _metalearnerJob, _metalearnerKey);
        }
    }

    /** Deep Learning metalearner with default parameters. */
    static class DLMetalearner extends SimpleMetalearner {
        public DLMetalearner() {
            super(Algorithm.deeplearning.name());
        }
    }

    /** DRF metalearner with default parameters. */
    static class DRFMetalearner extends SimpleMetalearner {
        public DRFMetalearner() {
            super(Algorithm.drf.name());
        }
    }

    /** GBM metalearner with default parameters. */
    static class GBMMetalearner extends SimpleMetalearner {
        public GBMMetalearner() {
            super(Algorithm.gbm.name());
        }
    }

    /**
     * GLM metalearner whose family is chosen from the ensemble's model category
     * (Regression: gaussian, Binomial: binomial, Multinomial: multinomial).
     */
    static class GLMMetalearner extends Metalearner<GLM, GLMModel, GLMParameters> {
        @Override
        GLM createBuilder() {
            // Builder name derived from the Algorithm enum, consistently with the other metalearners.
            return ModelBuilder.make(Algorithm.glm.name(), _metalearnerJob, _metalearnerKey);
        }

        /**
         * Sets the GLM family matching the ensemble's model category.
         *
         * @throws H2OIllegalArgumentException if the model category has no matching GLM family here.
         */
        @Override
        protected void setCustomParams(GLMParameters parms) {
            ModelCategory category = _model.modelCategory;
            if (category == ModelCategory.Regression) {
                parms._family = GLMParameters.Family.gaussian;
            } else if (category == ModelCategory.Binomial) {
                parms._family = GLMParameters.Family.binomial;
            } else if (category == ModelCategory.Multinomial) {
                parms._family = GLMParameters.Family.multinomial;
            } else {
                throw new H2OIllegalArgumentException("Model category " + category + " is not supported by the GLM metalearner.");
            }
        }
    }

    /** Naive Bayes metalearner with default parameters. */
    static class NaiveBayesMetalearner extends SimpleMetalearner {
        public NaiveBayesMetalearner() {
            super(Algorithm.naivebayes.name());
        }
    }

    /**
     * Default metalearner: a GLM with non-negative coefficients and no standardization,
     * plus lambda search when a validation frame is provided.
     */
    static class AUTOMetalearner extends GLMMetalearner {

        @Override
        protected void setCustomParams(GLMParameters parms) {
            // Family selection shared with the plain GLM metalearner.
            super.setCustomParams(parms);

            // Specific to AUTO mode.
            parms._non_negative = true;

            // The feature columns (base model predictions) are already on a homogeneous scale (probabilities).
            // Standardizing them can lose information: e.g. a column with very low probabilities compared with
            // the others (a bad base model) would be given more weight than it deserves.
            parms._standardize = false;

            // With a validation frame, enable lambda search to get a better GLM fit.
            // Because non-negative coefficients are enforced above, early stopping is disabled as well.
            if (parms._valid != null) {
                parms._lambda_search = true;
                parms._early_stopping = false;
            }
        }
    }

}
