/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.tools;

import au.com.bytecode.opencsv.CSVReader;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.OrdinalModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;

/**
 * Command-line tool that scores the rows of a CSV file with a generated model
 * (MOJO or POJO) and writes one prediction line per input row to an output CSV file.
 *
 * <p>Process exit codes used by this class: {@code 0} on success, {@code 1} after a
 * usage error or a failure while scoring rows, {@code 2} when {@code run()} fails for
 * any other reason (for example the input file cannot be opened).
 */
public class PredictCsv {
    /** Load the model as a POJO class; also the default when no model option is given. */
    private static final int LOAD_POJO = 0;
    /** Load the model as a MOJO archive ({@code --mojo}). */
    private static final int LOAD_MOJO = 1;
    /** Try MOJO first and fall back to POJO ({@code --model}). */
    private static final int LOAD_MOJO_THEN_POJO = 2;

    private String inputCSVFileName;
    private String outputCSVFileName;
    /** When false, doubles are printed with {@link Double#toHexString(double)}. */
    private boolean useDecimalOutput = false;
    /** Column separator of the input file. The output always uses a comma. */
    public char separator = ',';
    /** Passed to {@code setConvertInvalidNumbersToNa} when the model wrapper is built. */
    public boolean setInvNumNA = false;
    private EasyPredictModelWrapper model;

    /**
     * Program entry point: parses the arguments, loads the model and scores the input file.
     *
     * @param args command-line arguments, see {@code usage()}
     */
    public static void main(String[] args) {
        PredictCsv main = new PredictCsv();
        main.parseArgs(args);
        try {
            main.run();
        }
        catch (Exception e) {
            e.printStackTrace();
            System.exit(2);
        }
        System.exit(0);
    }

