/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.ae.jnet.cli;

import cc.mallet.fst.CRF;
import cc.mallet.types.Alphabet;
import de.julielab.jcore.ae.jnet.utils.FormatConverter;
import de.julielab.jcore.ae.jnet.utils.IOBEvaluation;
import de.julielab.jcore.ae.jnet.utils.IOEvaluation;
import de.julielab.jcore.ae.jnet.utils.Utils;
import de.julielab.jnet.tagger.NETagger;
import de.julielab.jnet.tagger.Sentence;
import de.julielab.jnet.tagger.Tags;
import de.julielab.jnet.tagger.Unit;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Properties;
import java.util.Random;

/**
 * Command line front end for the JNET tagger.
 *
 * <p>The first argument selects a mode; the remaining arguments are mode specific:
 * <ul>
 * <li>{@code f} - delegates to {@link FormatConverter#main(String[])}</li>
 * <li>{@code s} - 90-10 split evaluation, see {@link #eval9010}</li>
 * <li>{@code x} - cross validation, see {@link #evalXVal(File, File, int, File, File, File, int, boolean)}</li>
 * <li>{@code c} - compare a prediction against a gold standard, see {@link #compare}</li>
 * <li>{@code t} - train a model, see {@link #train}</li>
 * <li>{@code p} - predict with a stored model, see {@link #predict}</li>
 * <li>{@code oc} - print the feature configuration stored in a model</li>
 * <li>{@code oa} - print the output alphabet of a model</li>
 * </ul>
 */
public class JNETApplication {
    /** Seed of the shuffle that precedes every evaluation split; fixed so that repeated runs split identically. */
    private static final long SHUFFLE_SEED = 1L;

