/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.TreeSHAPHelper;
import biz.k11i.xgboost.util.FVec;
import hex.genmodel.GenModel;
import hex.genmodel.PredictContributions;
import hex.genmodel.PredictContributionsFactory;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import hex.genmodel.algos.xgboost.XGBoostJavaObjFunRegistration;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * XGBoost MOJO model that scores rows with the pure-Java {@link Predictor} from
 * {@code biz.k11i.xgboost} and can optionally compute per-feature contributions with TreeSHAP.
 *
 * <p>Categorical input columns are one-hot encoded on the fly by an internal encoder that is
 * created in {@link #postReadInit()}; scoring before that method has run will fail.
 */
public final class XGBoostJavaMojoModel
extends XGBoostMojoModel
implements PredictContributionsFactory {

    static {
        // Runs once at class initialization, i.e. before any constructor parses booster bytes.
        // NOTE(review): presumably registers H2O's objective functions so that ObjFunction.fromName
        // (used by getObjFunction and, likely, by Predictor) can resolve them - confirm.
        XGBoostJavaObjFunRegistration.register();
    }

    /** Parsed booster used for scoring; cleared by {@link #close()}. */
    private Predictor _predictor;
    /** TreeSHAP predictor built eagerly when requested in the constructor; {@code null} otherwise. */
    private TreeSHAPPredictor<FVec> _treeSHAPPredictor;
    /** One-hot row encoder; created in {@link #postReadInit()}, cleared by {@link #close()}. */
    private OneHotEncoderFactory _1hotFactory;

    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn) {
        this(boosterBytes, columns, domains, responseColumn, false);
    }

    /**
     * @param boosterBytes   serialized XGBoost booster, parsed by {@link Predictor}
     * @param enableTreeSHAP when {@code true}, the TreeSHAP predictor is built eagerly so that
     *                       {@link #makeContributionsWorkspace()} and
     *                       {@link #calculateContributions(FVec, float[], Object)} can be used
     * @throws UnsupportedOperationException if TreeSHAP is enabled for a model with more than 2 classes
     */
    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn, boolean enableTreeSHAP) {
        super(columns, domains, responseColumn);
        _predictor = makePredictor(boosterBytes);
        _treeSHAPPredictor = enableTreeSHAP ? makeTreeSHAPPredictor(_predictor) : null;
    }

    @Override
    public void postReadInit() {
        // The encoder reads column metadata (_cats, _catOffsets, _sparse, ...) that is only
        // populated once the MOJO has been read, so it cannot be built in the constructor.
        _1hotFactory = new OneHotEncoderFactory();
    }

    /**
     * Parses a serialized booster.
     *
     * @throws IllegalStateException wrapping the {@link IOException} if the bytes cannot be parsed
     */
    private static Predictor makePredictor(byte[] boosterBytes) {
        try (ByteArrayInputStream is = new ByteArrayInputStream(boosterBytes)) {
            return new Predictor(is);
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Builds a TreeSHAP ensemble over the trees of the booster's first tree group.
     *
     * @throws UnsupportedOperationException for models with more than 2 classes
     */
    private static TreeSHAPPredictor<FVec> makeTreeSHAPPredictor(Predictor predictor) {
        if (predictor.getNumClass() > 2) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
        }
        // NOTE(review): assumes a tree booster; a non-GBTree booster fails here with ClassCastException.
        GBTree gbTree = (GBTree) predictor.getBooster();
        // Only group 0 is used; binomial/regression boosters are presumably single-group (multinomial is rejected above).
        RegTree[] trees = gbTree.getGroupedTrees()[0];
        List<TreeSHAPPredictor<FVec>> predictors = new ArrayList<>(trees.length);
        for (RegTree tree : trees) {
            predictors.add(TreeSHAPHelper.makePredictor(tree));
        }
        float initPred = TreeSHAPHelper.getInitPrediction(predictor);
        return new TreeSHAPEnsemble<>(predictors, initPred);
    }

    /**
     * Returns the eagerly built TreeSHAP predictor, failing with a descriptive message when it is absent.
     *
     * @throws IllegalStateException if TreeSHAP was not enabled in the constructor or the model was closed
     */
    private TreeSHAPPredictor<FVec> requireTreeSHAPPredictor() {
        if (_treeSHAPPredictor == null) {
            throw new IllegalStateException(
                    "TreeSHAP predictor is not available: the model was created with enableTreeSHAP=false or has been closed.");
        }
        return _treeSHAPPredictor;
    }

    /**
     * Scores a single row.
     *
     * @throws UnsupportedOperationException if {@code offset} is non-zero (offsets are not supported)
     */
    @Override
    public final double[] score0(double[] doubles, double offset, double[] preds) {
        if (offset != 0.0) {
            throw new UnsupportedOperationException("Unsupported: offset != 0");
        }
        OneHotEncoderFVec row = _1hotFactory.fromArray(doubles);
        float[] out = _predictor.predict(row);
        return toPreds(doubles, out, preds, _nclasses, _priorClassDistrib, _defaultThreshold);
    }

    /**
     * Creates a workspace for {@link #calculateContributions(FVec, float[], Object)}.
     *
     * @throws IllegalStateException if TreeSHAP was not enabled in the constructor or the model was closed
     */
    public final Object makeContributionsWorkspace() {
        return requireTreeSHAPPredictor().makeWorkspace();
    }

    /**
     * Computes contributions for an already-encoded row into {@code out_contribs}.
     *
     * @return {@code out_contribs}
     * @throws IllegalStateException if TreeSHAP was not enabled in the constructor or the model was closed
     */
    public final float[] calculateContributions(FVec row, float[] out_contribs, Object workspace) {
        requireTreeSHAPPredictor().calculateContributions(row, out_contribs, 0, -1, workspace);
        return out_contribs;
    }

    /**
     * Returns a contributions predictor, reusing the eagerly built TreeSHAP predictor when
     * available and otherwise building a fresh one (not cached) from the booster.
     */
    @Override
    public final PredictContributions makeContributionsPredictor() {
        TreeSHAPPredictor<FVec> treeSHAPPredictor = _treeSHAPPredictor != null ? _treeSHAPPredictor : makeTreeSHAPPredictor(_predictor);
        return new XGBoostContributionsPredictor(treeSHAPPredictor);
    }

    /** Looks up an objective function by name via {@link ObjFunction#fromName(String)}. */
    static ObjFunction getObjFunction(String name) {
        return ObjFunction.fromName(name);
    }

    /** Drops references to the booster, TreeSHAP predictor and encoder; the model is unusable afterwards. */
    @Override
    public void close() {
        _predictor = null;
        _treeSHAPPredictor = null;
        _1hotFactory = null;
    }

    /** Builds a tree graph for {@code treeNumber}; {@code treeClass} is not used by this implementation. */
    @Override
    public SharedTreeGraph convert(int treeNumber, String treeClass) {
        GradBooster booster = _predictor.getBooster();
        return _computeGraph(booster, treeNumber);
    }

    /** Contributions predictor holding its own reusable TreeSHAP workspace. */
    private final class XGBoostContributionsPredictor
    implements PredictContributions {
        private final TreeSHAPPredictor<FVec> _treeSHAPPredictor;
        private final Object _workspace;

        public XGBoostContributionsPredictor(TreeSHAPPredictor<FVec> treeSHAPPredictor) {
            _treeSHAPPredictor = treeSHAPPredictor;
            _workspace = _treeSHAPPredictor.makeWorkspace();
        }

        @Override
        public float[] calculateContributions(double[] input) {
            OneHotEncoderFVec row = _1hotFactory.fromArray(input);
            // One slot per one-hot categorical feature and per numeric column, plus one extra
            // slot (presumably the bias / expected value). The one-hot width comes from the encoder
            // so that models without categorical offsets (_catOffsets == null) do not NPE here.
            float[] contribs = new float[_nums + _1hotFactory.oneHotWidth() + 1];
            return _treeSHAPPredictor.calculateContributions(row, contribs, 0, -1, _workspace);
        }
    }

    /**
     * Feature vector view over one encoded row: indices {@code [0, catMap.length)} are one-hot
     * categorical features, the remaining indices map to the numeric columns in order.
     * Static because it needs nothing from the enclosing model; this avoids pinning the model to every row.
     */
    private static final class OneHotEncoderFVec
    implements FVec {
        /** Feature index -> categorical column that owns it. */
        private final int[] _catMap;
        /** Per categorical column, the feature index that is "hot" for this row. */
        private final int[] _catValues;
        private final float[] _numValues;
        /** Value reported for a categorical feature that is not hot (0 or NaN in sparse mode). */
        private final float _notHot;

        private OneHotEncoderFVec(int[] catMap, int[] catValues, float[] numValues, float notHot) {
            _catMap = catMap;
            _catValues = catValues;
            _numValues = numValues;
            _notHot = notHot;
        }

        @Override
        public final float fvalue(int index) {
            int numCatFeatures = _catMap.length;
            if (index >= numCatFeatures) {
                return _numValues[index - numCatFeatures];
            }
            boolean isHot = _catValues[_catMap[index]] == index;
            return isHot ? 1.0f : _notHot;
        }
    }

    /** Encodes raw input rows (categoricals first, then numerics) into {@link OneHotEncoderFVec}s. */
    private final class OneHotEncoderFactory {
        private final int[] _catMap;
        private final float _notHot;

        OneHotEncoderFactory() {
            // In sparse mode, absent values are NaN rather than 0 (fromArray does the same for numeric zeros).
            _notHot = _sparse ? Float.NaN : 0.0f;
            if (_catOffsets == null) {
                _catMap = new int[0];
            } else {
                // _catOffsets[c] .. _catOffsets[c + 1] is the feature-index range owned by categorical column c.
                _catMap = new int[_catOffsets[_cats]];
                for (int c = 0; c < _cats; ++c) {
                    for (int j = _catOffsets[c]; j < _catOffsets[c + 1]; ++j) {
                        _catMap[j] = c;
                    }
                }
            }
        }

        /** Number of one-hot categorical features (0 when the model has no categorical offsets). */
        int oneHotWidth() {
            return _catMap.length;
        }

        OneHotEncoderFVec fromArray(double[] input) {
            float[] numValues = new float[_nums];
            int[] catValues = new int[_cats];
            // NOTE(review): presumably fills catValues[c] with the feature index of column c's level,
            // which is what OneHotEncoderFVec.fvalue compares against - confirm in GenModel.setCats.
            GenModel.setCats(input, catValues, _cats, _catOffsets, _useAllFactorLevels);
            for (int i = 0; i < numValues.length; ++i) {
                // Numeric columns follow the categorical columns in the raw input row.
                float val = (float) input[_cats + i];
                numValues[i] = _sparse && val == 0.0f ? Float.NaN : val;
            }
            return new OneHotEncoderFVec(_catMap, catValues, numValues, _notHot);
        }
    }
}

