/*
 * Decompiled with CFR 0.152.
 */
package hex.adaboost;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.adaboost.AdaBoostModel;
import hex.adaboost.CountWeTask;
import hex.adaboost.UpdateWeightsTask;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import java.util.ArrayList;
import org.apache.log4j.Logger;
import water.DKV;
import water.Key;
import water.Keyed;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Timer;
import water.util.TwoDimTable;

public class AdaBoost
extends ModelBuilder<AdaBoostModel, AdaBoostModel.AdaBoostParameters, AdaBoostModel.AdaBoostOutput> {
    private static final Logger LOG = Logger.getLogger(AdaBoost.class);
    private static final int MAX_LEARNERS = 100000;
    private AdaBoostModel _model;
    private String _weightsName = "weights";

    public AdaBoost(AdaBoostModel.AdaBoostParameters parms) {
        super((Model.Parameters)parms);
        this.init(false);
    }

    public AdaBoost(boolean startup_once) {
        super((Model.Parameters)new AdaBoostModel.AdaBoostParameters(), startup_once);
    }

    public boolean havePojo() {
        return false;
    }

    public boolean haveMojo() {
        return false;
    }

    public void init(boolean expensive) {
        super.init(expensive);
        if (((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners < 1 || ((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners > 100000) {
            this.error("n_estimators", "Parameter n_estimators must be in interval [1, 100000] but it is " + ((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners);
        }
        if (((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner == AdaBoostModel.Algorithm.AUTO) {
            ((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner = AdaBoostModel.Algorithm.DRF;
        }
        if (((AdaBoostModel.AdaBoostParameters)this._parms)._weights_column != null) {
            this._weightsName = ((AdaBoostModel.AdaBoostParameters)this._parms)._weights_column;
        }
        if (!(0.0 < ((AdaBoostModel.AdaBoostParameters)this._parms)._learn_rate) || !(((AdaBoostModel.AdaBoostParameters)this._parms)._learn_rate <= 1.0)) {
            this.error("learn_rate", "learn_rate must be between 0 and 1");
        }
    }

    protected ModelBuilder.Driver trainModelImpl() {
        return new AdaBoostDriver();
    }

    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    public boolean isSupervised() {
        return true;
    }

    private ModelBuilder chooseWeakLearner(Frame frame) {
        switch (((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner) {
            case GLM: {
                return this.getGLMWeakLearner(frame);
            }
            case GBM: {
                return this.getGBMWeakLearner(frame);
            }
        }
        return this.getDRFWeakLearner(frame);
    }

    private DRF getDRFWeakLearner(Frame frame) {
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        parms._mtries = 1;
        parms._min_rows = 1.0;
        parms._ntrees = 1;
        parms._sample_rate = 1.0;
        parms._max_depth = 1;
        parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        return new DRF(parms);
    }

    private GLM getGLMWeakLearner(Frame frame) {
        GLMModel.GLMParameters parms = new GLMModel.GLMParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        return new GLM(parms);
    }

    private GBM getGBMWeakLearner(Frame frame) {
        GBMModel.GBMParameters parms = new GBMModel.GBMParameters();
        parms._train = frame._key;
        parms._response_column = ((AdaBoostModel.AdaBoostParameters)this._parms)._response_column;
        parms._weights_column = this._weightsName;
        parms._min_rows = 1.0;
        parms._ntrees = 1;
        parms._sample_rate = 1.0;
        parms._max_depth = 1;
        parms._seed = ((AdaBoostModel.AdaBoostParameters)this._parms)._seed;
        return new GBM(parms);
    }

    public TwoDimTable createModelSummaryTable() {
        ArrayList<String> colHeaders = new ArrayList<String>();
        ArrayList<String> colTypes = new ArrayList<String>();
        ArrayList<String> colFormat = new ArrayList<String>();
        colHeaders.add("Number of weak learners");
        colTypes.add("int");
        colFormat.add("%d");
        colHeaders.add("Learn rate");
        colTypes.add("int");
        colFormat.add("%d");
        colHeaders.add("Weak learner");
        colTypes.add("int");
        colFormat.add("%d");
        colHeaders.add("Seed");
        colTypes.add("long");
        colFormat.add("%d");
        boolean rows = true;
        TwoDimTable table = new TwoDimTable("Model Summary", null, new String[1], colHeaders.toArray(new String[0]), colTypes.toArray(new String[0]), colFormat.toArray(new String[0]), "");
        int row = 0;
        int col = 0;
        table.set(row, col++, (Object)((AdaBoostModel.AdaBoostParameters)this._parms)._nlearners);
        table.set(row, col++, (Object)((AdaBoostModel.AdaBoostParameters)this._parms)._learn_rate);
        table.set(row, col++, (Object)((AdaBoostModel.AdaBoostParameters)this._parms)._weak_learner.toString());
        table.set(row, col, (Object)((AdaBoostModel.AdaBoostParameters)this._parms)._seed);
        return table;
    }

    private class AdaBoostDriver
    extends ModelBuilder.Driver {
        private AdaBoostDriver() {
            super((ModelBuilder)AdaBoost.this);
        }

        public void computeImpl() {
            AdaBoost.this._model = null;
            try {
                AdaBoost.this.init(true);
                if (AdaBoost.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder((ModelBuilder)AdaBoost.this);
                }
                AdaBoost.this._model = new AdaBoostModel((Key<AdaBoostModel>)AdaBoost.this.dest(), (AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms, new AdaBoostModel.AdaBoostOutput(AdaBoost.this));
                AdaBoost.this._model.delete_and_lock(AdaBoost.this._job);
                this.buildAdaboost();
                LOG.info((Object)AdaBoost.this._model.toString());
            }
            finally {
                if (AdaBoost.this._model != null) {
                    AdaBoost.this._model.unlock(AdaBoost.this._job);
                }
            }
        }

        private void buildAdaboost() {
            Frame _trainWithWeights;
            ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).alphas = new double[((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._nlearners];
            ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).models = new Key[((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._nlearners];
            if (((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._weights_column == null) {
                _trainWithWeights = new Frame(AdaBoost.this.train());
                Vec weights = _trainWithWeights.anyVec().makeCons(1, 1L, (String[][])null, null)[0];
                AdaBoost.this._weightsName = _trainWithWeights.uniquify(AdaBoost.this._weightsName);
                _trainWithWeights.add(AdaBoost.this._weightsName, weights);
                DKV.put((Keyed)_trainWithWeights);
                Scope.track((Vec)weights);
            } else {
                _trainWithWeights = ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms).train();
            }
            for (int n = 0; n < ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._nlearners; ++n) {
                double alphaM;
                Timer timer = new Timer();
                ModelBuilder job = AdaBoost.this.chooseWeakLearner(_trainWithWeights);
                job._parms._seed += (long)n;
                Model model = (Model)job.trainModel().get();
                DKV.put((Keyed)model);
                Scope.untrack((Key[])new Key[]{model._key});
                ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).models[n] = model._key;
                Frame predictions = model.score(_trainWithWeights);
                Scope.track((Frame[])new Frame[]{predictions});
                CountWeTask countWe = (CountWeTask)new CountWeTask().doAll(new Vec[]{_trainWithWeights.vec(AdaBoost.this._weightsName), _trainWithWeights.vec(((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._response_column), predictions.vec("predict")});
                double eM = countWe.We / countWe.W;
                ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output).alphas[n] = alphaM = ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._learn_rate * Math.log((1.0 - eM) / eM);
                UpdateWeightsTask updateWeightsTask = new UpdateWeightsTask(alphaM);
                updateWeightsTask.doAll(new Vec[]{_trainWithWeights.vec(AdaBoost.this._weightsName), _trainWithWeights.vec(((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms)._response_column), predictions.vec("predict")});
                AdaBoost.this._job.update(1L);
                AdaBoost.this._model.update(AdaBoost.this._job);
                LOG.info((Object)(n + 1 + ". estimator was built in " + timer.toString()));
                LOG.info((Object)"*********************************************************************");
            }
            if (_trainWithWeights != ((AdaBoostModel.AdaBoostParameters)AdaBoost.this._parms).train()) {
                DKV.remove((Key)_trainWithWeights._key);
            }
            ((AdaBoostModel.AdaBoostOutput)((AdaBoost)AdaBoost.this)._model._output)._model_summary = AdaBoost.this.createModelSummaryTable();
        }
    }
}

