/*
 * Decompiled with CFR 0.152.
 */
package edu.umn.biomedicus.acronym;

import edu.umn.biomedicus.acronym.OrthographicAcronymModel;
import edu.umn.biomedicus.common.collect.IndexMap;
import edu.umn.biomedicus.serialization.YamlSerialization;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.yaml.snakeyaml.Yaml;

/**
 * Trains character-trigram log-probability tables over the symbol alphabets defined in
 * {@link OrthographicAcronymModel}: one table estimated from a list of abbreviations and one from a
 * list of long forms. Counts are smoothed with interpolated absolute discounting (trigram -> bigram
 * -> unigram), and the resulting tables are written out as YAML.
 */
public class OrthographicAcronymModelTrainer {
    /** Absolute-discounting constant subtracted from every observed trigram/bigram count. */
    private static final double DISCOUNT = 0.9;
    private final boolean caseSensitive;
    private final transient IndexMap<Character> symbols;
    private final transient int symbolsCount;
    private final transient Set<Character> chars;
    private final double[][][] longformProbs;
    private final double[][][] abbrevProbs;
    private Set<String> longformsLower;
    private Path abbrevPath;
    private Path longformsPath;

    /**
     * Creates a trainer over the case-sensitive or case-insensitive symbol alphabet.
     *
     * @param caseSensitive whether characters keep their case when counted
     */
    public OrthographicAcronymModelTrainer(boolean caseSensitive) {
        this.caseSensitive = caseSensitive;
        this.symbols = caseSensitive ? OrthographicAcronymModel.CASE_SENS_SYMBOLS : OrthographicAcronymModel.CASE_INSENS_SYMBOLS;
        this.symbolsCount = this.symbols.size();
        this.chars = caseSensitive ? OrthographicAcronymModel.CASE_SENS_CHARS : OrthographicAcronymModel.CASE_INSENS_CHARS;
        this.longformProbs = new double[this.symbolsCount][this.symbolsCount][this.symbolsCount];
        this.abbrevProbs = new double[this.symbolsCount][this.symbolsCount][this.symbolsCount];
        this.longformsLower = new HashSet<>();
    }

