/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.provider.LabelAwareConverter;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.documentiterator.LabelAwareDocumentIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.interoperability.DocumentIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.interoperability.SentenceIteratorConverter;
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareSentenceIterator;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * A {@link DataSetIterator} that turns labeled sentences into word-vector feature arrays for
 * sentence classification.
 *
 * <p>Each sentence is tokenized with the configured {@link TokenizerFactory} and every token is
 * replaced by its vector from the supplied {@link WordVectors}. If the word vectors do not support
 * out-of-vocabulary lookup, tokens missing from the vocabulary are either dropped or mapped to an
 * "unknown" vector, depending on {@link UnknownWordHandling}. Sentences left with no tokens are
 * skipped.
 *
 * <p>Feature layouts produced by {@link #next(int)} and {@link #loadSingleSentence(String)}:
 * <ul>
 *   <li>{@link Format#RNN} / {@link Format#CNN1D}: rank 3, {@code [minibatch, vectorSize, length]}
 *       ('f' order for RNN, 'c' order for CNN1D); mask {@code [minibatch, length]}.</li>
 *   <li>{@link Format#CNN2D}: rank 4, {@code [minibatch, 1, length, vectorSize]} when
 *       sentencesAlongHeight is true, otherwise {@code [minibatch, 1, vectorSize, length]};
 *       mask {@code [minibatch, 1, length, 1]} or {@code [minibatch, 1, 1, length]}.</li>
 * </ul>
 * {@code length} is the longest sentence in the minibatch, clipped to {@code maxSentenceLength}
 * when that is positive. Labels are one-hot {@code [minibatch, numClasses]}; class indices are
 * assigned in sorted order of the provider's label strings. A features mask is attached only when
 * sentences in the minibatch have different clipped lengths.
 */
public class CnnSentenceDataSetIterator implements DataSetIterator {
    /**
     * Marker token substituted for out-of-vocabulary words under
     * {@link UnknownWordHandling#UseUnknownVector}. It is deliberately compared by reference
     * ({@code ==}) in {@link #getVector(String)}, so a tokenizer-produced String with the same text
     * (a distinct instance) is not treated as the marker.
     */
    private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";

    private Format format;
    private LabeledSentenceProvider sentenceProvider;
    private WordVectors wordVectors;
    private TokenizerFactory tokenizerFactory;
    private UnknownWordHandling unknownWordHandling;
    private boolean useNormalizedWordVectors;
    private int minibatchSize;
    /** Maximum number of tokens per sentence; values {@code <= 0} mean "no limit". */
    private int maxSentenceLength;
    private boolean sentencesAlongHeight;
    private DataSetPreProcessor dataSetPreProcessor;
    private int wordVectorSize;
    private int numClasses;
    /** Label string -> class index (indices assigned in sorted label order). */
    private Map<String, Integer> labelClassMap;
    /** Vector used for UNKNOWN_WORD_SENTINEL tokens; only set for UseUnknownVector handling. */
    private INDArray unknown;
    /** Number of examples returned since construction or the last reset(). */
    private int cursor = 0;
    /** One-sentence lookahead filled by hasNext() and consumed by next(int). */
    private Pair<List<String>, String> preLoadedTokens;

    /**
     * Creates the iterator from a {@link Builder}.
     *
     * <p>The word-vector size is read from the first vocabulary entry. The label/class map is
     * only populated when a sentence provider is set; without one, the iterator can still be used
     * for {@link #loadSingleSentence(String)}.
     */
    protected CnnSentenceDataSetIterator(Builder builder) {
        this.format = builder.format;
        this.sentenceProvider = builder.sentenceProvider;
        this.wordVectors = builder.wordVectors;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.unknownWordHandling = builder.unknownWordHandling;
        this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
        this.minibatchSize = builder.minibatchSize;
        this.maxSentenceLength = builder.maxSentenceLength;
        this.sentencesAlongHeight = builder.sentencesAlongHeight;
        this.dataSetPreProcessor = builder.dataSetPreProcessor;

        this.wordVectorSize = this.wordVectors.getWordVector(this.wordVectors.vocab().wordAtIndex(0)).length;

        this.labelClassMap = new HashMap<>();
        if (this.sentenceProvider != null) {
            this.numClasses = this.sentenceProvider.numLabelClasses();
            // Sort so that label -> index assignment is deterministic.
            List<String> sortedLabels = new ArrayList<>(this.sentenceProvider.allLabels());
            Collections.sort(sortedLabels);
            int count = 0;
            for (String label : sortedLabels) {
                this.labelClassMap.put(label, count++);
            }
        } else {
            this.numClasses = 0;
        }

        if (this.unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
            String unk = this.wordVectors.getUNK();
            this.unknown = this.useNormalizedWordVectors
                    ? this.wordVectors.getWordVectorMatrixNormalized(unk)
                    : this.wordVectors.getWordVectorMatrix(unk);
            if (this.unknown == null) {
                // No UNK vector available: fall back to an array shaped like a real word vector.
                this.unknown = this.wordVectors.getWordVectorMatrix(this.wordVectors.vocab().wordAtIndex(0)).like();
            }
        }
    }

    /**
     * Converts a single sentence into a features array with minibatch size 1, using the same
     * layout as {@link #next(int)}.
     *
     * @param sentence the raw sentence text
     * @return the features array for the sentence (no labels, no mask)
     * @throws IllegalStateException if tokenization leaves no tokens
     */
    public INDArray loadSingleSentence(String sentence) {
        List<String> tokens = this.tokenizeSentence(sentence);
        if (tokens.isEmpty()) {
            throw new IllegalStateException("No tokens available for input sentence - empty string or no words in vocabulary with RemoveWord unknown handling? Sentence = \"" + sentence + "\"");
        }
        return this.createFeatures(Collections.singletonList(tokens), this.clipLength(tokens.size()));
    }

    /** Clips a sentence length to maxSentenceLength when that limit is positive. */
    private int clipLength(int length) {
        return this.maxSentenceLength > 0 ? Math.min(length, this.maxSentenceLength) : length;
    }

    /**
     * Looks up the vector for a token.
     *
     * @return the unknown vector for the sentinel token, otherwise the (optionally normalized)
     *     word vector as returned by the WordVectors instance
     */
    private INDArray getVector(String word) {
        // Reference comparison is intentional: only the sentinel instance itself should match.
        if (this.unknownWordHandling == UnknownWordHandling.UseUnknownVector && word == UNKNOWN_WORD_SENTINEL) {
            return this.unknown;
        }
        return this.useNormalizedWordVectors
                ? this.wordVectors.getWordVectorMatrixNormalized(word)
                : this.wordVectors.getWordVectorMatrix(word);
    }

    /**
     * Tokenizes a sentence and applies the unknown-word policy.
     *
     * <p>If the word vectors do not support out-of-vocabulary lookup, tokens missing from the
     * vocabulary are dropped (RemoveWord) or replaced by {@link #UNKNOWN_WORD_SENTINEL}
     * (UseUnknownVector).
     *
     * @return the resulting tokens, possibly empty
     */
    private List<String> tokenizeSentence(String sentence) {
        Tokenizer t = this.tokenizerFactory.create(sentence);
        List<String> tokens = new ArrayList<>();
        while (t.hasMoreTokens()) {
            String token = t.nextToken();
            if (!this.wordVectors.outOfVocabularySupported() && !this.wordVectors.hasWord(token)) {
                switch (this.unknownWordHandling) {
                    case RemoveWord:
                        continue; // skip this token entirely
                    case UseUnknownVector:
                        token = UNKNOWN_WORD_SENTINEL;
                        break;
                    default:
                        break;
                }
            }
            tokens.add(token);
        }
        return tokens;
    }

    /** Returns a copy of the label string -> class index map. */
    public Map<String, Integer> getLabelClassMap() {
        return new HashMap<>(this.labelClassMap);
    }

    /** Returns the label strings ordered by class index. */
    @Override
    public List<String> getLabels() {
        String[] byIndex = new String[this.labelClassMap.size()];
        for (Map.Entry<String, Integer> e : this.labelClassMap.entrySet()) {
            byIndex[e.getValue()] = e.getKey();
        }
        return Arrays.asList(byIndex);
    }

    /**
     * Returns true if another non-empty sentence is available. Sentences that tokenize to nothing
     * are consumed and skipped; the first usable one is held as a lookahead for next(int).
     *
     * @throws UnsupportedOperationException if no sentence provider was configured
     */
    @Override
    public boolean hasNext() {
        if (this.sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        while (this.preLoadedTokens == null && this.sentenceProvider.hasNext()) {
            this.preLoadTokens();
        }
        return this.preLoadedTokens != null;
    }

    /** Reads one sentence into the lookahead slot if it is empty and the sentence has tokens. */
    private void preLoadTokens() {
        if (this.preLoadedTokens != null) {
            return;
        }
        Pair<String, String> p = this.sentenceProvider.nextSentence();
        List<String> tokens = this.tokenizeSentence(p.getFirst());
        if (!tokens.isEmpty()) {
            this.preLoadedTokens = new Pair<>(tokens, p.getSecond());
        }
    }

    /** Returns the next minibatch of the configured size. */
    @Override
    public DataSet next() {
        return this.next(this.minibatchSize);
    }

    /**
     * Returns a minibatch of up to {@code num} non-empty sentences (fewer if the provider runs
     * out).
     *
     * @throws UnsupportedOperationException if no sentence provider was configured
     * @throws NoSuchElementException if no sentences remain
     * @throws IllegalStateException if a sentence's label is not one of the provider's labels
     */
    @Override
    public DataSet next(int num) {
        if (this.sentenceProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        if (!this.hasNext()) {
            throw new NoSuchElementException("No next element");
        }

        List<Pair<List<String>, String>> examples = this.collectExamples(num);
        List<List<String>> sentences = new ArrayList<>(examples.size());
        int maxLength = -1;
        int minLength = Integer.MAX_VALUE;
        for (Pair<List<String>, String> example : examples) {
            List<String> tokens = example.getFirst();
            sentences.add(tokens);
            maxLength = Math.max(maxLength, tokens.size());
            minLength = Math.min(minLength, tokens.size());
        }
        maxLength = this.clipLength(maxLength);

        INDArray labels = this.createLabels(examples);
        INDArray features = this.createFeatures(sentences, maxLength);
        // A mask is only needed if some sentence is shorter than the (clipped) time dimension.
        INDArray featuresMask = minLength < maxLength ? this.createFeaturesMask(sentences, maxLength) : null;

        DataSet ds = new DataSet(features, labels, featuresMask, null);
        if (this.dataSetPreProcessor != null) {
            this.dataSetPreProcessor.preProcess(ds);
        }
        this.cursor += ds.numExamples();
        return ds;
    }

    /**
     * Gathers up to {@code num} tokenized, non-empty sentences with their labels, starting with
     * the lookahead sentence (if any).
     */
    private List<Pair<List<String>, String>> collectExamples(int num) {
        List<Pair<List<String>, String>> examples = new ArrayList<>(Math.max(num, 1));
        if (this.preLoadedTokens != null) {
            examples.add(this.preLoadedTokens);
            this.preLoadedTokens = null;
        }
        while (examples.size() < num && this.sentenceProvider.hasNext()) {
            Pair<String, String> p = this.sentenceProvider.nextSentence();
            List<String> tokens = this.tokenizeSentence(p.getFirst());
            if (!tokens.isEmpty()) {
                examples.add(new Pair<>(tokens, p.getSecond()));
            }
        }
        return examples;
    }

    /** Builds the one-hot {@code [minibatch, numClasses]} labels array. */
    private INDArray createLabels(List<Pair<List<String>, String>> examples) {
        INDArray labels = Nd4j.create(new int[]{examples.size(), this.numClasses});
        for (int i = 0; i < examples.size(); ++i) {
            String labelStr = examples.get(i).getSecond();
            Integer labelIdx = this.labelClassMap.get(labelStr);
            if (labelIdx == null) {
                throw new IllegalStateException("Got label \"" + labelStr + "\" that is not present in list of LabeledSentenceProvider labels");
            }
            labels.putScalar((long) i, (long) labelIdx.intValue(), 1.0);
        }
        return labels;
    }

    /**
     * Builds the features array for the given token lists. Each sentence contributes its first
     * {@code min(size, length)} token vectors; remaining positions are left as created.
     */
    private INDArray createFeatures(List<List<String>> sentences, int length) {
        int n = sentences.size();
        if (this.format == Format.CNN1D || this.format == Format.RNN) {
            char ordering = this.format == Format.CNN1D ? 'c' : 'f';
            INDArray features = Nd4j.create(new int[]{n, this.wordVectorSize, length}, ordering);
            INDArrayIndex[] idxs = new INDArrayIndex[3];
            idxs[1] = NDArrayIndex.all();
            for (int i = 0; i < n; ++i) {
                idxs[0] = NDArrayIndex.point(i);
                List<String> tokens = sentences.get(i);
                int used = Math.min(tokens.size(), length);
                for (int j = 0; j < used; ++j) {
                    idxs[2] = NDArrayIndex.point(j);
                    features.put(idxs, this.getVector(tokens.get(j)));
                }
            }
            return features;
        }

        int[] featuresShape = this.sentencesAlongHeight
                ? new int[]{n, 1, length, this.wordVectorSize}
                : new int[]{n, 1, this.wordVectorSize, length};
        INDArray features = Nd4j.create(featuresShape);
        INDArrayIndex[] indices = new INDArrayIndex[4];
        indices[1] = NDArrayIndex.point(0);
        for (int i = 0; i < n; ++i) {
            indices[0] = NDArrayIndex.point(i);
            List<String> tokens = sentences.get(i);
            int used = Math.min(tokens.size(), length);
            for (int j = 0; j < used; ++j) {
                if (this.sentencesAlongHeight) {
                    indices[2] = NDArrayIndex.point(j);
                    indices[3] = NDArrayIndex.all();
                } else {
                    indices[2] = NDArrayIndex.all();
                    indices[3] = NDArrayIndex.point(j);
                }
                features.put(indices, this.getVector(tokens.get(j)));
            }
        }
        return features;
    }

    /**
     * Builds the features mask: 1.0 for positions holding a real token, 0.0 (as created) for
     * padding. Shape matches the layout described on the class.
     */
    private INDArray createFeaturesMask(List<List<String>> sentences, int length) {
        int n = sentences.size();
        boolean rank3Features = this.format == Format.CNN1D || this.format == Format.RNN;
        INDArray featuresMask;
        if (rank3Features) {
            featuresMask = Nd4j.create(new int[]{n, length});
        } else if (this.sentencesAlongHeight) {
            featuresMask = Nd4j.create(new int[]{n, 1, length, 1});
        } else {
            featuresMask = Nd4j.create(new int[]{n, 1, 1, length});
        }

        for (int i = 0; i < n; ++i) {
            int sentenceLength = sentences.get(i).size();
            if (sentenceLength >= length) {
                // Sentence fills (or was truncated to) the whole time dimension.
                if (rank3Features) {
                    featuresMask.getRow(i).assign(1.0);
                } else {
                    featuresMask.slice(i).assign(1.0);
                }
                continue;
            }
            INDArrayIndex valid = NDArrayIndex.interval(0, sentenceLength);
            INDArrayIndex[] idx;
            if (rank3Features) {
                idx = new INDArrayIndex[]{NDArrayIndex.point(i), valid};
            } else if (this.sentencesAlongHeight) {
                idx = new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(0), valid, NDArrayIndex.point(0)};
            } else {
                idx = new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0), valid};
            }
            featuresMask.get(idx).assign(1.0);
        }
        return featuresMask;
    }

    /** Returns the word vector length. */
    @Override
    public int inputColumns() {
        return this.wordVectorSize;
    }

    /** Returns the number of label classes reported by the sentence provider. */
    @Override
    public int totalOutcomes() {
        return this.numClasses;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Restarts iteration: clears the example counter and the lookahead sentence (so no sentence
     * from the previous pass leaks into the next), and resets the sentence provider if present.
     */
    @Override
    public void reset() {
        this.cursor = 0;
        this.preLoadedTokens = null;
        if (this.sentenceProvider != null) {
            this.sentenceProvider.reset();
        }
    }

    /** Returns the configured minibatch size. */
    @Override
    public int batch() {
        return this.minibatchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.dataSetPreProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.dataSetPreProcessor;
    }

    /** Not supported. */
    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Creates an iterator directly from all of its internal state. No derived state is computed;
     * prefer {@link Builder} for normal use.
     */
    public CnnSentenceDataSetIterator(Format format, LabeledSentenceProvider sentenceProvider, WordVectors wordVectors, TokenizerFactory tokenizerFactory, UnknownWordHandling unknownWordHandling, boolean useNormalizedWordVectors, int minibatchSize, int maxSentenceLength, boolean sentencesAlongHeight, DataSetPreProcessor dataSetPreProcessor, int wordVectorSize, int numClasses, Map<String, Integer> labelClassMap, INDArray unknown, int cursor, Pair<List<String>, String> preLoadedTokens) {
        this.format = format;
        this.sentenceProvider = sentenceProvider;
        this.wordVectors = wordVectors;
        this.tokenizerFactory = tokenizerFactory;
        this.unknownWordHandling = unknownWordHandling;
        this.useNormalizedWordVectors = useNormalizedWordVectors;
        this.minibatchSize = minibatchSize;
        this.maxSentenceLength = maxSentenceLength;
        this.sentencesAlongHeight = sentencesAlongHeight;
        this.dataSetPreProcessor = dataSetPreProcessor;
        this.wordVectorSize = wordVectorSize;
        this.numClasses = numClasses;
        this.labelClassMap = labelClassMap;
        this.unknown = unknown;
        this.cursor = cursor;
        this.preLoadedTokens = preLoadedTokens;
    }

    /**
     * Builder for {@link CnnSentenceDataSetIterator}.
     *
     * <p>Defaults: {@link DefaultTokenizerFactory}, {@link UnknownWordHandling#RemoveWord},
     * normalized word vectors, no sentence-length limit ({@code -1}), minibatch size 32, sentences
     * along height. A {@link WordVectors} instance is required.
     */
    public static class Builder {
        private Format format;
        private LabeledSentenceProvider sentenceProvider = null;
        private WordVectors wordVectors;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
        private boolean useNormalizedWordVectors = true;
        private int maxSentenceLength = -1;
        private int minibatchSize = 32;
        private boolean sentencesAlongHeight = true;
        private DataSetPreProcessor dataSetPreProcessor;

        /** @deprecated use {@link #Builder(Format)}; this constructor uses {@link Format#CNN2D}. */
        @Deprecated
        public Builder() {
            this(Format.CNN2D);
        }

        /** @param format output layout of the features array */
        public Builder(@NonNull Format format) {
            if (format == null) {
                throw new NullPointerException("format is marked non-null but is null");
            }
            this.format = format;
        }

        /** Sets the source of (sentence, label) pairs. */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /** Uses a {@link LabelAwareIterator} as the sentence source, with the given label set. */
        public Builder sentenceProvider(LabelAwareIterator iterator, @NonNull List<String> labels) {
            if (labels == null) {
                throw new NullPointerException("labels is marked non-null but is null");
            }
            LabelAwareConverter converter = new LabelAwareConverter(iterator, labels);
            return this.sentenceProvider(converter);
        }

        /** Uses a {@link LabelAwareDocumentIterator} as the sentence source, with the given label set. */
        public Builder sentenceProvider(LabelAwareDocumentIterator iterator, @NonNull List<String> labels) {
            if (labels == null) {
                throw new NullPointerException("labels is marked non-null but is null");
            }
            DocumentIteratorConverter converter = new DocumentIteratorConverter(iterator);
            return this.sentenceProvider(converter, labels);
        }

        /** Uses a {@link LabelAwareSentenceIterator} as the sentence source, with the given label set. */
        public Builder sentenceProvider(LabelAwareSentenceIterator iterator, @NonNull List<String> labels) {
            if (labels == null) {
                throw new NullPointerException("labels is marked non-null but is null");
            }
            SentenceIteratorConverter converter = new SentenceIteratorConverter(iterator);
            return this.sentenceProvider(converter, labels);
        }

        /** Sets the word vectors used to embed tokens (required). */
        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /** Sets how out-of-vocabulary tokens are handled. */
        public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
            this.unknownWordHandling = unknownWordHandling;
            return this;
        }

        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
            this.useNormalizedWordVectors = useNormalizedWordVectors;
            return this;
        }

        /** Maximum tokens kept per sentence; values {@code <= 0} mean no limit. */
        public Builder maxSentenceLength(int maxSentenceLength) {
            this.maxSentenceLength = maxSentenceLength;
            return this;
        }

        /** For CNN2D only: place the sentence (time) axis on dimension 2 (true) or 3 (false). */
        public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
            this.sentencesAlongHeight = sentencesAlongHeight;
            return this;
        }

        public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
            this.dataSetPreProcessor = dataSetPreProcessor;
            return this;
        }

        /**
         * Builds the iterator.
         *
         * @throws IllegalStateException if no WordVectors instance was set
         */
        public CnnSentenceDataSetIterator build() {
            if (this.wordVectors == null) {
                throw new IllegalStateException("Cannot build CnnSentenceDataSetIterator without a WordVectors instance");
            }
            return new CnnSentenceDataSetIterator(this);
        }
    }

    /** Layout of the features array. */
    public static enum Format {
        RNN,
        CNN1D,
        CNN2D;

    }

    /** Policy for tokens absent from the word-vector vocabulary. */
    public static enum UnknownWordHandling {
        RemoveWord,
        UseUnknownVector;

    }
}

