/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.jcore.consumer.ew;

import de.julielab.jcore.consumer.ew.Encoder;
import de.julielab.jcore.consumer.ew.VectorOperations;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

/**
 * Reads and merges binary files of text/embedding-vector pairs.
 *
 * <p>Record layout, as read by this class (all numbers big-endian):
 * <ol>
 *   <li>4-byte int: length in bytes of the text,</li>
 *   <li>that many bytes: the text, UTF-8 encoded,</li>
 *   <li>4-byte int: number of vector components,</li>
 *   <li>that many 8-byte IEEE-754 doubles: the vector components.</li>
 * </ol>
 * Records follow each other without any separator.
 *
 * <p>All methods are static and keep no state between calls.
 */
public class Decoder {
    /** Size in bytes of the read and write buffers used when the caller does not specify one. */
    private static final int DEFAULT_BUFFER_SIZE = 8192;

    /**
     * Decodes all records of the stream using a read buffer of 8192 bytes.
     *
     * @param is stream positioned at the start of a record
     * @return the texts (left) and their vectors (right), index-aligned and in stream order
     * @throws IOException if reading fails or a record carries a negative length prefix
     * @throws NoSuchElementException if the stream ends in the middle of a record
     */
    public static Pair<List<String>, List<double[]>> decodeBinaryEmbeddingVectors(InputStream is) throws IOException {
        return decodeBinaryEmbeddingVectors(is, DEFAULT_BUFFER_SIZE);
    }

    /**
     * Decodes all records of the stream.
     *
     * @param is stream positioned at the start of a record
     * @param bufferSize size in bytes of the internal read buffer; must be positive
     * @return the texts (left) and their vectors (right), index-aligned and in stream order
     * @throws IllegalArgumentException if {@code bufferSize} is not positive
     * @throws IOException if reading fails or a record carries a negative length prefix
     * @throws NoSuchElementException if the stream ends in the middle of a record
     */
    public static Pair<List<String>, List<double[]>> decodeBinaryEmbeddingVectors(InputStream is, int bufferSize) throws IOException {
        // A zero-length buffer makes InputStream.read(byte[]) return 0 forever, which would never
        // reach the -1 end-of-stream signal; a negative size cannot be allocated at all.
        if (bufferSize <= 0) {
            throw new IllegalArgumentException("The buffer size must be positive but was " + bufferSize);
        }
        byte[] buffer = new byte[bufferSize];
        byte[] integerBuffer = new byte[4];
        byte[] doubleBuffer = new byte[8];
        List<String> expressions = new ArrayList<>();
        List<double[]> vectors = new ArrayList<>();
        ByteBuffer bb = ByteBuffer.wrap(buffer);
        int bytesRead;
        // The end of the stream is only detected here, i.e. on a record boundary. Running out of
        // bytes inside a record makes readNumberOfBytes throw instead.
        while ((bytesRead = is.read(buffer)) != -1) {
            bb.clear();
            bb.limit(bytesRead);
            while (bb.hasRemaining()) {
                // Records spanning the buffer end are completed by refilling the same buffer.
                Pair<String, double[]> pair = getNextTextEmbeddingPair(is, bb, integerBuffer, doubleBuffer);
                expressions.add(pair.getLeft());
                vectors.add(pair.getRight());
            }
        }
        return new ImmutablePair<>(expressions, vectors);
    }

    /**
     * Reads exactly one record, taking bytes from {@code bb} first and refilling it from
     * {@code is} as needed.
     *
     * @param is stream that delivers the bytes following the current buffer contents
     * @param bb array-backed buffer holding the bytes already fetched from {@code is}
     * @param integerBuffer scratch array of length 4
     * @param doubleBuffer scratch array of length 8
     * @return the decoded text and its vector
     * @throws IOException if reading fails or a length prefix is negative
     * @throws NoSuchElementException if the stream ends before the record is complete
     */
    private static Pair<String, double[]> getNextTextEmbeddingPair(InputStream is, ByteBuffer bb, byte[] integerBuffer, byte[] doubleBuffer) throws IOException {
        int textLength = readInt(integerBuffer, bb, is);
        if (textLength < 0) {
            throw new IOException("Corrupt embedding data: negative text length " + textLength);
        }
        byte[] currentText = new byte[textLength];
        readNumberOfBytes(currentText, bb, is);
        String text = new String(currentText, StandardCharsets.UTF_8);

        int vectorLength = readInt(integerBuffer, bb, is);
        if (vectorLength < 0) {
            throw new IOException("Corrupt embedding data: negative vector length " + vectorLength + " for text '" + text + "'");
        }
        double[] currentVector = new double[vectorLength];
        for (int i = 0; i < vectorLength; ++i) {
            currentVector[i] = readDouble(doubleBuffer, bb, is);
        }
        return new ImmutablePair<>(text, currentVector);
    }

