/*
 * Decompiled with CFR 0.152.
 */
package smile.nlp.pos;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.io.Paths;
import smile.math.MathEx;
import smile.nlp.pos.POSTagger;
import smile.nlp.pos.PennTreebankPOS;
import smile.nlp.pos.RegexPOSTagger;

/**
 * Part-of-speech tagger backed by a first-order hidden Markov model whose
 * hidden states are the {@link PennTreebankPOS} tags (indexed by ordinal).
 *
 * <p>Each word is observed through two integer codes: its index in the
 * {@code symbol} vocabulary (index {@code 0} stands for "unknown word") and
 * the index of its two-character suffix in the {@code suffix} table
 * ({@code -1} when the word has at most two characters or the suffix is
 * unknown). Unknown words with a known suffix are scored with the suffix
 * emission table instead of the word emission table.
 *
 * <p>Instances are serializable; field names and types are part of the
 * serialized form and must stay stable for existing model files.
 */
public class HMMPOSTagger implements POSTagger, Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(HMMPOSTagger.class);

    /** Word to vocabulary index; indices assigned by {@link #fit} start at 1, 0 means unknown. */
    private Map<String, Integer> symbol;
    /** Two-character word suffix to suffix index; indices assigned by {@link #fit} start at 0. */
    private Map<String, Integer> suffix;
    /** Initial state (tag) probabilities. */
    private double[] pi;
    /** State transition probabilities, {@code a[from][to]}. */
    private double[][] a;
    /** Word emission probabilities, {@code b[state][symbolIndex]}. */
    private double[][] b;
    /** Suffix emission probabilities, {@code c[state][suffixIndex]}. */
    private double[][] c;

    /** Lazily loaded default model; see {@link #getDefault()}. */
    private static HMMPOSTagger DEFAULT_TAGGER;

    /**
     * Creates an empty tagger with no model parameters set.
     * Calling {@link #tag(String[])} on such an instance fails because the
     * model tables are {@code null}.
     */
    public HMMPOSTagger() {
    }

    /**
     * Creates a tagger from fully specified model parameters.
     *
     * @param symbol word to vocabulary index map.
     * @param suffix two-character suffix to suffix index map.
     * @param pi initial state probabilities, one per Penn Treebank tag.
     * @param a state transition probabilities.
     * @param b word emission probabilities; each row has {@code symbol.size() + 1} entries.
     * @param c suffix emission probabilities; each row has {@code suffix.size()} entries.
     * @throws IllegalArgumentException if any table size is inconsistent.
     */
    private HMMPOSTagger(Map<String, Integer> symbol, Map<String, Integer> suffix, double[] pi, double[][] a, double[][] b, double[][] c) {
        int numTags = PennTreebankPOS.values().length;
        if (pi.length != numTags) {
            throw new IllegalArgumentException("The number of states is different from the size of Penn Treebank tagset.");
        }
        if (a[0].length != numTags) {
            throw new IllegalArgumentException("Invlid state transition probability size.");
        }
        if (b[0].length != symbol.size() + 1) {
            throw new IllegalArgumentException("Invlid symbol emission probability size.");
        }
        if (c[0].length != suffix.size()) {
            throw new IllegalArgumentException("Invlid symbol suffix emission probability size.");
        }
        this.pi = pi;
        this.a = a;
        this.b = b;
        this.c = c;
        this.symbol = symbol;
        this.suffix = suffix;
    }

    /**
     * Returns the default tagger, deserializing it from the classpath resource
     * {@code /smile/nlp/pos/hmm-pos-tagger.model} on first use.
     *
     * <p>Loading failures are logged rather than thrown. This method is not
     * synchronized, so concurrent first calls may each load the model.
     *
     * @return the default tagger, or {@code null} if the model could not be loaded.
     */
    public static HMMPOSTagger getDefault() {
        if (DEFAULT_TAGGER == null) {
            // try-with-resources closes the stream even when deserialization fails.
            try (ObjectInputStream ois = new ObjectInputStream(Objects.requireNonNull(
                    HMMPOSTagger.class.getResourceAsStream("/smile/nlp/pos/hmm-pos-tagger.model"),
                    "model resource not found on classpath"))) {
                DEFAULT_TAGGER = (HMMPOSTagger) ois.readObject();
            } catch (Exception ex) {
                logger.error("Failed to load /smile/nlp/pos/hmm-pos-tagger.model", ex);
            }
        }
        return DEFAULT_TAGGER;
    }

    /**
     * Tags every word of a sentence.
     *
     * <p>Words missing from the vocabulary are first offered to
     * {@link RegexPOSTagger}; if it yields no tag (or the word is known),
     * the tag from the Viterbi decoding is used.
     *
     * @param sentence the words of the sentence.
     * @return one tag per word; an empty array for an empty sentence.
     */
    @Override
    public PennTreebankPOS[] tag(String[] sentence) {
        int n = sentence.length;
        PennTreebankPOS[] pos = new PennTreebankPOS[n];
        if (n == 0) {
            // Nothing to decode; Viterbi needs at least one observation.
            return pos;
        }

        int[] states = viterbi(sentence);
        PennTreebankPOS[] tags = PennTreebankPOS.values();
        for (int i = 0; i < n; ++i) {
            if (symbol.get(sentence[i]) == null) {
                pos[i] = RegexPOSTagger.tag(sentence[i]);
            }
            if (pos[i] == null) {
                pos[i] = tags[states[i]];
            }
        }
        return pos;
    }

    /**
     * Finds the highest-scoring state sequence for a sentence using the
     * Viterbi algorithm in log space.
     *
     * @param sentence the words of the sentence.
     * @return the state (tag ordinal) for each word; empty for an empty sentence.
     */
    private int[] viterbi(String[] sentence) {
        int n = sentence.length;
        int[] s = new int[n];
        if (n == 0) {
            return s;
        }

        int[][] o = translate(symbol, suffix, sentence);
        int numStates = pi.length;
        // delta[t][k]: best log score of any path ending in state k at position t.
        double[][] delta = new double[n][numStates];
        // psy[t][k]: predecessor state on that best path.
        int[][] psy = new int[n][numStates];

        for (int i = 0; i < numStates; ++i) {
            delta[0][i] = MathEx.log(pi[i]) + logEmission(i, o[0]);
        }

        for (int t = 1; t < n; ++t) {
            for (int k = 0; k < numStates; ++k) {
                double maxDelta = Double.NEGATIVE_INFINITY;
                int maxPsy = -1;
                for (int i = 0; i < numStates; ++i) {
                    double thisDelta = delta[t - 1][i] + MathEx.log(a[i][k]);
                    if (maxDelta < thisDelta) {
                        maxDelta = thisDelta;
                        maxPsy = i;
                    }
                }
                delta[t][k] = maxDelta + logEmission(k, o[t]);
                psy[t][k] = maxPsy;
            }
        }

        // Pick the best final state, then follow the back-pointers.
        int last = n - 1;
        double maxDelta = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < numStates; ++i) {
            if (maxDelta < delta[last][i]) {
                maxDelta = delta[last][i];
                s[last] = i;
            }
        }
        for (int t = last - 1; t >= 0; --t) {
            s[t] = psy[t + 1][s[t + 1]];
        }
        return s;
    }

    /**
     * Log emission score of one observation in a given state: the suffix table
     * is used for an unknown word with a known suffix, the word table otherwise.
     *
     * @param state the state (tag ordinal).
     * @param obs the observation as {@code {symbolIndex, suffixIndex}}.
     * @return the log emission score.
     */
    private double logEmission(int state, int[] obs) {
        if (obs[0] == 0 && obs[1] >= 0) {
            return MathEx.log(c[state][obs[1]]);
        }
        return MathEx.log(b[state][obs[0]]);
    }

    /**
     * Encodes words as {@code {symbolIndex, suffixIndex}} pairs.
     *
     * @param symbol word to vocabulary index map.
     * @param suffix two-character suffix to suffix index map.
     * @param o the words to encode.
     * @return one pair per word; symbol index 0 marks an unknown word and
     *         suffix index -1 marks a missing or unknown suffix.
     */
    private static int[][] translate(Map<String, Integer> symbol, Map<String, Integer> suffix, String[] o) {
        int[][] seq = new int[o.length][2];
        for (int i = 0; i < o.length; ++i) {
            String word = o[i];
            seq[i][0] = Objects.requireNonNullElse(symbol.get(word), 0);
            // Only words longer than two characters carry a suffix feature.
            Integer suffixIndex = word.length() > 2 ? suffix.get(word.substring(word.length() - 2)) : null;
            seq[i][1] = Objects.requireNonNullElse(suffixIndex, -1);
        }
        return seq;
    }

    /**
     * Encodes tags as their enum ordinals, which serve as HMM state indices.
     *
     * @param tags the tags to encode.
     * @return the ordinal of each tag.
     */
    private static int[] translate(PennTreebankPOS[] tags) {
        int[] seq = new int[tags.length];
        for (int i = 0; i < tags.length; ++i) {
            seq[i] = tags[i].ordinal();
        }
        return seq;
    }

    /**
     * Fits a tagger by counting initial tags, tag transitions, word emissions
     * and suffix emissions over a tagged corpus, then passing each count
     * vector through {@code MathEx.unitize1} (presumably L1 normalization,
     * judging by the name; verify against MathEx).
     *
     * <p>Open-class tags start with a count of 1 for the unknown-word symbol
     * so that unseen words can still be emitted by those tags. Empty sentences
     * are skipped.
     *
     * @param sentences the training sentences, each an array of words.
     * @param labels the tags of each sentence, parallel to {@code sentences}.
     * @return the fitted tagger.
     * @throws IllegalArgumentException if {@code sentences} and {@code labels}
     *         differ in length, overall or for any sentence.
     */
    public static HMMPOSTagger fit(String[][] sentences, PennTreebankPOS[][] labels) {
        if (sentences.length != labels.length) {
            throw new IllegalArgumentException("The number of sentences (" + sentences.length
                    + ") differs from the number of label sequences (" + labels.length + ").");
        }

        // Build the vocabulary (indices from 1; 0 is the unknown word) and the suffix table.
        int index = 1;
        int suffixIndex = 0;
        Map<String, Integer> symbol = new HashMap<>();
        Map<String, Integer> suffix = new HashMap<>();
        for (String[] sentence : sentences) {
            for (String word : sentence) {
                if (!symbol.containsKey(word)) {
                    symbol.put(word, index++);
                }
                if (word.length() > 2) {
                    String ending = word.substring(word.length() - 2);
                    if (!suffix.containsKey(ending)) {
                        suffix.put(ending, suffixIndex++);
                    }
                }
            }
        }

        PennTreebankPOS[] tags = PennTreebankPOS.values();
        int numStates = tags.length;
        double[] pi = new double[numStates];
        double[][] a = new double[numStates][numStates];
        double[][] b = new double[numStates][symbol.size() + 1];
        double[][] c = new double[numStates][suffix.size()];

        // Seed the unknown-word emission for open-class tags.
        for (int i = 0; i < numStates; ++i) {
            if (tags[i].open) {
                b[i][0] = 1.0;
            }
        }

        for (int i = 0; i < sentences.length; ++i) {
            if (sentences[i].length != labels[i].length) {
                throw new IllegalArgumentException("Sentence " + i + " has " + sentences[i].length
                        + " words but " + labels[i].length + " labels.");
            }
            if (sentences[i].length == 0) {
                // No first tag and no observations to count.
                continue;
            }

            int[] tag = translate(labels[i]);
            int[][] obs = translate(symbol, suffix, sentences[i]);

            pi[tag[0]] += 1.0;
            for (int j = 0; j < obs.length; ++j) {
                if (j > 0) {
                    a[tag[j - 1]][tag[j]] += 1.0;
                }
                b[tag[j]][obs[j][0]] += 1.0;
                if (obs[j][1] >= 0) {
                    c[tag[j]][obs[j][1]] += 1.0;
                }
            }
        }

        MathEx.unitize1(pi);
        for (int i = 0; i < numStates; ++i) {
            MathEx.unitize1(a[i]);
            MathEx.unitize1(b[i]);
            MathEx.unitize1(c[i]);
        }
        return new HMMPOSTagger(symbol, suffix, pi, a, b, c);
    }

    /**
     * Reads tagged sentences from every {@code .POS} file under a directory
     * tree and appends them to the given lists.
     *
     * <p>Tokens have the form {@code word/TAG}; tokens that do not split into
     * exactly two parts on {@code '/'} are ignored. When a tag lists
     * alternatives separated by {@code '|'}, the first one is used. Blank lines
     * end a sentence; lines starting with {@code ===} or {@code *x*} are
     * skipped. A failure while reading a file is logged and the remaining
     * files are still processed; sentences already appended stay in the lists.
     * Files are decoded with the platform default charset.
     *
     * @param dir the root directory to scan.
     * @param sentences receives the words of each sentence.
     * @param labels receives the tags of each sentence, parallel to {@code sentences}.
     */
    public static void read(Path dir, List<String[]> sentences, List<PennTreebankPOS[]> labels) {
        List<File> files = new ArrayList<>();
        walkin(dir.toFile(), files);

        for (File file : files) {
            // try-with-resources closes the file even when parsing fails midway.
            try (FileInputStream stream = new FileInputStream(file);
                 BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) {
                List<String> sent = new ArrayList<>();
                List<PennTreebankPOS> label = new ArrayList<>();

                String line;
                while ((line = reader.readLine()) != null) {
                    line = line.trim();
                    if (line.isEmpty()) {
                        flushSentence(sent, label, sentences, labels);
                        continue;
                    }
                    if (line.startsWith("===") || line.startsWith("*x*")) {
                        continue;
                    }

                    for (String word : line.split("\\s")) {
                        String[] w = word.split("/");
                        if (w.length != 2) {
                            continue;
                        }
                        sent.add(w[0]);

                        int bar = w[1].indexOf('|');
                        String tag = bar == -1 ? w[1] : w[1].substring(0, bar);
                        // Map corpus tag variants onto the tagset's names.
                        if (tag.equals("PRP$R")) {
                            tag = "PRP$";
                        }
                        if (tag.equals("JJSS")) {
                            tag = "JJS";
                        }
                        label.add(PennTreebankPOS.getValue(tag));
                    }
                }

                // A file may end without a trailing blank line.
                flushSentence(sent, label, sentences, labels);
            } catch (Exception e) {
                logger.error("Failed to load training data {}", file, e);
            }
        }
    }

    /**
     * Appends the pending sentence (if any) to the output lists and clears
     * the pending buffers.
     */
    private static void flushSentence(List<String> sent, List<PennTreebankPOS> label, List<String[]> sentences, List<PennTreebankPOS[]> labels) {
        if (sent.isEmpty()) {
            return;
        }
        sentences.add(sent.toArray(new String[0]));
        labels.add(label.toArray(new PennTreebankPOS[0]));
        sent.clear();
        label.clear();
    }

    /**
     * Recursively collects the files whose names end with {@code .POS}.
     *
     * @param dir the directory to scan; ignored if it cannot be listed.
     * @param files receives the matching files.
     */
    public static void walkin(File dir, List<File> files) {
        String pattern = ".POS";
        File[] listFile = dir.listFiles();
        if (listFile == null) {
            // Not a directory, or an I/O error occurred while listing it.
            return;
        }
        for (File file : listFile) {
            if (file.isDirectory()) {
                walkin(file, files);
            } else if (file.getName().endsWith(pattern)) {
                files.add(file);
            }
        }
    }

    /**
     * Trains a tagger on the WSJ and BROWN portions of the Penn Treebank test
     * data and serializes it to {@code hmm-pos-tagger.model} in the working
     * directory. A failure to save is logged.
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        List<String[]> sentences = new ArrayList<>();
        List<PennTreebankPOS[]> labels = new ArrayList<>();
        read(Paths.getTestData(new String[]{"nlp/PennTreebank/PennTreebank2/TAGGED/POS/WSJ"}), sentences, labels);
        read(Paths.getTestData(new String[]{"nlp/PennTreebank/PennTreebank2/TAGGED/POS/BROWN"}), sentences, labels);

        String[][] x = sentences.toArray(new String[0][]);
        PennTreebankPOS[][] y = labels.toArray(new PennTreebankPOS[0][]);
        HMMPOSTagger tagger = fit(x, y);

        try (FileOutputStream fos = new FileOutputStream("hmm-pos-tagger.model");
             ObjectOutputStream oos = new ObjectOutputStream(fos)) {
            oos.writeObject(tagger);
            oos.flush();
        } catch (Exception ex) {
            logger.error("Failed to save HMM POS model", ex);
        }
    }
}