    /**
     * Builds a {@link RowData} from one parsed CSV line.
     *
     * <p>Cells equal to {@code ""}, {@code "NA"}, {@code "N/A"} or {@code "-"} are left out
     * of the row. If the line and the header differ in length, only the columns present
     * in both are used.
     *
     * @param splitLine        cells of the current data line
     * @param inputColumnNames column names taken from the header line
     * @return row holding every non-missing cell keyed by its column name
     */
    private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
        RowData row = new RowData();
        int maxI = Math.min(inputColumnNames.length, splitLine.length);
        for (int i = 0; i < maxI; ++i) {
            String cellData = splitLine[i];
            switch (cellData) {
                case "":
                case "NA":
                case "N/A":
                case "-": {
                    // Missing-value marker: do not put the column into the row.
                    break;
                }
                default: {
                    row.put(inputColumnNames[i], cellData);
                }
            }
        }
        return row;
    }

    /**
     * Formats a double for the output file.
     *
     * @param d value to format
     * @return {@code "NA"} for NaN, otherwise the decimal or hexadecimal representation
     *         depending on {@code useDecimalOutput}
     */
    private String myDoubleToString(double d) {
        if (Double.isNaN(d)) {
            return "NA";
        }
        return this.useDecimalOutput ? Double.toString(d) : Double.toHexString(d);
    }

    /**
     * Scores every data row of the input CSV file and writes the predictions to the
     * output CSV file, preceded by a header line.
     *
     * <p>A failure while reading or scoring rows is reported on standard output and ends
     * the process with exit code 1, but only after both files have been closed so that
     * predictions written so far are flushed to disk.
     *
     * @throws Exception if a file name is missing, a file cannot be opened, the header
     *                   cannot be written, or the model category is not supported
     */
    private void run() throws Exception {
        if (this.inputCSVFileName == null) {
            throw new IllegalArgumentException("Input CSV file name was not specified (use --input)");
        }
        if (this.outputCSVFileName == null) {
            throw new IllegalArgumentException("Output CSV file name was not specified (use --output)");
        }
        ModelCategory category = this.model.getModelCategory();
        boolean failed = false;
        // Closing the underlying Reader is sufficient; the CSVReader only wraps it.
        try (Reader in = new FileReader(this.inputCSVFileName);
             BufferedWriter output = new BufferedWriter(new FileWriter(this.outputCSVFileName))) {
            CSVReader reader = new CSVReader(in, this.separator);
            int lastCommaAutoEn = this.writeHeader(output, category);
            output.write("\n");
            int lineNum = 1;
            try {
                String[] inputColumnNames = reader.readNext();
                if (inputColumnNames == null) {
                    throw new Exception("Input dataset file is empty!");
                }
                String[] splitLine;
                while ((splitLine = reader.readNext()) != null) {
                    RowData row = PredictCsv.formatDataRow(splitLine, inputColumnNames);
                    this.writePrediction(output, category, row, lastCommaAutoEn);
                    output.write("\n");
                    ++lineNum;
                }
            }
            catch (Exception e) {
                System.out.println("Caught exception on line " + lineNum);
                System.out.println("");
                e.printStackTrace();
                // Do not exit here: System.exit would skip closing (and flushing) the files.
                failed = true;
            }
        }
        if (failed) {
            System.exit(1);
        }
    }

    /**
     * Writes the header line (without the trailing newline) for the given model category.
     *
     * @param output   destination writer
     * @param category category of the loaded model
     * @return for AutoEncoder models the zero-based index of the last header column,
     *         {@code -1} for every other category
     * @throws Exception if writing fails or the category is not supported
     */
    private int writeHeader(BufferedWriter output, ModelCategory category) throws Exception {
        int lastCommaAutoEn = -1;
        switch (category) {
            case AutoEncoder: {
                String[] cnames = this.model.m.getNames();
                int numCats = this.model.domainMap.size();
                int numNums = this.model.m.nfeatures() - numCats;
                String[][] domainValues = this.model.m.getDomainValues();
                int lastCatIdx = numCats - 1;
                // Categorical columns: one header per domain level plus a "missing" column.
                for (int index = 0; index <= lastCatIdx; ++index) {
                    for (String level : domainValues[index]) {
                        ++lastCommaAutoEn;
                        output.write("reconstr_" + level);
                        output.write(',');
                    }
                    ++lastCommaAutoEn;
                    output.write("reconstr_" + cnames[index] + ".missing(NA)");
                    // A separator follows unless this is the very last header column.
                    if (numNums > 0 || index < lastCatIdx) {
                        output.write(',');
                    }
                }
                // Remaining columns get a single header each.
                int lastComma = cnames.length - 1;
                for (int index = numCats; index < cnames.length; ++index) {
                    ++lastCommaAutoEn;
                    output.write("reconstr_" + cnames[index]);
                    if (index < lastComma) {
                        output.write(',');
                    }
                }
                break;
            }
            case Binomial:
            case Multinomial:
            case Ordinal: {
                output.write("predict");
                for (String s : this.model.getResponseDomainValues()) {
                    output.write(",");
                    output.write(s);
                }
                break;
            }
            case Clustering: {
                output.write("cluster");
                break;
            }
            case Regression: {
                output.write("predict");
                break;
            }
            default: {
                throw new Exception("Unknown model category " + category);
            }
        }
        return lastCommaAutoEn;
    }

    /**
     * Scores one row and writes the prediction (without the trailing newline).
     *
     * @param output          destination writer
     * @param category        category of the loaded model
     * @param row             input row to score
     * @param lastCommaAutoEn value returned by {@code writeHeader}; only used for
     *                        AutoEncoder models, where a comma is written after every
     *                        value whose index is below it
     * @throws Exception if prediction or writing fails, or the category is not supported
     */
    private void writePrediction(BufferedWriter output, ModelCategory category, RowData row, int lastCommaAutoEn) throws Exception {
        switch (category) {
            case AutoEncoder: {
                double[] reconstructed = this.model.predictAutoEncoder(row).reconstructed;
                for (int i = 0; i < reconstructed.length; ++i) {
                    output.write(this.myDoubleToString(reconstructed[i]));
                    if (i < lastCommaAutoEn) {
                        output.write(',');
                    }
                }
                break;
            }
            case Binomial: {
                BinomialModelPrediction p = this.model.predictBinomial(row);
                this.writeClassPrediction(output, p.label, p.classProbabilities);
                break;
            }
            case Multinomial: {
                MultinomialModelPrediction p = this.model.predictMultinomial(row);
                this.writeClassPrediction(output, p.label, p.classProbabilities);
                break;
            }
            case Ordinal: {
                OrdinalModelPrediction p = this.model.predictOrdinal(row);
                this.writeClassPrediction(output, p.label, p.classProbabilities);
                break;
            }
            case Clustering: {
                ClusteringModelPrediction p = this.model.predictClustering(row);
                output.write(this.myDoubleToString(p.cluster));
                break;
            }
            case Regression: {
                RegressionModelPrediction p = this.model.predictRegression(row);
                output.write(this.myDoubleToString(p.value));
                break;
            }
            default: {
                throw new Exception("Unknown model category " + category);
            }
        }
    }

    /**
     * Writes a predicted label followed by the comma-separated class probabilities.
     *
     * @param output             destination writer
     * @param label              predicted label
     * @param classProbabilities per-class probabilities, written in array order
     * @throws IOException if writing fails
     */
    private void writeClassPrediction(BufferedWriter output, String label, double[] classProbabilities) throws IOException {
        output.write(label);
        output.write(",");
        for (int i = 0; i < classProbabilities.length; ++i) {
            if (i > 0) {
                output.write(",");
            }
            output.write(this.myDoubleToString(classProbabilities[i]));
        }
    }

    /**
     * Loads a model given only its name: tries to load it as a MOJO and, if that fails
     * with an {@link IOException}, loads it as a POJO class instead.
     *
     * @param modelName MOJO file name or POJO class name
     * @throws Exception if the POJO fallback fails as well
     */
    private void loadModel(String modelName) throws Exception {
        try {
            this.loadMojo(modelName);
        }
        catch (IOException e) {
            this.loadPojo(modelName);
        }
    }

    /**
     * Instantiates the POJO class by name and wraps it for scoring.
     *
     * @param className name of the generated model class, resolved with {@code Class.forName}
     * @throws Exception if the class cannot be found or instantiated
     */
    private void loadPojo(String className) throws Exception {
        GenModel genModel = (GenModel)Class.forName(className).getDeclaredConstructor().newInstance();
        this.model = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genModel).setConvertUnknownCategoricalLevelsToNa(true).setConvertInvalidNumbersToNa(this.setInvNumNA));
    }

    /**
     * Loads a MOJO and wraps it for scoring.
     *
     * @param modelName name passed to {@code MojoModel.load}
     * @throws IOException if the MOJO cannot be loaded
     */
    private void loadMojo(String modelName) throws IOException {
        MojoModel genModel = MojoModel.load(modelName);
        this.model = new EasyPredictModelWrapper(new EasyPredictModelWrapper.Config().setModel(genModel).setConvertUnknownCategoricalLevelsToNa(true).setConvertInvalidNumbersToNa(this.setInvNumNA));
    }

    /** Prints the command-line help to standard output and exits with code 1. */
    private static void usage() {
        System.out.println("");
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName");
        System.out.println("             --pojo pojoName --input inputFile --output outputFile --separator sepStr --decimal --setConvertInvalidNum");
        System.out.println("");
        System.out.println("     --mojo    Name of the zip file containing model's MOJO.");
        System.out.println("     --pojo    Name of the java class containing the model's POJO. Either this ");
        System.out.println("               parameter or --model must be specified.");
        System.out.println("     --input   text file containing the test data set to score.");
        System.out.println("     --output  Name of the output CSV file with computed predictions.");
        System.out.println("     --separator Separator to be used in input file containing test data set.");
        System.out.println("     --decimal Use decimal numbers in the output (default is to use hexademical).");
        System.out.println("     --setConvertInvalidNum Will call .setConvertInvalidNumbersToNa(true) when loading models.");
        System.out.println("");
        System.exit(1);
    }

    /**
     * Parses the command line, stores the options in this instance and loads the model.
     *
     * <p>Any problem (unknown option, option without a value, model that cannot be
     * loaded) ends in {@code usage()}, which terminates the process.
     *
     * @param args command-line arguments
     */
    private void parseArgs(String[] args) {
        try {
            String pojoMojoModelNames = "";
            int loadType = LOAD_POJO;
            for (int i = 0; i < args.length; ++i) {
                String s = args[i];
                // Options without a value.
                if (s.equals("--header")) {
                    continue;
                }
                if (s.equals("--decimal")) {
                    this.useDecimalOutput = true;
                    continue;
                }
                if (s.equals("--setConvertInvalidNum")) {
                    this.setInvNumNA = true;
                    continue;
                }
                // Every remaining option consumes the next argument as its value.
                if (++i >= args.length) {
                    PredictCsv.usage();
                }
                String sarg = args[i];
                switch (s) {
                    case "--model": {
                        pojoMojoModelNames = sarg;
                        loadType = LOAD_MOJO_THEN_POJO;
                        break;
                    }
                    case "--mojo": {
                        pojoMojoModelNames = sarg;
                        loadType = LOAD_MOJO;
                        break;
                    }
                    case "--pojo": {
                        pojoMojoModelNames = sarg;
                        loadType = LOAD_POJO;
                        break;
                    }
                    case "--input": {
                        this.inputCSVFileName = sarg;
                        break;
                    }
                    case "--output": {
                        this.outputCSVFileName = sarg;
                        break;
                    }
                    case "--separator": {
                        // Only the last character of the value is used as the separator.
                        this.separator = sarg.charAt(sarg.length() - 1);
                        break;
                    }
                    default: {
                        System.out.println("ERROR: Unknown command line argument: " + s);
                        PredictCsv.usage();
                    }
                }
            }
            switch (loadType) {
                case LOAD_POJO: {
                    this.loadPojo(pojoMojoModelNames);
                    break;
                }
                case LOAD_MOJO: {
                    this.loadMojo(pojoMojoModelNames);
                    break;
                }
                case LOAD_MOJO_THEN_POJO: {
                    this.loadModel(pojoMojoModelNames);
                    break;
                }
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            PredictCsv.usage();
        }
    }
}