    /**
     * Merges several record streams into one output stream, always emitting next the record whose
     * text is smallest according to {@link String#compareTo(String)} among the current heads of
     * all streams. On ties the stream with the lowest index wins.
     *
     * <p>Consecutive output records with equal text are collected first. With {@code groupByText}
     * they are written as a single record whose vector is obtained from
     * {@code VectorOperations.getAverageEmbeddingVector}; otherwise they are written one by one.
     *
     * <p>NOTE(review): equal texts are only grouped when they end up adjacent in the merged
     * order, which presumably requires every input stream to be sorted by text - confirm with the
     * producer of the files.
     *
     * <p>A stream that ends in the middle of a record is treated like an exhausted stream; the
     * incomplete record is dropped silently. The streams are neither closed nor flushed here.
     *
     * @param inputStreams the streams to merge, each positioned at the start of a record
     * @param os receives the encoded merged records
     * @param groupByText whether records with equal text are averaged into one record
     * @throws IOException if reading or writing fails or a length prefix is negative
     */
    public static void mergeEmbeddingFiles(List<InputStream> inputStreams, OutputStream os, boolean groupByText) throws IOException {
        byte[] integerBuffer = new byte[4];
        byte[] doubleBuffer = new byte[8];
        List<ByteBuffer> streamBuffers = new ArrayList<>(inputStreams.size());
        // heads.get(i) is the next unwritten record of stream i, or null once it is exhausted.
        List<Pair<String, double[]>> heads = new ArrayList<>(inputStreams.size());
        for (InputStream is : inputStreams) {
            ByteBuffer bb = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE);
            // Mark the fresh buffer as fully consumed so the first read fetches from the stream.
            bb.position(bb.capacity());
            streamBuffers.add(bb);
            heads.add(readPairOrNull(is, bb, integerBuffer, doubleBuffer));
        }

