/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.iterator.LabeledPairSentenceProvider;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.primitives.Triple;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * A {@link MultiDataSetIterator} that converts sentences (or sentence pairs) into BERT-style input arrays.
 * <p>
 * Each minibatch is tokenized with the configured {@link TokenizerFactory}. Tokens are mapped to integer
 * indices via {@code vocabMap}, and a per-token mask (1 = real token, 0 = padding) is produced. When
 * {@link FeatureArrays#INDICES_MASK_SEGMENTID} is selected, a segment-id array is also produced. Tokens of the
 * second sentence of a pair (including its trailing append token) get segment id 1; all others get 0.
 * <p>
 * Labels depend on the {@link Task}:
 * <ul>
 *     <li>{@link Task#SEQ_CLASSIFICATION}: one-hot class labels, looked up in the provider's label list.</li>
 *     <li>{@link Task#UNSUPERVISED}: masked-language-model targets produced by a {@link BertSequenceMasker}.</li>
 * </ul>
 * Instances are created via {@link #builder()}.
 * <p>
 * NOTE(review): no synchronization is used here; presumably not safe for concurrent use from several threads.
 */
public class BertIterator implements MultiDataSetIterator {
    protected Task task;
    protected TokenizerFactory tokenizerFactory;
    protected int maxTokens = -1;
    protected int minibatchSize = 32;
    protected boolean padMinibatches = false;
    protected MultiDataSetPreProcessor preProcessor;
    protected LabeledSentenceProvider sentenceProvider = null;
    protected LabeledPairSentenceProvider sentencePairProvider = null;
    protected LengthHandling lengthHandling;
    protected FeatureArrays featureArrays;
    protected Map<String, Integer> vocabMap;
    protected BertSequenceMasker masker = null;
    protected UnsupervisedLabelFormat unsupervisedLabelFormat = null;
    protected String maskToken;
    protected String prependToken;
    protected String appendToken;
    /** Vocabulary words ordered by their index in {@code vocabMap}; built lazily for UNSUPERVISED masking. */
    protected List<String> vocabKeysAsList;

    /**
     * Copies the configuration from the given builder.
     * <p>
     * Normally invoked via {@link Builder#build()}, which validates the configuration first.
     *
     * @param b builder holding the configuration (values are copied by reference)
     */
    protected BertIterator(Builder b) {
        this.task = b.task;
        this.tokenizerFactory = b.tokenizerFactory;
        this.maxTokens = b.maxTokens;
        this.minibatchSize = b.minibatchSize;
        this.padMinibatches = b.padMinibatches;
        this.preProcessor = b.preProcessor;
        this.sentenceProvider = b.sentenceProvider;
        this.sentencePairProvider = b.sentencePairProvider;
        this.lengthHandling = b.lengthHandling;
        this.featureArrays = b.featureArrays;
        this.vocabMap = b.vocabMap;
        this.masker = b.masker;
        this.unsupervisedLabelFormat = b.unsupervisedLabelFormat;
        this.maskToken = b.maskToken;
        this.prependToken = b.prependToken;
        this.appendToken = b.appendToken;
    }

    /**
     * Returns whether another minibatch is available.
     *
     * @return true if the configured sentence (or sentence pair) provider has more data. Returns false when
     *         neither provider is set, e.g. for an iterator used only with {@link #featurizeSentences(List)}.
     */
    @Override
    public boolean hasNext() {
        if (sentenceProvider != null) {
            return sentenceProvider.hasNext();
        }
        return sentencePairProvider != null && sentencePairProvider.hasNext();
    }

    /**
     * Returns the next minibatch of (at most) the configured minibatch size.
     *
     * @return the next minibatch
     */
    @Override
    public org.nd4j.linalg.dataset.api.MultiDataSet next() {
        return next(minibatchSize);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Collects up to {@code num} examples from the provider, tokenizes them and converts them into
     * features, feature masks, labels and label masks.
     * <p>
     * The configured pre-processor (if any) is applied before returning.
     *
     * @param num maximum number of examples in the returned minibatch
     * @return the next minibatch
     * @throws IllegalStateException if no further data is available
     */
    @Override
    public org.nd4j.linalg.dataset.api.MultiDataSet next(int num) {
        Preconditions.checkState(hasNext(), "No next element available");
        List<Pair<List<String>, String>> tokensAndLabelList;
        int outLength;
        long[] segIdOnesFrom = null;
        if (sentenceProvider != null) {
            List<Pair<String, String>> sentences = new ArrayList<>(num);
            while (sentenceProvider.hasNext() && sentences.size() < num) {
                sentences.add(sentenceProvider.nextSentence());
            }
            SentenceListProcessed processed = tokenizeMiniBatch(sentences);
            tokensAndLabelList = processed.getTokensAndLabelList();
            outLength = processed.getMaxL();
        } else if (sentencePairProvider != null) {
            List<Triple<String, String, String>> pairs = new ArrayList<>(num);
            while (sentencePairProvider.hasNext() && pairs.size() < num) {
                pairs.add(sentencePairProvider.nextSentencePair());
            }
            SentencePairListProcessed processed = tokenizePairsMiniBatch(pairs);
            tokensAndLabelList = processed.getTokensAndLabelList();
            outLength = processed.getMaxL();
            segIdOnesFrom = processed.getSegIdOnesFrom();
        } else {
            throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
        }

        Pair<INDArray[], INDArray[]> featuresAndMasks = convertMiniBatchFeatures(tokensAndLabelList, outLength, segIdOnesFrom);
        INDArray[] featureArray = featuresAndMasks.getFirst();
        INDArray[] featureMaskArray = featuresAndMasks.getSecond();
        // For UNSUPERVISED this call also overwrites the masked positions of featureArray[0]
        Pair<INDArray[], INDArray[]> labelsAndMasks = convertMiniBatchLabels(tokensAndLabelList, featureArray, outLength);
        MultiDataSet mds = new MultiDataSet(featureArray, labelsAndMasks.getFirst(), featureMaskArray, labelsAndMasks.getSecond());
        if (preProcessor != null) {
            preProcessor.preProcess(mds);
        }
        return mds;
    }

    /**
     * Featurizes unlabeled single sentences for inference.
     * <p>
     * Uses the same tokenization, length handling and pre-processing as {@link #next()}.
     *
     * @param listOnlySentences sentences to featurize
     * @return pair of (feature arrays, feature mask arrays)
     */
    public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> listOnlySentences) {
        List<Pair<String, String>> sentencesWithNullLabel = addDummyLabel(listOnlySentences);
        SentenceListProcessed processed = tokenizeMiniBatch(sentencesWithNullLabel);
        Pair<INDArray[], INDArray[]> featuresAndMasks =
                convertMiniBatchFeatures(processed.getTokensAndLabelList(), processed.getMaxL(), null);
        return preProcessFeatures(featuresAndMasks);
    }

    /**
     * Featurizes unlabeled sentence pairs for inference.
     * <p>
     * Uses the same tokenization, truncation and pre-processing as {@link #next()}.
     *
     * @param listOnlySentencePairs sentence pairs (first, second) to featurize
     * @return pair of (feature arrays, feature mask arrays)
     * @throws IllegalStateException if this iterator was not configured with a sentence pair provider
     */
    public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> listOnlySentencePairs) {
        Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
        List<Triple<String, String, String>> sentencePairsWithNullLabel = addDummyLabelForPairs(listOnlySentencePairs);
        SentencePairListProcessed processed = tokenizePairsMiniBatch(sentencePairsWithNullLabel);
        Pair<INDArray[], INDArray[]> featuresAndMasks = convertMiniBatchFeatures(
                processed.getTokensAndLabelList(), processed.getMaxL(), processed.getSegIdOnesFrom());
        return preProcessFeatures(featuresAndMasks);
    }

    /**
     * Applies the configured pre-processor (if any) to a features-only minibatch.
     * <p>
     * The pre-processor is applied by wrapping the arrays in a label-less {@link MultiDataSet}.
     *
     * @param featuresAndMasks pair of (feature arrays, feature mask arrays)
     * @return the (possibly pre-processed) features and masks
     */
    private Pair<INDArray[], INDArray[]> preProcessFeatures(Pair<INDArray[], INDArray[]> featuresAndMasks) {
        if (preProcessor == null) {
            return featuresAndMasks;
        }
        MultiDataSet dummyMDS = new MultiDataSet(featuresAndMasks.getFirst(), null, featuresAndMasks.getSecond(), null);
        preProcessor.preProcess(dummyMDS);
        return new Pair<>(dummyMDS.getFeatures(), dummyMDS.getFeaturesMaskArrays());
    }

    /**
     * Returns the number of rows to allocate for a minibatch of {@code actualSize} examples.
     * <p>
     * When padding is enabled, this is at least {@code minibatchSize}. It never drops below the actual
     * size, so calls with more examples than {@code minibatchSize} do not overflow the arrays.
     *
     * @param actualSize number of real examples in the minibatch
     * @return number of rows to allocate
     */
    private int paddedMinibatchSize(int actualSize) {
        return padMinibatches ? Math.max(minibatchSize, actualSize) : actualSize;
    }

    /**
     * Converts tokenized sentences into int index/mask (and optionally segment-id) arrays of shape
     * {@code [paddedMinibatch, outLength]}.
     * <p>
     * Tokens beyond {@code outLength} are dropped. Unused positions stay 0 and are masked out.
     *
     * @param tokensAndLabelList tokenized sentences (labels are ignored here)
     * @param outLength          sequence length of the output arrays
     * @param segIdOnesFrom      per-example index from which the segment id is 1, or null for single sentences
     * @return pair of (features, feature masks). Features are {indices} or {indices, segmentIds}; masks are
     *         {mask} or {mask, null}, depending on {@link #featureArrays}.
     * @throws IllegalStateException if a token is not present in the vocabulary
     */
    private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> tokensAndLabelList, int outLength, long[] segIdOnesFrom) {
        int mbPadded = paddedMinibatchSize(tokensAndLabelList.size());
        int[][] outIdxs = new int[mbPadded][outLength];
        int[][] outMask = new int[mbPadded][outLength];
        boolean withSegmentIds = featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID;
        int[][] outSegmentId = withSegmentIds ? new int[mbPadded][outLength] : null;

        for (int i = 0; i < tokensAndLabelList.size(); i++) {
            List<String> tokens = tokensAndLabelList.get(i).getFirst();
            for (int j = 0; j < outLength && j < tokens.size(); j++) {
                String token = tokens.get(j);
                Integer idx = vocabMap.get(token);
                Preconditions.checkState(idx != null, "Unknown token encountered: token \"%s\" is not in vocabulary", token);
                outIdxs[i][j] = idx;
                outMask[i][j] = 1;
                // Tokens of the second sentence of a pair belong to segment 1
                if (outSegmentId != null && segIdOnesFrom != null && j >= segIdOnesFrom[i]) {
                    outSegmentId[i][j] = 1;
                }
            }
        }

        INDArray outIdxsArr = Nd4j.createFromArray(outIdxs);
        INDArray outMaskArr = Nd4j.createFromArray(outMask);
        INDArray[] f;
        INDArray[] fm;
        if (withSegmentIds) {
            f = new INDArray[]{outIdxsArr, Nd4j.createFromArray(outSegmentId)};
            fm = new INDArray[]{outMaskArr, null};
        } else {
            f = new INDArray[]{outIdxsArr};
            fm = new INDArray[]{outMaskArr};
        }
        return new Pair<>(f, fm);
    }

    /**
     * Tokenizes single sentences (adding prepend/append tokens if configured) and determines the output
     * sequence length according to {@link #lengthHandling}.
     *
     * @param list (sentence, label) pairs; the label may be null
     * @return the tokenized sentences with their labels, and the output length
     */
    private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
        SentenceListProcessed processed = new SentenceListProcessed(list.size());
        int longestSeq = 0;
        for (Pair<String, String> p : list) {
            List<String> tokens = tokenizeSentence(p.getFirst());
            processed.addProcessedToList(new Pair<>(tokens, p.getSecond()));
            longestSeq = Math.max(longestSeq, tokens.size());
        }

        int outLength;
        switch (lengthHandling) {
            case FIXED_LENGTH:
                outLength = maxTokens;
                break;
            case ANY_LENGTH:
                outLength = longestSeq;
                break;
            case CLIP_ONLY:
                outLength = Math.min(maxTokens, longestSeq);
                break;
            default:
                throw new RuntimeException("Not implemented length handling mode: " + lengthHandling);
        }
        processed.setMaxL(outLength);
        return processed;
    }

    /**
     * Tokenizes sentence pairs into {@code [prependToken] left [appendToken] right [appendToken]}.
     * <p>
     * Each side is truncated so the total length does not exceed {@link #maxTokens}. The output length is
     * always {@link #maxTokens}.
     *
     * @param listPairs (first sentence, second sentence, label) triples; the label may be null
     * @return the tokenized pairs with labels, the index where segment 1 starts for each pair, and the output length
     */
    private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> listPairs) {
        SentencePairListProcessed processed = new SentencePairListProcessed(listPairs.size());
        // Budget left for real tokens once the special tokens are accounted for
        int maxLength = maxTokens;
        if (prependToken != null) {
            maxLength -= 1;
        }
        if (appendToken != null) {
            // one separator after each sentence
            maxLength -= 2;
        }

        for (Triple<String, String, String> t : listPairs) {
            List<String> tokensL = tokenizeSentence(t.getFirst(), true);
            List<String> tokensR = tokenizeSentence(t.getSecond(), true);
            truncatePair(tokensL, tokensR, maxLength);

            List<String> tokens = new ArrayList<>(maxTokens);
            if (prependToken != null) {
                tokens.add(prependToken);
            }
            tokens.addAll(tokensL);
            if (appendToken != null) {
                tokens.add(appendToken);
            }
            int segIdOnesFrom = tokens.size();
            tokens.addAll(tokensR);
            if (appendToken != null) {
                tokens.add(appendToken);
            }
            processed.addProcessedToList(segIdOnesFrom, new Pair<>(tokens, t.getThird()));
        }
        processed.setMaxL(maxTokens);
        return processed;
    }

    /**
     * Truncates (in place) the two sides of a sentence pair so their combined size is at most {@code maxLength}.
     * <p>
     * If both sides are longer than half the budget, each keeps roughly half. Otherwise the shorter side is
     * kept whole and the longer side is cut.
     *
     * @param left      tokens of the first sentence (mutable)
     * @param right     tokens of the second sentence (mutable)
     * @param maxLength maximum combined number of tokens
     */
    private static void truncatePair(List<String> left, List<String> right, int maxLength) {
        if (left.size() + right.size() <= maxLength) {
            return;
        }
        int half = maxLength / 2;
        if (Math.min(left.size(), right.size()) > half) {
            left.subList(half, left.size()).clear();
            right.subList(maxLength - half, right.size()).clear();
        } else if (left.size() < right.size()) {
            right.subList(maxLength - left.size(), right.size()).clear();
        } else {
            left.subList(maxLength - right.size(), left.size()).clear();
        }
    }

    /**
     * Builds the label and label-mask arrays for the configured {@link #task}.
     *
     * @param tokenizedSentences tokenized sentences with their labels
     * @param featureArray       feature arrays; for UNSUPERVISED, masked positions of {@code featureArray[0]}
     *                           are overwritten with the masked token indices
     * @param outLength          sequence length of the feature arrays
     * @return pair of (label arrays, label mask arrays or null)
     * @throws IllegalStateException for an unsupported task
     */
    private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
        if (task == Task.SEQ_CLASSIFICATION) {
            return convertSeqClassificationLabels(tokenizedSentences);
        } else if (task == Task.UNSUPERVISED) {
            return convertUnsupervisedLabels(tokenizedSentences, featureArray, outLength);
        }
        throw new IllegalStateException("Task not yet implemented: " + task);
    }

    /**
     * Builds one-hot class labels of shape {@code [paddedMinibatch, numClasses]}.
     * <p>
     * When padding rows were added, a {@code [paddedMinibatch, 1]} label mask is also built (1 for real rows).
     *
     * @param tokenizedSentences tokenized sentences with their string labels
     * @return pair of ({labels}, {labelMask} or null)
     */
    private Pair<INDArray[], INDArray[]> convertSeqClassificationLabels(List<Pair<List<String>, String>> tokenizedSentences) {
        int mbSize = tokenizedSentences.size();
        int mbPadded = paddedMinibatchSize(mbSize);
        List<String> labels;
        int numClasses;
        if (sentenceProvider != null) {
            numClasses = sentenceProvider.numLabelClasses();
            labels = sentenceProvider.allLabels();
        } else if (sentencePairProvider != null) {
            numClasses = sentencePairProvider.numLabelClasses();
            labels = sentencePairProvider.allLabels();
        } else {
            throw new IllegalStateException("No sentence provider or sentence pair provider is set: cannot determine label classes");
        }

        INDArray labelArr = Nd4j.create(DataType.FLOAT, new long[]{mbPadded, numClasses});
        for (int i = 0; i < mbSize; i++) {
            String lbl = tokenizedSentences.get(i).getSecond();
            int classIdx = labels.indexOf(lbl);
            Preconditions.checkState(classIdx >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", lbl);
            labelArr.putScalar((long) i, (long) classIdx, 1.0);
        }

        INDArray[] labelMasks = null;
        if (padMinibatches && mbSize != mbPadded) {
            INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{mbPadded, 1L});
            mask.get(NDArrayIndex.interval(0, mbSize), NDArrayIndex.all()).assign(1);
            labelMasks = new INDArray[]{mask};
        }
        return new Pair<>(new INDArray[]{labelArr}, labelMasks);
    }

    /**
     * Builds masked-language-model labels using {@link #masker}.
     * <p>
     * For each position the masker selects as a prediction target:
     * <ul>
     *     <li>the original token index is written to the labels, laid out per {@link #unsupervisedLabelFormat};</li>
     *     <li>the label mask is set to 1;</li>
     *     <li>{@code featureArray[0]} is overwritten with the index of the masker's replacement token.</li>
     * </ul>
     *
     * @param tokenizedSentences tokenized sentences
     * @param featureArray       feature arrays to update in place
     * @param outLength          sequence length of the feature arrays
     * @return pair of ({labels}, {labelMask}), where the label mask is an int array of shape [paddedMinibatch, outLength]
     */
    private Pair<INDArray[], INDArray[]> convertUnsupervisedLabels(List<Pair<List<String>, String>> tokenizedSentences, INDArray[] featureArray, int outLength) {
        List<String> vocabWords = getVocabKeysAsList();
        int mbSize = tokenizedSentences.size();
        int mbPadded = paddedMinibatchSize(mbSize);
        int vocabSize = vocabMap.size();
        INDArray lMask = Nd4j.zeros(DataType.INT, new long[]{mbPadded, outLength});

        INDArray labelArr;
        if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
            labelArr = Nd4j.create(DataType.INT, new long[]{mbPadded, outLength});
        } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
            labelArr = Nd4j.create(DataType.FLOAT, new long[]{mbPadded, vocabSize, outLength});
        } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
            labelArr = Nd4j.create(DataType.FLOAT, new long[]{outLength, mbPadded, vocabSize});
        } else {
            throw new IllegalStateException("Unknown unsupervised label format: " + unsupervisedLabelFormat);
        }

        for (int i = 0; i < mbSize; i++) {
            List<String> tokens = tokenizedSentences.get(i).getFirst();
            Pair<List<String>, boolean[]> masked = masker.maskSequence(tokens, maskToken, vocabWords);
            List<String> maskedTokens = masked.getFirst();
            boolean[] predictionTarget = masked.getSecond();
            int seqLen = Math.min(predictionTarget.length, outLength);
            for (int j = 0; j < seqLen; j++) {
                if (!predictionTarget[j]) {
                    continue;
                }
                int targetTokenIdx = vocabMap.get(tokens.get(j));
                if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                    labelArr.putScalar((long) i, (long) j, (double) targetTokenIdx);
                } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                    // Layout is [minibatch, vocab, length]: the class index is dimension 1, the time step dimension 2
                    labelArr.putScalar((long) i, (long) targetTokenIdx, (long) j, 1.0);
                } else if (unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                    labelArr.putScalar((long) j, (long) i, (long) targetTokenIdx, 1.0);
                }
                lMask.putScalar((long) i, (long) j, 1.0);
                int newTokenIdx = vocabMap.get(maskedTokens.get(j));
                featureArray[0].putScalar((long) i, (long) j, (double) newTokenIdx);
            }
        }
        return new Pair<>(new INDArray[]{labelArr}, new INDArray[]{lMask});
    }

    /**
     * Lazily builds the list of vocabulary words ordered by index.
     * <p>
     * NOTE(review): assumes the vocabMap values are exactly 0..size-1. Other values cause an
     * ArrayIndexOutOfBoundsException or leave null gaps.
     *
     * @return the vocabulary words ordered by index
     */
    private List<String> getVocabKeysAsList() {
        if (vocabKeysAsList == null) {
            String[] words = new String[vocabMap.size()];
            for (Map.Entry<String, Integer> e : vocabMap.entrySet()) {
                words[e.getValue()] = e.getKey();
            }
            vocabKeysAsList = Arrays.asList(words);
        }
        return vocabKeysAsList;
    }

    /**
     * Tokenizes a single sentence, adding the prepend/append tokens if configured.
     *
     * @param sentence sentence to tokenize
     * @return the tokens
     */
    private List<String> tokenizeSentence(String sentence) {
        return tokenizeSentence(sentence, false);
    }

    /**
     * Tokenizes a sentence with {@link #tokenizerFactory}.
     *
     * @param sentence            sentence to tokenize
     * @param ignorePrependAppend if true, {@link #prependToken} and {@link #appendToken} are not added
     * @return a new mutable list of tokens
     */
    private List<String> tokenizeSentence(String sentence, boolean ignorePrependAppend) {
        Tokenizer t = tokenizerFactory.create(sentence);
        List<String> tokens = new ArrayList<>();
        if (prependToken != null && !ignorePrependAppend) {
            tokens.add(prependToken);
        }
        while (t.hasMoreTokens()) {
            tokens.add(t.nextToken());
        }
        if (appendToken != null && !ignorePrependAppend) {
            tokens.add(appendToken);
        }
        return tokens;
    }

    /**
     * Pairs each sentence with a null label, for inference.
     *
     * @param listOnlySentences sentences to wrap
     * @return (sentence, null) pairs
     */
    private List<Pair<String, String>> addDummyLabel(List<String> listOnlySentences) {
        List<Pair<String, String>> list = new ArrayList<>(listOnlySentences.size());
        for (String s : listOnlySentences) {
            list.add(new Pair<>(s, (String) null));
        }
        return list;
    }

    /**
     * Extends each sentence pair with a null label, for inference.
     *
     * @param listOnlySentencePairs sentence pairs to wrap
     * @return (first, second, null) triples
     */
    private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> listOnlySentencePairs) {
        List<Triple<String, String, String>> list = new ArrayList<>(listOnlySentencePairs.size());
        for (Pair<String, String> p : listOnlySentencePairs) {
            list.add(new Triple<>(p.getFirst(), p.getSecond(), (String) null));
        }
        return list;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Resets whichever provider (sentence or sentence pair) is configured, so iteration starts again.
     */
    @Override
    public void reset() {
        if (sentenceProvider != null) {
            sentenceProvider.reset();
        }
        if (sentencePairProvider != null) {
            sentencePairProvider.reset();
        }
    }

    /**
     * Creates a new builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        return preProcessor;
    }

    @Override
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /** Tokenized sentences with their labels, plus the output sequence length for the minibatch. */
    private static class SentenceListProcessed {
        private final List<Pair<List<String>, String>> tokensAndLabelList;
        private int maxL;

        private SentenceListProcessed(int listLength) {
            this.tokensAndLabelList = new ArrayList<>(listLength);
        }

        private void addProcessedToList(Pair<List<String>, String> tokenizedSentenceAndLabel) {
            tokensAndLabelList.add(tokenizedSentenceAndLabel);
        }

        public int getMaxL() {
            return maxL;
        }

        public void setMaxL(int maxL) {
            this.maxL = maxL;
        }

        public List<Pair<List<String>, String>> getTokensAndLabelList() {
            return tokensAndLabelList;
        }
    }

    /**
     * Tokenized sentence pairs with labels and the output length.
     * <p>
     * Also records, for each pair, the token index from which the segment id becomes 1.
     */
    private static class SentencePairListProcessed {
        private final long[] segIdOnesFrom;
        private final SentenceListProcessed sentenceListProcessed;
        private int cursor = 0;

        private SentencePairListProcessed(int listLength) {
            this.segIdOnesFrom = new long[listLength];
            this.sentenceListProcessed = new SentenceListProcessed(listLength);
        }

        private void addProcessedToList(long segIdIdx, Pair<List<String>, String> tokenizedSentencePairAndLabel) {
            segIdOnesFrom[cursor] = segIdIdx;
            sentenceListProcessed.addProcessedToList(tokenizedSentencePairAndLabel);
            cursor++;
        }

        private void setMaxL(int maxL) {
            sentenceListProcessed.setMaxL(maxL);
        }

        private int getMaxL() {
            return sentenceListProcessed.getMaxL();
        }

        private List<Pair<List<String>, String>> getTokensAndLabelList() {
            return sentenceListProcessed.getTokensAndLabelList();
        }

        public long[] getSegIdOnesFrom() {
            return segIdOnesFrom;
        }
    }

    /** Builder for {@link BertIterator}; {@link #build()} validates the configuration. */
    public static class Builder {
        protected Task task;
        protected TokenizerFactory tokenizerFactory;
        protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
        protected int maxTokens = -1;
        protected int minibatchSize = 32;
        protected boolean padMinibatches = false;
        protected MultiDataSetPreProcessor preProcessor;
        protected LabeledSentenceProvider sentenceProvider = null;
        protected LabeledPairSentenceProvider sentencePairProvider = null;
        protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
        protected Map<String, Integer> vocabMap;
        protected BertSequenceMasker masker = new BertMaskedLMMasker();
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
        protected String maskToken;
        protected String prependToken;
        protected String appendToken;

        /** Sets the task (required). */
        public Builder task(Task task) {
            this.task = task;
            return this;
        }

        /** Sets the tokenizer factory used to split sentences into vocabulary tokens (required). */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets how sequence lengths are determined.
         *
         * @param lengthHandling length handling mode
         * @param maxLength      maximum number of tokens; must be positive unless the mode is ANY_LENGTH
         * @return this builder
         * @throws NullPointerException if {@code lengthHandling} is null
         */
        public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int maxLength) {
            if (lengthHandling == null) {
                throw new NullPointerException("lengthHandling is marked non-null but is null");
            }
            this.lengthHandling = lengthHandling;
            this.maxTokens = maxLength;
            return this;
        }

        /** Sets the default number of examples per minibatch (default 32). */
        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        /** If true, arrays are allocated with at least {@code minibatchSize} rows, padding short minibatches. */
        public Builder padMinibatches(boolean padMinibatches) {
            this.padMinibatches = padMinibatches;
            return this;
        }

        /** Sets an optional pre-processor applied to every produced minibatch. */
        public Builder preProcessor(MultiDataSetPreProcessor preProcessor) {
            this.preProcessor = preProcessor;
            return this;
        }

        /** Sets the provider of labeled single sentences. */
        public Builder sentenceProvider(LabeledSentenceProvider sentenceProvider) {
            this.sentenceProvider = sentenceProvider;
            return this;
        }

        /** Sets the provider of labeled sentence pairs. */
        public Builder sentencePairProvider(LabeledPairSentenceProvider sentencePairProvider) {
            this.sentencePairProvider = sentencePairProvider;
            return this;
        }

        /** Selects which feature arrays are produced (default INDICES_MASK_SEGMENTID). */
        public Builder featureArrays(FeatureArrays featureArrays) {
            this.featureArrays = featureArrays;
            return this;
        }

        /** Sets the token-to-index vocabulary (required). */
        public Builder vocabMap(Map<String, Integer> vocabMap) {
            this.vocabMap = vocabMap;
            return this;
        }

        /** Sets the masker used for UNSUPERVISED training (default {@link BertMaskedLMMasker}). */
        public Builder masker(BertSequenceMasker masker) {
            this.masker = masker;
            return this;
        }

        /** Sets the label layout used for UNSUPERVISED training (required for that task). */
        public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat labelFormat) {
            this.unsupervisedLabelFormat = labelFormat;
            return this;
        }

        /** Sets the vocabulary token used to mask inputs for UNSUPERVISED training (required for that task). */
        public Builder maskToken(String maskToken) {
            this.maskToken = maskToken;
            return this;
        }

        /** Sets an optional token placed at the start of every sequence. */
        public Builder prependToken(String prependToken) {
            this.prependToken = prependToken;
            return this;
        }

        /** Sets an optional token placed after each sentence of a pair (sentence pairs only). */
        public Builder appendToken(String appendToken) {
            this.appendToken = appendToken;
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @return a new {@link BertIterator}
         * @throws IllegalStateException if the configuration is invalid or incomplete
         */
        public BertIterator build() {
            Preconditions.checkState(task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
            Preconditions.checkState(tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
            Preconditions.checkState(vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");
            Preconditions.checkState(task != Task.UNSUPERVISED || masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
            Preconditions.checkState(task != Task.UNSUPERVISED || unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
            Preconditions.checkState(task != Task.UNSUPERVISED || maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");
            // FIXED_LENGTH and CLIP_ONLY size their arrays from maxTokens; a non-positive value would fail at iteration time
            Preconditions.checkState(lengthHandling == LengthHandling.ANY_LENGTH || maxTokens > 0, "A positive maximum length must be set via lengthHandling(LengthHandling, int) when using length handling mode %s; got %s", lengthHandling, maxTokens);
            if (sentencePairProvider != null) {
                Preconditions.checkState(task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
                Preconditions.checkState(featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
                Preconditions.checkState(lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
                Preconditions.checkState(sentenceProvider == null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
            }
            if (appendToken != null) {
                Preconditions.checkState(sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
            }
            return new BertIterator(this);
        }
    }

    /** Layout of the labels for {@link Task#UNSUPERVISED}. */
    public enum UnsupervisedLabelFormat {
        /** Int token indices, shape [minibatch, length]. */
        RANK2_IDX,
        /** One-hot float labels, shape [minibatch, vocabSize, length]. */
        RANK3_NCL,
        /** One-hot float labels, shape [length, minibatch, vocabSize]. */
        RANK3_LNC;
    }

    /** Which feature arrays are produced. */
    public enum FeatureArrays {
        /** Features: {token indices}; masks: {mask}. */
        INDICES_MASK,
        /** Features: {token indices, segment ids}; masks: {mask, null}. */
        INDICES_MASK_SEGMENTID;
    }

    /** How the output sequence length of a minibatch is chosen. */
    public enum LengthHandling {
        /** Always the configured maximum length. */
        FIXED_LENGTH,
        /** The length of the longest sequence in the minibatch. */
        ANY_LENGTH,
        /** The smaller of the configured maximum length and the longest sequence in the minibatch. */
        CLIP_ONLY;
    }

    /** The training task, which determines the labels. */
    public enum Task {
        /** Masked language modeling; labels come from the configured {@link BertSequenceMasker}. */
        UNSUPERVISED,
        /** Sequence classification; one-hot labels from the provider's label set. */
        SEQ_CLASSIFICATION;
    }
}

