/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.tools;

import a.a.a.a.b;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AnomalyDetectionPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.CoxPHModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.OrdinalModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;

public class PredictCsv {
    private String inputCSVFileName;
    private String outputCSVFileName;
    private boolean useDecimalOutput = false;
    public char separator = (char)44;
    public boolean setInvNumNA = false;
    public boolean getTreePath = false;
    public boolean predictContributions = false;
    boolean returnGLRMReconstruct = false;
    public int glrmIterNumber = -1;
    private EasyPredictModelWrapper model;

    public static void main(String[] args) {
        PredictCsv predictCsv = PredictCsv.make(args, null);
        try {
            predictCsv.run();
        }
        catch (Exception exception) {
            System.out.println("Predict error: " + exception.getMessage());
            System.out.println();
            exception.printStackTrace();
            System.exit(1);
        }
        System.exit(0);
    }

    public static PredictCsv make(String[] args, GenModel model) {
        PredictCsv predictCsv = new PredictCsv();
        predictCsv.parseArgs(args);
        if (model != null) {
            try {
                predictCsv.setModel(model);
            }
            catch (IOException iOException) {
                throw new RuntimeException(iOException);
            }
        }
        return predictCsv;
    }

    private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
        RowData rowData = new RowData();
        int n2 = Math.min(inputColumnNames.length, splitLine.length);
        block9: for (int i2 = 0; i2 < n2; ++i2) {
            String string;
            String string2 = inputColumnNames[i2];
            switch (string = splitLine[i2]) {
                case "": 
                case "NA": 
                case "N/A": 
                case "-": {
                    continue block9;
                }
                default: {
                    rowData.put(string2, string);
                }
            }
        }
        return rowData;
    }

    private String myDoubleToString(double d2) {
        if (Double.isNaN(d2)) {
            return "NA";
        }
        if (this.useDecimalOutput) {
            return Double.toString(d2);
        }
        return Double.toHexString(d2);
    }

    private void writeTreePathNames(BufferedWriter output) throws Exception {
        String[] stringArray = ((SharedTreeMojoModel)this.model.m).getDecisionPathNames();
        this.writeColumnNames(output, stringArray);
    }

    private void writeContributionNames(BufferedWriter output) throws Exception {
        this.writeColumnNames(output, this.model.getContributionNames());
    }

    private void writeColumnNames(BufferedWriter output, String[] columnNames) throws Exception {
        int n2 = columnNames.length - 1;
        for (int i2 = 0; i2 < n2; ++i2) {
            output.write(columnNames[i2]);
            output.write(",");
        }
        output.write(columnNames[n2]);
    }

    public void run() throws Exception {
        String[] stringArray;
        int n2;
        String[] stringArray2;
        ModelCategory modelCategory = this.model.getModelCategory();
        b b2 = new b(new FileReader(this.inputCSVFileName), this.separator);
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(this.outputCSVFileName));
        switch (modelCategory) {
            case Binomial: 
            case Multinomial: 
            case Regression: {
                if (this.getTreePath) {
                    this.writeTreePathNames(bufferedWriter);
                    break;
                }
                if (this.predictContributions) {
                    this.writeContributionNames(bufferedWriter);
                    break;
                }
                PredictCsv predictCsv = this;
                predictCsv.writeHeader(predictCsv.model.m.getOutputNames(), bufferedWriter);
                break;
            }
            case DimReduction: {
                if (this.returnGLRMReconstruct) {
                    stringArray2 = this.model.m.getNames();
                    n2 = ((GlrmMojoModel)this.model.m)._permutation.length;
                    stringArray = "reconstr_";
                    int n3 = n2 - 1;
                    for (int i2 = 0; i2 < n2; ++i2) {
                        String string = this.returnGLRMReconstruct ? (String)stringArray + stringArray2[i2] : (String)stringArray + (i2 + 1);
                        bufferedWriter.write(string);
                        if (i2 >= n3) continue;
                        bufferedWriter.write(44);
                    }
                    break;
                }
                PredictCsv predictCsv = this;
                predictCsv.writeHeader(predictCsv.model.m.getOutputNames(), bufferedWriter);
                break;
            }
            default: {
                PredictCsv predictCsv = this;
                predictCsv.writeHeader(predictCsv.model.m.getOutputNames(), bufferedWriter);
            }
        }
        bufferedWriter.write("\n");
        n2 = 1;
        try {
            stringArray2 = b2.a();
            if (stringArray2 != null) {
                stringArray = stringArray2;
                this.checkMissingColumns(stringArray);
            } else {
                throw new Exception("Input dataset file is empty!");
            }
            while ((stringArray2 = b2.a()) != null) {
                RowData rowData = PredictCsv.formatDataRow(stringArray2, stringArray);
                switch (modelCategory) {
                    case AutoEncoder: {
                        AbstractPrediction abstractPrediction = this.model.predictAutoEncoder(rowData);
                        for (int i3 = 0; i3 < abstractPrediction.reconstructed.length; ++i3) {
                            bufferedWriter.write(this.myDoubleToString(abstractPrediction.reconstructed[i3]));
                            if (i3 >= -1) continue;
                            bufferedWriter.write(44);
                        }
                        break;
                    }
                    case Binomial: {
                        AbstractPrediction abstractPrediction = this.model.predictBinomial(rowData);
                        if (this.getTreePath) {
                            this.writeTreePaths(((BinomialModelPrediction)abstractPrediction).leafNodeAssignments, bufferedWriter);
                            break;
                        }
                        if (this.predictContributions) {
                            this.writeContributions(((BinomialModelPrediction)abstractPrediction).contributions, bufferedWriter);
                            break;
                        }
                        bufferedWriter.write(((BinomialModelPrediction)abstractPrediction).label);
                        bufferedWriter.write(",");
                        for (int i4 = 0; i4 < ((BinomialModelPrediction)abstractPrediction).classProbabilities.length; ++i4) {
                            if (i4 > 0) {
                                bufferedWriter.write(",");
                            }
                            bufferedWriter.write(this.myDoubleToString(((BinomialModelPrediction)abstractPrediction).classProbabilities[i4]));
                        }
                        break;
                    }
                    case Multinomial: {
                        AbstractPrediction abstractPrediction = this.model.predictMultinomial(rowData);
                        if (this.getTreePath) {
                            this.writeTreePaths(((MultinomialModelPrediction)abstractPrediction).leafNodeAssignments, bufferedWriter);
                            break;
                        }
                        bufferedWriter.write(((MultinomialModelPrediction)abstractPrediction).label);
                        bufferedWriter.write(",");
                        for (int i5 = 0; i5 < ((MultinomialModelPrediction)abstractPrediction).classProbabilities.length; ++i5) {
                            if (i5 > 0) {
                                bufferedWriter.write(",");
                            }
                            bufferedWriter.write(this.myDoubleToString(((MultinomialModelPrediction)abstractPrediction).classProbabilities[i5]));
                        }
                        break;
                    }
                    case Ordinal: {
                        AbstractPrediction abstractPrediction = this.model.predictOrdinal(rowData);
                        bufferedWriter.write(((OrdinalModelPrediction)abstractPrediction).label);
                        bufferedWriter.write(",");
                        for (int i6 = 0; i6 < ((OrdinalModelPrediction)abstractPrediction).classProbabilities.length; ++i6) {
                            if (i6 > 0) {
                                bufferedWriter.write(",");
                            }
                            bufferedWriter.write(this.myDoubleToString(((OrdinalModelPrediction)abstractPrediction).classProbabilities[i6]));
                        }
                        break;
                    }
                    case Clustering: {
                        AbstractPrediction abstractPrediction = this.model.predictClustering(rowData);
                        bufferedWriter.write(this.myDoubleToString(((ClusteringModelPrediction)abstractPrediction).cluster));
                        break;
                    }
                    case Regression: {
                        AbstractPrediction abstractPrediction = this.model.predictRegression(rowData);
                        if (this.getTreePath) {
                            this.writeTreePaths(((RegressionModelPrediction)abstractPrediction).leafNodeAssignments, bufferedWriter);
                            break;
                        }
                        if (this.predictContributions) {
                            this.writeContributions(((RegressionModelPrediction)abstractPrediction).contributions, bufferedWriter);
                            break;
                        }
                        bufferedWriter.write(this.myDoubleToString(((RegressionModelPrediction)abstractPrediction).value));
                        break;
                    }
                    case CoxPH: {
                        AbstractPrediction abstractPrediction = this.model.predictCoxPH(rowData);
                        bufferedWriter.write(this.myDoubleToString(((CoxPHModelPrediction)abstractPrediction).value));
                        break;
                    }
                    case DimReduction: {
                        AbstractPrediction abstractPrediction = this.model.predictDimReduction(rowData);
                        double[] dArray = this.returnGLRMReconstruct ? ((DimReductionModelPrediction)abstractPrediction).reconstructed : ((DimReductionModelPrediction)abstractPrediction).dimensions;
                        int n4 = dArray.length - 1;
                        for (int i7 = 0; i7 < dArray.length; ++i7) {
                            bufferedWriter.write(this.myDoubleToString(dArray[i7]));
                            if (i7 >= n4) continue;
                            bufferedWriter.write(44);
                        }
                        break;
                    }
                    case AnomalyDetection: {
                        int n4;
                        AbstractPrediction abstractPrediction = this.model.predictAnomalyDetection(rowData);
                        double[] dArray = ((AnomalyDetectionPrediction)abstractPrediction).toPreds();
                        for (n4 = 0; n4 < dArray.length - 1; ++n4) {
                            bufferedWriter.write(this.myDoubleToString(dArray[n4]));
                            bufferedWriter.write(44);
                        }
                        bufferedWriter.write(this.myDoubleToString(dArray[dArray.length - 1]));
                        break;
                    }
                    default: {
                        throw new Exception("Unknown model category " + (Object)((Object)modelCategory));
                    }
                }
                bufferedWriter.write("\n");
                ++n2;
            }
            return;
        }
        catch (Exception exception) {
            throw new Exception("Prediction failed on line " + n2, exception);
        }
        finally {
            bufferedWriter.close();
            b2.close();
        }
    }

    private void writeHeader(String[] colNames, BufferedWriter output) throws Exception {
        output.write(colNames[0]);
        for (int i2 = 1; i2 < colNames.length; ++i2) {
            output.write(",");
            output.write(colNames[i2]);
        }
    }

    private void writeTreePaths(String[] treePaths, BufferedWriter output) throws Exception {
        int n2 = treePaths.length - 1;
        for (int i2 = 0; i2 < n2; ++i2) {
            output.write(treePaths[i2]);
            output.write(",");
        }
        output.write(treePaths[n2]);
    }

    private void writeContributions(float[] contributions, BufferedWriter output) throws Exception {
        for (int i2 = 0; i2 < contributions.length; ++i2) {
            if (i2 > 0) {
                output.write(",");
            }
            output.write(this.myDoubleToString(contributions[i2]));
        }
    }

    private void loadModel(String modelName) throws Exception {
        try {
            this.loadMojo(modelName);
            return;
        }
        catch (IOException iOException) {
            this.loadPojo(modelName);
            return;
        }
    }

    private void setModel(GenModel genModel) throws IOException {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config().setModel(genModel).setConvertUnknownCategoricalLevelsToNa(true).setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            config.setEnableLeafAssignment(true);
        }
        if (this.predictContributions) {
            config.setEnableContributions(true);
        }
        if (this.returnGLRMReconstruct) {
            config.setEnableGLRMReconstrut(true);
        }
        this.model = new EasyPredictModelWrapper(config);
    }

    private void loadPojo(String className) throws Exception {
        GenModel genModel = (GenModel)Class.forName(className).newInstance();
        this.setModel(genModel);
    }

    private void loadMojo(String modelName) throws IOException {
        MojoModel mojoModel = MojoModel.load(modelName);
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config().setModel(mojoModel).setConvertUnknownCategoricalLevelsToNa(true).setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            config.setEnableLeafAssignment(true);
        }
        if (this.predictContributions) {
            config.setEnableContributions(true);
        }
        if (this.returnGLRMReconstruct) {
            config.setEnableGLRMReconstrut(true);
        }
        if (this.glrmIterNumber > 0) {
            config.setGLRMIterNumber(this.glrmIterNumber);
        }
        this.model = new EasyPredictModelWrapper(config);
    }

    private static void usage() {
        System.out.println();
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName");
        System.out.println("             --pojo pojoName --input inputFile --output outputFile --separator sepStr --decimal --setConvertInvalidNum");
        System.out.println();
        System.out.println("     --mojo    Name of the zip file containing model's MOJO.");
        System.out.println("     --pojo    Name of the java class containing the model's POJO. Either this ");
        System.out.println("               parameter or --model must be specified.");
        System.out.println("     --input   text file containing the test data set to score.");
        System.out.println("     --output  Name of the output CSV file with computed predictions.");
        System.out.println("     --separator Separator to be used in input file containing test data set.");
        System.out.println("     --decimal Use decimal numbers in the output (default is to use hexademical).");
        System.out.println("     --setConvertInvalidNum Will call .setConvertInvalidNumbersToNa(true) when loading models.");
        System.out.println("     --leafNodeAssignment will show the leaf node assignment for tree based models instead of prediction results");
        System.out.println("     --predictContributions will output prediction contributions (Shapley values) for tree based models instead of regular model predictions");
        System.out.println("     --glrmReconstruct will return the reconstructed dataset for GLRM mojo instead of X factor derived from the dataset.");
        System.out.println("     --glrmIterNumber integer indicating number of iterations to go through when constructing X factor derived from the dataset.");
        System.out.println();
        System.exit(1);
    }

    private void checkMissingColumns(String[] parsedColumnNamesArr) {
        String[] stringArray = this.model.m._names;
        HashSet<String> hashSet = new HashSet<String>(parsedColumnNamesArr.length);
        for (int i2 = 0; i2 < parsedColumnNamesArr.length; ++i2) {
            hashSet.add(parsedColumnNamesArr[i2]);
        }
        ArrayList<String> arrayList = new ArrayList<String>();
        Object object = stringArray;
        int n2 = stringArray.length;
        for (int i3 = 0; i3 < n2; ++i3) {
            String string = object[i3];
            if (!hashSet.contains(string) && !string.equals(this.model.m._responseColumn)) {
                arrayList.add(string);
                continue;
            }
            hashSet.remove(string);
        }
        if (arrayList.size() > 0) {
            object = new StringBuilder("There were ");
            ((StringBuilder)object).append(arrayList.size());
            ((StringBuilder)object).append(" missing columns found in the input data set: {");
            for (n2 = 0; n2 < arrayList.size(); ++n2) {
                ((StringBuilder)object).append((String)arrayList.get(n2));
                if (n2 == arrayList.size() - 1) continue;
                ((StringBuilder)object).append(",");
            }
            ((StringBuilder)object).append('}');
            System.out.println(object);
        }
        if (hashSet.size() > 0) {
            object = new StringBuilder("Detected ");
            ((StringBuilder)object).append(hashSet.size());
            ((StringBuilder)object).append(" unused columns in the input data set: {");
            Iterator iterator = hashSet.iterator();
            while (iterator.hasNext()) {
                ((StringBuilder)object).append((String)iterator.next());
                if (!iterator.hasNext()) continue;
                ((StringBuilder)object).append(",");
            }
            ((StringBuilder)object).append('}');
            System.out.println(object);
        }
    }

    private void parseArgs(String[] args) {
        try {
            String string = "";
            int n2 = 0;
            block26: for (int i2 = 0; i2 < args.length; ++i2) {
                String string2 = args[i2];
                if (string2.equals("--header")) continue;
                if (string2.equals("--decimal")) {
                    this.useDecimalOutput = true;
                    continue;
                }
                if (string2.equals("--glrmReconstruct")) {
                    this.returnGLRMReconstruct = true;
                    continue;
                }
                if (string2.equals("--setConvertInvalidNum")) {
                    this.setInvNumNA = true;
                    continue;
                }
                if (string2.equals("--leafNodeAssignment")) {
                    this.getTreePath = true;
                    continue;
                }
                if (string2.equals("--predictContributions")) {
                    this.predictContributions = true;
                    continue;
                }
                if (string2.equals("--embedded")) {
                    n2 = -1;
                    continue;
                }
                if (++i2 >= args.length) {
                    PredictCsv.usage();
                }
                String string3 = args[i2];
                switch (string2) {
                    case "--model": {
                        string = string3;
                        n2 = 2;
                        continue block26;
                    }
                    case "--mojo": {
                        string = string3;
                        n2 = 1;
                        continue block26;
                    }
                    case "--pojo": {
                        string = string3;
                        n2 = 0;
                        continue block26;
                    }
                    case "--input": {
                        this.inputCSVFileName = string3;
                        continue block26;
                    }
                    case "--output": {
                        this.outputCSVFileName = string3;
                        continue block26;
                    }
                    case "--separator": {
                        String string4 = string3;
                        this.separator = string4.charAt(string4.length() - 1);
                        continue block26;
                    }
                    case "--glrmIterNumber": {
                        this.glrmIterNumber = Integer.valueOf(string3);
                        continue block26;
                    }
                    default: {
                        System.out.println("ERROR: Unknown command line argument: " + string2);
                        PredictCsv.usage();
                    }
                }
            }
            switch (n2) {
                case -1: {
                    break;
                }
                case 0: {
                    this.loadPojo(string);
                    break;
                }
                case 1: {
                    this.loadMojo(string);
                    break;
                }
                case 2: {
                    this.loadModel(string);
                }
                default: {
                    return;
                }
            }
        }
        catch (Exception exception) {
            Exception exception2 = exception;
            exception.printStackTrace();
            PredictCsv.usage();
        }
    }
}