    /**
     * Command-line entry point.
     *
     * @param args {@code <abbrevs-file> <longforms-file> <output-yaml>}
     */
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("Usage: OrthographicAcronymModelTrainer <abbrevs-file> <longforms-file> <output-yaml>");
            System.exit(1);
            return;
        }
        OrthographicAcronymModelTrainer trainer = new OrthographicAcronymModelTrainer(true);
        trainer.setAbbrevPath(Paths.get(args[0]));
        trainer.setLongformsPath(Paths.get(args[1]));
        try {
            trainer.trainTrigramModel();
            trainer.write(Paths.get(args[2]));
        } catch (IOException e) {
            // Propagate so the process exits with a failure status instead of appearing to succeed.
            throw new UncheckedIOException("failed to train orthographic acronym model", e);
        }
    }

    /** Sets the file containing one abbreviation per line. */
    public void setAbbrevPath(Path abbrevPath) {
        this.abbrevPath = abbrevPath;
    }

    /** Sets the file containing one long form per line. */
    public void setLongformsPath(Path longformsPath) {
        this.longformsPath = longformsPath;
    }

    /**
     * Reads both training lists and fills the long-form and abbreviation probability tables.
     *
     * @throws IOException if either file cannot be read
     * @throws NullPointerException if either path has not been set
     * @throws IllegalArgumentException if either file contains no lines
     */
    private void trainTrigramModel() throws IOException {
        Set<String> abbrevs = readLineSet(Objects.requireNonNull(this.abbrevPath, "abbreviation path has not been set"));
        Set<String> longforms = readLineSet(Objects.requireNonNull(this.longformsPath, "long form path has not been set"));
        // Locale.ROOT keeps the serialized lowercase forms independent of the JVM's default locale.
        this.longformsLower = longforms.stream()
                .map(longform -> longform.toLowerCase(Locale.ROOT))
                .collect(Collectors.toSet());
        this.wordsToLogProbs(longforms, this.longformProbs);
        this.wordsToLogProbs(abbrevs, this.abbrevProbs);
    }

    /** Reads every line of a UTF-8 text file into a set; the file is closed before returning. */
    private static Set<String> readLineSet(Path path) throws IOException {
        return new HashSet<>(Files.readAllLines(path));
    }

    /**
     * Counts the trigrams of {@code words} and stores the smoothed log probability of every
     * (w2, w1, w) cell into {@code probs}.
     *
     * @throws IllegalArgumentException if {@code words} is empty (every probability would be NaN)
     */
    private void wordsToLogProbs(Set<String> words, double[][][] probs) {
        if (words.isEmpty()) {
            throw new IllegalArgumentException("cannot estimate trigram probabilities from an empty word list");
        }
        int[][][] counts = new int[this.symbolsCount][this.symbolsCount][this.symbolsCount];
        for (String word : words) {
            this.addTrigramsFromWord(word, counts);
        }
        TrigramStatistics stats = new TrigramStatistics(counts);
        for (int i = 0; i < this.symbolsCount; ++i) {
            for (int j = 0; j < this.symbolsCount; ++j) {
                for (int k = 0; k < this.symbolsCount; ++k) {
                    // NOTE(review): values are narrowed to float precision before being stored in a
                    // double table; kept as-is because removing it changes the serialized numbers.
                    probs[i][j][k] = (float) stats.trigramLogProbability(i, j, k);
                }
            }
        }
    }

    /**
     * Adds the trigrams of one word, padded with '^' at the start and '$' at the end, to {@code counts}.
     */
    private void addTrigramsFromWord(String word, int[][][] counts) {
        char minus2 = '^';
        char minus1 = '^';
        char thisChar = '^';
        for (int i = 0; i < word.length(); ++i) {
            thisChar = this.fixChar(word.charAt(i));
            counts[this.symbolIndex(minus2)][this.symbolIndex(minus1)][this.symbolIndex(thisChar)]++;
            minus2 = minus1;
            minus1 = thisChar;
        }
        // NOTE(review): after the loop minus1 == thisChar, so for non-empty words this records
        // (last, last, '$') rather than (second-to-last, last, '$'). Left unchanged because the
        // model's scoring code (not visible here) may walk words the same way; fix both together.
        counts[this.symbolIndex(minus1)][this.symbolIndex(thisChar)][this.symbolIndex('$')]++;
        counts[this.symbolIndex(thisChar)][this.symbolIndex('$')][this.symbolIndex('$')]++;
    }

    /** Index of {@code c} in the active symbol alphabet. */
    private int symbolIndex(char c) {
        return this.symbols.indexOf(Character.valueOf(c));
    }

    /**
     * Normalizes a character into the symbol alphabet: lowercases it when case-insensitive, maps any
     * digit to '0' and any other character outside the alphabet to '?'.
     */
    private char fixChar(char c) {
        char normalized = this.caseSensitive ? c : Character.toLowerCase(c);
        if (Character.isDigit(normalized)) {
            return '0';
        }
        if (!this.chars.contains(Character.valueOf(normalized))) {
            return '?';
        }
        return normalized;
    }

    /**
     * Serializes both probability tables, the lowercase long forms and the case flag as YAML.
     *
     * @param out destination file (created or truncated)
     * @throws IOException if the file cannot be written
     */
    private void write(Path out) throws IOException {
        Yaml yaml = YamlSerialization.createYaml();
        Map<String, Object> serObj = new TreeMap<>();
        serObj.put("abbrevProbs", this.collapseProbs(this.abbrevProbs));
        serObj.put("longformProbs", this.collapseProbs(this.longformProbs));
        serObj.put("longformsLower", this.longformsLower.stream().collect(Collectors.toList()));
        serObj.put("caseSensitive", this.caseSensitive);
        // try-with-resources guarantees the file handle is flushed and released.
        try (Writer writer = Files.newBufferedWriter(out)) {
            yaml.dump(serObj, writer);
        }
    }

    /**
     * Flattens a probability table into a sorted map keyed by the three-symbol string of each cell.
     * Cells equal to 0.0 are omitted (presumably the reader treats missing keys as 0.0 — confirm).
     */
    private Map<String, Double> collapseProbs(double[][][] probs) {
        TreeMap<String, Double> collapsed = new TreeMap<>();
        for (int i = 0; i < probs.length; ++i) {
            for (int j = 0; j < probs[i].length; ++j) {
                for (int k = 0; k < probs[i][j].length; ++k) {
                    double prob = probs[i][j][k];
                    if (prob == 0.0) {
                        continue;
                    }
                    collapsed.put("" + this.symbols.forIndex(i) + this.symbols.forIndex(j) + this.symbols.forIndex(k), prob);
                }
            }
        }
        return collapsed;
    }

    /**
     * Trigram counts plus every marginal sum the smoothing formulas read, each computed once.
     * Re-summing these per cell (as previously done) made a full table evaluation O(S^6) in the
     * alphabet size S; with cached sums it is O(S^3). Integer sums are identical either way.
     */
    private static final class TrigramStatistics {
        /** counts[a][b][c]: occurrences of trigram (a, b, c). */
        private final int[][][] counts;
        /** Sum of all trigram counts. */
        private final int total;
        /** firstSums[a]: sum of counts[a][*][*]. */
        private final int[] firstSums;
        /** firstNonzeros[a]: number of positive cells in counts[a][*][*]. */
        private final int[] firstNonzeros;
        /** pairSums[a][b]: sum of counts[a][b][*]. */
        private final int[][] pairSums;
        /** pairNonzeros[a][b]: number of positive cells in counts[a][b][*]. */
        private final int[][] pairNonzeros;

        TrigramStatistics(int[][][] counts) {
            int n = counts.length;
            this.counts = counts;
            this.firstSums = new int[n];
            this.firstNonzeros = new int[n];
            this.pairSums = new int[n][];
            this.pairNonzeros = new int[n][];
            int sum = 0;
            for (int a = 0; a < n; ++a) {
                int[][] plane = counts[a];
                this.pairSums[a] = new int[plane.length];
                this.pairNonzeros[a] = new int[plane.length];
                for (int b = 0; b < plane.length; ++b) {
                    for (int c : plane[b]) {
                        this.pairSums[a][b] += c;
                        if (c > 0) {
                            ++this.pairNonzeros[a][b];
                        }
                    }
                    this.firstSums[a] += this.pairSums[a][b];
                    this.firstNonzeros[a] += this.pairNonzeros[a][b];
                }
                sum += this.firstSums[a];
            }
            this.total = sum;
        }

        /** Log of the discounted trigram probability of {@code w} after (w2, w1), interpolated with the bigram estimate. */
        double trigramLogProbability(int w2, int w1, int w) {
            double prob = 0.0;
            int contextCount = this.pairSums[w2][w1];
            if (contextCount == 0) {
                // Unseen context: back off entirely to the bigram estimate.
                prob = this.bigramProbability(w1, w);
            } else {
                int triCount = this.counts[w2][w1][w];
                if (triCount > 0) {
                    prob += ((double) triCount - DISCOUNT) / (double) contextCount;
                }
                // Mass removed by discounting is redistributed through the lower-order estimate.
                double interpolationCoefficient = DISCOUNT * (double) this.pairNonzeros[w2][w1] / (double) contextCount;
                prob += interpolationCoefficient * this.bigramProbability(w1, w);
            }
            if (prob <= 0.0) {
                // Floor so the logarithm stays finite.
                prob = 1.0 / (double) this.total;
            }
            return Math.log(prob);
        }

        /** Discounted bigram probability of {@code w} after {@code w1}, interpolated with the unigram estimate. */
        double bigramProbability(int w1, int w) {
            int historyCount = this.firstSums[w1];
            if (historyCount == 0) {
                return this.unigramProbability(w);
            }
            double prob = 0.0;
            int biCount = this.pairSums[w1][w];
            if (biCount > 0) {
                prob += ((double) biCount - DISCOUNT) / (double) historyCount;
            }
            double unigramInterpCoefficient = DISCOUNT * (double) this.firstNonzeros[w1] / (double) historyCount;
            return prob + unigramInterpCoefficient * this.unigramProbability(w);
        }

        /** Fraction of all trigrams whose first symbol is {@code w}. */
        double unigramProbability(int w) {
            return (double) this.firstSums[w] / (double) this.total;
        }
    }
}

