/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.util.FVec;
import hex.genmodel.GenModel;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
 * XGBoost MOJO model that scores rows with the pure-Java {@link Predictor} from
 * {@code biz.k11i.xgboost}.
 *
 * <p>Categorical columns are one-hot encoded on the fly via {@link OneHotEncoderFVec}.
 * Numeric columns are passed through as floats. When {@code _sparse} is set, zero
 * numerics and "not hot" one-hot slots are presented to the predictor as {@code NaN}.
 */
public final class XGBoostJavaMojoModel
extends XGBoostMojoModel {
    /** Deserialized booster; set to {@code null} by {@link #close()}. */
    private Predictor _predictor;
    /** Built in {@link #postReadInit()}; required by {@link #score0}. */
    private OneHotEncoderFactory _1hotFactory;

    /**
     * Creates the model and eagerly deserializes the booster.
     *
     * @param boosterBytes   serialized booster, parsed by {@link Predictor}
     * @param columns        column names, forwarded to the superclass
     * @param domains        categorical domains, forwarded to the superclass
     * @param responseColumn response column name, forwarded to the superclass
     * @throws IllegalStateException if reading the booster raises an {@link IOException}
     */
    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
        this._predictor = makePredictor(boosterBytes);
    }

    /**
     * Builds the one-hot encoder factory from {@code _sparse}, {@code _catOffsets} and
     * {@code _cats}. These fields must already be populated, and this method must run
     * before {@link #score0}.
     * NOTE(review): presumably invoked by the MOJO reader once fields are read — confirm.
     */
    @Override
    public void postReadInit() {
        this._1hotFactory = new OneHotEncoderFactory();
    }

    /**
     * Parses a serialized booster.
     *
     * @param boosterBytes serialized booster bytes
     * @return the loaded predictor
     * @throws IllegalStateException wrapping any {@link IOException} raised while reading
     */
    private static Predictor makePredictor(byte[] boosterBytes) {
        try (ByteArrayInputStream is = new ByteArrayInputStream(boosterBytes)) {
            return new Predictor(is);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the live predictor, failing clearly if the model was closed.
     *
     * @return the non-null predictor
     * @throws IllegalStateException if {@link #close()} has already been called
     */
    private Predictor requirePredictor() {
        Predictor predictor = this._predictor;
        if (predictor == null) {
            throw new IllegalStateException("XGBoostJavaMojoModel has been closed");
        }
        return predictor;
    }

    /**
     * Scores a single row.
     *
     * @param doubles input row: the first {@code _cats} entries are categorical columns,
     *                followed by {@code _nums} numeric columns
     * @param offset  must be {@code 0.0}; offsets are not supported
     * @param preds   output buffer, filled by {@code toPreds}
     * @return the result of {@code toPreds} for the raw predictor output
     * @throws UnsupportedOperationException if {@code offset != 0}
     * @throws IllegalStateException if {@link #postReadInit()} was not called or the
     *                               model has been closed
     */
    @Override
    public final double[] score0(double[] doubles, double offset, double[] preds) {
        if (offset != 0.0) {
            throw new UnsupportedOperationException("Unsupported: offset != 0");
        }
        OneHotEncoderFactory factory = this._1hotFactory;
        if (factory == null) {
            throw new IllegalStateException("postReadInit() must be called before scoring");
        }
        OneHotEncoderFVec row = factory.fromArray(doubles);
        float[] out = requirePredictor().predict(row);
        return toPreds(doubles, out, preds, this._nclasses, this._priorClassDistrib, this._defaultThreshold);
    }

    /**
     * Looks up a registered objective function by name, including the ones registered
     * in this class's static initializer.
     *
     * @param name objective name, e.g. {@code "reg:gamma"}
     * @return whatever {@link ObjFunction#fromName(String)} returns for {@code name}
     */
    static ObjFunction getObjFunction(String name) {
        return ObjFunction.fromName(name);
    }

    /**
     * Releases the predictor. Any later call to {@link #score0} or {@link #convert}
     * throws {@link IllegalStateException}.
     */
    @Override
    public void close() {
        this._predictor = null;
    }

    /**
     * Converts one tree of the booster into a {@link SharedTreeGraph}.
     *
     * @param treeNumber index of the tree, forwarded to {@code _computeGraph}
     * @param treeClass  not used by this implementation
     * @return the graph built by {@code _computeGraph}
     * @throws IllegalStateException if the model has been closed
     */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClass) {
        GradBooster booster = requirePredictor().getBooster();
        return this._computeGraph(booster, treeNumber);
    }

    // Register objectives whose raw margin is turned into a prediction with exp().
    // NOTE(review): presumably these are not registered by the predictor library itself — confirm.
    static {
        ObjFunction.register("reg:gamma", new RegObjFunction());
        ObjFunction.register("reg:tweedie", new RegObjFunction());
        ObjFunction.register("count:poisson", new RegObjFunction());
    }

    /**
     * Feature vector over a one-hot encoded row.
     *
     * <p>Layout: indices {@code [0, catMap.length)} are one-hot slots of categorical
     * columns. Indices {@code >= catMap.length} are numeric columns, in order.
     * Static because it never touches the enclosing model.
     */
    private static final class OneHotEncoderFVec
    implements FVec {
        /** Maps each one-hot slot to the index of the categorical column owning it. */
        private final int[] _catMap;
        /** Per categorical column, the one-hot slot that is hot for this row. */
        private final int[] _catValues;
        /** Numeric column values (already NaN-substituted when sparse). */
        private final float[] _numValues;
        /** Value reported for one-hot slots that are not hot. */
        private final float _notHot;

        private OneHotEncoderFVec(int[] catMap, int[] catValues, float[] numValues, float notHot) {
            this._catMap = catMap;
            this._catValues = catValues;
            this._numValues = numValues;
            this._notHot = notHot;
        }

        /**
         * Returns the feature value at {@code index}.
         *
         * @param index position in the one-hot + numeric layout
         * @return {@code 1.0f} for the hot slot of a categorical column, {@code _notHot}
         *         for its other slots, or the numeric value past the categorical slots
         */
        @Override
        public final float fvalue(int index) {
            if (index >= this._catMap.length) {
                return this._numValues[index - this._catMap.length];
            }
            boolean isHot = this._catValues[this._catMap[index]] == index;
            return isHot ? 1.0f : this._notHot;
        }
    }

    /**
     * Builds {@link OneHotEncoderFVec}s for this model. It precomputes the slot-to-column
     * map from {@code _catOffsets} once, so each row only fills its own values.
     */
    private class OneHotEncoderFactory {
        private final int[] _catMap;
        private final float _notHot;

        OneHotEncoderFactory() {
            // Sparse models treat "absent" as missing (NaN) rather than zero.
            this._notHot = XGBoostJavaMojoModel.this._sparse ? Float.NaN : 0.0f;
            int[] catOffsets = XGBoostJavaMojoModel.this._catOffsets;
            if (catOffsets == null) {
                this._catMap = new int[0];
            } else {
                int cats = XGBoostJavaMojoModel.this._cats;
                // catOffsets[cats] is the total number of one-hot slots.
                this._catMap = new int[catOffsets[cats]];
                for (int c = 0; c < cats; ++c) {
                    for (int j = catOffsets[c]; j < catOffsets[c + 1]; ++j) {
                        this._catMap[j] = c;
                    }
                }
            }
        }

        /**
         * Wraps one input row as a feature vector.
         *
         * @param input row whose first {@code _cats} entries are categorical and the next
         *              {@code _nums} entries are numeric
         * @return a feature vector sharing this factory's slot map
         */
        OneHotEncoderFVec fromArray(double[] input) {
            XGBoostJavaMojoModel model = XGBoostJavaMojoModel.this;
            float[] numValues = new float[model._nums];
            int[] catValues = new int[model._cats];
            // NOTE(review): presumably fills catValues with the hot one-hot slot per column — confirm in GenModel.setCats.
            GenModel.setCats(input, catValues, model._cats, model._catOffsets, model._useAllFactorLevels);
            for (int i = 0; i < numValues.length; ++i) {
                float val = (float) input[model._cats + i];
                numValues[i] = model._sparse && val == 0.0f ? Float.NaN : val;
            }
            return new OneHotEncoderFVec(this._catMap, catValues, numValues, this._notHot);
        }
    }

    /** Objective whose prediction transform is {@code exp(margin)} on a single output. */
    private static class RegObjFunction
    extends ObjFunction {
        private RegObjFunction() {
        }

        /**
         * Applies {@code exp} in place to the single predicted value.
         *
         * @throws IllegalStateException if {@code preds} does not have exactly one element
         */
        @Override
        public float[] predTransform(float[] preds) {
            if (preds.length != 1) {
                throw new IllegalStateException("Regression problem is supposed to have just a single predicted value, got " + preds.length + " instead.");
            }
            preds[0] = (float) Math.exp(preds[0]);
            return preds;
        }

        /** Returns {@code exp(pred)}. */
        @Override
        public float predTransform(float pred) {
            return (float) Math.exp(pred);
        }
    }
}

