/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.word2vec.wordstore;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Fills a {@link VocabCache} from one or more {@link SequenceIterator} sources.
 *
 * <p>Instances are created through {@link Builder}. Two build modes are implemented:
 * <ul>
 *   <li>{@link #buildJointVocabulary(boolean, boolean)} scans every registered source,
 *       counts element frequencies, drops rare elements per source and imports the
 *       result into the target cache;</li>
 *   <li>{@link #buildMergedVocabulary(VocabCache, boolean)} copies an existing vocabulary
 *       into the target cache and, optionally, appends sequence labels found in the
 *       registered sources.</li>
 * </ul>
 *
 * <p>If no target cache was supplied to the builder, a new {@link AbstractCache} is
 * created on the first build call.
 *
 * @param <T> element type stored in the vocabulary
 */
public class VocabConstructor<T extends SequenceElement> {
    protected static final Logger log = LoggerFactory.getLogger(VocabConstructor.class);

    /** Number of sequences between two progress messages while scanning a source. */
    private static final long PROGRESS_LOG_INTERVAL = 100000L;

    private List<VocabSource<T>> sources = new ArrayList<VocabSource<T>>();
    private VocabCache<T> cache;
    private List<String> stopWords;
    // Settable through the builder, but never read inside this class.
    private boolean useAdaGrad = false;
    private boolean fetchLabels = false;
    private int limit;
    // Incremented once for every sequence pulled from any source; never reset.
    private AtomicLong seqCount = new AtomicLong(0L);
    private InvertedIndex<T> index;

    /** Instances are assembled by {@link Builder#build()}. */
    private VocabConstructor() {
    }

    /**
     * Placeholder: not implemented, always returns {@code null}.
     *
     * @return {@code null}
     */
    protected WeightLookupTable<T> buildExtendedLookupTable() {
        return null;
    }

    /**
     * Placeholder: not implemented, always returns {@code null}.
     *
     * @return {@code null}
     */
    protected VocabCache<T> buildExtendedVocabulary() {
        return null;
    }

    /**
     * Convenience overload that merges the vocabulary held by the given model.
     *
     * @param wordVectors model whose {@code vocab()} is used as the merge source
     * @param fetchLabels forwarded to {@link #buildMergedVocabulary(VocabCache, boolean)}
     * @return the target cache
     * @throws NullPointerException if {@code wordVectors} is {@code null}
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull WordVectors wordVectors, boolean fetchLabels) {
        if (wordVectors == null) {
            throw new NullPointerException("wordVectors");
        }
        return this.buildMergedVocabulary(wordVectors.vocab(), fetchLabels);
    }

    /**
     * Returns how many sequences have been read from the sources so far. The counter is
     * shared by all build methods and accumulates across calls.
     *
     * @return number of sequences read
     */
    public long getNumberOfSequences() {
        return this.seqCount.get();
    }

    /**
     * Copies the elements of {@code vocabCache} into the target cache, keeping the index
     * each element already carries.
     *
     * <p>When {@code fetchLabels} is {@code false}, elements flagged as labels are skipped
     * and the sources are not touched. When it is {@code true}, label elements are copied
     * as well and every registered source is additionally scanned: each sequence label
     * that is not yet in the target cache is marked as a special label element and
     * appended with the next free index.
     *
     * @param vocabCache  vocabulary to copy from
     * @param fetchLabels whether labels are transferred and collected from the sources
     * @return the target cache
     * @throws NullPointerException  if {@code vocabCache} is {@code null}
     * @throws IllegalStateException if the target cache is still empty after the copy
     */
    public VocabCache<T> buildMergedVocabulary(@NonNull VocabCache<T> vocabCache, boolean fetchLabels) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (this.cache == null) {
            this.cache = new AbstractCache.Builder<T>().build();
        }

        for (int position = 0; position < vocabCache.numWords(); ++position) {
            String label = vocabCache.wordAtIndex(position);
            if (label == null) {
                continue;
            }
            T element = vocabCache.wordFor(label);
            if (!fetchLabels && element.isLabel()) {
                continue;
            }
            this.cache.addToken(element);
            this.cache.addWordToIndex(element.getIndex(), element.getLabel());
            this.cache.putVocabWord(element.getLabel());
        }

        if (this.cache.numWords() == 0) {
            throw new IllegalStateException("Source VocabCache has no indexes available, transfer is impossible");
        }

        log.info("Vocab size before labels: {}", this.cache.numWords());
        if (fetchLabels) {
            this.appendLabelsFromSources();
        }
        log.info("Vocab size after labels: {}", this.cache.numWords());
        return this.cache;
    }

    /**
     * Scans all registered sources and builds the target vocabulary from them.
     *
     * <p>Each source is counted into its own temporary cache. Elements rarer than that
     * source's minimum frequency are dropped, unless they are special or labels. The
     * per-source caches are combined and imported into the target cache. Afterwards
     * either the element frequencies are zeroed ({@code resetCounters}) or Huffman
     * indexes are assigned and the optional entries limit is applied
     * ({@code buildHuffmanTree}).
     *
     * @param resetCounters    set every element frequency in the target cache to zero
     * @param buildHuffmanTree build a Huffman tree over the target cache and apply its indexes
     * @return the target cache
     * @throws IllegalStateException if both flags are {@code true}
     */
    public VocabCache<T> buildJointVocabulary(boolean resetCounters, boolean buildHuffmanTree) {
        if (resetCounters && buildHuffmanTree) {
            throw new IllegalStateException("You can't reset counters and build Huffman tree at the same time!");
        }
        if (this.cache == null) {
            this.cache = new AbstractCache.Builder<T>().build();
        }
        log.debug("Target vocab size before building: [{}]", this.cache.numWords());

        // Snapshot the stop list into a hash set once, so the per-token check is O(1).
        Set<String> stopSet = this.stopWords == null
                ? new HashSet<String>()
                : new HashSet<String>(this.stopWords);

        AtomicLong elementsCounter = new AtomicLong(0L);
        AbstractCache<T> topHolder = new AbstractCache.Builder<T>().minElementFrequency(0).build();

        int sourceNumber = 0;
        for (VocabSource<T> source : this.sources) {
            AbstractCache<T> sourceHolder = this.collectSource(source, sourceNumber, stopSet, elementsCounter);
            topHolder.importVocabulary(sourceHolder);
            ++sourceNumber;
        }
        this.cache.importVocabulary(topHolder);

        if (resetCounters) {
            for (T element : this.cache.vocabWords()) {
                element.setElementFrequency(0L);
            }
            this.cache.updateWordsOccurencies();
        }

        if (buildHuffmanTree) {
            Huffman huffman = new Huffman(this.cache.vocabWords());
            huffman.build();
            huffman.applyIndexes(this.cache);
            if (this.limit > 0) {
                this.dropElementsBeyondLimit();
            }
        }

        log.info("Sequences checked: [{}], Current vocabulary size: [{}]", this.seqCount.get(), this.cache.numWords());
        return this.cache;
    }

    /** Walks every source and appends sequence labels that the target cache does not contain yet. */
    private void appendLabelsFromSources() {
        for (VocabSource<T> source : this.sources) {
            SequenceIterator<T> iterator = source.getIterator();
            iterator.reset();
            while (iterator.hasMoreSequences()) {
                Sequence<T> sequence = iterator.nextSequence();
                this.seqCount.incrementAndGet();
                for (T label : sequence.getSequenceLabels()) {
                    if (this.cache.containsWord(label.getLabel())) {
                        log.info("Label [{}] already exists: {}", label.getLabel(), this.cache.wordFor(label.getLabel()));
                        continue;
                    }
                    label.markAsLabel(true);
                    label.setSpecial(true);
                    // New labels go to the end of the current index space.
                    label.setIndex(this.cache.numWords());
                    this.cache.addToken(label);
                    this.cache.addWordToIndex(label.getIndex(), label.getLabel());
                    this.cache.putVocabWord(label.getLabel());
                    log.info("Adding label [{}]: {}", label.getLabel(), this.cache.wordFor(label.getLabel()));
                }
            }
        }
    }

    /** Reads one source from its start into a fresh temporary cache and drops its rare elements. */
    private AbstractCache<T> collectSource(VocabSource<T> source, int sourceNumber, Set<String> stopSet,
                                           AtomicLong elementsCounter) {
        SequenceIterator<T> iterator = source.getIterator();
        iterator.reset();
        log.debug("Trying source iterator: [{}]", sourceNumber);
        log.debug("Target vocab size before building: [{}]", this.cache.numWords());

        AbstractCache<T> holder = new AbstractCache.Builder<T>().build();
        int sequences = 0;
        long counter = 0L;
        while (iterator.hasMoreSequences()) {
            Sequence<T> document = iterator.nextSequence();
            this.seqCount.incrementAndGet();
            holder.incrementTotalDocCount();
            counter += this.countDocument(document, holder, stopSet, elementsCounter);
            ++sequences;
            if (this.seqCount.get() % PROGRESS_LOG_INTERVAL == 0L) {
                log.info("Sequences checked: [{}], Current vocabulary size: [{}]", this.seqCount.get(), elementsCounter.get());
            }
        }

        this.logHolderState("before", holder, sequences, counter);
        if (source.getMinWordFrequency() > 0) {
            this.dropRareElements(holder, source.getMinWordFrequency());
        }
        this.logHolderState("after", holder, sequences, counter);
        return holder;
    }

    /** Counts the tokens of one sequence into {@code holder}; returns how many tokens were counted. */
    private long countDocument(Sequence<T> document, AbstractCache<T> holder, Set<String> stopSet,
                               AtomicLong elementsCounter) {
        if (this.fetchLabels) {
            T labelWord = document.getSequenceLabel();
            // Unlabeled sequences are legal (the index update below also handles a null
            // label), so skip label registration instead of failing on them.
            if (labelWord != null) {
                labelWord.setSpecial(true);
                labelWord.markAsLabel(true);
                labelWord.setElementFrequency(1L);
                holder.addToken(labelWord);
            }
        }

        // Tokens already met in THIS sequence; keeps the per-element sequences count
        // from growing more than once per sequence.
        Set<String> seenInSequence = new HashSet<String>();
        long counted = 0L;
        for (String token : document.asLabels()) {
            if (token == null || token.isEmpty() || stopSet.contains(token)) {
                continue;
            }
            ++counted;

            if (!holder.containsWord(token)) {
                // First occurrence within this source: register with frequency 1.
                T element = document.getElementByLabel(token);
                element.setElementFrequency(1L);
                holder.addToken(element);
                elementsCounter.incrementAndGet();
                element.setSequencesCount(1L);
                seenInSequence.add(token);
                continue;
            }

            holder.incrementWordCount(token);
            if (seenInSequence.add(token)) {
                holder.wordFor(token).incrementSequencesCount();
            }

            // NOTE(review): the index is only updated for tokens that were already known
            // to the holder, and once per such token rather than once per sequence.
            // That placement is preserved as-is; confirm whether it is intentional.
            if (this.index != null) {
                T sequenceLabel = document.getSequenceLabel();
                if (sequenceLabel != null) {
                    this.index.addWordsToDoc(this.index.numDocuments(), document.getElements(), sequenceLabel);
                } else {
                    this.index.addWordsToDoc(this.index.numDocuments(), document.getElements());
                }
            }
        }
        return counted;
    }

    /** Removes elements rarer than {@code minFrequency} from {@code holder}, sparing special and label elements. */
    private void dropRareElements(AbstractCache<T> holder, int minFrequency) {
        // Collect first, remove afterwards, so the collection is not modified mid-iteration.
        List<String> labelsToRemove = new ArrayList<String>();
        for (T element : holder.vocabWords()) {
            boolean tooRare = element.getElementFrequency() < (double) minFrequency;
            if (tooRare && !element.isSpecial() && !element.isLabel()) {
                labelsToRemove.add(element.getLabel());
            }
        }
        for (String label : labelsToRemove) {
            holder.removeElement(label);
        }
    }

    /** Removes target-cache elements whose index is greater than the entries limit, sparing special and label elements. */
    private void dropElementsBeyondLimit() {
        // NOTE(review): elements with index == limit are kept, so up to limit + 1
        // indexed entries survive; confirm this boundary is the intended one.
        List<String> labelsToRemove = new ArrayList<String>();
        for (T element : this.cache.vocabWords()) {
            if (element.getIndex() > this.limit && !element.isSpecial() && !element.isLabel()) {
                labelsToRemove.add(element.getLabel());
            }
        }
        for (String label : labelsToRemove) {
            this.cache.removeElement(label);
        }
    }

    /** Emits the per-source debug summary; {@code phase} is "before" or "after" truncation. */
    private void logHolderState(String phase, AbstractCache<T> holder, int sequences, long counter) {
        if (log.isDebugEnabled()) {
            log.debug("Vocab size " + phase + " truncation: [" + holder.numWords() + "],  NumWords: ["
                    + holder.totalWordOccurrences() + "], sequences parsed: [" + sequences + "], counter: ["
                    + counter + "]");
        }
    }

    /**
     * One vocabulary source: an iterator over sequences plus the minimum frequency an
     * element needs to survive truncation for that source.
     */
    private static class VocabSource<T extends SequenceElement> {
        @NonNull
        private SequenceIterator<T> iterator;
        private int minWordFrequency;

        /**
         * @param iterator         sequences to read
         * @param minWordFrequency minimum element frequency; values of 0 or less disable truncation
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        @ConstructorProperties(value={"iterator", "minWordFrequency"})
        public VocabSource(@NonNull SequenceIterator<T> iterator, int minWordFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
            this.minWordFrequency = minWordFrequency;
        }

        @NonNull
        public SequenceIterator<T> getIterator() {
            return this.iterator;
        }

        public int getMinWordFrequency() {
            return this.minWordFrequency;
        }

        /**
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        public void setIterator(@NonNull SequenceIterator<T> iterator) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.iterator = iterator;
        }

        public void setMinWordFrequency(int minWordFrequency) {
            this.minWordFrequency = minWordFrequency;
        }

        /** Two sources are equal when their iterators are equal and their minimum frequencies match. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof VocabSource)) {
                return false;
            }
            VocabSource<?> other = (VocabSource<?>) o;
            if (!other.canEqual(this)) {
                return false;
            }
            Object mine = this.getIterator();
            Object theirs = other.getIterator();
            boolean sameIterator = mine == null ? theirs == null : mine.equals(theirs);
            return sameIterator && this.getMinWordFrequency() == other.getMinWordFrequency();
        }

        /** Hook that lets a subclass refuse equality with instances of a parent type. */
        protected boolean canEqual(Object other) {
            return other instanceof VocabSource;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            Object current = this.getIterator();
            result = result * prime + (current == null ? 0 : current.hashCode());
            result = result * prime + this.getMinWordFrequency();
            return result;
        }

        @Override
        public String toString() {
            return "VocabConstructor.VocabSource(iterator=" + this.getIterator() + ", minWordFrequency=" + this.getMinWordFrequency() + ")";
        }
    }

    /**
     * Fluent builder for {@link VocabConstructor}.
     *
     * @param <T> element type stored in the vocabulary
     */
    public static class Builder<T extends SequenceElement> {
        private List<VocabSource<T>> sources = new ArrayList<VocabSource<T>>();
        private VocabCache<T> cache;
        private List<String> stopWords = new ArrayList<String>();
        private boolean useAdaGrad = false;
        private boolean fetchLabels = false;
        private InvertedIndex<T> index;
        private int limit;

        /**
         * Sets the index threshold applied after the Huffman tree is built: non-special,
         * non-label elements with an index above it are removed. Values of 0 or less
         * disable the truncation.
         *
         * @param limit index threshold
         * @return this builder
         */
        public Builder<T> setEntriesLimit(int limit) {
            this.limit = limit;
            return this;
        }

        /**
         * Stores the AdaGrad flag on the constructor being built.
         *
         * @param useAdaGrad flag value
         * @return this builder
         */
        protected Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /**
         * Sets the cache that the build methods fill and return.
         *
         * @param cache target cache
         * @return this builder
         * @throws NullPointerException if {@code cache} is {@code null}
         */
        public Builder<T> setTargetVocabCache(@NonNull VocabCache<T> cache) {
            if (cache == null) {
                throw new NullPointerException("cache");
            }
            this.cache = cache;
            return this;
        }

        /**
         * Registers a source of sequences.
         *
         * @param iterator            sequences to read
         * @param minElementFrequency elements of this source that occur less often are dropped
         *                            (special and label elements are kept); 0 or less keeps everything
         * @return this builder
         * @throws NullPointerException if {@code iterator} is {@code null}
         */
        public Builder<T> addSource(@NonNull SequenceIterator<T> iterator, int minElementFrequency) {
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            this.sources.add(new VocabSource<T>(iterator, minElementFrequency));
            return this;
        }

        /**
         * Sets the tokens that are ignored while counting elements.
         *
         * @param stopWords tokens to skip
         * @return this builder
         * @throws NullPointerException if {@code stopWords} is {@code null}
         */
        public Builder<T> setStopWords(@NonNull List<String> stopWords) {
            if (stopWords == null) {
                throw new NullPointerException("stopWords");
            }
            this.stopWords = stopWords;
            return this;
        }

        /**
         * Controls whether sequence labels are added to the vocabulary by
         * {@link VocabConstructor#buildJointVocabulary(boolean, boolean)}.
         *
         * @param reallyFetch {@code true} to add labels
         * @return this builder
         */
        public Builder<T> fetchLabels(boolean reallyFetch) {
            this.fetchLabels = reallyFetch;
            return this;
        }

        /**
         * Sets an optional inverted index that receives the words of scanned sequences.
         *
         * @param index inverted index, may be {@code null}
         * @return this builder
         */
        public Builder<T> setIndex(InvertedIndex<T> index) {
            this.index = index;
            return this;
        }

        /**
         * Creates a constructor configured with the current builder state.
         *
         * @return new {@link VocabConstructor}
         */
        public VocabConstructor<T> build() {
            VocabConstructor<T> constructor = new VocabConstructor<T>();
            // Copy, so that sources added to this builder later do not leak into an
            // already built constructor.
            constructor.sources = new ArrayList<VocabSource<T>>(this.sources);
            constructor.cache = this.cache;
            constructor.stopWords = this.stopWords;
            constructor.useAdaGrad = this.useAdaGrad;
            constructor.fetchLabels = this.fetchLabels;
            constructor.limit = this.limit;
            constructor.index = this.index;
            return constructor;
        }
    }
}

