/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.sequencevectors.transformers.impl.iterables;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.transformers.impl.SentenceTransformer;
import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.BasicTransformerIterator;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.AsyncLabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Transformer iterator that tokenizes documents on background threads.
 *
 * <p>Documents are pulled from the underlying iterator (wrapped in an
 * {@link AsyncLabelAwareIterator}) into {@link #stringBuffer}. A set of
 * {@code TokenizerThread} workers takes documents from that queue, converts them
 * to {@link Sequence} instances via the supplied {@link SentenceTransformer},
 * attaches document labels as sequence labels, and publishes the result to
 * {@link #buffer}, from which {@link #next()} reads.
 *
 * <p>Callers are expected to check {@link #hasNext()} before every {@link #next()};
 * {@code next()} blocks until a sequence is available.
 */
public class ParallelTransformerIterator
extends BasicTransformerIterator {
    private static final Logger log = LoggerFactory.getLogger(ParallelTransformerIterator.class);
    /** Tokenized sequences ready to be handed out by {@link #next()}; bounded to 1024 entries. */
    protected BlockingQueue<Sequence<VocabWord>> buffer = new LinkedBlockingQueue<Sequence<VocabWord>>(1024);
    /** Raw documents waiting to be tokenized by the worker threads. */
    protected BlockingQueue<LabelledDocument> stringBuffer;
    /** Worker threads performing the tokenization. */
    protected TokenizerThread[] threads;
    /** Last known answer of the underlying iterator's {@code hasNextDocument()}; once false it stays false. */
    protected boolean underlyingHas = true;
    /** Number of documents currently being transformed by worker threads. */
    protected AtomicInteger processing = new AtomicInteger(0);
    protected static final AtomicInteger count = new AtomicInteger(0);

    /**
     * Creates an iterator with multithreading enabled.
     *
     * @param iterator    source of labelled documents, must not be null
     * @param transformer transformer used to turn document content into sequences, must not be null
     * @throws NullPointerException if either argument is null
     */
    public ParallelTransformerIterator(@NonNull LabelAwareIterator iterator, @NonNull SentenceTransformer transformer) {
        // Null checks are performed by the delegated constructor.
        this(iterator, transformer, true);
    }

    /**
     * Creates an iterator, pre-fills the document queue and starts the worker threads.
     *
     * @param iterator            source of labelled documents, must not be null
     * @param transformer         transformer used to turn document content into sequences, must not be null
     * @param allowMultithreading when true, {@code max(availableProcessors, 2)} workers are started;
     *                            otherwise a single worker is used
     * @throws NullPointerException if {@code iterator} or {@code transformer} is null
     */
    public ParallelTransformerIterator(@NonNull LabelAwareIterator iterator, @NonNull SentenceTransformer transformer, boolean allowMultithreading) {
        super(new AsyncLabelAwareIterator(iterator, 512), transformer);
        if (iterator == null) {
            throw new NullPointerException("iterator is marked @NonNull but is null");
        }
        if (transformer == null) {
            throw new NullPointerException("transformer is marked @NonNull but is null");
        }
        this.allowMultithreading = allowMultithreading;
        this.stringBuffer = new LinkedBlockingQueue<LabelledDocument>(512);
        this.threads = new TokenizerThread[allowMultithreading ? Math.max(Runtime.getRuntime().availableProcessors(), 2) : 1];
        // Pre-fill the document queue with up to 256 documents so workers have input as soon as they start.
        try {
            for (int cnt = 0; cnt < 256 && this.underlyingHas; ++cnt) {
                this.underlyingHas = this.iterator.hasNextDocument();
                if (!this.underlyingHas) {
                    break;
                }
                this.stringBuffer.put(this.iterator.nextDocument());
            }
        }
        catch (InterruptedException e) {
            // Pre-filling is best-effort; keep the interrupt visible to the caller instead of dropping it.
            Thread.currentThread().interrupt();
            log.warn("Interrupted while pre-filling the document buffer", e);
        }
        catch (Exception e) {
            // Pre-filling is best-effort: construction continues, but the failure is no longer hidden.
            log.warn("Failed to pre-fill the document buffer", e);
        }
        for (int x = 0; x < this.threads.length; ++x) {
            this.threads[x] = new TokenizerThread(x, transformer, this.stringBuffer, this.buffer, this.processing);
            this.threads[x].setDaemon(true);
            this.threads[x].setName("ParallelTransformer thread " + x);
            this.threads[x].start();
        }
    }

    /**
     * Shuts down the underlying iterator and asks every worker thread to stop,
     * interrupting it so that a blocking queue operation is released.
     */
    @Override
    public void reset() {
        this.iterator.shutdown();
        for (TokenizerThread worker : this.threads) {
            if (worker == null) {
                continue;
            }
            worker.shutdown();
            try {
                worker.interrupt();
            }
            catch (Exception e) {
                // Interrupting is best-effort; a failure here must not prevent stopping the remaining workers.
                log.debug("Failed to interrupt {}", worker.getName(), e);
            }
        }
    }

    /**
     * Reports whether more sequences can be expected.
     *
     * @return true if the underlying iterator has more documents, or if any document is
     *         still queued, being processed, or already tokenized and waiting in the buffer
     */
    @Override
    public boolean hasNext() {
        // Once the underlying iterator has reported exhaustion it is not queried again.
        this.underlyingHas = this.underlyingHas && this.iterator.hasNextDocument();
        // The stages are checked in pipeline order (queued -> processing -> tokenized). Checking the
        // output buffer last means a sequence that a worker publishes while this expression is being
        // evaluated is still observed, rather than being missed between the two reads.
        // NOTE(review): a worker that has taken a document but not yet incremented `processing`
        // is still invisible here; closing that window would need a larger redesign.
        return this.underlyingHas || !this.stringBuffer.isEmpty() || this.processing.get() > 0 || !this.buffer.isEmpty();
    }

    /**
     * Feeds one more document to the workers (if the underlying iterator still has any, as
     * determined by the last {@link #hasNext()} call) and returns the next tokenized sequence,
     * blocking until one is available.
     *
     * @return the next tokenized sequence
     * @throws RuntimeException wrapping any exception raised while queueing or waiting,
     *                          including {@link InterruptedException}
     */
    @Override
    public Sequence<VocabWord> next() {
        try {
            if (this.underlyingHas) {
                this.stringBuffer.put(this.iterator.nextDocument());
            }
            return this.buffer.take();
        }
        catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Worker thread that takes documents from a queue, transforms them into sequences
     * and puts the results into the output queue.
     */
    private static class TokenizerThread
    extends Thread {
        protected BlockingQueue<Sequence<VocabWord>> sequencesBuffer;
        protected BlockingQueue<LabelledDocument> stringsBuffer;
        protected SentenceTransformer sentenceTransformer;
        /** Cleared by {@link #shutdown()} to make the work loop exit. */
        protected AtomicBoolean shouldWork = new AtomicBoolean(true);
        /** Shared counter of documents currently being transformed. */
        protected AtomicInteger processing;

        /**
         * @param threadIdx       index used in the thread name
         * @param transformer     transformer applied to each document's content
         * @param stringsBuffer   input queue of documents
         * @param sequencesBuffer output queue of tokenized sequences
         * @param processing      shared in-flight counter, incremented while a document is being handled
         */
        public TokenizerThread(int threadIdx, SentenceTransformer transformer, BlockingQueue<LabelledDocument> stringsBuffer, BlockingQueue<Sequence<VocabWord>> sequencesBuffer, AtomicInteger processing) {
            this.stringsBuffer = stringsBuffer;
            this.sequencesBuffer = sequencesBuffer;
            this.sentenceTransformer = transformer;
            this.processing = processing;
            this.setDaemon(true);
            this.setName("Tokenization thread " + threadIdx);
        }

        /**
         * Work loop: runs until {@link #shutdown()} is called or the thread is interrupted.
         * Documents without content are skipped.
         */
        @Override
        public void run() {
            try {
                while (this.shouldWork.get()) {
                    LabelledDocument document = this.stringsBuffer.take();
                    if (document == null || document.getContent() == null) {
                        continue;
                    }
                    this.processing.incrementAndGet();
                    try {
                        Sequence<VocabWord> sequence = this.sentenceTransformer.transformToSequence(document.getContent());
                        // Guard before touching the sequence: labels must not be added to a null result.
                        if (sequence != null) {
                            if (document.getLabels() != null) {
                                for (String label : document.getLabels()) {
                                    if (label == null || label.isEmpty()) {
                                        continue;
                                    }
                                    sequence.addSequenceLabel(new VocabWord(1.0, label));
                                }
                            }
                            this.sequencesBuffer.put(sequence);
                        }
                    }
                    finally {
                        // Always release the in-flight slot, even on failure or interruption;
                        // otherwise the counter stays positive and hasNext() never turns false.
                        this.processing.decrementAndGet();
                    }
                }
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                this.shutdown();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /** Requests the work loop to stop after the current iteration. */
        public void shutdown() {
            this.shouldWork.set(false);
        }
    }
}

