/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.cli.transforms.text.nlp;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.math3.util.Pair;
import org.datavec.api.berkeley.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.cli.transforms.Transform;
import org.datavec.cli.transforms.text.nlp.NLPUtils;
import org.datavec.nlp.metadata.DefaultVocabCache;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.stopwords.StopWords;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Transform that converts a two-field record {@code (text, label)} into a
 * TF-IDF vector followed by a numeric label id.
 *
 * <p>Usage is two-pass: {@link #collectStatistics(Collection)} is called for
 * every record to build the vocabulary cache and the label table, then
 * {@link #transform(Collection)} rewrites each record in place using those
 * statistics. {@link #initialize(Configuration)} must be called before either
 * pass because it creates the tokenizer factory and the vocabulary cache.
 *
 * <p>NOTE(review): {@link #stopWords} is loaded in {@code initialize} but is not
 * consulted anywhere in this class; confirm whether subclasses rely on it.
 */
public class TfidfTextVectorizerTransform
implements Transform {
    protected TokenizerFactory tokenizerFactory;
    protected int minWordFrequency = 0;
    /** Configuration key for the minimum word frequency handed to the vocabulary cache. */
    public static final String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency";
    /** Configuration key for a custom stop-word collection. */
    public static final String STOP_WORDS = "org.nd4j.nlp.stopwords";
    /** Configuration key for the fully-qualified {@link TokenizerFactory} class name. */
    public static final String TOKENIZER = "org.datavec.nlp.tokenizerfactory";
    protected Collection<String> stopWords;
    protected VocabCache cache;
    /** Maps a trimmed label to a pair of (label id, number of records seen with that label). */
    public Map<String, Pair<Integer, Integer>> recordLabels = new LinkedHashMap<String, Pair<Integer, Integer>>();
    final EndingPreProcessor preProcessor = new EndingPreProcessor();

    /**
     * @return the number of words currently held by the vocabulary cache
     */
    public int getVocabularySize() {
        return this.cache.vocabWords().size();
    }

    /**
     * Prints every vocabulary word with its index to standard output.
     */
    public void debugPrintVocabList() {
        System.out.println("Vocabulary Words: ");
        int vocabSize = this.cache.vocabWords().size();
        for (int i = 0; i < vocabSize; ++i) {
            System.out.println(i + ". " + this.cache.wordAt(i));
        }
    }

    /**
     * Feeds all tokens of one document into the vocabulary cache.
     *
     * <p>The word count is incremented for every occurrence, while the document
     * count is incremented only once per distinct token of this document.
     *
     * @param tokenizer tokenizer positioned at the start of a single document
     */
    public void doWithTokens(Tokenizer tokenizer) {
        HashSet<String> seen = new HashSet<String>();
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken();
            this.cache.incrementCount(token);
            // add() returns false when the token was already recorded for this
            // document, which keeps the document frequency at one per document.
            if (seen.add(token)) {
                this.cache.incrementDocCount(token);
            }
        }
    }

    /**
     * Instantiates the tokenizer factory named under {@link #TOKENIZER}, falling
     * back to {@link DefaultTokenizerFactory} when the key is absent.
     *
     * @param conf configuration to read the class name from
     * @return a new factory instance created through its no-argument constructor
     * @throws RuntimeException if the class cannot be loaded or instantiated
     */
    public TokenizerFactory createTokenizerFactory(Configuration conf) {
        String clazz = conf.get(TOKENIZER, DefaultTokenizerFactory.class.getName());
        try {
            Class<?> tokenizerFactoryClazz = Class.forName(clazz);
            return (TokenizerFactory)tokenizerFactoryClazz.getDeclaredConstructor().newInstance();
        }
        catch (Exception e) {
            throw new RuntimeException("Unable to create tokenizer factory " + clazz, e);
        }
    }

    /**
     * Sets up the tokenizer factory, token pre-processing, stop words and the
     * vocabulary cache from the given configuration.
     *
     * @param conf configuration supplying {@link #TOKENIZER},
     *             {@link #MIN_WORD_FREQUENCY} (default 5) and {@link #STOP_WORDS}
     */
    public void initialize(Configuration conf) {
        this.tokenizerFactory = this.createTokenizerFactory(conf);
        this.tokenizerFactory.setTokenPreProcessor(new TokenPreProcess(){

            @Override
            public String preProcess(String token) {
                // Tokens starting with "http://" keep their original characters;
                // everything else is reduced to lower-case ASCII letters and spaces.
                if (!token.startsWith("http://")) {
                    token = token.replaceAll("[^a-zA-Z ]", "").toLowerCase();
                }
                String base = TfidfTextVectorizerTransform.this.preProcessor.preProcess(token);
                // Collapse every remaining digit into the placeholder letter 'd'.
                return base.replaceAll("\\d", "d");
            }
        });
        this.minWordFrequency = conf.getInt(MIN_WORD_FREQUENCY, 5);
        this.stopWords = conf.getStringCollection(STOP_WORDS);
        if (this.stopWords == null || this.stopWords.isEmpty()) {
            this.stopWords = StopWords.getStopWords();
        }
        this.cache = new DefaultVocabCache(this.minWordFrequency);
    }

    /**
     * Counts how often each token occurs in the given text.
     *
     * @param sentence text to tokenize with the configured tokenizer factory
     * @return counter mapping each token to its number of occurrences
     */
    protected Counter<String> wordFrequenciesForSentence(String sentence) {
        Tokenizer tokenizer = this.tokenizerFactory.create(sentence);
        Counter<String> ret = new Counter<String>();
        while (tokenizer.hasMoreTokens()) {
            try {
                String token = tokenizer.nextToken();
                ret.incrementCount(token, 1.0);
            }
            catch (NoSuchElementException e) {
                // NOTE(review): assumes the tokenizer still advances after throwing;
                // otherwise this loop would not terminate - confirm against Tokenizer.
                System.out.println("Bad Token");
            }
        }
        return ret;
    }

    /**
     * Builds the TF-IDF vector of a text record over the current vocabulary.
     *
     * @param textRecord raw text of one record
     * @return array with one entry per vocabulary word, in vocabulary index order
     */
    public INDArray convertTextRecordToTFIDFVector(String textRecord) {
        Counter<String> wordFrequenciesForDocument = this.wordFrequenciesForSentence(textRecord);
        int vocabSize = this.cache.vocabWords().size();
        INDArray ret = Nd4j.create(vocabSize);
        int totalDocsInCorpus = (int)this.cache.numDocs();
        for (int i = 0; i < vocabSize; ++i) {
            String term = this.cache.wordAt(i);
            int termFreqForThisDoc = (int)wordFrequenciesForDocument.getCount(term);
            // The value from cache.idf(term) is used here as the number of documents
            // containing the term; the actual IDF is computed by NLPUtils.idf below.
            int numberOfDocsThisTermAppearsIn = (int)this.cache.idf(term);
            double tfTerm = NLPUtils.tf(termFreqForThisDoc);
            double idfTerm = NLPUtils.idf(totalDocsInCorpus, numberOfDocsThisTermAppearsIn);
            ret.putScalar(i, NLPUtils.tfidf(tfTerm, idfTerm));
        }
        return ret;
    }

    /**
     * First pass: records the label and folds the text of one record into the
     * vocabulary statistics.
     *
     * <p>Records that do not have exactly two fields are ignored, matching the
     * records that {@link #transform(Collection)} leaves untouched.
     *
     * @param vector record whose first field is the text and second the label
     */
    @Override
    public void collectStatistics(Collection<Writable> vector) {
        if (vector.size() != 2) {
            return;
        }
        Object[] fields = vector.toArray();
        String sentence = fields[0].toString();
        String label = fields[1].toString();
        this.trackLabel(label);
        Tokenizer tokenizer = this.tokenizerFactory.create(sentence);
        this.cache.incrementNumDocs(1.0);
        this.doWithTokens(tokenizer);
    }

    /**
     * Registers one occurrence of a label, assigning the next free id to labels
     * that have not been seen before. Labels are keyed by their trimmed value.
     */
    private void trackLabel(String label_value) {
        String trimmedKey = label_value.trim();
        Pair<Integer, Integer> existing = this.recordLabels.get(trimmedKey);
        if (existing != null) {
            Integer labelID = existing.getFirst();
            Integer count = existing.getSecond() + 1;
            this.recordLabels.put(trimmedKey, new Pair<Integer, Integer>(labelID, count));
        } else {
            // Ids are handed out in first-seen order, starting at 0.
            Integer labelID = this.recordLabels.size();
            this.recordLabels.put(trimmedKey, new Pair<Integer, Integer>(labelID, 1));
        }
    }

    /**
     * @return the number of distinct labels recorded so far
     */
    public int getNumberOfLabelsSeen() {
        return this.recordLabels.keySet().size();
    }

    /**
     * Looks up the id assigned to a label.
     *
     * <p>The label is matched exactly first and then by its trimmed value,
     * because labels are stored trimmed when statistics are collected.
     *
     * @param label label text, may be {@code null}
     * @return the label id, or {@code null} if the label is unknown
     */
    public Integer getLabelID(String label) {
        Pair<Integer, Integer> entry = this.recordLabels.get(label);
        if (entry == null && label != null) {
            entry = this.recordLabels.get(label.trim());
        }
        return entry == null ? null : entry.getFirst();
    }

    /**
     * Second pass: replaces the record's contents with its TF-IDF values
     * followed by the numeric label id. Records that do not have exactly two
     * fields are left unchanged.
     *
     * @param vector record whose first field is the text and second the label
     * @throws IllegalStateException if the label was never seen while collecting
     *                               statistics; the record is not modified in that case
     */
    @Override
    public void transform(Collection<Writable> vector) {
        if (vector.size() != 2) {
            return;
        }
        Object[] fields = vector.toArray();
        String textRecord = fields[0].toString();
        String label = fields[1].toString();
        Integer labelID = this.getLabelID(label);
        // Validate before clearing the record so a failure does not corrupt it.
        if (labelID == null) {
            throw new IllegalStateException("Label '" + label + "' was not seen during collectStatistics");
        }
        INDArray tfidfVector = this.convertTextRecordToTFIDFVector(textRecord);
        vector.clear();
        int columns = tfidfVector.columns();
        for (int colID = 0; colID < columns; ++colID) {
            vector.add(new DoubleWritable(tfidfVector.getDouble(0, colID)));
        }
        vector.add(new DoubleWritable((double)labelID.intValue()));
    }

    /**
     * No-op: this transform needs no work between collecting statistics and
     * transforming records.
     */
    @Override
    public void evaluateStatistics() {
    }
}

