/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier.sequencelearning.hmm;

import com.google.common.io.Resources;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.io.Charsets;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmEvaluator;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmModel;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmTrainer;
import org.apache.mahout.classifier.sequencelearning.hmm.HmmUtils;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line example that trains a supervised hidden Markov model for
 * part-of-speech tagging from a tagged corpus fetched from a URL, measures its
 * error rate on a second corpus and finally tags one example sentence.
 *
 * <p>Corpus format, as consumed by {@link #readFromURL(String, boolean)}: one
 * token per line, fields separated by a single space, field 0 being the word
 * and field 1 its tag; an empty line terminates a sentence.
 *
 * <p>All state is kept in static fields, so the class is not safe for
 * concurrent use.
 */
public final class PosTagger {
    private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
    private static final Pattern SPACE = Pattern.compile(" ");
    private static final Pattern SPACES = Pattern.compile("[ ]+");

    /** Model produced by {@link #trainModel(String)}. */
    private static HmmModel taggingModel;
    /** Tag name to hidden-state ID; IDs are handed out starting at 0. */
    private static Map<String, Integer> tagIDs;
    /** ID that will be assigned to the next unseen tag. */
    private static int nextTagId;
    /** Word to output-state ID; IDs are handed out starting at 1. */
    private static Map<String, Integer> wordIDs;
    /** ID that will be assigned to the next unseen word; 0 is used for unknown words. */
    private static int nextWordId = 1;
    /** Tag ID sequences (one array per sentence) of the most recently read corpus. */
    private static List<int[]> hiddenSequences;
    /** Word ID sequences (one array per sentence) of the most recently read corpus. */
    private static List<int[]> observedSequences;
    /** Number of non-empty lines in the most recently read corpus. */
    private static int readLines;

    private PosTagger() {
    }

    /**
     * Reads a tagged corpus and stores it, encoded as ID sequences, in
     * {@link #hiddenSequences} and {@link #observedSequences}.
     *
     * <p>Words that have no ID are encoded as 0. Tags that have no ID are also
     * encoded as 0.
     * NOTE(review): 0 is a regular tag ID (the first tag seen during training),
     * so an unknown tag in test data is indistinguishable from that tag.
     *
     * @param url       location of the corpus
     * @param assignIDs when {@code true}, unseen words and tags are given fresh
     *                  IDs (training); when {@code false} the existing ID maps
     *                  are only looked up (testing)
     * @throws IOException if the resource cannot be read or a non-empty line
     *                     has fewer than two space-separated fields
     */
    private static void readFromURL(String url, boolean assignIDs) throws IOException {
        hiddenSequences = new LinkedList<int[]>();
        observedSequences = new LinkedList<int[]>();
        readLines = 0;
        // ArrayList rather than LinkedList: the sequences are only appended to
        // and then copied out, so a linked list buys nothing.
        List<Integer> observedSequence = new ArrayList<Integer>();
        List<Integer> hiddenSequence = new ArrayList<Integer>();
        int lineNumber = 0;
        for (String line : Resources.readLines(new URL(url), Charsets.UTF_8)) {
            ++lineNumber;
            if (line.isEmpty()) {
                flushSentence(observedSequence, hiddenSequence);
                continue;
            }
            ++readLines;
            String[] fields = SPACE.split(line);
            if (fields.length < 2) {
                throw new IOException("Malformed line " + lineNumber + " in " + url
                    + ": expected '<word> <tag> ...' but got '" + line + "'");
            }
            String word = fields[0];
            String tag = fields[1];
            if (assignIDs) {
                if (!wordIDs.containsKey(word)) {
                    wordIDs.put(word, nextWordId++);
                }
                if (!tagIDs.containsKey(tag)) {
                    tagIDs.put(tag, nextTagId++);
                }
            }
            Integer wordID = wordIDs.get(word);
            Integer tagID = tagIDs.get(tag);
            observedSequence.add(wordID == null ? Integer.valueOf(0) : wordID);
            hiddenSequence.add(tagID == null ? Integer.valueOf(0) : tagID);
        }
        // The corpus may not end with an empty line; keep the trailing sentence.
        flushSentence(observedSequence, hiddenSequence);
    }

    /**
     * Appends the sentence collected so far to the corpus lists and clears the
     * buffers. Does nothing when the buffers are empty, so consecutive empty
     * lines do not produce zero-length sequences.
     */
    private static void flushSentence(List<Integer> observedSequence, List<Integer> hiddenSequence) {
        if (observedSequence.isEmpty()) {
            return;
        }
        hiddenSequences.add(toIntArray(hiddenSequence));
        observedSequences.add(toIntArray(observedSequence));
        observedSequence.clear();
        hiddenSequence.clear();
    }

    /** Copies a list of boxed integers into a primitive array of the same length. */
    private static int[] toIntArray(List<Integer> values) {
        int[] result = new int[values.size()];
        int index = 0;
        for (Integer value : values) {
            result[index++] = value;
        }
        return result;
    }

    /**
     * Reads the training corpus, assigns word and tag IDs, and trains
     * {@link #taggingModel} from it.
     *
     * @param trainingURL location of the training corpus
     * @throws IOException if the corpus cannot be read or is malformed
     */
    private static void trainModel(String trainingURL) throws IOException {
        // Initial capacities only; presumably sized for the CoNLL-2000 corpus
        // used by main() - TODO confirm.
        tagIDs = new HashMap<String, Integer>(44);
        wordIDs = new HashMap<String, Integer>(19122);
        // The ID counters must restart together with the maps they index.
        nextTagId = 0;
        nextWordId = 1;
        log.info("Reading and parsing training data file from URL: {}", trainingURL);
        long start = System.currentTimeMillis();
        readFromURL(trainingURL, true);
        long end = System.currentTimeMillis();
        double duration = (double) (end - start) / 1000.0;
        log.info("Parsing done in {} seconds!", duration);
        log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
            new Object[] {readLines, hiddenSequences.size(), wordIDs.size(), tagIDs.size()});
        start = System.currentTimeMillis();
        // nextWordId counts output state 0 (unknown word) plus every known word.
        taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId, hiddenSequences,
            observedSequences, 0.05);
        // Output state 0 (unknown word) never occurs in the training data, so its
        // emission weights are set by hand: a small weight for every tag ...
        Matrix emissions = taggingModel.getEmissionMatrix();
        int hiddenStateCount = taggingModel.getNrOfHiddenStates();
        for (int i = 0; i < hiddenStateCount; ++i) {
            emissions.setQuick(i, 0, 0.1 / (double) hiddenStateCount);
        }
        // ... and a ten times larger one for the "NNP" tag.
        Integer nnpTag = tagIDs.get("NNP");
        if (nnpTag == null) {
            log.warn("Training data contains no NNP tag; unknown words get no NNP preference.");
        } else {
            emissions.setQuick(nnpTag, 0, 1.0 / (double) hiddenStateCount);
        }
        HmmUtils.normalizeModel(taggingModel);
        taggingModel.registerHiddenStateNames(tagIDs);
        taggingModel.registerOutputStateNames(wordIDs);
        end = System.currentTimeMillis();
        duration = (double) (end - start) / 1000.0;
        log.info("Trained HMM models in {} seconds!", duration);
    }

    /**
     * Reads the test corpus, decodes every sentence with {@link #taggingModel}
     * and logs the fraction of tokens whose decoded tag differs from the
     * expected one. Must be called after {@link #trainModel(String)}.
     *
     * @param testingURL location of the test corpus
     * @throws IOException if the corpus cannot be read or is malformed
     */
    private static void testModel(String testingURL) throws IOException {
        log.info("Reading and parsing test data file from URL: {}", testingURL);
        long start = System.currentTimeMillis();
        readFromURL(testingURL, false);
        long end = System.currentTimeMillis();
        double duration = (double) (end - start) / 1000.0;
        log.info("Parsing done in {} seconds!", duration);
        log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
        start = System.currentTimeMillis();
        int errorCount = 0;
        int totalCount = 0;
        for (int i = 0; i < observedSequences.size(); ++i) {
            int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
            int[] posExpected = hiddenSequences.get(i);
            for (int j = 0; j < posExpected.length; ++j) {
                ++totalCount;
                if (posEstimate[j] != posExpected[j]) {
                    ++errorCount;
                }
            }
        }
        end = System.currentTimeMillis();
        duration = (double) (end - start) / 1000.0;
        log.info("POS tagged test file in {} seconds!", duration);
        // NaN when the test corpus contains no tokens.
        double errorRate = (double) errorCount / (double) totalCount;
        log.info("Tagged the test file with an error rate of: {}", errorRate);
    }

    /**
     * Splits a sentence into the tokens that are fed to the model: the
     * characters {@code , . ! ? : ; "} and the sequence {@code ''} become
     * tokens of their own, and runs of spaces separate tokens.
     *
     * @param sentence raw sentence text
     * @return the tokens in order; empty when the sentence holds no token
     */
    private static String[] tokenize(String sentence) {
        String spaced = sentence.replaceAll("[,.!?:;\"]", " $0 ");
        spaced = spaced.replaceAll("''", " '' ");
        // Trim first: a leading separator would otherwise yield an empty first token.
        spaced = spaced.trim();
        if (spaced.isEmpty()) {
            return new String[0];
        }
        return SPACES.split(spaced);
    }

    /**
     * Tags a sentence with {@link #taggingModel}.
     *
     * @param sentence raw sentence text
     * @return one tag name per token of {@link #tokenize(String)}, in order
     */
    private static List<String> tagSentence(String sentence) {
        String[] tokens = tokenize(sentence);
        if (tokens.length == 0) {
            return new ArrayList<String>();
        }
        // Words unknown to the model are encoded as output state 0.
        int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
        int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
        return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
    }

    /**
     * Trains on the CoNLL-2000 training file, evaluates on the matching test
     * file and logs the tags of one example sentence.
     *
     * @param args ignored
     * @throws IOException if either corpus cannot be downloaded or parsed
     */
    public static void main(String[] args) throws IOException {
        trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
        testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
        String test = "McDonalds is a huge company with many employees .";
        // Use the tagger's own tokenization so words and tags stay aligned.
        String[] testWords = tokenize(test);
        List<String> posTags = tagSentence(test);
        for (int i = 0; i < posTags.size(); ++i) {
            log.info("{}[{}]", testWords[i], posTags.get(i));
        }
    }
}