    /**
     * Dispatches to the mode named by {@code args[0]} and finally prints the elapsed time in minutes.
     *
     * <p>The modes {@code s}, {@code x} and {@code t} accept up to three trailing optional arguments, in
     * this order: feature configuration file, number of iterations, and a boolean that is handed to
     * {@code NETagger.set_Max_Ent}. An optional argument is used whenever it is present at its position;
     * surplus arguments behind the last optional one are ignored.
     *
     * <p>Missing arguments or an unknown mode print a usage message and terminate the JVM.
     *
     * @param args the mode followed by its mode-specific parameters
     * @throws NumberFormatException if a numeric argument (rounds, iterations) is not an integer
     */
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        if (args.length < 1) {
            System.err.println("usage: JNETApplication <mode> <mode-specific-parameters>");
            JNETApplication.showModes();
            System.exit(-1);
        }
        String mode = args[0];
        if (mode.equals("f")) {
            if (args.length < 4) {
                System.out.println("usage: JNETApplication f <iobFile> <1st meta data file> [further meta data files] <outFile> <taglist (or 0 if not used)>");
                System.exit(0);
            }
            // The converter receives everything after the mode flag.
            String[] converterArgs = new String[args.length - 1];
            System.arraycopy(args, 1, converterArgs, 0, converterArgs.length);
            FormatConverter.main(converterArgs);
        } else if (mode.equals("s")) {
            if (args.length < 4) {
                System.err.println("usage: JNETApplication s <data.ppd> <tags.def> <pred-out> [featureConfigFile] [number of iterations]");
                System.err.println("pred-out format: token pred gold");
                System.exit(-1);
            }
            File dataFile = new File(args[1]);
            File tagsFile = new File(args[2]);
            File predFile = new File(args[3]);
            JNETApplication.eval9010(dataFile, tagsFile, predFile,
                    optionalFile(args, 4), optionalInt(args, 5), optionalBoolean(args, 6));
        } else if (mode.equals("x")) {
            if (args.length < 6) {
                System.err.println("usage: JNETApplication x <trainData.ppd> <tags.def> <pred-out> <x-rounds> <performance-out-file> [featureConfigFile] [number of iterations]");
                System.err.println("pred-out format: token pred gold");
                System.exit(-1);
            }
            File dataFile = new File(args[1]);
            File tagsFile = new File(args[2]);
            File predFile = new File(args[3]);
            int rounds = Integer.parseInt(args[4]);
            File performanceOutFile = new File(args[5]);
            JNETApplication.evalXVal(dataFile, tagsFile, rounds, predFile, performanceOutFile,
                    optionalFile(args, 6), optionalInt(args, 7), optionalBoolean(args, 8));
        } else if (mode.equals("t")) {
            if (args.length < 3) {
                System.err.println("usage: JNETApplication t <trainData.ppd> <model-out-file> [featureConfigFile] [number of iterations]");
                System.exit(-1);
            }
            File trainFile = new File(args[1]);
            File modelFile = new File(args[2]);
            JNETApplication.train(trainFile, modelFile,
                    optionalFile(args, 3), optionalInt(args, 4), optionalBoolean(args, 5));
        } else if (mode.equals("p")) {
            if (args.length != 5) {
                System.err.println("usage: JNETApplication p <unlabeled data.ppd> <modelFile> <outFile> <estimate segment conf>");
                System.exit(-1);
            }
            File dataFile = new File(args[1]);
            File modelFile = new File(args[2]);
            File outFile = new File(args[3]);
            boolean conf = Boolean.parseBoolean(args[4]);
            JNETApplication.predict(dataFile, modelFile, outFile, conf);
        } else if (mode.equals("c")) {
            if (args.length != 4) {
                System.err.println("\ncompares the gold standard agains the prediction: give both IOB files, they must have the same length!");
                System.err.println("\nusage: JNETApplication c <predData.iob> <goldData.iob> <tag.def>");
                System.exit(-1);
            }
            File predFile = new File(args[1]);
            File goldFile = new File(args[2]);
            File tagsFile = new File(args[3]);
            double[] eval = JNETApplication.compare(predFile, goldFile, tagsFile);
            System.out.println(eval[0] + "\t" + eval[1] + "\t" + eval[2]);
        } else if (mode.equals("oc")) {
            if (args.length != 2) {
                System.err.println("\nusage: JNETApplication oc <model>");
                System.exit(-1);
            }
            JNETApplication.printFeatureConfig(new File(args[1]));
        } else if (mode.equals("oa")) {
            if (args.length != 2) {
                System.err.println("\nusage: JNETApplication oa <model>");
                System.exit(-1);
            }
            JNETApplication.printOutputAlphabet(new File(args[1]));
        } else {
            System.err.println("ERR: unknown mode");
            JNETApplication.showModes();
            System.exit(-1);
        }
        long timeNeeded = (System.currentTimeMillis() - startTime) / 1000L / 60L;
        System.out.println("Finished in " + timeNeeded + " minutes");
    }

    /** Returns {@code args[index]} as a file, or {@code null} when that optional argument was not given. */
    private static File optionalFile(String[] args, int index) {
        return args.length > index ? new File(args[index]) : null;
    }

    /** Parses {@code args[index]} as an int, or returns 0 when that optional argument was not given. */
    private static int optionalInt(String[] args, int index) {
        return args.length > index ? Integer.parseInt(args[index]) : 0;
    }

    /** Parses {@code args[index]} as a boolean, or returns {@code false} when that optional argument was not given. */
    private static boolean optionalBoolean(String[] args, int index) {
        return args.length > index && Boolean.parseBoolean(args[index]);
    }

    /**
     * Prints the list of available modes to standard error and terminates the JVM with exit code -1.
     * This method never returns normally.
     */
    static void showModes() {
        System.err.println("\nAvailable modes:");
        System.err.println("f: converting multiple annotations to one file");
        System.err.println("s: 90-10 split evaluation");
        System.err.println("x: cross validation ");
        System.err.println("c: compare goldstandard and prediction");
        System.err.println("t: train ");
        System.err.println("p: predict ");
        System.err.println("oc: output model configuration ");
        System.err.println("oa: output the model's output alphabet ");
        System.exit(-1);
    }

    /** Creates a tagger, from the feature configuration file if one is given, and applies the training settings. */
    private static NETagger createTagger(File featureConfigFile, int numberIterations, boolean maxEnt) {
        NETagger tagger = featureConfigFile != null ? new NETagger(featureConfigFile) : new NETagger();
        tagger.set_Number_Iterations(numberIterations);
        tagger.set_Max_Ent(maxEnt);
        return tagger;
    }

    /** Converts every entry of {@code ppdSentences} into a {@link Sentence} via {@code tagger.PPDtoUnits}. */
    private static ArrayList<Sentence> toSentences(NETagger tagger, List<String> ppdSentences) {
        ArrayList<Sentence> sentences = new ArrayList<Sentence>(ppdSentences.size());
        for (String ppdSentence : ppdSentences) {
            sentences.add(tagger.PPDtoUnits(ppdSentence));
        }
        return sentences;
    }

    /**
     * Trains a tagger on the contents of {@code trainFile} and writes the resulting model.
     *
     * @param trainFile         training data; every entry returned by {@code Utils.readFile} becomes one sentence
     * @param outFile           destination of the trained model
     * @param featureConfigFile feature configuration for the tagger, or {@code null} for the tagger's default
     * @param number_iter       value passed to {@code NETagger.set_Number_Iterations}
     * @param maxEnt            value passed to {@code NETagger.set_Max_Ent}
     */
    static void train(File trainFile, File outFile, File featureConfigFile, int number_iter, boolean maxEnt) {
        ArrayList<String> ppdSentences = Utils.readFile(trainFile);
        NETagger tagger = createTagger(featureConfigFile, number_iter, maxEnt);
        ArrayList<Sentence> sentences = toSentences(tagger, ppdSentences);
        tagger.train(sentences);
        tagger.writeModel(outFile.toString());
    }

    /**
     * Reads {@code dataFile} and runs the cross validation implemented by
     * {@link #evalXVal(List, File, int, File, File, File, int, boolean)} on its contents.
     */
    static void evalXVal(File dataFile, File tagsFile, int n, File predictionOutFile, File performanceOutFile, File featureConfigFile, int number_iter, boolean maxEnt) {
        ArrayList<String> ppdData = Utils.readFile(dataFile);
        JNETApplication.evalXVal(ppdData, tagsFile, n, predictionOutFile, performanceOutFile, featureConfigFile, number_iter, maxEnt);
    }

    /**
     * Runs an {@code n}-fold cross validation.
     *
     * <p>A copy of the data is shuffled with a fixed seed and cut into {@code n} consecutive folds of
     * {@code size / n} sentences; the last fold additionally takes the {@code size % n} remaining
     * sentences. In every round one fold is the test set and all other sentences are the training set.
     * Recall, precision and F-score of all rounds are averaged; the summary is written to
     * {@code performanceOutFile} and printed, the collected prediction lines of all rounds are written to
     * {@code predictionOutFile}.
     *
     * @param ppdData            the sentences to evaluate on; the list itself is not modified
     * @param tagsFile           tag definition file used to construct the {@link Tags}
     * @param n                  number of rounds; at least 2 and at most {@code ppdData.size()}
     * @param predictionOutFile  file receiving the prediction lines produced by {@link #eval}
     * @param performanceOutFile file receiving the textual summary
     * @param featureConfigFile  feature configuration for the tagger, or {@code null} for the tagger's default
     * @param number_iter        value passed to {@code NETagger.set_Number_Iterations}
     * @param maxEnt             value passed to {@code NETagger.set_Max_Ent}
     * @throws IllegalArgumentException if {@code n} is smaller than 2 or larger than the number of sentences
     */
    public static void evalXVal(List<String> ppdData, File tagsFile, int n, File predictionOutFile, File performanceOutFile, File featureConfigFile, int number_iter, boolean maxEnt) {
        if (n < 2) {
            // With fewer than two rounds there is no sentence left to train on.
            throw new IllegalArgumentException("cross validation needs at least 2 rounds but got " + n);
        }
        if (n > ppdData.size()) {
            // Otherwise the fold size would be 0 and all but the last round would test on nothing.
            throw new IllegalArgumentException("cannot split " + ppdData.size() + " sentences into " + n + " rounds");
        }
        ArrayList<String> output = new ArrayList<String>();
        Tags tags = new Tags(tagsFile.toString());
        // Shuffle a private copy so that the caller's list keeps its order.
        ArrayList<String> data = new ArrayList<String>(ppdData);
        Collections.shuffle(data, new Random(SHUFFLE_SEED));
        int sizeAll = data.size();
        int sizeRound = sizeAll / n;
        int sizeLastRound = sizeRound + sizeAll % n;
        System.out.println(" * number of sentences: " + sizeAll);
        System.out.println(" * size of each/last round: " + sizeRound + "/" + sizeLastRound);
        System.out.println();
        double[] fscores = new double[n];
        double[] recalls = new double[n];
        double[] precisions = new double[n];
        for (int round = 0; round < n; round++) {
            // The test fold is [testStart, testEnd); the last round runs to the end of the data.
            int testStart = round * sizeRound;
            int testEnd = round == n - 1 ? sizeAll : testStart + sizeRound;
            ArrayList<String> ppdTestData = new ArrayList<String>(data.subList(testStart, testEnd));
            ArrayList<String> ppdTrainData = new ArrayList<String>(sizeAll - ppdTestData.size());
            ppdTrainData.addAll(data.subList(0, testStart));
            ppdTrainData.addAll(data.subList(testEnd, sizeAll));
            System.out.println(" * training on: " + ppdTrainData.size() + " -- testing on: " + ppdTestData.size());
            double[] eval = JNETApplication.eval(ppdTrainData, ppdTestData, tags, output, featureConfigFile, number_iter, maxEnt);
            recalls[round] = eval[0];
            precisions[round] = eval[1];
            fscores[round] = eval[2];
            System.out.println("\n** round " + (round + 1) + ": R/P/F: " + eval[0] + "/" + eval[1] + "/" + eval[2]);
        }
        double avgRecall = JNETApplication.getAverage(recalls);
        double avgPrecision = JNETApplication.getAverage(precisions);
        double avgFscore = JNETApplication.getAverage(fscores);
        double stdRecall = JNETApplication.getStandardDeviation(recalls, avgRecall);
        double stdPrecision = JNETApplication.getStandardDeviation(precisions, avgPrecision);
        double stdFscore = JNETApplication.getStandardDeviation(fscores, avgFscore);
        DecimalFormat df = new DecimalFormat("0.000");
        StringBuilder summary = new StringBuilder();
        summary.append("Cross-validation results:\n");
        summary.append("Number of sentences in evaluation data set: " + sizeAll + "\n");
        // NOTE(review): the two numbers reported here are the fold (test set) sizes, although the label
        // says "for training"; the text is kept unchanged because it is part of the output file format.
        summary.append("Number of sentences for training in each/last round: " + sizeRound + "/" + sizeLastRound + "\n\n");
        summary.append("Overall performance: avg (standard deviation)\n");
        summary.append("Recall: " + df.format(avgRecall) + "(" + df.format(stdRecall) + ")\n");
        summary.append("Precision: " + df.format(avgPrecision) + "(" + df.format(stdPrecision) + ")\n");
        summary.append("F1-Score: " + df.format(avgFscore) + "(" + df.format(stdFscore) + ")\n");
        Utils.writeFile(performanceOutFile, summary.toString());
        Utils.writeFile(predictionOutFile, output);
        System.out.println("\n\nCross-validation finished. Results written to: " + performanceOutFile);
        System.out.println(summary.toString());
    }

    /**
     * Computes the sample standard deviation (divisor {@code values.length - 1}) around a given mean.
     *
     * @param values the observations
     * @param avg    the mean of {@code values}, e.g. from {@link #getAverage(double[])}
     * @return the sample standard deviation, or 0.0 if there are fewer than two values, for which the
     *         divisor would be zero or negative
     */
    public static double getStandardDeviation(double[] values, double avg) {
        if (values.length < 2) {
            return 0.0;
        }
        double squaredDeviations = 0.0;
        for (double value : values) {
            double deviation = value - avg;
            squaredDeviations += deviation * deviation;
        }
        return Math.sqrt(squaredDeviations / (values.length - 1));
    }

    /**
     * Computes the arithmetic mean.
     *
     * @param values the observations
     * @return the mean of {@code values}; NaN for an empty array
     */
    public static double getAverage(double[] values) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        return sum / values.length;
    }

    /**
     * Evaluates on a single 90-10 split.
     *
     * <p>The data is shuffled with a fixed seed; the last {@code (int) (size * 0.1)} sentences form the
     * test set and the rest the training set. Recall, precision and F-score are printed and the
     * prediction lines produced by {@link #eval} are written to {@code outFile}. If the test set would be
     * empty, an error is printed and the JVM is terminated with exit code -1.
     *
     * @param dataFile          the data to split
     * @param tagsFile          tag definition file used to construct the {@link Tags}
     * @param outFile           file receiving the prediction lines
     * @param featureConfigFile feature configuration for the tagger, or {@code null} for the tagger's default
     * @param number_iter       value passed to {@code NETagger.set_Number_Iterations}
     * @param maxEnt            value passed to {@code NETagger.set_Max_Ent}
     */
    static void eval9010(File dataFile, File tagsFile, File outFile, File featureConfigFile, int number_iter, boolean maxEnt) {
        ArrayList<String> output = new ArrayList<String>();
        Tags tags = new Tags(tagsFile.toString());
        ArrayList<String> ppdData = Utils.readFile(dataFile);
        Collections.shuffle(ppdData, new Random(SHUFFLE_SEED));
        int sizeAll = ppdData.size();
        int sizeTest = (int)((double)sizeAll * 0.1);
        int sizeTrain = sizeAll - sizeTest;
        if (sizeTest == 0) {
            System.err.println("Error: no test files for this split.");
            System.exit(-1);
        }
        System.out.println(" * all: " + sizeAll + "\ttrain: " + sizeTrain + "\t" + "test: " + sizeTest);
        ArrayList<String> ppdTrainData = new ArrayList<String>(ppdData.subList(0, sizeTrain));
        ArrayList<String> ppdTestData = new ArrayList<String>(ppdData.subList(sizeTrain, sizeAll));
        System.out.println(" * training on: " + ppdTrainData.size() + " -- testing on: " + ppdTestData.size());
        double[] eval = JNETApplication.eval(ppdTrainData, ppdTestData, tags, output, featureConfigFile, number_iter, maxEnt);
        DecimalFormat df = new DecimalFormat("0.000");
        System.out.println("\n\n** R/P/F: " + df.format(eval[0]) + "/" + df.format(eval[1]) + "/" + df.format(eval[2]));
        Utils.writeFile(outFile, output);
    }

    /**
     * Tags the contents of {@code testDataFile} with a stored model and writes the result of
     * {@code NETagger.predictIOB} to {@code outFile}.
     *
     * <p>Any exception raised while loading the model, predicting or writing is printed to standard
     * error and otherwise ignored, so the method itself never throws.
     *
     * @param testDataFile          the data to tag
     * @param modelFile             the model to load
     * @param outFile               destination of the prediction
     * @param showSegmentConfidence flag passed through to {@code NETagger.predictIOB}
     */
    static void predict(File testDataFile, File modelFile, File outFile, boolean showSegmentConfidence) {
        ArrayList<String> ppdTestData = Utils.readFile(testDataFile);
        NETagger tagger = new NETagger();
        try {
            tagger.readModel(modelFile);
            ArrayList<Sentence> sentences = toSentences(tagger, ppdTestData);
            Utils.writeFile(outFile, tagger.predictIOB(sentences, showSegmentConfidence));
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a tagger on {@code ppdTrainData}, tags {@code ppdTestData} and scores the prediction against
     * the labels the test data carried before tagging.
     *
     * <p>For every token of the test data one line {@code token<TAB>predicted<TAB>gold<TAB>pos} is
     * appended to {@code output}, where {@code pos} is the unit's meta information named by the feature
     * configuration property {@code pos_feat_unit}. After each sentence a separator line built from
     * {@code O} labels and an empty pos value is appended.
     *
     * @param ppdTrainData      training sentences
     * @param ppdTestData       test sentences including their gold labels
     * @param tags              tag set; its {@code type} selects the IO or the IOB evaluation
     * @param output            list the prediction lines are appended to
     * @param featureConfigFile feature configuration for the tagger, or {@code null} for the tagger's default
     * @param number_iter       value passed to {@code NETagger.set_Number_Iterations}
     * @param maxEnt            value passed to {@code NETagger.set_Max_Ent}
     * @return recall, precision and F-score, in this order
     */
    static double[] eval(ArrayList<String> ppdTrainData, ArrayList<String> ppdTestData, Tags tags, ArrayList<String> output, File featureConfigFile, int number_iter, boolean maxEnt) {
        NETagger tagger = createTagger(featureConfigFile, number_iter, maxEnt);
        ArrayList<Sentence> trainSentences = toSentences(tagger, ppdTrainData);
        ArrayList<Sentence> testSentences = toSentences(tagger, ppdTestData);
        tagger.train(trainSentences);

        // The gold labels have to be collected before predicting: the labels read back after
        // predictIOB are used as the prediction below.
        String posFeature = tagger.getFeatureConfig().getProperty("pos_feat_unit");
        ArrayList<String> gold = new ArrayList<String>();
        ArrayList<String> pos = new ArrayList<String>();
        for (Sentence sentence : testSentences) {
            for (Unit unit : sentence.getUnits()) {
                gold.add(unit.getRep() + "\t" + unit.getLabel());
                pos.add(unit.getMetaInfo(posFeature));
            }
            gold.add("O\tO");
            pos.add("");
        }

        tagger.predictIOB(testSentences, false);
        ArrayList<String> pred = new ArrayList<String>();
        for (Sentence sentence : testSentences) {
            for (Unit unit : sentence.getUnits()) {
                pred.add(unit.getRep() + "\t" + unit.getLabel());
            }
            pred.add("O\tO");
        }

        double[] eval = evaluate(tags, gold, pred);
        for (int i = 0; i < pred.size(); i++) {
            String goldLabel = gold.get(i).split("\t")[1];
            output.add(pred.get(i) + "\t" + goldLabel + "\t" + pos.get(i));
        }
        return eval;
    }

    /**
     * Scores {@code pred} against {@code gold} with the evaluation matching the tag set: the IO
     * evaluation if {@code tags.type} is {@code "IO"}, the IOB evaluation otherwise. A failing IOB
     * evaluation is reported on standard error and yields all-zero scores.
     */
    private static double[] evaluate(Tags tags, ArrayList<String> gold, ArrayList<String> pred) {
        if (tags.type.equals("IO")) {
            return IOEvaluation.evaluate(gold, pred);
        }
        try {
            return IOBEvaluation.evaluate(gold, pred);
        }
        catch (Exception e) {
            e.printStackTrace();
            return new double[]{0.0, 0.0, 0.0};
        }
    }

    /** Replaces every empty line by the separator line {@code O<TAB>O} used by the evaluation. */
    private static void fillEmptyLines(List<String> lines) {
        for (int i = 0; i < lines.size(); i++) {
            if (lines.get(i).equals("")) {
                lines.set(i, "O\tO");
            }
        }
    }

    /**
     * Scores a prediction file against a gold standard file.
     *
     * <p>Empty lines in both files are replaced by {@code O<TAB>O} first. If the files then differ in
     * their number of lines, an error is printed and the JVM is terminated with exit code -1.
     *
     * @param predFile the predicted labels
     * @param goldFile the gold standard labels
     * @param tagsFile tag definition file; the type of the resulting {@link Tags} selects the IO or the
     *                 IOB evaluation
     * @return the three scores of the evaluation; {@link #evalXVal(List, File, int, File, File, File, int, boolean)}
     *         reads such a result as recall, precision and F-score
     */
    static double[] compare(File predFile, File goldFile, File tagsFile) {
        ArrayList<String> gold = Utils.readFile(goldFile);
        ArrayList<String> pred = Utils.readFile(predFile);
        Tags tags = new Tags(tagsFile.toString());
        fillEmptyLines(gold);
        fillEmptyLines(pred);
        if (gold.size() != pred.size()) {
            System.err.println("ERR: number of tokens/lines in gold standard is different from prediction... please check!");
            System.exit(-1);
        }
        return evaluate(tags, gold, pred);
    }

    /**
     * Loads a model into a fresh tagger.
     *
     * @return the tagger, or {@code null} if the model could not be read; the cause has then been
     *         printed to standard error
     */
    private static NETagger loadTagger(File modelFile) {
        NETagger tagger = new NETagger();
        try {
            tagger.readModel(modelFile);
            return tagger;
        }
        catch (IOException e) {
            // Also covers FileNotFoundException.
            e.printStackTrace();
        }
        catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Prints every property of the feature configuration stored in a model as {@code key = value}.
     * Nothing but the error is printed if the model cannot be read.
     *
     * @param modelFile the model to inspect
     */
    public static void printFeatureConfig(File modelFile) {
        NETagger tagger = loadTagger(modelFile);
        if (tagger == null) {
            return;
        }
        Properties featureConfig = tagger.getFeatureConfig();
        Enumeration<?> keys = featureConfig.propertyNames();
        while (keys.hasMoreElements()) {
            String key = (String)keys.nextElement();
            System.out.printf("%s = %s\n", key, featureConfig.getProperty(key));
        }
    }

    /**
     * Prints the entries of a model's output alphabet, one per line. Nothing but an error is printed if
     * the model cannot be read or is not a {@link CRF}.
     *
     * @param modelFile the model to inspect
     */
    public static void printOutputAlphabet(File modelFile) {
        NETagger tagger = loadTagger(modelFile);
        if (tagger == null) {
            return;
        }
        Object model = tagger.getModel();
        if (!(model instanceof CRF)) {
            System.err.println("ERR: the output alphabet can only be printed for CRF models");
            return;
        }
        Alphabet alpha = ((CRF)model).getOutputAlphabet();
        for (Object modelLabel : alpha.toArray()) {
            System.out.println(modelLabel);
        }
    }
}

