/*
 * Decompiled with CFR 0.152.
 */
package edu.umn.biomedicus.acronym;

import edu.umn.biomedicus.acronym.AcronymExpansionsModel;
import edu.umn.biomedicus.acronym.AcronymVectorModel;
import edu.umn.biomedicus.acronym.SparseVector;
import edu.umn.biomedicus.acronym.WordVectorSpace;
import edu.umn.biomedicus.exc.BiomedicusException;
import edu.umn.biomedicus.tokenization.Token;
import edu.umn.nlpengine.AbstractTextRange;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Offline trainer that builds context vectors for the senses (expansions) of ambiguous acronyms.
 *
 * <p>Training walks a corpus directory twice: first to count word frequencies and build a
 * dictionary of the {@code nWords} most frequent words, then to find occurrences of every sense
 * phrase (or one of its alternate forms) and accumulate the surrounding context into that sense's
 * {@link SparseVector}. {@link #writeAcronymModel(String)} finally weights and writes the model.
 */
public class AcronymVectorOfflineTrainer {
    /** Default number of most-frequent words kept in the word-vector dictionary. */
    public static final int DEFAULT_N_WORDS = 100000;
    private static final Logger LOGGER = LoggerFactory.getLogger(AcronymVectorOfflineTrainer.class);
    /** Word separators: any run of characters other than word characters, '-' and '/'. */
    private static final String TEXTBREAK = "[^\\w\\-/]+";
    /** Precompiled {@link #TEXTBREAK}; splitting with it is equivalent to String.split(TEXTBREAK). */
    private static final Pattern TEXT_BREAK_PATTERN = Pattern.compile(TEXTBREAK);
    private static final Pattern initialJunk = Pattern.compile("^\\W+");
    private static final Pattern finalJunk = Pattern.compile("\\W+$");
    /** Word counting stops once this many characters of corpus text have been counted. */
    private static final long maxBytesToCountWords = 5000000000L;
    /** Files at least this many bytes long are streamed in chunks instead of read whole. */
    private static final long LARGE_FILE_BYTES = 100000000L;
    /** Number of characters read per chunk from a large file. */
    private static final int LARGE_FILE_CHUNK_CHARS = 10000000;

    final AcronymExpansionsModel aem;
    /** Maps an alternate phrasing to the sense(s) whose vectors its contexts should train. */
    private final Map<String, Set<String>> alternateFormOf;
    private final int nWords;
    @Nullable
    WordVectorSpace vectorSpace;
    /** When true, alternate forms shared by several senses (or that are senses) are dropped. */
    private boolean ignoreDoubleAlternates = false;
    @Nullable
    private Map<String, SparseVector> senseVectors;
    @Nullable
    private Map<String, Integer> wordFrequency;
    /** Characters counted so far during word counting (compared against maxBytesToCountWords). */
    private long bytesWordCounted = 0L;
    @Nullable
    private PhraseGraph phraseGraph;
    /** Number of regular files in the corpus, used only for progress logging. */
    private long total = 0L;
    private long visited = 0L;

    /**
     * Loads the acronym expansions and (optionally) alternate phrasings of those expansions.
     *
     * @param expansionsFile path of the acronym expansions model
     * @param nWords number of most frequent words to keep in the dictionary
     * @param alternateLongformsFile optional tab-separated file; each line is a sense followed by
     *     alternate phrasings whose contexts should also train that sense
     * @throws BiomedicusException if the expansions model cannot be loaded
     * @throws IOException if a file cannot be read
     */
    public AcronymVectorOfflineTrainer(String expansionsFile, int nWords,
            @Nullable String alternateLongformsFile) throws BiomedicusException, IOException {
        this.nWords = nWords;
        this.aem = new AcronymExpansionsModel.Loader(Paths.get(expansionsFile)).loadModel();

        // Only acronyms with more than one expansion are ambiguous, so only their senses need vectors.
        Set<String> allExpansions = new HashSet<String>();
        for (String acronym : aem.getAcronyms()) {
            Collection<String> expansions = aem.getExpansions(acronym);
            if (expansions != null && expansions.size() > 1) {
                allExpansions.addAll(expansions);
            }
        }
        LOGGER.info("{} possible acronym expansions/senses", allExpansions.size());

        Map<String, SparseVector> vectors = new TreeMap<String, SparseVector>();
        for (String expansion : allExpansions) {
            vectors.put(expansion, new SparseVector());
        }
        this.senseVectors = vectors;

        this.alternateFormOf = new HashMap<String, Set<String>>();
        if (alternateLongformsFile != null) {
            loadAlternateForms(alternateLongformsFile, vectors.keySet());
            // Alternate forms are searched for in text just like the senses themselves.
            allExpansions.addAll(alternateFormOf.keySet());
        }
        LOGGER.info("{} possible senses, counting equivalents", allExpansions.size());
        this.phraseGraph = new PhraseGraph(allExpansions, this::tokenize);
    }

    /**
     * Reads the alternate-longforms file into {@link #alternateFormOf}.
     *
     * @param alternateLongformsFile tab-separated file: sense, then its alternate phrasings
     * @param senses the known senses; lines for unknown senses are skipped with a warning
     */
    private void loadAlternateForms(String alternateLongformsFile, Set<String> senses)
            throws IOException {
        Set<String> doublyReferencedAlternateForms = new HashSet<String>();
        try (BufferedReader reader = new BufferedReader(new FileReader(alternateLongformsFile))) {
            LOGGER.info("Adding expansion phrase search equivalents");
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split("\\t");
                // A line of only tabs splits to an empty array.
                if (fields.length == 0) {
                    continue;
                }
                String longform = fields[0];
                if (!senses.contains(longform)) {
                    LOGGER.warn("Trying to add alternate forms of \"" + longform
                            + "\", which is not a known sense of any abbreviation");
                    continue;
                }
                for (int i = 1; i < fields.length; i++) {
                    addAlternateForm(fields[i], longform, senses, doublyReferencedAlternateForms);
                }
            }
        }
        if (ignoreDoubleAlternates) {
            doublyReferencedAlternateForms.forEach(alternateFormOf::remove);
        }
    }

    /**
     * Records {@code alternate} as a phrasing of {@code longform}.
     *
     * <p>Sets stored in {@link #alternateFormOf} are mutable so an alternate shared by several
     * senses can train all of them. A sense that is also listed as an alternate keeps training its
     * own vector as well.
     */
    private void addAlternateForm(String alternate, String longform, Set<String> senses,
            Set<String> doublyReferencedAlternateForms) {
        Set<String> existing = alternateFormOf.get(alternate);
        if (existing != null && !existing.equals(Collections.singleton(longform))) {
            doublyReferencedAlternateForms.add(alternate);
            if (ignoreDoubleAlternates) {
                LOGGER.warn(String.format(
                        "%s appears as an alternate for multiple longforms; ignoring", alternate));
                return;
            }
        }
        if (senses.contains(alternate)) {
            if (ignoreDoubleAlternates) {
                LOGGER.warn(String.format("%s appears as a sense and as an alternate form for another sense; ignoring alternate form use", alternate));
                return;
            }
            alternateFormOf.computeIfAbsent(alternate,
                    self -> new HashSet<String>(Collections.singleton(self))).add(longform);
            return;
        }
        alternateFormOf.computeIfAbsent(alternate, k -> new HashSet<String>()).add(longform);
    }

    /**
     * Command-line entry point.
     *
     * <p>Arguments: {@code expansionsFile corpusPath [outDir] [nWords] [alternateLongformsFile]}.
     */
    public static void main(String[] args) throws BiomedicusException, IOException {
        if (args.length < 2) {
            throw new IllegalArgumentException("usage: AcronymVectorOfflineTrainer <expansionsFile>"
                    + " <corpusPath> [outDir] [nWords] [alternateLongformsFile]");
        }
        String expansionsFile = args[0];
        String corpusPath = args[1];
        String outDir = args.length > 2 ? args[2] : ".";
        int nWords = args.length > 3 ? Integer.parseInt(args[3]) : DEFAULT_N_WORDS;
        String alternateLongformsFile = args.length > 4 ? args[4] : null;
        AcronymVectorOfflineTrainer trainer =
                new AcronymVectorOfflineTrainer(expansionsFile, nWords, alternateLongformsFile);
        trainer.countDocuments(corpusPath);
        trainer.trainOnCorpus(corpusPath);
        trainer.writeAcronymModel(outDir);
    }

    /** Counts the regular files under {@code corpusPath} for progress logging. */
    private void countDocuments(String corpusPath) throws IOException {
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> paths = Files.walk(Paths.get(corpusPath))) {
            total = paths.filter(Files::isRegularFile).count();
        }
    }

    /**
     * Accumulates context vectors for every sense found in the corpus. Words are pre-counted first
     * if no vector space exists yet.
     */
    public void trainOnCorpus(String corpusPath) throws IOException {
        if (vectorSpace == null) {
            precountWords(corpusPath);
        }
        visited = 0L;
        Files.walkFileTree(Paths.get(corpusPath), new FileVectorizer(true));
    }

    /**
     * Counts word frequencies over the corpus and installs a dictionary of the {@code nWords} most
     * frequent words. The most frequent word gets index 0. Ties are broken by descending word
     * order.
     */
    public void precountWords(String corpusPath) throws IOException {
        vectorSpace = new WordVectorSpace();
        wordFrequency = new HashMap<String, Integer>();
        visited = 0L;
        bytesWordCounted = 0L;
        Files.walkFileTree(Paths.get(corpusPath), new FileVectorizer(false));

        // Ascending by (frequency, word); iterate descending to take the most frequent first.
        TreeSet<String> sortedWordFreq = new TreeSet<String>(new ByValue<String, Integer>(wordFrequency));
        sortedWordFreq.addAll(wordFrequency.keySet());
        HashMap<String, Integer> dictionary = new HashMap<String, Integer>();
        Iterator<String> iter = sortedWordFreq.descendingIterator();
        for (int i = 0; i < nWords && iter.hasNext(); i++) {
            dictionary.put(iter.next(), i);
        }
        vectorSpace.setDictionary(dictionary);
    }

    /**
     * Weights the accumulated sense vectors and writes the acronym vector model to {@code outFile}
     * (a directory).
     *
     * @throws IllegalStateException if called before training has built a vector space
     */
    public void writeAcronymModel(String outFile) throws IOException {
        WordVectorSpace space = vectorSpace;
        Map<String, SparseVector> senses = senseVectors;
        if (space == null || senses == null) {
            throw new IllegalStateException(
                    "No word vector space; call trainOnCorpus before writeAcronymModel");
        }
        space.buildIdf();
        SparseVector idf = space.getIdf();
        LOGGER.info("Creating vectors for senses");
        for (SparseVector vector : senses.values()) {
            vector.applyOperation(Math::sqrt);
            vector.multiply(idf);
            vector.normVector();
            // NOTE(review): idf is applied once before normalization and twice after it; presumably
            // a deliberate weighting matched to AcronymVectorModel's scoring -- confirm.
            vector.multiply(idf);
            vector.multiply(idf);
        }
        LOGGER.info("{} vectors total", senses.size());
        LOGGER.info("initializing acronym vector model");
        AcronymVectorModel avm = new AcronymVectorModel(space, null, aem, null, 0.0);
        LOGGER.info("writing acronym vector model");
        avm.writeToDirectory(Paths.get(outFile), senses);
    }

    /**
     * Strips leading/trailing non-word characters, lower-cases, and splits on {@link #TEXTBREAK}.
     * Never returns an empty array; an all-junk input yields {@code [""]}.
     */
    private String[] tokenize(String orig) {
        String trimmed = initialJunk.matcher(orig).replaceFirst("");
        trimmed = finalJunk.matcher(trimmed).replaceFirst("");
        return TEXT_BREAK_PATTERN.split(trimmed.toLowerCase());
    }

    /** Adds the context of {@code words[startPos, endPos)} to the vector of {@code expansion}. */
    private void vectorizeForWord(String expansion, List<Token> words, int startPos, int endPos) {
        assert vectorSpace != null;
        assert senseVectors != null;
        SparseVector vec = vectorSpace.vectorize(words, startPos, endPos);
        senseVectors.get(expansion).add(vec);
    }

    /** Finds every sense (or alternate form) phrase in {@code context} and trains on its context. */
    private void vectorizeChunk(String context) {
        assert phraseGraph != null;
        List<Token> words = Arrays.stream(tokenize(context))
                .<Token>map(DummyToken::new)
                .collect(Collectors.toList());
        for (int i = 0; i < words.size(); i++) {
            String phrase = phraseGraph.getLongestPhraseFrom(words, i);
            if (phrase == null) {
                continue;
            }
            int end = i + tokenize(phrase).length;
            for (String fullPhrase
                    : alternateFormOf.getOrDefault(phrase, Collections.singleton(phrase))) {
                vectorizeForWord(fullPhrase, words, i, end);
            }
        }
    }

    /** Adds the tokens of {@code context} to the word-frequency counts. */
    private void countChunk(String context) {
        assert wordFrequency != null;
        for (String word : tokenize(context)) {
            wordFrequency.merge(word, 1, Integer::sum);
        }
    }

    /**
     * Orders keys by their value in a map, breaking ties by the keys' natural order. Null values
     * sort first.
     *
     * <p>Values are compared with {@code compareTo}, never by reference. Cached boxed values such
     * as small Integers must not make distinct keys compare equal, because a TreeSet would then
     * drop them.
     */
    public class ByValue<K extends Comparable<K>, V extends Comparable<V>>
    implements Comparator<K> {
        private final Map<K, V> map;
        private final Comparator<V> valueOrder = Comparator.nullsFirst(Comparator.<V>naturalOrder());

        public ByValue(Map<K, V> map) {
            this.map = map;
        }

        @Override
        public int compare(K o1, K o2) {
            int cmp = valueOrder.compare(map.get(o1), map.get(o2));
            return cmp != 0 ? cmp : o1.compareTo(o2);
        }
    }

    /**
     * File visitor that either counts words (pass 1) or vectorizes sense contexts (pass 2) for each
     * non-hidden file. Files at least {@link #LARGE_FILE_BYTES} long are streamed in chunks.
     */
    private class FileVectorizer
    extends SimpleFileVisitor<Path> {
        private final boolean vectorizeNotCount;

        FileVectorizer(boolean vectorizeNotCount) {
            this.vectorizeNotCount = vectorizeNotCount;
        }

        @Override
        public FileVisitResult visitFile(Path file, BasicFileAttributes attr) throws IOException {
            if (file.getFileName().toString().startsWith(".")) {
                return FileVisitResult.CONTINUE;
            }
            boolean countingBudgetSpent = file.toFile().length() < LARGE_FILE_BYTES
                    ? processWholeFile(file)
                    : processLargeFile(file);
            if (countingBudgetSpent) {
                return FileVisitResult.TERMINATE;
            }
            LOGGER.trace("{} visited", file);
            visited++;
            if (visited % 1000L == 0L) {
                LOGGER.info("Visited {} of {}", visited, total);
            }
            return FileVisitResult.CONTINUE;
        }

        /** Reads a file in one piece; returns true once the word-count budget is spent. */
        private boolean processWholeFile(Path file) throws IOException {
            String fileText;
            try (Scanner scanner = new Scanner(file.toFile()).useDelimiter("\\Z")) {
                // An empty file has no token; Scanner.next() would throw NoSuchElementException.
                if (!scanner.hasNext()) {
                    return false;
                }
                fileText = scanner.next();
            }
            if (vectorizeNotCount) {
                vectorizeChunk(fileText);
                return false;
            }
            countChunk(fileText);
            return addCountedChars(fileText.length());
        }

        /** Streams a large file in chunks; returns true once the word-count budget is spent. */
        private boolean processLargeFile(Path file) throws IOException {
            try (BufferedReader reader = new BufferedReader(new FileReader(file.toFile()))) {
                char[] chunk = new char[LARGE_FILE_CHUNK_CHARS];
                long totalLength = 0L;
                int charsRead;
                while ((charsRead = reader.read(chunk)) > 0) {
                    // Only the first charsRead chars are fresh; the rest is left over from the last read.
                    StringBuilder lineBuilder = new StringBuilder(charsRead + 64).append(chunk, 0, charsRead);
                    // Extend to the next space/tab/newline so no word straddles two chunks.
                    int next;
                    while ((next = reader.read()) >= 0 && next != ' ' && next != '\t' && next != '\n') {
                        lineBuilder.append((char) next);
                    }
                    String line = lineBuilder.toString();
                    totalLength += line.length();
                    if (vectorizeNotCount) {
                        vectorizeChunk(line);
                    } else {
                        countChunk(line);
                        LOGGER.info("{} total words found", wordFrequency.size());
                        if (addCountedChars(line.length())) {
                            return true;
                        }
                    }
                    LOGGER.info("{} bytes of large file {} processed", totalLength, file);
                }
            }
            return false;
        }

        /** Adds to the counted-character total; returns true (and logs) once the limit is reached. */
        private boolean addCountedChars(long nChars) {
            bytesWordCounted += nChars;
            if (bytesWordCounted >= maxBytesToCountWords) {
                LOGGER.info("Done counting words.");
                return true;
            }
            return false;
        }
    }

    /** Minimal token wrapping a bare word; offsets are not tracked and are always 0. */
    private static final class DummyToken
    extends AbstractTextRange
    implements Token {
        private final String text;

        DummyToken(String text) {
            this.text = text;
        }

        @Override
        public String getText() {
            return this.text;
        }

        @Override
        public boolean getHasSpaceAfter() {
            return true;
        }

        public int getStartIndex() {
            return 0;
        }

        public int getEndIndex() {
            return 0;
        }
    }

    /**
     * Word trie over the tokenized sense phrases, used to find the longest phrase starting at a
     * token position.
     */
    class PhraseGraph {
        /**
         * Root node. Each node maps a word to its child node (a {@code Map}). Its {@code null} key,
         * when present, holds the phrase that ends at that node.
         */
        private final Map<String, Object> graph = new HashMap<String, Object>();

        public PhraseGraph(Iterable<String> phrases, Function<String, String[]> tokenizer) {
            for (String phrase : phrases) {
                Map<String, Object> node = graph;
                for (String word : tokenizer.apply(phrase)) {
                    node = childNode(node, word);
                }
                node.put(null, phrase);
            }
        }

        /** Returns the child of {@code node} for {@code word}, creating it if needed. */
        @SuppressWarnings("unchecked") // non-null keys always map to child nodes
        private Map<String, Object> childNode(Map<String, Object> node, String word) {
            return (Map<String, Object>) node.computeIfAbsent(word, w -> new HashMap<String, Object>());
        }

        /**
         * Returns the longest known phrase whose tokens match {@code words} starting at {@code index}.
         * Returns null if none matches. A phrase ending at the last token of {@code words} is found
         * too.
         */
        @Nullable
        @SuppressWarnings("unchecked") // non-null keys always map to child nodes
        public String getLongestPhraseFrom(List<Token> words, int index) {
            String longestEligiblePhrase = null;
            Map<String, Object> node = graph;
            for (int i = index; i < words.size(); i++) {
                Object child = node.get(words.get(i).getText());
                if (child == null) {
                    break;
                }
                node = (Map<String, Object>) child;
                // Check for a phrase end after advancing, so the final token is not skipped.
                Object phrase = node.get(null);
                if (phrase != null) {
                    longestEligiblePhrase = (String) phrase;
                }
            }
            return longestEligiblePhrase;
        }
    }
}

