/*
 * Decompiled with CFR 0.152.
 */
package hex.tree.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.Dart;
import biz.k11i.xgboost.gbm.GBLinear;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.RegTreeNode;
import hex.LinkFunction;
import hex.LinkFunctionFactory;
import hex.genmodel.algos.xgboost.XGBoostMojoModel;
import hex.genmodel.utils.LinkFunctionType;
import hex.tree.xgboost.XGBoostOutput;
import water.codegen.CodeGenerator;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.SBPrintStream;

public abstract class XGBoostPojoWriter {
    protected final Predictor _p;
    protected final String _namePrefix;
    protected final XGBoostOutput _output;
    private final double _defaultThreshold;

    public static XGBoostPojoWriter make(Predictor p, String namePrefix, XGBoostOutput output, double defaultThreshold) {
        if (p.getBooster() instanceof GBTree) {
            return new XGBoostPojoTreeWriter(p, namePrefix, output, defaultThreshold);
        }
        return new XGBoostPojoLinearWriter(p, namePrefix, output, defaultThreshold);
    }

    protected XGBoostPojoWriter(Predictor p, String namePrefix, XGBoostOutput output, double defaultThreshold) {
        this._p = p;
        this._namePrefix = namePrefix;
        this._output = output;
        this._defaultThreshold = defaultThreshold;
    }

    protected String getFeatureAccessor(int idx) {
        if (idx >= this._output._catOffsets[this._output._cats]) {
            int colIdx = idx - this._output._catOffsets[this._output._cats] + this._output._cats;
            if (this._output._sparse) {
                return "(data[" + colIdx + "] == 0 ? Double.NaN : data[" + colIdx + "])";
            }
            return "data[" + colIdx + "]";
        }
        int colIdx = 0;
        while (idx >= this._output._catOffsets[colIdx + 1]) {
            ++colIdx;
        }
        int colValue = idx - this._output._catOffsets[colIdx];
        return "(data[" + colIdx + "] == " + colValue + " ? 1 : " + (this._output._sparse ? "NaN" : "0") + ")";
    }

    private void renderPredTransformViaLinkFunction(LinkFunctionType type, SBPrintStream sb) {
        LinkFunction lf = LinkFunctionFactory.getLinkFunction((LinkFunctionType)type);
        sb.ip("preds[0] = (float) ").p(lf.linkInvString("preds[0]")).p(";").nl();
    }

    private void renderMultiClassPredTransform(SBPrintStream sb) {
        sb.ip("double max = preds[0];").nl();
        sb.ip("for (int i = 1; i < preds.length-1; i++) max = Math.max(preds[i], max); ").nl();
        sb.ip("double sum = 0.0D;").nl();
        sb.ip("for (int i = 0; i < preds.length-1; i++) {").nl();
        sb.ip("  preds[i] = Math.exp(preds[i] - max);").nl();
        sb.ip("  sum += preds[i];").nl();
        sb.ip("}").nl();
        sb.ip("for (int i = 0; i < preds.length-1; i++) {").nl();
        sb.ip("  preds[i] /= (float) sum;").nl();
        sb.ip("}").nl();
    }

    private void renderPredTransform(SBPrintStream sb) {
        String objFunction = this._p.getObjName();
        if (XGBoostMojoModel.ObjectiveType.REG_GAMMA.getId().equals(objFunction) || XGBoostMojoModel.ObjectiveType.REG_TWEEDIE.getId().equals(objFunction) || XGBoostMojoModel.ObjectiveType.COUNT_POISSON.getId().equals(objFunction)) {
            this.renderPredTransformViaLinkFunction(LinkFunctionType.log, sb);
        } else if (XGBoostMojoModel.ObjectiveType.BINARY_LOGISTIC.getId().equals(objFunction)) {
            this.renderPredTransformViaLinkFunction(LinkFunctionType.logit, sb);
        } else if (XGBoostMojoModel.ObjectiveType.REG_LINEAR.getId().equals(objFunction) || XGBoostMojoModel.ObjectiveType.RANK_PAIRWISE.getId().equals(objFunction)) {
            this.renderPredTransformViaLinkFunction(LinkFunctionType.identity, sb);
        } else if (XGBoostMojoModel.ObjectiveType.MULTI_SOFTPROB.getId().equals(objFunction)) {
            this.renderMultiClassPredTransform(sb);
        } else {
            throw new IllegalArgumentException("Unexpected objFunction " + objFunction);
        }
    }

    private void renderPredPostProcess(SBPrintStream sb) {
        if (this._output.nclasses() > 2) {
            sb.ip("for (int i = preds.length-2; i >= 0; i--)").nl();
            sb.ip("  preds[1 + i] = preds[i];").nl();
            sb.ip("preds[0] = GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, ").pj(this._defaultThreshold).p(");").nl();
        } else if (this._output.nclasses() == 2) {
            sb.ip("preds[1] = 1f - preds[0];").nl();
            sb.ip("preds[2] = preds[0];").nl();
            sb.ip("preds[0] = GenModel.getPrediction(preds, PRIOR_CLASS_DISTRIB, data, ").pj(this._defaultThreshold).p(");").nl();
        }
    }

