/*
 * Decompiled with CFR 0.152.
 */
package edu.umn.biomedicus.acronym;

import edu.umn.biomedicus.acronym.Acronyms;
import edu.umn.biomedicus.acronym.SparseVector;
import edu.umn.biomedicus.tokenization.Token;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Turns the tokens surrounding a span of "center" tokens into a distance-weighted
 * sparse bag-of-words vector.
 *
 * <p>The instance keeps a word-to-index dictionary and, while
 * {@link #getCountingDocuments() counting} is enabled, tallies used to derive an IDF
 * vector via {@link #buildIdf()}. Both the dictionary and the tallies are mutated by
 * {@code vectorize}, and no synchronization is performed here.
 */
public class WordVectorSpace {
    private static final Logger LOGGER = LoggerFactory.getLogger(WordVectorSpace.class);

    /** Steepness of the logistic distance weighting. */
    private static final double SLOPE = 0.3;

    /**
     * Logistic weight of a context word at {@code dist} tokens from the center:
     * {@code 1 / (1 + exp(SLOPE * (|dist| - maxDist)))}. The weight is 0.5 when
     * {@code |dist| == maxDist} and decreases as the distance grows.
     */
    private static final BiFunction<Integer, Double, Double> DIST_WEIGHT =
            (BiFunction<Integer, Double, Double> & Serializable) (dist, maxDist) ->
                    1.0 / (1.0 + Math.exp(SLOPE * ((double) Math.abs(dist) - maxDist)));

    /** Exponent applied to the logarithmic IDF term in {@link #buildIdf()}. */
    private static final double IDF_POWER = 1.0;

    /** Distance weight at the edge of the context window; see {@link #setMaxDist(double)}. */
    private static final double THRESH_WEIGHT = 0.25;

    /** Only words made up entirely of these characters are vectorized. */
    private static final Pattern ALPHANUMERIC = Pattern.compile("[a-zA-Z0-9.&_]*");

    private transient double maxDist;
    private transient double windowSize;
    private Map<String, Integer> dictionary = new HashMap<String, Integer>();
    private Map<Integer, Integer> documentsPerTerm = new HashMap<Integer, Integer>();
    private long totalDocs = 0L;
    private SparseVector idf;
    private boolean countingDocuments = true;
    private boolean buildingDictionary = true;

    /** Creates an empty vector space with a maximum distance of 9. */
    public WordVectorSpace() {
        this.setMaxDist(9.0);
    }

    /**
     * @return the distance at which the weighting function yields 0.5
     */
    public double getMaxDist() {
        return this.maxDist;
    }

    /**
     * Sets the midpoint of the distance weighting and recomputes the context window.
     *
     * <p>The window extends to the distance at which the weight equals
     * {@link #THRESH_WEIGHT}: solving {@code 1 / (1 + exp(SLOPE * (d - maxDist))) = t}
     * for {@code d} gives {@code ln(1/t - 1) / SLOPE + maxDist}.
     *
     * @param maxDist distance at which the weighting function yields 0.5
     */
    public void setMaxDist(double maxDist) {
        this.maxDist = maxDist;
        this.windowSize = Math.log(1.0 / THRESH_WEIGHT - 1.0) / SLOPE + maxDist;
    }

    /**
     * @return the IDF vector, or {@code null} if none has been built or set yet
     */
    public SparseVector getIdf() {
        return this.idf;
    }

    /**
     * @param idf the IDF vector to use
     */
    public void setIdf(SparseVector idf) {
        this.idf = idf;
    }

    /**
     * @return the live word-to-index map (not a copy)
     */
    public Map<String, Integer> getDictionary() {
        return this.dictionary;
    }

    /**
     * Replaces the dictionary and stops new words from being added to it.
     *
     * @param dictionary word-to-index map to use from now on
     */
    public void setDictionary(Map<String, Integer> dictionary) {
        this.dictionary = dictionary;
        this.buildingDictionary = false;
    }

    /**
     * @return the live index-to-count map (not a copy)
     */
    public Map<Integer, Integer> getDocumentsPerTerm() {
        return this.documentsPerTerm;
    }

    /**
     * @param documentsPerTerm index-to-count map to use from now on
     */
    public void setDocumentsPerTerm(Map<Integer, Integer> documentsPerTerm) {
        this.documentsPerTerm = documentsPerTerm;
    }

    /**
     * @return the number of contexts vectorized while counting was enabled
     */
    public long getTotalDocs() {
        return this.totalDocs;
    }

    /**
     * @param totalDocs the context count to use when building the IDF vector
     */
    public void setTotalDocs(long totalDocs) {
        this.totalDocs = totalDocs;
    }

    /**
     * @return whether unseen words are added to the dictionary during vectorization
     */
    public boolean getBuildingDictionary() {
        return this.buildingDictionary;
    }

    /**
     * @param buildingDictionary whether unseen words should be added to the dictionary
     */
    public void setBuildingDictionary(boolean buildingDictionary) {
        this.buildingDictionary = buildingDictionary;
    }

    /**
     * @return whether vectorization updates the per-term and total counts
     */
    public boolean getCountingDocuments() {
        return this.countingDocuments;
    }

    /**
     * @param countingDocuments whether vectorization should update the counts
     */
    public void setCountingDocuments(boolean countingDocuments) {
        this.countingDocuments = countingDocuments;
    }

    /**
     * Builds the IDF vector from the current counts and turns counting off.
     *
     * <p>Each term's value is {@code log((1 + totalDocs) / count) ^ IDF_POWER}.
     */
    public void buildIdf() {
        HashMap<Integer, Double> weights = new HashMap<Integer, Double>();
        for (Map.Entry<Integer, Integer> entry : this.documentsPerTerm.entrySet()) {
            double ratio = (1.0 + (double) this.totalDocs) / (double) entry.getValue().intValue();
            weights.put(entry.getKey(), Math.pow(Math.log(ratio), IDF_POWER));
        }
        this.idf = new SparseVector(weights);
        this.countingDocuments = false;
    }

    /**
     * Vectorizes the words around the center span {@code [startCenterToken, stopCenterToken)}.
     *
     * <p>Tokens within the context window on either side of the span are considered; the
     * span itself is skipped. Words that do not fully match {@link #ALPHANUMERIC} are
     * ignored. Unknown words are added to the dictionary only while dictionary building is
     * enabled; otherwise they are ignored entirely and do not affect the counts.
     *
     * @param context          tokens containing the center span
     * @param startCenterToken index of the first center token (inclusive)
     * @param stopCenterToken  index just past the last center token (exclusive)
     * @return sparse vector mapping dictionary indices to summed distance weights
     */
    SparseVector vectorize(List<? extends Token> context, int startCenterToken, int stopCenterToken) {
        HashMap<Integer, Double> wordVector = new HashMap<Integer, Double>();
        int window = (int) this.windowSize;
        int startIndex = Math.max(startCenterToken - window, 0);
        int stopIndex = Math.min(stopCenterToken + window, context.size());
        for (int i = startIndex; i < stopIndex; ++i) {
            if (i == startCenterToken) {
                // Jump over the center span; stop if nothing inside the window follows it.
                i = stopCenterToken;
                if (i >= stopIndex) {
                    break;
                }
            }
            String word = Acronyms.standardContextForm(context.get(i));
            if (!ALPHANUMERIC.matcher(word).matches()) {
                continue;
            }
            int wordInt = this.dictionary.getOrDefault(word, -1);
            if (wordInt == -1) {
                if (!this.buildingDictionary) {
                    // Unknown word and the dictionary is frozen: it has no index, so it must
                    // neither be counted nor contribute to the vector.
                    continue;
                }
                wordInt = this.dictionary.size();
                this.dictionary.put(word, wordInt);
            }
            if (this.countingDocuments) {
                // NOTE(review): this increments once per occurrence, so a word repeated inside
                // one context is counted several times — confirm this is intended for IDF.
                this.documentsPerTerm.merge(wordInt, 1, Integer::sum);
            }
            // NOTE(review): the token directly after the span gets distance 0 while the token
            // directly before it gets distance 1 — confirm the asymmetry is intended.
            int dist = i < startCenterToken ? startCenterToken - i : i - stopCenterToken;
            double increment = DIST_WEIGHT.apply(dist, this.maxDist);
            double previous = wordVector.getOrDefault(wordInt, 0.0);
            wordVector.put(wordInt, previous + increment);
        }
        if (this.countingDocuments) {
            ++this.totalDocs;
        }
        return new SparseVector(wordVector);
    }

    /**
     * Vectorizes the words around a single center token.
     *
     * @param context     tokens containing the center token
     * @param centerToken index of the center token
     * @return sparse vector mapping dictionary indices to summed distance weights
     */
    public SparseVector vectorize(List<? extends Token> context, int centerToken) {
        return this.vectorize(context, centerToken, centerToken + 1);
    }

    /**
     * Removes a word from the dictionary, zeroes its IDF entry (when an IDF vector exists)
     * and drops its count.
     *
     * @param word the dictionary key to remove
     * @return the index the word had, or {@code null} if it was not in the dictionary
     */
    @Nullable
    public Integer removeWord(String word) {
        LOGGER.info("removing word {}", (Object) word);
        Integer wordInt = this.dictionary.remove(word);
        if (wordInt != null) {
            // The IDF vector only exists after buildIdf() or setIdf(); tolerate its absence.
            if (this.idf != null) {
                this.idf.set(wordInt, 0.0);
            }
            this.documentsPerTerm.remove(wordInt);
        }
        return wordInt;
    }

    /**
     * Removes every dictionary word whose standard context form is not in {@code wordsToKeep}.
     *
     * @param wordsToKeep standard-form words that must stay in the dictionary
     * @return the indices of the words that were removed
     */
    public Set<Integer> removeWordsExcept(Set<String> wordsToKeep) {
        LOGGER.info("dictionary size before de-ID: {}", (Object) this.dictionary.size());
        HashSet<Integer> indicesRemoved = new HashSet<Integer>();
        // Iterate over a snapshot because removeWord mutates the dictionary.
        HashSet<String> snapshot = new HashSet<String>(this.dictionary.keySet());
        for (String word : snapshot) {
            if (wordsToKeep.contains(Acronyms.standardContextForm(word))) {
                continue;
            }
            Integer removed = this.removeWord(word);
            if (removed != null) {
                indicesRemoved.add(removed);
            }
        }
        LOGGER.info("{} indices removed", (Object) indicesRemoved.size());
        LOGGER.info("dictionary size after de-ID: {}", (Object) this.dictionary.size());
        return indicesRemoved;
    }
}