        // Records with the same text that have been selected but not yet written.
        List<Pair<String, double[]>> pending = new ArrayList<>();
        ByteBuffer outputBb = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE);
        int minIndex;
        while ((minIndex = indexOfSmallestText(heads)) != -1) {
            Pair<String, double[]> smallest = heads.get(minIndex);
            // Flushes 'pending' if 'smallest' starts a new text, then 'smallest' joins the group.
            writeTextEmbeddingPairToOutputStream(smallest, pending, os, outputBb, groupByText);
            pending.add(smallest);
            heads.set(minIndex, readPairOrNull(inputStreams.get(minIndex), streamBuffers.get(minIndex), integerBuffer, doubleBuffer));
        }
        // A null pair forces the last group out.
        writeTextEmbeddingPairToOutputStream(null, pending, os, outputBb, groupByText);
    }

    /** Reads the next record of a stream, or returns {@code null} when the stream offers no further complete record. */
    private static Pair<String, double[]> readPairOrNull(InputStream is, ByteBuffer bb, byte[] integerBuffer, byte[] doubleBuffer) throws IOException {
        try {
            return getNextTextEmbeddingPair(is, bb, integerBuffer, doubleBuffer);
        } catch (NoSuchElementException endOfStream) {
            // Thrown by readNumberOfBytes when the stream has run dry.
            return null;
        }
    }

    /** Returns the index of the first non-null head with the smallest text, or -1 if all heads are {@code null}. */
    private static int indexOfSmallestText(List<Pair<String, double[]>> heads) {
        int minIndex = -1;
        String min = null;
        for (int i = 0; i < heads.size(); ++i) {
            Pair<String, double[]> head = heads.get(i);
            if (head == null) {
                continue;
            }
            // Strict comparison keeps the earliest stream on ties.
            if (min == null || head.getLeft().compareTo(min) < 0) {
                min = head.getLeft();
                minIndex = i;
            }
        }
        return minIndex;
    }

    /**
     * Writes out the buffered records if {@code textEmbeddingPair} does not belong to their group,
     * and clears the buffer afterwards. Does nothing if the buffer is empty or if the text of the
     * given pair equals the text of the last buffered record. The given pair itself is never
     * written or buffered by this method.
     *
     * @param textEmbeddingPair the record about to be buffered, or {@code null} to force a flush
     * @param outputBuffer the records collected so far; emptied when they are written
     * @param os receives the encoded records
     * @param outputBb buffer handed to {@code Encoder.encodeTextVectorPair}
     * @param groupByText whether the buffered records are written as one averaged record
     * @throws IOException if writing fails
     */
    private static void writeTextEmbeddingPairToOutputStream(Pair<String, double[]> textEmbeddingPair, List<Pair<String, double[]>> outputBuffer, OutputStream os, ByteBuffer outputBb, boolean groupByText) throws IOException {
        if (outputBuffer.isEmpty()) {
            return;
        }
        String lastBufferedText = outputBuffer.get(outputBuffer.size() - 1).getLeft();
        if (textEmbeddingPair != null && lastBufferedText.equals(textEmbeddingPair.getLeft())) {
            // Same text as the current group: the caller just adds the pair to the buffer.
            return;
        }
        if (groupByText) {
            double[] avgVector = VectorOperations.getAverageEmbeddingVector(outputBuffer.stream().map(Pair::getRight));
            String text = outputBuffer.get(0).getLeft();
            os.write(Encoder.encodeTextVectorPair(text, avgVector, outputBb));
        } else {
            for (Pair<String, double[]> p : outputBuffer) {
                os.write(Encoder.encodeTextVectorPair(p, outputBb));
            }
        }
        outputBuffer.clear();
    }

    /**
     * Fills {@code dest} via {@link #readNumberOfBytes(byte[], ByteBuffer, InputStream)} and
     * interprets its first eight bytes as a big-endian IEEE-754 double.
     *
     * @param dest scratch array of at least eight bytes; it is filled completely
     * @param bb array-backed buffer holding the bytes already fetched from {@code is}
     * @param is stream used to refill {@code bb}
     * @return the decoded value
     * @throws IOException if reading fails
     * @throws NoSuchElementException if the stream ends before {@code dest} is filled
     */
    public static double readDouble(byte[] dest, ByteBuffer bb, InputStream is) throws IOException {
        readNumberOfBytes(dest, bb, is);
        long bits = 0L;
        for (int i = 0; i < 8; ++i) {
            // Mask before widening so that bytes >= 0x80 do not sign-extend into the upper bits.
            bits = (bits << 8) | (dest[i] & 0xFF);
        }
        return Double.longBitsToDouble(bits);
    }

    /**
     * Fills {@code dest} via {@link #readNumberOfBytes(byte[], ByteBuffer, InputStream)} and
     * interprets its first four bytes as a big-endian int.
     *
     * @param dest scratch array of at least four bytes; it is filled completely
     * @param bb array-backed buffer holding the bytes already fetched from {@code is}
     * @param is stream used to refill {@code bb}
     * @return the decoded value
     * @throws IOException if reading fails
     * @throws NoSuchElementException if the stream ends before {@code dest} is filled
     */
    public static int readInt(byte[] dest, ByteBuffer bb, InputStream is) throws IOException {
        readNumberOfBytes(dest, bb, is);
        return (dest[0] & 0xFF) << 24 | (dest[1] & 0xFF) << 16 | (dest[2] & 0xFF) << 8 | dest[3] & 0xFF;
    }

    /**
     * Fills {@code dest} completely, first with the bytes remaining in {@code bb} and, whenever
     * those run out, with fresh bytes read from {@code is} into the backing array of {@code bb}.
     *
     * <p>The refill writes to {@code bb.array()} from index 0 and resets position and limit
     * accordingly, so {@code bb} has to be array-backed and writable. NOTE(review): a buffer with
     * a non-zero {@code arrayOffset()} (e.g. a slice) would not be refilled correctly.
     *
     * @param dest the array to fill; an empty array returns immediately
     * @param bb array-backed buffer holding the bytes already fetched from {@code is}
     * @param is stream used to refill {@code bb}
     * @throws IOException if reading fails
     * @throws NoSuchElementException if the stream ends before {@code dest} is filled
     * @throws IllegalArgumentException if a refill is needed but the backing array of {@code bb} is empty
     */
    public static void readNumberOfBytes(byte[] dest, ByteBuffer bb, InputStream is) throws IOException {
        int filled = 0;
        while (filled < dest.length) {
            int chunk = Math.min(bb.remaining(), dest.length - filled);
            if (chunk > 0) {
                bb.get(dest, filled, chunk);
                filled += chunk;
            }
            if (filled >= dest.length) {
                break;
            }
            byte[] backing = bb.array();
            if (backing.length == 0) {
                // read() on an empty array returns 0 forever, so this could never make progress.
                throw new IllegalArgumentException("The byte buffer must have a non-empty backing array.");
            }
            int numRead = is.read(backing);
            if (numRead < 0) {
                throw new NoSuchElementException("The input stream does not offer enough bytes to fill the passed destination array.");
            }
            bb.position(0);
            bb.limit(numRead);
        }
    }
}

