/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.gbm;

import hex.schemas.GBMModelV2;
import hex.tree.SharedTreeModel;
import hex.tree.gbm.GBM;
import java.util.Arrays;
import water.Key;
import water.api.ModelSchema;
import water.fvec.Chunk;
import water.util.ArrayUtils;
import water.util.ModelUtils;

/**
 * Gradient Boosting Machine model. Scoring delegates to {@link SharedTreeModel}
 * for the raw tree outputs, then post-processes them here into a regression
 * value or per-class probabilities.
 */
public class GBMModel
extends SharedTreeModel<GBMModel, GBMParameters, GBMOutput> {
    /**
     * Creates a GBM model; all arguments are forwarded to the superclass.
     *
     * @param selfKey the model's own key
     * @param parms   parameters of this model
     * @param output  output object of this model
     */
    public GBMModel(Key selfKey, GBMParameters parms, GBMOutput output) {
        super(selfKey, parms, output);
    }

    /**
     * Returns a new {@link GBMModelV2} schema instance for this model.
     */
    public ModelSchema schema() {
        return new GBMModelV2();
    }

    /**
     * Scores a single row stored in chunk form.
     *
     * <p>The first {@code tmp.length} chunk columns of the row are copied into
     * {@code tmp}, which is then scored via {@link #score0(double[], float[])}.
     *
     * @param chks         one chunk per column; must have at least {@code tmp.length} entries
     * @param row_in_chunk row index passed to {@link Chunk#at0}
     * @param tmp          scratch buffer receiving the row's values
     * @param preds        prediction buffer passed through to {@code score0(double[], float[])}
     * @return the result of {@code score0(tmp, preds)}
     */
    protected float[] score0(Chunk[] chks, int row_in_chunk, double[] tmp, float[] preds) {
        assert (chks.length >= tmp.length);
        for (int col = 0; col < tmp.length; ++col) {
            tmp[col] = chks[col].at0(row_in_chunk);
        }
        return this.score0(tmp, preds);
    }

    /**
     * Scores one row and converts the superclass's raw tree outputs into final
     * predictions, modifying the returned array in place.
     *
     * <ul>
     *   <li><b>Bernoulli loss:</b> {@code fx = p[1] + initialPrediction} is passed
     *       through the logistic function to give {@code p[2]}; {@code p[1]} becomes
     *       {@code 1 - p[2]}, and {@code p[0]} is set by
     *       {@link ModelUtils#getPrediction(float[], double[])}.</li>
     *   <li><b>Other classification</b> ({@code nclasses() > 1}): with exactly two
     *       classes, {@code p[1]} is first offset by the initial prediction and
     *       {@code p[2]} set to {@code -p[1]}. Then {@code p[1..]} is replaced by
     *       its softmax, and {@code p[0]} is set by {@code ModelUtils.getPrediction}.</li>
     *   <li><b>Regression:</b> the initial prediction is added to {@code p[0]}.</li>
     * </ul>
     *
     * @param data  the row's feature values
     * @param preds prediction buffer handed to the superclass
     * @return the array returned by {@code super.score0}, post-processed as above
     */
    @Override
    protected float[] score0(double[] data, float[] preds) {
        float[] p = super.score0(data, preds);
        GBMOutput output = (GBMOutput)this._output;

        if (((GBMParameters)this._parms)._loss == GBMParameters.Family.bernoulli) {
            double fx = (double)p[1] + output._initialPrediction;
            p[2] = 1.0f / (float)(1.0 + Math.exp(-fx));
            p[1] = 1.0f - p[2];
            p[0] = ModelUtils.getPrediction(p, data);
            return p;
        }

        int nclasses = output.nclasses();
        if (nclasses > 1) {
            if (nclasses == 2) {
                p[1] = (float)((double)p[1] + output._initialPrediction);
                p[2] = -p[1];
            }
            // Subtract the maximum before exponentiating so Math.exp cannot
            // overflow to Infinity (which would then turn into NaN on division).
            float maxval = Float.NEGATIVE_INFINITY;
            for (int k = 1; k < p.length; ++k) {
                maxval = Math.max(maxval, p[k]);
            }
            // Math.max propagates NaN, so reject NaN as well as infinities here.
            assert !Float.isInfinite(maxval) && !Float.isNaN(maxval) : "Something is wrong with GBM trees since returned prediction is " + Arrays.toString(p);
            float dsum = 0.0f;
            for (int k = 1; k < p.length; ++k) {
                p[k] = (float)Math.exp(p[k] - maxval);
                dsum += p[k];
            }
            ArrayUtils.div(p, dsum);
            p[0] = ModelUtils.getPrediction(p, data);
        } else {
            // Write into the array we return (p), not the caller's buffer, so the
            // intercept is never lost if the superclass returns a different array.
            p[0] = (float)((double)p[0] + output._initialPrediction);
        }
        return p;
    }

    /**
     * Output of a GBM model; construction is delegated entirely to
     * {@link SharedTreeModel.SharedTreeOutput}.
     */
    public static class GBMOutput
    extends SharedTreeModel.SharedTreeOutput {
        /**
         * @param b the GBM builder this output is created from
         */
        public GBMOutput(GBM b) {
            super(b);
        }
    }

    /**
     * Parameters of a GBM model, extending the shared tree parameters.
     */
    public static class GBMParameters
    extends SharedTreeModel.SharedTreeParameters {
        /**
         * Loss family; defaults to {@link Family#AUTO}. {@link Family#bernoulli}
         * selects the logistic post-processing in {@code GBMModel.score0}.
         */
        public Family _loss = Family.AUTO;
        /**
         * Learning rate; defaults to 0.1. NOTE(review): not read in this file —
         * presumably the per-tree shrinkage applied by the builder; confirm.
         */
        public float _learn_rate = 0.1f;

        /** Loss families selectable via {@code _loss}. */
        public enum Family {
            AUTO,
            bernoulli
        }
    }
}

