/*
 * Decompiled with CFR 0.152.
 */
package hex.tree;

import hex.Model;
import hex.PojoWriter;
import hex.genmodel.CategoricalEncoding;
import hex.tree.CompressedTree;
import hex.tree.TreeJCodeGen;
import hex.tree.TreeStats;
import water.Key;
import water.codegen.CodeGeneratorPipeline;
import water.exceptions.JCodeSB;
import water.util.JCodeGen;
import water.util.PojoUtils;
import water.util.SB;
import water.util.SBPrintStream;

/**
 * Base {@link PojoWriter} for tree models: writes Java source text that scores a
 * two-dimensional array of {@link CompressedTree}s (indexed by tree, then by class).
 *
 * <p>For every tree index a {@code <model>_Forest_<t>} class is emitted whose static
 * {@code score0} adds the result of each per-class {@code <model>_Tree_<t>_class_<c>}
 * class into the {@code preds} array. Subclasses supply the final post-processing of
 * {@code preds} through {@link #toJavaUnifyPreds(SBPrintStream)}.
 */
public abstract class SharedTreePojoWriter implements PojoWriter {
    protected final Key<?> _modelKey;
    protected final Model.Output _output;
    protected final CategoricalEncoding _encoding;
    protected final boolean _binomialOpt;
    protected final CompressedTree[][] _trees;
    protected final TreeStats _treeStats;

    /**
     * Stores the given model pieces; no validation or copying is performed.
     *
     * @param modelKey    key whose string form is turned into the Java identifier prefix
     * @param output      model output queried for feature/class counts and passed to the tree code generator
     * @param encoding    categorical encoding reported by the generated code
     * @param binomialOpt when {@code true} and the model has exactly two classes, the tree for class 1 is not emitted
     * @param trees       compressed trees, indexed {@code [tree][class]}; entries may be {@code null}
     * @param treeStats   tree statistics used for the size check; may be {@code null}
     */
    protected SharedTreePojoWriter(Key<?> modelKey, Model.Output output, CategoricalEncoding encoding, boolean binomialOpt, CompressedTree[][] trees, TreeStats treeStats) {
        _modelKey = modelKey;
        _output = output;
        _encoding = encoding;
        _binomialOpt = binomialOpt;
        _trees = trees;
        _treeStats = treeStats;
    }

    /**
     * Reports whether the model is considered too big.
     *
     * @return {@code true} when no tree statistics are available, or when the number of
     *         trees multiplied by the mean leaf count exceeds one million
     */
    @Override
    public boolean toJavaCheckTooBig() {
        if (_treeStats == null) {
            return true;
        }
        return (float) _treeStats._num_trees * _treeStats._mean_leaves > 1000000.0f;
    }

    /**
     * Writes the fixed accessor methods of the generated model class.
     *
     * <p>{@code getOrigProjectionArray()} is only written for the {@code Eigen} encoding and
     * {@code getCategoricalEncoding()} only when the encoding is something other than {@code AUTO}.
     *
     * @param sb          stream receiving the generated source
     * @param fileContext not used by this implementation
     * @return the same {@code sb} instance that was passed in
     */
    @Override
    public SBPrintStream toJavaInit(SBPrintStream sb, CodeGeneratorPipeline fileContext) {
        sb.nl();
        sb.ip("public boolean isSupervised() { return true; }").nl();

        String nfeaturesMethod = "public int nfeatures() { return " + _output.nfeatures() + "; }";
        sb.ip(nfeaturesMethod).nl();

        String nclassesMethod = "public int nclasses() { return " + _output.nclasses() + "; }";
        sb.ip(nclassesMethod).nl();

        if (_encoding == CategoricalEncoding.Eigen) {
            String projectionMethod = "public double[] getOrigProjectionArray() { return " + PojoUtils.toJavaDoubleArray(_output._orig_projection_array) + "; }";
            sb.ip(projectionMethod).nl();
        }
        if (_encoding != CategoricalEncoding.AUTO) {
            String encodingMethod = "public hex.genmodel.CategoricalEncoding getCategoricalEncoding() { return hex.genmodel.CategoricalEncoding." + _encoding.name() + "; }";
            sb.ip(encodingMethod).nl();
        }
        return sb;
    }

