/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.gbm.GradBooster;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.SharedTreeNode;
import hex.genmodel.algos.tree.SharedTreeSubgraph;
import java.io.Closeable;
import java.util.Arrays;

/**
 * Base class for XGBoost MOJO models.
 *
 * <p>Provides the shared pieces visible in this file: conversion of raw booster output into the
 * {@code preds} array layout ({@link #toPreds}), and conversion of a tree-based booster into a
 * {@link SharedTreeGraph} ({@link #_computeGraph}). The class declares {@link Closeable} but does
 * not implement {@code close()} itself; concrete subclasses have to.
 */
public abstract class XGBoostMojoModel
extends MojoModel
implements SharedTreeGraphConverter,
Closeable {
    /** Token separator used inside a single line of {@link #_featureMap}. */
    private static final String SPACE = " ";

    // NOTE(review): the fields below are populated outside of this file (presumably by the MOJO
    // reader); their exact semantics are not visible here, only their names and types.
    public String _boosterType;
    public int _ntrees;
    public int _nums;
    public int _cats;
    public int[] _catOffsets;
    public boolean _useAllFactorLevels;
    public boolean _sparse;
    /**
     * Newline-separated feature map. Each line is split on a single space; the code in this class
     * reads the second token as the feature name and the third token as the feature type flag.
     */
    public String _featureMap;

    public XGBoostMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
    }

    /** Hook that does nothing in this base class; subclasses may override it. */
    public void postReadInit() {
    }

    /**
     * Scores a row by delegating to the three-argument {@code score0} overload with {@code 0.0} as
     * the middle argument (presumably the offset -- confirm against the overload's declaration).
     */
    @Override
    public final double[] score0(double[] row, double[] preds) {
        return this.score0(row, 0.0, preds);
    }

    /**
     * Copies raw booster output into {@code preds} and fills in the predicted label where needed.
     *
     * <ul>
     *   <li>{@code nclasses > 2}: {@code out[i]} is written to {@code preds[1 + i]} and
     *       {@code preds[0]} is set from {@link GenModel#getPrediction}.</li>
     *   <li>{@code nclasses == 2}: {@code preds[2] = out[0]}, {@code preds[1] = 1 - out[0]} and
     *       {@code preds[0]} is set from {@link GenModel#getPrediction}.</li>
     *   <li>otherwise: {@code preds[0] = out[0]}.</li>
     * </ul>
     *
     * @param in                input row, forwarded to {@code GenModel.getPrediction}
     * @param out               raw booster output
     * @param preds             destination array; modified in place
     * @param nclasses          number of response classes
     * @param priorClassDistrib forwarded to {@code GenModel.getPrediction}
     * @param defaultThreshold  forwarded to {@code GenModel.getPrediction}
     * @return the same {@code preds} array that was passed in
     */
    public static double[] toPreds(double[] in, float[] out, double[] preds, int nclasses, double[] priorClassDistrib, double defaultThreshold) {
        if (nclasses > 2) {
            for (int i = 0; i < out.length; ++i) {
                preds[1 + i] = out[i];
            }
            preds[0] = GenModel.getPrediction(preds, priorClassDistrib, in, defaultThreshold);
        } else if (nclasses == 2) {
            // The subtraction is deliberately kept in float precision, as before.
            preds[1] = 1.0f - out[0];
            preds[2] = out[0];
            preds[0] = GenModel.getPrediction(preds, priorClassDistrib, in, defaultThreshold);
        } else {
            preds[0] = out[0];
        }
        return preds;
    }

    /**
     * Recursively copies the XGBoost node at {@code nodeIndex}, and everything below it, into
     * {@code sharedTreeSubgraph}.
     *
     * @param xgBoostNodes       all nodes of the source tree, addressed by node index
     * @param sharedTreeNode     target node that receives the data of {@code xgBoostNodes[nodeIndex]}
     * @param nodeIndex          index of the source node to copy
     * @param sharedTreeSubgraph subgraph used to create the child nodes
     * @param oneHotEncodedMap   per-feature flag; a {@code true} entry makes the split value 1.0
     *                           instead of the node's own split condition
     * @param inclusiveNA        value passed to {@code setInclusiveNa} on the target node
     * @param features           feature map lines; the second space-separated token is used as the
     *                           column name
     */
    protected void constructSubgraph(RegTreeNode[] xgBoostNodes, SharedTreeNode sharedTreeNode, int nodeIndex, SharedTreeSubgraph sharedTreeSubgraph, boolean[] oneHotEncodedMap, boolean inclusiveNA, String[] features) {
        RegTreeNode xgBoostNode = xgBoostNodes[nodeIndex];
        int splitIndex = xgBoostNode.getSplitIndex();
        if (oneHotEncodedMap[splitIndex]) {
            sharedTreeNode.setSplitValue(1.0f);
        } else {
            sharedTreeNode.setSplitValue(xgBoostNode.getSplitCondition());
        }
        sharedTreeNode.setPredValue(xgBoostNode.getLeafValue());
        sharedTreeNode.setCol(splitIndex, features[splitIndex].split(SPACE)[1]);
        sharedTreeNode.setInclusiveNa(inclusiveNA);
        sharedTreeNode.setNodeNumber(nodeIndex);

        // A child index of -1 means "no child on this side".
        int leftChildIndex = xgBoostNode.getLeftChildIndex();
        if (leftChildIndex != -1) {
            this.constructSubgraph(xgBoostNodes, sharedTreeSubgraph.makeLeftChildNode(sharedTreeNode), leftChildIndex, sharedTreeSubgraph, oneHotEncodedMap, xgBoostNode.default_left(), features);
        }
        int rightChildIndex = xgBoostNode.getRightChildIndex();
        if (rightChildIndex != -1) {
            this.constructSubgraph(xgBoostNodes, sharedTreeSubgraph.makeRightChildNode(sharedTreeNode), rightChildIndex, sharedTreeSubgraph, oneHotEncodedMap, !xgBoostNode.default_left(), features);
        }
    }

    /**
     * Splits {@link #_featureMap} into lines and returns the lines that precede the first blank
     * (empty or whitespace-only) line. If there is no blank line, all lines are returned.
     */
    private String[] constructFeatureMap() {
        String[] featureMapTokens = this._featureMap.split("\n");
        int nonEmptyTokenRange = featureMapTokens.length;
        for (int i = 0; i < featureMapTokens.length; ++i) {
            if (featureMapTokens[i].trim().isEmpty()) {
                // The upper bound of copyOfRange is exclusive, so stopping at i (not i + 1) keeps
                // the blank line out of the result; a blank line cannot be tokenized later on.
                nonEmptyTokenRange = i;
                break;
            }
        }
        return Arrays.copyOfRange(featureMapTokens, 0, nonEmptyTokenRange);
    }

    /**
     * Builds a per-feature flag array in which the leading run of features whose type token
     * (third space-separated token) equals {@code "i"} is marked {@code true}; the first feature
     * with a different type and every feature after it are {@code false}.
     *
     * @param featureMap feature map lines, one feature per element
     * @return array of the same length as {@code featureMap}
     */
    protected boolean[] markOneHotEncodedCategoricals(String[] featureMap) {
        int numColumns = featureMap.length;
        // Number of leading "i"-typed features; stays at numColumns when every feature is "i".
        int numCatCols = numColumns;
        for (int i = 0; i < numColumns; ++i) {
            String[] s = featureMap[i].split(SPACE);
            // Only tokens [1] and [2] are ever read, so three tokens are sufficient.
            assert (s.length >= 3);
            if (!s[2].equals("i")) {
                numCatCols = i;
                break;
            }
        }
        boolean[] categorical = new boolean[numColumns];
        Arrays.fill(categorical, 0, numCatCols, true);
        return categorical;
    }

    /**
     * Converts tree number {@code treeNumber} of every tree group of a tree-based booster into a
     * {@link SharedTreeGraph}, producing one subgraph per group named {@code "Class <groupIndex>"}.
     *
     * @param booster    booster to convert; must be a {@link GBTree}
     * @param treeNumber zero-based index of the tree inside each group
     * @return graph with one subgraph per tree group
     * @throws IllegalArgumentException if {@code booster} is not a {@link GBTree}, or if
     *                                  {@code treeNumber} is out of range for any group
     */
    protected SharedTreeGraph _computeGraph(GradBooster booster, int treeNumber) {
        if (!(booster instanceof GBTree)) {
            // %s, not %d: the argument is a class name (String); %d would itself throw an
            // IllegalFormatConversionException and hide the real problem.
            throw new IllegalArgumentException(String.format("Given XGBoost model is not backed by a tree-based booster. Booster class is %s", booster.getClass().getCanonicalName()));
        }
        RegTree[][] treesAndClasses = ((GBTree)booster).getGroupedTrees();
        SharedTreeGraph sharedTreeGraph = new SharedTreeGraph();
        for (int i = 0; i < treesAndClasses.length; ++i) {
            RegTree[] treesInGroup = treesAndClasses[i];
            if (treeNumber >= treesInGroup.length || treeNumber < 0) {
                throw new IllegalArgumentException(String.format("There is no such tree number for given class. Total number of trees is %d.", treesInGroup.length));
            }
            RegTreeNode[] treeNodes = treesInGroup[treeNumber].getNodes();
            assert (treeNodes.length >= 1);
            SharedTreeSubgraph sharedTreeSubgraph = sharedTreeGraph.makeSubgraph(String.format("Class %d", i));
            String[] features = this.constructFeatureMap();
            boolean[] oneHotEncodedMap = this.markOneHotEncodedCategoricals(features);
            // The root is created with inclusiveNA = true; children derive theirs from default_left().
            this.constructSubgraph(treeNodes, sharedTreeSubgraph.makeRootNode(), 0, sharedTreeSubgraph, oneHotEncodedMap, true, features);
        }
        return sharedTreeGraph;
    }

    /** Objective types known to this class, each paired with its XGBoost objective identifier. */
    public static enum ObjectiveType {
        BINARY_LOGISTIC("binary:logistic"),
        REG_GAMMA("reg:gamma"),
        REG_TWEEDIE("reg:tweedie"),
        COUNT_POISSON("count:poisson"),
        REG_LINEAR("reg:linear"),
        MULTI_SOFTPROB("multi:softprob");

        private final String _id;

        private ObjectiveType(String id) {
            this._id = id;
        }

        /** Returns the XGBoost objective identifier string of this constant. */
        public String getId() {
            return this._id;
        }

        /**
         * Looks up the constant whose identifier equals {@code type}.
         *
         * @param type XGBoost objective identifier, e.g. {@code "binary:logistic"}
         * @return the matching constant, or {@code null} if none matches
         */
        public static ObjectiveType fromXGBoost(String type) {
            for (ObjectiveType t : ObjectiveType.values()) {
                if (t.getId().equals(type)) {
                    return t;
                }
            }
            return null;
        }
    }
}

