/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.tools;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.gbm.GbmMojoModel;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeGraphConverter;
import hex.genmodel.algos.tree.TreeBackedMojoModel;
import hex.genmodel.tools.MojoPrinter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import water.util.JavaVersionUtils;

/**
 * Command-line tool that prints the trees of a tree-based MOJO in a human-consumable form:
 * Graphviz {@code dot}, {@code json}, or the {@code raw} dump produced by {@link SharedTreeGraph#print()}.
 *
 * <p>This class does not handle the {@code png} format ({@link #supportsFormat} rejects it).
 * On Java 8+, {@link #main(String[])} looks up other {@link MojoPrinter} implementations via
 * {@link ServiceLoader} so that a printer supporting the requested format can be used.
 */
public class PrintMojo
implements MojoPrinter {
    /** MOJO loaded from the {@code -i/--input} argument; {@code null} until parsed. */
    protected MojoModel genModel;
    /** Requested output format ({@code --format} or {@code --raw}); defaults to dot. */
    protected MojoPrinter.Format format = MojoPrinter.Format.dot;
    /** Tree number to print ({@code --tree}); -1 is the default ("all" according to the usage text). */
    protected int treeToPrint = -1;
    /** Number of levels per edge to print ({@code --levels}). */
    protected int maxLevelsToPrintPerEdge = 10;
    /** Whether to print additional details such as node numbers ({@code --detail}). */
    protected boolean detail = false;
    /** Output file ({@code -o/--output}); {@code null} means stdout. */
    protected String outputFileName = null;
    /** Optional forced title ({@code --title}). */
    protected String optionalTitle = null;
    /** Formatting options assembled by {@link #parseArgs(String[])}. */
    protected PrintTreeOptions pTreeOptions;
    /** Use the internal H2O tree representation for dot output ({@code --internal}). */
    protected boolean internal;
    /** Not referenced in this class; presumably used by subclasses (e.g. a png printer) - verify. */
    protected final String tmpOutputFileName = "tmpOutputFileName.gv";

    /**
     * Entry point. Selects a {@link MojoPrinter} that supports the requested {@code --format},
     * lets it parse the arguments, and runs it.
     *
     * <p>Exits with status 0 on success, 1 when no printer supports the requested format
     * (argument errors also exit with 1 via {@link #usage()}), and 2 if printing throws.
     *
     * @param args command-line arguments, see {@link #usage()}
     */
    public static void main(String[] args) {
        MojoPrinter.Format requestedFormat = PrintMojo.getFormat(args);
        MojoPrinter mojoPrinter = null;
        if (JavaVersionUtils.JAVA_VERSION.isKnown() && JavaVersionUtils.JAVA_VERSION.getMajor() > 7) {
            // Pluggable printers are only looked up on Java 8+; the last matching provider wins.
            for (MojoPrinter printer : ServiceLoader.load(MojoPrinter.class)) {
                if (printer.supportsFormat(requestedFormat)) {
                    mojoPrinter = printer;
                }
            }
        }
        if (mojoPrinter == null) {
            // No provider matched (or old/unknown JVM): fall back to the built-in printer
            // instead of dereferencing null below.
            mojoPrinter = new PrintMojo();
        }
        if (!mojoPrinter.supportsFormat(requestedFormat)) {
            // E.g. png on Java 7: previously this silently printed nothing and exited with 0.
            System.out.println("ERROR: No printer available for format " + requestedFormat
                    + ". For .png output at least Java 8 is required.");
            System.exit(1);
        }
        mojoPrinter.parseArgs(args);
        try {
            mojoPrinter.run();
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(2);
        }
        System.exit(0);
    }

    /**
     * Reports whether this printer can produce the given format.
     *
     * @param format requested format; {@code null} (format not given or invalid) is accepted
     * @return {@code true} for every format except png
     */
    @Override
    public boolean supportsFormat(MojoPrinter.Format format) {
        return !MojoPrinter.Format.png.equals(format);
    }

    /**
     * Extracts the value of the first {@code --format} option without fully parsing the arguments.
     *
     * @param args command-line arguments
     * @return the parsed format, or {@code null} if the option is absent, has no value, or is invalid
     */
    static MojoPrinter.Format getFormat(String[] args) {
        for (int i = 0; i + 1 < args.length; ++i) {
            if (!"--format".equals(args[i])) {
                continue;
            }
            try {
                return MojoPrinter.Format.valueOf(args[i + 1]);
            }
            catch (IllegalArgumentException | NullPointerException e) {
                // Invalid value: parseArgs reports the error later.
                return null;
            }
        }
        return null;
    }

    /** Loads the MOJO at {@code modelName} into {@link #genModel}. */
    private void loadMojo(String modelName) throws IOException {
        this.genModel = MojoModel.load(modelName);
    }

    /** Prints the command-line help and terminates the JVM with exit status 1. */
    protected static void usage() {
        System.out.println("Emit a human-consumable graph of a model for use with dot (graphviz).");
        System.out.println("The currently supported model types are DRF, GBM and XGBoost.");
        System.out.println();
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PrintMojo [--tree n] [--levels n] [--title sss] [-o outputFileName]");
        System.out.println();
        System.out.println("    --format        Output format. For .png output at least Java 8 is required.");
        System.out.println("                    dot|json|raw|png [default dot]");
        System.out.println();
        System.out.println("    --tree          Tree number to print.");
        System.out.println("                    [default all]");
        System.out.println();
        System.out.println("    --levels        Number of levels per edge to print.");
        System.out.println("                    [default 10]");
        System.out.println();
        System.out.println("    --title         (Optional) Force title of tree graph.");
        System.out.println();
        System.out.println("    --detail        Specify to print additional detailed information like node numbers.");
        System.out.println();
        System.out.println("    --input | -i    Input mojo file.");
        System.out.println();
        System.out.println("    --output | -o   Output filename. Taken as a directory name in case of .png format and multiple trees to visualize.");
        System.out.println("                    [default stdout]");
        System.out.println("    --decimalplaces | -d    Set decimal places of all numerical values.");
        System.out.println();
        System.out.println("    --fontsize | -f    Set font sizes of strings.");
        System.out.println();
        System.out.println("    --internal    Internal H2O representation of the decision tree (splits etc.) is used for generating the GRAPHVIZ format.");
        System.out.println();
        System.out.println();
        System.out.println("Example:");
        System.out.println();
        System.out.println("    (brew install graphviz)");
        System.out.println("    java -cp h2o.jar hex.genmodel.tools.PrintMojo --tree 0 -i model_mojo.zip -o model.gv -f 20 -d 3");
        System.out.println("    dot -Tpng model.gv -o model.png");
        System.out.println("    open model.png");
        System.out.println();
        System.exit(1);
    }

    /**
     * Parses command-line arguments into this printer's fields and builds {@link #pTreeOptions}.
     * Any invalid, missing, or unknown argument terminates the JVM with exit status 1.
     *
     * @param args command-line arguments, see {@link #usage()}
     */
    @Override
    public void parseArgs(String[] args) {
        int nPlaces = -1;
        int fontSize = 14;
        boolean setDecimalPlaces = false;
        try {
            for (int i = 0; i < args.length; ++i) {
                String s = args[i];
                switch (s) {
                    case "--format": {
                        s = PrintMojo.requireValue(args, ++i);
                        try {
                            this.format = MojoPrinter.Format.valueOf(s);
                        }
                        catch (IllegalArgumentException e) {
                            System.out.println("ERROR: invalid --format argument (" + s + ")");
                            System.exit(1);
                        }
                        break;
                    }
                    case "--tree": {
                        this.treeToPrint = PrintMojo.parseIntArg(PrintMojo.requireValue(args, ++i), "--tree");
                        break;
                    }
                    case "--levels": {
                        this.maxLevelsToPrintPerEdge = PrintMojo.parseIntArg(PrintMojo.requireValue(args, ++i), "--levels");
                        break;
                    }
                    case "--title": {
                        this.optionalTitle = PrintMojo.requireValue(args, ++i);
                        break;
                    }
                    case "--detail": {
                        this.detail = true;
                        break;
                    }
                    case "--input":
                    case "-i": {
                        this.loadMojo(PrintMojo.requireValue(args, ++i));
                        break;
                    }
                    case "--fontsize":
                    case "-f": {
                        // Report bad numbers like --tree/--levels instead of dumping a stack trace.
                        fontSize = PrintMojo.parseIntArg(PrintMojo.requireValue(args, ++i), "--fontsize");
                        break;
                    }
                    case "--decimalplaces":
                    case "-d": {
                        nPlaces = PrintMojo.parseIntArg(PrintMojo.requireValue(args, ++i), "--decimalplaces");
                        setDecimalPlaces = true;
                        break;
                    }
                    case "--raw": {
                        this.format = MojoPrinter.Format.raw;
                        break;
                    }
                    case "--internal": {
                        this.internal = true;
                        break;
                    }
                    case "-o":
                    case "--output": {
                        this.outputFileName = PrintMojo.requireValue(args, ++i);
                        break;
                    }
                    default: {
                        System.out.println("ERROR: Unknown command line argument: " + s);
                        PrintMojo.usage();
                    }
                }
            }
            // Built after the loop so that --internal takes effect regardless of argument order.
            this.pTreeOptions = new PrintTreeOptions(setDecimalPlaces, nPlaces, fontSize, this.internal);
        }
        catch (Exception e) {
            // E.g. an IOException from loading the MOJO.
            e.printStackTrace();
            PrintMojo.usage();
        }
    }

    /**
     * Returns the value at {@code args[i]} for an option that requires one.
     * If the value is missing, prints usage and exits instead.
     */
    private static String requireValue(String[] args, int i) {
        if (i >= args.length) {
            PrintMojo.usage();
        }
        return args[i];
    }

    /**
     * Parses an integer option value. On failure, prints
     * {@code ERROR: invalid <option> argument (<value>)} and exits with status 1.
     */
    private static int parseIntArg(String value, String option) {
        try {
            return Integer.parseInt(value);
        }
        catch (NumberFormatException e) {
            System.out.println("ERROR: invalid " + option + " argument (" + value + ")");
            System.exit(1);
            return -1; // unreachable: System.exit does not return
        }
    }

    /** Ensures an input MOJO was given; otherwise prints an error and the usage, and exits. */
    protected void validateArgs() {
        if (this.genModel == null) {
            System.out.println("ERROR: Must specify -i");
            PrintMojo.usage();
        }
    }

    /**
     * Converts the selected tree(s) of the loaded MOJO and prints them in {@link #format}.
     *
     * <p>The output file (if any) is opened only after the model has been validated and converted,
     * so an unsupported model or a failed conversion does not truncate an existing file. The file
     * is always closed when printing finishes.
     *
     * @throws Exception if the output file cannot be opened or the conversion/printing fails
     */
    @Override
    public void run() throws Exception {
        this.validateArgs();
        if (!(this.genModel instanceof SharedTreeGraphConverter)) {
            System.out.println("ERROR: Unsupported MOJO type");
            System.exit(1);
            return;
        }
        SharedTreeGraphConverter treeBackedModel = (SharedTreeGraphConverter) this.genModel;
        ConvertTreeOptions options = new ConvertTreeOptions().withTreeConsistencyCheckEnabled();
        SharedTreeGraph g = treeBackedModel.convert(this.treeToPrint, null, options);
        if (this.format == MojoPrinter.Format.raw) {
            // print() is not given the output stream, so -o does not apply to raw output.
            g.print();
            return;
        }
        if (this.format == MojoPrinter.Format.json && !(treeBackedModel instanceof TreeBackedMojoModel)) {
            System.out.println("ERROR: Printing XGBoost MOJO as JSON not supported");
            System.exit(1);
            return;
        }
        PrintStream os = this.openOutput();
        try {
            switch (this.format) {
                case dot: {
                    g.printDot(os, this.maxLevelsToPrintPerEdge, this.detail, this.optionalTitle, this.pTreeOptions);
                    break;
                }
                case json: {
                    this.printJson((TreeBackedMojoModel) treeBackedModel, g, os);
                    break;
                }
                default: {
                    // Other formats (e.g. png) are not produced by this printer; see supportsFormat.
                    break;
                }
            }
        }
        finally {
            if (os == System.out) {
                os.flush();
            } else {
                os.close();
            }
        }
    }

    /** Opens {@link #outputFileName} for writing, or returns {@code System.out} when it is unset. */
    private PrintStream openOutput() throws IOException {
        if (this.outputFileName == null) {
            return System.out;
        }
        return new PrintStream(new FileOutputStream(new File(this.outputFileName)));
    }

    /**
     * Collects model-level metadata for the JSON output, in insertion order.
     *
     * @param tree the tree-backed view of {@link #genModel}
     * @return map of parameter name to value
     */
    private Map<String, Object> getParamsAsJson(TreeBackedMojoModel tree) {
        LinkedHashMap<String, Object> params = new LinkedHashMap<String, Object>();
        params.put("h2o_version", this.genModel._h2oVersion);
        params.put("mojo_version", this.genModel._mojo_version);
        params.put("algo", this.genModel._algoName);
        params.put("model_category", this.genModel._category.toString());
        params.put("classifier", this.genModel.isClassifier());
        params.put("supervised", this.genModel._supervised);
        params.put("nfeatures", this.genModel._nfeatures);
        params.put("nclasses", this.genModel._nclasses);
        params.put("balance_classes", this.genModel._balanceClasses);
        params.put("n_tree_groups", tree.getNTreeGroups());
        params.put("n_trees_in_group", tree.getNTreesPerGroup());
        params.put("base_score", tree.getInitF());
        if (this.genModel.isClassifier()) {
            String[] responseValues = this.genModel.getDomainValues(this.genModel.getResponseIdx());
            params.put("class_labels", responseValues);
        }
        if (this.genModel instanceof GbmMojoModel) {
            GbmMojoModel m = (GbmMojoModel) this.genModel;
            params.put("family", m._family.toString());
            params.put("link_function", m._link_function.toString());
        }
        return params;
    }

    /**
     * Lists the categorical domains of the model's columns for the JSON output.
     * Columns without a domain are skipped.
     *
     * @return one {colId, colName, values} map per categorical column
     */
    private List<Object> getDomainValuesAsJSON() {
        ArrayList<Object> domainValues = new ArrayList<Object>();
        String[][] values = this.genModel.getDomainValues();
        // NOTE(review): the last column is excluded; presumably it is the response, whose labels
        // are emitted as "class_labels" in params - confirm.
        for (int i = 0; i < values.length - 1; ++i) {
            if (values[i] == null) {
                continue;
            }
            LinkedHashMap<String, Object> colValuesObject = new LinkedHashMap<String, Object>();
            colValuesObject.put("colId", i);
            colValuesObject.put("colName", this.genModel._names[i]);
            colValuesObject.put("values", values[i]);
            domainValues.add(colValuesObject);
        }
        return domainValues;
    }

    /**
     * Writes the model parameters, categorical domains, trees and optional title as pretty-printed JSON.
     *
     * @param mojo  tree-backed view of the model
     * @param trees converted tree graph
     * @param os    destination stream (not closed here)
     */
    private void printJson(TreeBackedMojoModel mojo, SharedTreeGraph trees, PrintStream os) {
        LinkedHashMap<String, Object> json = new LinkedHashMap<String, Object>();
        json.put("params", this.getParamsAsJson(mojo));
        json.put("domainValues", this.getDomainValuesAsJSON());
        json.put("trees", trees.toJson());
        if (this.optionalTitle != null) {
            json.put("title", this.optionalTitle);
        }
        Gson gson = new GsonBuilder().setPrettyPrinting().create();
        os.print(gson.toJson(json));
    }

    /** Formatting options passed to {@link SharedTreeGraph#printDot}. */
    public static class PrintTreeOptions {
        public boolean _setDecimalPlace;
        /** Decimal places for rounding; negative means "do not round". */
        public int _nPlaces = -1;
        public int _fontSize;
        public boolean _internal;

        /**
         * @param setdecimalplaces whether {@code nplaces} should be applied
         * @param nplaces          decimal places to round to (ignored unless {@code setdecimalplaces})
         * @param fontsize         font size for strings
         * @param internal         use the internal H2O tree representation
         */
        public PrintTreeOptions(boolean setdecimalplaces, int nplaces, int fontsize, boolean internal) {
            this._setDecimalPlace = setdecimalplaces;
            // Without an explicit request keep -1. The previous field default of 0 made
            // roundNPlace round everything to integers.
            this._nPlaces = setdecimalplaces ? nplaces : -1;
            this._fontSize = fontsize;
            this._internal = internal;
        }

        /**
         * Rounds {@code value} half-up to {@link #_nPlaces} decimal places.
         *
         * @return the rounded value. The value is returned unchanged when rounding is disabled, when
         *         it is NaN or infinite, or when the requested precision is finer than float precision.
         */
        public float roundNPlace(float value) {
            if (this._nPlaces < 0 || Float.isNaN(value) || Float.isInfinite(value)) {
                return value;
            }
            double sc = Math.pow(10.0, this._nPlaces);
            double scaled = (double) value * sc;
            // Math.round saturates at Long.MAX_VALUE. At that magnitude, rounding to _nPlaces is
            // finer than the float's own precision, so it would be a no-op anyway.
            if (Double.isInfinite(scaled) || Math.abs(scaled) >= (double) Long.MAX_VALUE) {
                return value;
            }
            return (float) ((double) Math.round(scaled) / sc);
        }
    }
}