    /**
     * Writes the body of the generated predict method and registers one deferred
     * generator per tree index that emits the forest class and its per-class tree classes.
     *
     * @param body        stream receiving the predict-method body
     * @param classCtx    not used by this implementation
     * @param fileCtx     pipeline to which the per-tree class generators are added
     * @param verboseCode forwarded unchanged to {@link TreeJCodeGen}
     */
    @Override
    public void toJavaPredictBody(SBPrintStream body, CodeGeneratorPipeline classCtx, CodeGeneratorPipeline fileCtx, boolean verboseCode) {
        final int nclass = _output.nclasses();
        body.ip("java.util.Arrays.fill(preds,0);").nl();
        final String mname = JCodeGen.toJavaId(_modelKey.toString());

        for (int t = 0; t < _trees.length; t++) {
            // Call into the forest class for this tree index from the predict body.
            toJavaForestName(body.i(), mname, t).p(".score0(data,preds);").nl();

            // The generator below runs later, so it needs its own stable copy of the index.
            final int treeIdx = t;
            fileCtx.add(out -> {
                try {
                    // Forest class: one accumulation line per emitted class tree.
                    out.nl();
                    toJavaForestName(out.ip("class "), mname, treeIdx).p(" {").nl().ii(1);
                    out.ip("public static void score0(double[] fdata, double[] preds) {").nl().ii(1);
                    for (int c = 0; c < nclass; c++) {
                        if (isTreeOmitted(treeIdx, c, nclass)) {
                            continue;
                        }
                        // A single-class model accumulates into preds[0]; otherwise class c uses preds[c + 1].
                        int predIdx;
                        if (nclass == 1) {
                            predIdx = 0;
                        } else {
                            predIdx = c + 1;
                        }
                        toJavaTreeName(out.ip("preds[").p(predIdx).p("] += "), mname, treeIdx, c).p(".score0(fdata);").nl();
                    }
                    out.di(1).ip("}").nl();
                    out.di(1).ip("}").nl();

                    // The tree classes themselves follow the forest class.
                    for (int c = 0; c < nclass; c++) {
                        if (isTreeOmitted(treeIdx, c, nclass)) {
                            continue;
                        }
                        String treeClassName = toJavaTreeName(new SB(), mname, treeIdx, c).toString();
                        SB treeCode = new SB();
                        new TreeJCodeGen(_output, _trees[treeIdx][c], treeCode, treeClassName, verboseCode).generate();
                        out.p(treeCode);
                    }
                }
                catch (Exception e) {
                    throw new RuntimeException("Internal error creating the POJO.", e);
                }
            });
        }

        toJavaUnifyPreds(body);
    }

    /**
     * Called at the end of {@link #toJavaPredictBody} with the predict-method body stream,
     * after all per-tree {@code score0} calls have been written to it.
     *
     * @param var1 stream receiving the predict-method body
     */
    protected abstract void toJavaUnifyPreds(SBPrintStream var1);

    /**
     * Tells whether no code is emitted for the given tree/class slot: either the slot holds
     * no tree, or the binomial flag is set and this is class 1 of a two-class model.
     * The slot is read before the flag is consulted.
     */
    private boolean isTreeOmitted(int treeIdx, int classIdx, int nclass) {
        if (_trees[treeIdx][classIdx] == null) {
            return true;
        }
        return _binomialOpt && classIdx == 1 && nclass == 2;
    }

    /** Appends {@code <mname>_Tree_<t2>_class_<c2>} to {@code sb} and returns the result of the last append. */
    private static <T extends JCodeSB<T>> T toJavaTreeName(T sb, String mname, int t2, int c2) {
        T named = sb.p(mname).p("_Tree_").p(t2);
        return named.p("_class_").p(c2);
    }

    /** Appends {@code <mname>_Forest_<t2>} to {@code sb} and returns the result of the last append. */
    private static <T extends JCodeSB<T>> T toJavaForestName(T sb, String mname, int t2) {
        T named = sb.p(mname).p("_Forest_");
        return named.p(t2);
    }
}