    public void renderJavaPredictBody(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
        this.renderComputePredict(sb, fileCtx);
        this.renderPredTransform(sb);
        this.renderPredPostProcess(sb);
    }

    protected abstract void renderComputePredict(SBPrintStream var1, CodeGeneratorPipeline var2);

    static class XGBoostPojoLinearWriter
    extends XGBoostPojoWriter {
        protected XGBoostPojoLinearWriter(Predictor p, String namePrefix, XGBoostOutput output, double defaultThreshold) {
            super(p, namePrefix, output, defaultThreshold);
        }

        @Override
        public void renderComputePredict(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
            GBLinear booster = (GBLinear)this._p.getBooster();
            for (int gidx = 0; gidx < booster.getNumOutputGroup(); ++gidx) {
                sb.ip("preds[").p(gidx).p("] =").nl();
                sb.ii(1);
                for (int fid = 0; fid < booster.getNumFeature(); ++fid) {
                    String accessor = this.getFeatureAccessor(fid);
                    sb.ip("(Double.isNaN(").p(accessor).p(") ? 0 : (").pj(booster.weight(fid, gidx)).p(" * ").p(accessor).p(")) + ").nl();
                }
                sb.ip("").pj(booster.bias(gidx)).p(" +").nl();
                sb.ip("").pj(this._p.getBaseScore()).p(";").nl();
                sb.di(1);
            }
        }
    }

    static class XGBoostPojoTreeWriter
    extends XGBoostPojoWriter {
        protected XGBoostPojoTreeWriter(Predictor p, String namePrefix, XGBoostOutput output, double defaultThreshold) {
            super(p, namePrefix, output, defaultThreshold);
        }

        @Override
        public void renderComputePredict(SBPrintStream sb, CodeGeneratorPipeline fileCtx) {
            GBTree booster = (GBTree)this._p.getBooster();
            Dart dartBooster = null;
            if (booster instanceof Dart) {
                dartBooster = (Dart)booster;
            }
            RegTree[][] trees = booster.getGroupedTrees();
            for (int gidx = 0; gidx < trees.length; ++gidx) {
                sb.ip("float preds_").p(gidx).p(" = 0f;").nl();
                for (int tidx = 0; tidx < trees[gidx].length; ++tidx) {
                    String treeClassName = this.renderTreeClass(trees, gidx, tidx, dartBooster, fileCtx);
                    sb.ip("preds_").p(gidx).p(" += ").p(treeClassName).p(".score0(data);").nl();
                }
                sb.ip("preds_").p(gidx).p(" += ").pj(this._p.getBaseScore()).p(";").nl();
                sb.ip("preds[").p(gidx).p("] = preds_").p(gidx).p(";").nl();
            }
        }

        private String renderTreeClass(RegTree[][] trees, int gidx, final int tidx, final Dart dart, CodeGeneratorPipeline fileCtx) {
            final RegTree tree = trees[gidx][tidx];
            final String className = this._namePrefix + "_Tree_g_" + gidx + "_t_" + tidx;
            fileCtx.add((Object)new CodeGenerator(){

                public void generate(JCodeSB sb) {
                    sb.nl().p("class ").p(className).p(" {").nl();
                    sb.ii(1);
                    sb.ip("static float score0(double[] data) {").nl();
                    sb.ii(1);
                    sb.ip("return ");
                    if (dart != null) {
                        sb.pj(dart.weight(tidx)).p(" * ");
                    }
                    XGBoostPojoTreeWriter.this.renderTree(sb, tree, 0);
                    sb.p(";").nl();
                    sb.di(1);
                    sb.ip("}").nl();
                    sb.di(1);
                    sb.ip("}").nl();
                }
            });
            return className;
        }

        private void renderTree(JCodeSB sb, RegTree tree, int nidx) {
            RegTreeNode node = tree.getNodes()[nidx];
            if (node.isLeaf()) {
                sb.ip("").pj(node.getLeafValue());
            } else {
                int falseChild;
                int trueChild;
                String operator;
                String accessor = this.getFeatureAccessor(node.getSplitIndex());
                if (node.default_left()) {
                    operator = " < ";
                    trueChild = node.getLeftChildIndex();
                    falseChild = node.getRightChildIndex();
                } else {
                    operator = " >= ";
                    trueChild = node.getRightChildIndex();
                    falseChild = node.getLeftChildIndex();
                }
                sb.ip("((Double.isNaN(").p(accessor).p(") || ((float)").p(accessor).p(")").p(operator).pj(node.getSplitCondition()).p(") ?").nl();
                sb.ii(1);
                this.renderTree(sb, tree, trueChild);
                sb.nl().ip(":").nl();
                this.renderTree(sb, tree, falseChild);
                sb.di(1);
                sb.nl().ip(")");
            }
        }
    }
}

