/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.consumer.ew;

import de.julielab.jcore.consumer.ew.Encoder;
import de.julielab.jcore.consumer.ew.VectorOperations;
import de.julielab.jcore.types.Token;
import de.julielab.jcore.utility.index.Comparators;
import de.julielab.jcore.utility.index.IndexTermGenerator;
import de.julielab.jcore.utility.index.JCoReTreeMapAnnotationIndex;
import de.julielab.jcore.utility.index.TermGenerators;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.lang.StringUtils;
import org.apache.uima.UimaContext;
import org.apache.uima.analysis_component.JCasAnnotator_ImplBase;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.cas.Type;
import org.apache.uima.fit.descriptor.ConfigurationParameter;
import org.apache.uima.fit.descriptor.ResourceMetaData;
import org.apache.uima.fit.descriptor.TypeCapability;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.apache.uima.resource.ResourceInitializationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ResourceMetaData(name="JCoRe Flair Embedding Writer", description="Given a Flair compatible embedding and a UIMA annotation type, this component prints the embeddings of tokens annotated with the annotation to a file.")
@TypeCapability(inputs={"de.julielab.jcore.types.Token", "de.julielab.jcore.types.EmbeddingVector"})
public class EmbeddingWriter
extends JCasAnnotator_ImplBase {
    public static final String PARAM_ANNOTATION_TYPE = "AnnotationType";
    public static final String PARAM_OUTDIR = "OutputDirectory";
    public static final String PARAM_GZIP = "UseGzip";
    public static final String PARAM_MAX_FILE_ENTRY_SIZE = "MaximumEntriesPerOutputFile";
    private static final Logger log = LoggerFactory.getLogger(EmbeddingWriter.class);
    private static int currentConsumerNumber = 0;
    @ConfigurationParameter(name="UseGzip", mandatory=false, description="If set to true, the output data will be compressed. Defaults to false.")
    boolean gzip;
    @ConfigurationParameter(name="AnnotationType", mandatory=false, description="Fully qualified type name to output embeddings for. If an annotation spans multiple tokens, their embeddings are averaged. If this parameter is omitted, the embeddings of all tokens will be written")
    private String annotationType;
    @ConfigurationParameter(name="OutputDirectory", description="The directory into which the embedding files should be written. In a multi-threaded pipeline, each thread writes its own files. The file names will also include the the host name on which it ran. All output files are ordered by tokens or covered annotation text spans. To control the maximum file size, refer to the MaximumEntriesPerOutputFile parameter.")
    private String outputDir;
    @ConfigurationParameter(name="MaximumEntriesPerOutputFile", mandatory=false, description="The text-embedding pairs are accumulated from multiple CASes before writing them to file. The accumulator keeps the entries sorted by the text part, thus output files are also ordered. This parameter defines the maximum size the accumulate will take before writing its contents to file and clearing itself.", defaultValue={"200000"})
    private int maxEntriesPerFile;
    private String pid;
    private String hostName;
    private int consumerNumber;
    private OutputStream os;
    private File nextOutputFile;
    private int currentBatch;
    private ByteBuffer bb;
    private TreeMap<String, byte[]> outputCache;

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void initialize(UimaContext aContext) throws ResourceInitializationException {
        this.annotationType = (String)aContext.getConfigParameterValue(PARAM_ANNOTATION_TYPE);
        this.outputDir = (String)aContext.getConfigParameterValue(PARAM_OUTDIR);
        this.gzip = Optional.ofNullable((Boolean)aContext.getConfigParameterValue(PARAM_GZIP)).orElse(false);
        this.maxEntriesPerFile = Integer.valueOf(Optional.ofNullable((String)aContext.getConfigParameterValue(PARAM_MAX_FILE_ENTRY_SIZE)).orElse("200000"));
        this.pid = this.getPID();
        this.hostName = this.getHostName();
        String string = PARAM_OUTDIR;
        synchronized (PARAM_OUTDIR) {
            this.consumerNumber = currentConsumerNumber++;
            // ** MonitorExit[var2_2] (shouldn't be in output)
            this.currentBatch = 0;
            this.nextOutputFile = this.getNextOutputFile();
            File dir = this.nextOutputFile.getParentFile();
            if (!dir.exists()) {
                dir.mkdirs();
            }
            try {
                this.os = new BufferedOutputStream(new FileOutputStream(this.nextOutputFile));
                if (this.gzip) {
                    this.os = new GZIPOutputStream(this.os);
                }
            }
            catch (FileNotFoundException e) {
                log.error("Could not create output stream for the output file {}", (Object)this.nextOutputFile, (Object)e);
                throw new ResourceInitializationException((Throwable)e);
            }
            catch (IOException e) {
                log.error("Could not create GZIPOutputStream", (Throwable)e);
                throw new ResourceInitializationException((Throwable)e);
            }
            this.outputCache = new TreeMap();
            return;
        }
    }

    private File getNextOutputFile() {
        return new File(this.outputDir + File.separator + "embeddings-" + this.hostName + "-" + this.pid + "-writer" + this.consumerNumber + "-batch" + ++this.currentBatch + ".dat" + (this.gzip ? ".gz" : ""));
    }

    public void process(JCas aJCas) throws AnalysisEngineProcessException {
        try {
            if (!StringUtils.isBlank((String)this.annotationType)) {
                Type type = aJCas.getTypeSystem().getType(this.annotationType);
                if (type == null) {
                    throw new AnalysisEngineProcessException((Throwable)new IllegalArgumentException("The type " + this.annotationType + " was not found in the type system."));
                }
                if (!aJCas.getAnnotationIndex(type).iterator().hasNext()) {
                    return;
                }
                JCoReTreeMapAnnotationIndex tokenIndex = new JCoReTreeMapAnnotationIndex(Comparators.longOverlapComparator(), (IndexTermGenerator)TermGenerators.longOffsetTermGenerator(), (IndexTermGenerator)TermGenerators.longOffsetTermGenerator(), aJCas, Token.type);
                for (Annotation a : aJCas.getAnnotationIndex(type)) {
                    Stream overlappingTokens = tokenIndex.search(a);
                    this.cacheEmbeddingsForAnnotation(overlappingTokens.collect(Collectors.toList()));
                }
            } else {
                for (Annotation token : aJCas.getAnnotationIndex(Token.type)) {
                    this.cacheEmbeddingsForAnnotation(Arrays.asList((Token)token));
                }
            }
            if (this.outputCache.size() >= this.maxEntriesPerFile) {
                this.writeEmbeddingsToFile();
            }
        }
        catch (IOException e) {
            log.error("Could not write to output stream", (Throwable)e);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
    }

    private void writeEmbeddingsToFile() throws IOException {
        for (byte[] textVector : this.outputCache.values()) {
            this.os.write(textVector);
        }
        this.nextOutputFile = this.getNextOutputFile();
        this.outputCache.clear();
    }

    private void cacheEmbeddingsForAnnotation(List<Token> tokens) throws IOException {
        String text = tokens.get(0).getCAS().getDocumentText().substring(tokens.get(0).getBegin(), tokens.get(tokens.size() - 1).getEnd());
        double[] avgEmbedding = VectorOperations.getAverageEmbeddingVector(tokens.stream().map(t -> t.getEmbeddingVectors(0).getVector().toArray()));
        byte[] cacheArray = Encoder.encodeTextVectorPair(text, avgEmbedding, this.bb);
        this.outputCache.put(text, cacheArray);
    }

    public void collectionProcessComplete() throws AnalysisEngineProcessException {
        try {
            this.writeEmbeddingsToFile();
        }
        catch (IOException e) {
            log.error("Exception while writing the last batch of embedding vectors to file {}", (Object)this.nextOutputFile, (Object)e);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
        try {
            this.os.close();
        }
        catch (IOException e) {
            log.error("Exception when closing the output stream to file {}", (Object)this.nextOutputFile, (Object)e);
            throw new AnalysisEngineProcessException((Throwable)e);
        }
    }

    private String getPID() {
        String id = ManagementFactory.getRuntimeMXBean().getName();
        return id.substring(0, id.indexOf(64));
    }

    private String getHostName() {
        String hostName;
        try {
            InetAddress address = InetAddress.getLocalHost();
            hostName = address.getHostName();
        }
        catch (UnknownHostException e) {
            throw new IllegalStateException(e);
        }
        return hostName;
    }
}

