/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.DistributionFactory;
import hex.FeatureInteractions;
import hex.FeatureInteractionsCollector;
import hex.FriedmanPopescusHCollector;
import hex.KeyValue;
import hex.Model;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.BranchInteractionConstraints;
import hex.tree.CompressedForest;
import hex.tree.CompressedTree;
import hex.tree.Constraints;
import hex.tree.FriedmanPopescusH;
import hex.tree.GlobalInteractionConstraints;
import hex.tree.Score;
import hex.tree.SharedTreeModel;
import hex.tree.SharedTreeModelWithContributions;
import hex.tree.SharedTreePojoWriter;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GbmMojoWriter;
import hex.tree.gbm.GbmPojoWriter;
import hex.util.EffectiveParametersUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import water.DKV;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.util.TwoDimTable;

public class GBMModel
extends SharedTreeModelWithContributions<GBMModel, GBMParameters, GBMOutput>
implements Model.StagedPredictions,
FeatureInteractionsCollector,
FriedmanPopescusHCollector {
    public GBMModel(Key<GBMModel> selfKey, GBMParameters parms, GBMOutput output) {
        super(selfKey, parms, output);
    }

    public void initActualParamValues() {
        super.initActualParamValues();
        EffectiveParametersUtils.initFoldAssignment(this._parms);
        EffectiveParametersUtils.initHistogramType((SharedTreeModel.SharedTreeParameters)this._parms);
        EffectiveParametersUtils.initCategoricalEncoding(this._parms, Model.Parameters.CategoricalEncodingScheme.Enum);
    }

    public void initActualParamValuesAfterOutputSetup(int nclasses, boolean isClassifier) {
        EffectiveParametersUtils.initStoppingMetric(this._parms, isClassifier);
        EffectiveParametersUtils.initDistribution(this._parms, nclasses);
    }

    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsTask getScoreContributionsTask(SharedTreeModel model) {
        return new SharedTreeModelWithContributions.ScoreContributionsTask(this);
    }

    @Override
    protected SharedTreeModelWithContributions.ScoreContributionsTask getScoreContributionsSoringTask(SharedTreeModel model, Model.Contributions.ContributionsOptions options) {
        return new SharedTreeModelWithContributions.ScoreContributionsSortingTask(model, options);
    }

    public Frame scoreStagedPredictions(Frame frame, Key<Frame> destination_key) {
        Frame adaptFrm = new Frame(frame);
        this.adaptTestForTrain(adaptFrm, true, false);
        String[] names = this.makeAllTreeColumnNames();
        int outputcols = names.length;
        return ((StagedPredictionsTask)new StagedPredictionsTask(this).doAll(outputcols, (byte)3, adaptFrm)).outputFrame(destination_key, names, null);
    }

    @Override
    protected final double[] score0Incremental(Score.ScoreIncInfo sii, Chunk[] chks, double offset, int row_in_chunk, double[] tmp, double[] preds) {
        int i;
        assert (((GBMOutput)this._output).nfeatures() == tmp.length);
        for (i = 0; i < tmp.length; ++i) {
            tmp[i] = chks[i].atd(row_in_chunk);
        }
        if (sii._startTree == 0) {
            Arrays.fill(preds, 0.0);
        } else {
            for (i = 0; i < sii._workspaceColCnt; ++i) {
                preds[sii._predsAryOffset + i] = chks[sii._workspaceColIdx + i].atd(row_in_chunk);
            }
        }
        this.score0(tmp, preds, offset, sii._startTree, ((GBMOutput)this._output)._treeKeys.length);
        for (i = 0; i < sii._workspaceColCnt; ++i) {
            chks[sii._workspaceColIdx + i].set(row_in_chunk, preds[sii._predsAryOffset + i]);
        }
        this.score0Probabilities(preds, offset);
        this.score0PostProcessSupervised(preds, tmp);
        return preds;
    }

    @Override
    protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
        super.score0(data, preds, offset, ntrees);
        return this.score0Probabilities(preds, offset);
    }

    private double[] score0Probabilities(double[] preds, double offset) {
        if (((GBMParameters)this._parms)._distribution == DistributionFamily.bernoulli || ((GBMParameters)this._parms)._distribution == DistributionFamily.quasibinomial || ((GBMParameters)this._parms)._distribution == DistributionFamily.modified_huber || ((GBMParameters)this._parms)._distribution == DistributionFamily.custom && ((GBMOutput)this._output).nclasses() == 2) {
            double f = preds[1] + ((GBMOutput)this._output)._init_f + offset;
            preds[2] = DistributionFactory.getDistribution((Model.Parameters)this._parms).linkInv(f);
            preds[1] = 1.0 - preds[2];
        } else if (((GBMParameters)this._parms)._distribution == DistributionFamily.multinomial || ((GBMParameters)this._parms)._distribution == DistributionFamily.custom && ((GBMOutput)this._output).nclasses() > 2) {
            if (((GBMOutput)this._output).nclasses() == 2) {
                preds[1] = preds[1] + (((GBMOutput)this._output)._init_f + offset);
                preds[2] = -preds[1];
            }
            GenModel.GBM_rescale((double[])preds);
        } else {
            double f = preds[0] + ((GBMOutput)this._output)._init_f + offset;
            preds[0] = DistributionFactory.getDistribution((Model.Parameters)this._parms).linkInv(f);
        }
        return preds;
    }

    @Override
    protected SharedTreePojoWriter makeTreePojoWriter() {
        CompressedForest compressedForest = new CompressedForest(((GBMOutput)this._output)._treeKeys, ((GBMOutput)this._output)._domains);
        CompressedForest.LocalCompressedForest localCompressedForest = compressedForest.fetch();
        return new GbmPojoWriter(this, localCompressedForest._trees);
    }

    public GbmMojoWriter getMojo() {
        return new GbmMojoWriter(this);
    }

    public FeatureInteractions getFeatureInteractions(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
        FeatureInteractions featureInteractions = new FeatureInteractions();
        int nclasses = ((GBMOutput)this._output)._nclasses > 2 ? ((GBMOutput)this._output)._nclasses : 1;
        for (int i = 0; i < ((GBMParameters)this._parms)._ntrees; ++i) {
            for (int j = 0; j < nclasses; ++j) {
                FeatureInteractions currentTreeFeatureInteractions = new FeatureInteractions();
                SharedTreeSubgraph tree = this.getSharedTreeSubgraph(i, j);
                ArrayList interactionPath = new ArrayList();
                HashSet memo = new HashSet();
                FeatureInteractions.collectFeatureInteractions((SharedTreeNode)tree.rootNode, interactionPath, (double)0.0, (double)0.0, (double)1.0, (int)0, (int)0, (FeatureInteractions)currentTreeFeatureInteractions, memo, (int)maxInteractionDepth, (int)maxTreeDepth, (int)maxDeepening, (int)i, (boolean)true);
                featureInteractions.mergeWith(currentTreeFeatureInteractions);
            }
        }
        return featureInteractions;
    }

    public TwoDimTable[][] getFeatureInteractionsTable(int maxInteractionDepth, int maxTreeDepth, int maxDeepening) {
        return FeatureInteractions.getFeatureInteractionsTable((FeatureInteractions)this.getFeatureInteractions(maxInteractionDepth, maxTreeDepth, maxDeepening));
    }

    public double getFriedmanPopescusH(Frame frame, String[] vars) {
        int nclasses = ((GBMOutput)this._output)._nclasses > 2 ? ((GBMOutput)this._output)._nclasses : 1;
        SharedTreeSubgraph[][] sharedTreeSubgraphs = new SharedTreeSubgraph[((GBMParameters)this._parms)._ntrees][nclasses];
        for (int i = 0; i < ((GBMParameters)this._parms)._ntrees; ++i) {
            for (int j = 0; j < nclasses; ++j) {
                sharedTreeSubgraphs[i][j] = this.getSharedTreeSubgraph(i, j);
            }
        }
        return FriedmanPopescusH.h(frame, vars, ((GBMParameters)this._parms)._learn_rate, sharedTreeSubgraphs);
    }

    private static class StagedPredictionsTask
    extends MRTask<StagedPredictionsTask> {
        private final Key<GBMModel> _modelKey;
        private transient GBMModel _model;

        private StagedPredictionsTask(GBMModel model) {
            this._modelKey = model._key;
        }

        protected void setupLocal() {
            this._model = (GBMModel)this._modelKey.get();
            assert (this._model != null);
        }

        public void map(Chunk[] chks, NewChunk[] nc) {
            double[] input = new double[chks.length];
            int contribOffset = ((GBMOutput)this._model._output).nclasses() == 1 ? 0 : 1;
            for (int row = 0; row < chks[0]._len; ++row) {
                for (int i = 0; i < chks.length; ++i) {
                    input[i] = chks[i].atd(row);
                }
                double[] contribs = new double[contribOffset + ((GBMOutput)this._model._output).nclasses()];
                double[] preds = new double[contribs.length];
                int col = 0;
                for (int tidx = 0; tidx < ((GBMOutput)this._model._output)._treeKeys.length; ++tidx) {
                    int i;
                    Key[] keys = ((GBMOutput)this._model._output)._treeKeys[tidx];
                    for (i = 0; i < keys.length; ++i) {
                        if (keys[i] != null) {
                            int n = contribOffset + i;
                            contribs[n] = contribs[n] + ((CompressedTree)DKV.get((Key)keys[i]).get()).score(input, ((GBMOutput)this._model._output)._domains);
                        }
                        preds[contribOffset + i] = contribs[contribOffset + i];
                    }
                    this._model.score0Probabilities(preds, 0.0);
                    this._model.score0PostProcessSupervised(preds, input);
                    for (i = 0; i < keys.length; ++i) {
                        if (keys[i] == null) continue;
                        nc[col++].addNum(preds[contribOffset + i]);
                    }
                }
                assert (col == nc.length);
            }
        }
    }

    public static class GBMOutput
    extends SharedTreeModel.SharedTreeOutput {
        public String[] _quasibinomialDomains;
        boolean _quasibinomial;
        int _nclasses;

        public int nclasses() {
            return this._nclasses;
        }

        public GBMOutput(GBM b) {
            super(b);
            this._quasibinomial = ((GBMParameters)b._parms)._distribution == DistributionFamily.quasibinomial;
            this._nclasses = b.nclasses();
        }

        public String[] classNames() {
            String[] res = super.classNames();
            if (this._quasibinomial) {
                return this._quasibinomialDomains;
            }
            return res;
        }
    }

    public static class GBMParameters
    extends SharedTreeModel.SharedTreeParameters {
        public double _learn_rate = 0.1;
        public double _learn_rate_annealing = 1.0;
        public double _col_sample_rate = 1.0;
        public double _max_abs_leafnode_pred;
        public double _pred_noise_bandwidth;
        public KeyValue[] _monotone_constraints;
        public String[][] _interaction_constraints;

        public GBMParameters() {
            this._sample_rate = 1.0;
            this._ntrees = 50;
            this._max_depth = 5;
            this._max_abs_leafnode_pred = Double.MAX_VALUE;
            this._pred_noise_bandwidth = 0.0;
        }

        @Override
        public boolean useColSampling() {
            return super.useColSampling() || this._col_sample_rate != 1.0;
        }

        public String algoName() {
            return "GBM";
        }

        public String fullName() {
            return "Gradient Boosting Machine";
        }

        public String javaName() {
            return GBMModel.class.getName();
        }

        @Override
        public boolean forceStrictlyReproducibleHistograms() {
            return this.usesMonotoneConstraints();
        }

        private boolean usesMonotoneConstraints() {
            if (this.areMonotoneConstraintsEmpty()) {
                return this.emptyConstraints(0) != null;
            }
            return true;
        }

        private boolean areMonotoneConstraintsEmpty() {
            return this._monotone_constraints == null || this._monotone_constraints.length == 0;
        }

        public Constraints constraints(Frame f) {
            if (this.areMonotoneConstraintsEmpty()) {
                return this.emptyConstraints(f.numCols());
            }
            int[] cs = new int[f.numCols()];
            for (KeyValue spec : this._monotone_constraints) {
                if (spec.getValue() == 0.0) continue;
                int col = f.find(spec.getKey());
                if (col < 0) {
                    throw new IllegalStateException("Invalid constraint specification, column '" + spec.getKey() + "' doesn't exist.");
                }
                cs[col] = spec.getValue() < 0.0 ? -1 : 1;
            }
            boolean useBounds = this._distribution == DistributionFamily.gaussian || this._distribution == DistributionFamily.bernoulli || this._distribution == DistributionFamily.tweedie || this._distribution == DistributionFamily.quasibinomial || this._distribution == DistributionFamily.multinomial || this._distribution == DistributionFamily.quantile;
            return new Constraints(cs, DistributionFactory.getDistribution((Model.Parameters)this), useBounds);
        }

        Constraints emptyConstraints(int nCols) {
            return null;
        }

        public GlobalInteractionConstraints interactionConstraints(Frame frame) {
            return new GlobalInteractionConstraints(this._interaction_constraints, frame.names());
        }

        public BranchInteractionConstraints initialInteractionConstraints(GlobalInteractionConstraints ics) {
            return new BranchInteractionConstraints(ics.getAllAllowedColumnIndices());
        }
    }
}

