/*
 * Decompiled with CFR 0.152.
 */
package com.quasiris.qsf.commons.ai.embedding;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.quasiris.qsf.commons.ai.dto.Document;
import com.quasiris.qsf.commons.ai.dto.TextVector;
import com.quasiris.qsf.commons.ai.dto.TextVectorDocument;
import com.quasiris.qsf.commons.ai.embedding.TextEmbeddingEncoder;
import com.quasiris.qsf.commons.nlp.SentenceSplitter;
import com.quasiris.qsf.commons.text.normalizer.TextNormalizerService;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpEntity;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TextEmbeddingEncoder} that obtains sentence embeddings from a BERT-as-a-service HTTP
 * endpoint.
 *
 * <p>Each embedding call sends a single JSON POST of the form
 * {@code {"id": "", "texts": [...], "is_tokenized": false}} to {@code baseUrl}. It expects back a
 * JSON object whose {@code "result"} entry holds one list of numbers per submitted text. HTTP, I/O
 * and response-parsing failures are logged; they leave the affected {@link TextVector}s without a
 * vector instead of propagating to the caller.
 */
public class BertAsAServiceEncoder implements TextEmbeddingEncoder {
    private static final Logger logger = LoggerFactory.getLogger(BertAsAServiceEncoder.class);

    /** Internal field name under which the vectors of one embedded text are collected. */
    private static final String BULK_FIELD = "_bulk";

    private final String baseUrl;
    /** Used as connect, connection-request and socket timeout of every request. */
    private final Integer timeout;
    private final ObjectMapper objectMapper;

    /**
     * @param baseUrl URL that embedding requests are POSTed to
     * @param timeout timeout in milliseconds (HttpClient {@code RequestConfig} unit), applied to
     *                connecting, leasing a connection and reading. A {@code null} value causes a
     *                {@code NullPointerException} once a request is actually sent.
     */
    public BertAsAServiceEncoder(String baseUrl, Integer timeout) {
        this.baseUrl = baseUrl;
        this.timeout = timeout;
        this.objectMapper = new ObjectMapper();
    }

    /**
     * Embeds a single text.
     *
     * @param text       the text to embed
     * @param normalizer optional normalizer applied to every sentence; may be {@code null}
     * @param autosplit  whether to split {@code text} into sentences before embedding
     * @return one {@link TextVector} per non-blank (normalized) sentence, in order
     */
    @Override
    public List<TextVector> embed(String text, TextNormalizerService normalizer, boolean autosplit) {
        List<TextVectorDocument> documents = this.embedTextBulk(Arrays.asList(text), normalizer, autosplit);
        if (documents.size() != 1) {
            return new ArrayList<>();
        }
        return bulkVectors(documents.get(0));
    }

    /**
     * Embeds every text of {@code textList}.
     *
     * <p>With {@code autosplit}, each text is split by {@link SentenceSplitter}; otherwise the whole
     * text is treated as one sentence. Each sentence is normalized when a normalizer is given.
     * Sentences that are blank after normalization are skipped. All remaining sentences of all
     * texts go out in one HTTP request.
     *
     * @return one {@link TextVectorDocument} per input text, in input order. Its {@code BULK_FIELD}
     *         entry lists the text's {@link TextVector}s; their vectors stay {@code null} if the
     *         request failed.
     */
    private List<TextVectorDocument> embedTextBulk(List<String> textList, TextNormalizerService normalizer, boolean autosplit) {
        SentenceSplitter textSplitter = new SentenceSplitter();
        List<String> allSentences = new ArrayList<>();
        List<TextVector> allVectors = new ArrayList<>();
        List<TextVectorDocument> vectorDocs = new ArrayList<>(textList.size());
        for (String text : textList) {
            List<String> sentences = autosplit ? textSplitter.split(text) : Arrays.asList(text);
            TextVectorDocument vectorDoc = new TextVectorDocument();
            vectorDoc.getFields().put(BULK_FIELD, new ArrayList<>());
            List<TextVector> docVectors = bulkVectors(vectorDoc);
            for (String sentence : sentences) {
                String normalized = normalizer != null ? normalizer.normalize(sentence) : sentence;
                if (StringUtils.isBlank(normalized)) {
                    continue;
                }
                TextVector textVector = new TextVector(sentence, normalized, null);
                docVectors.add(textVector);
                allVectors.add(textVector);
                allSentences.add(normalized);
            }
            vectorDocs.add(vectorDoc);
        }
        // Nothing to embed: skip the network round trip instead of posting an empty "texts" array.
        if (!allSentences.isEmpty()) {
            this.fetchVectors(allSentences, allVectors);
        }
        return vectorDocs;
    }

    /**
     * POSTs {@code sentences} to the service. On success, it stores the i-th returned vector on
     * {@code targets.get(i)}.
     *
     * <p>Vectors are assigned only after the whole response has been parsed, so a malformed
     * response never leaves {@code targets} partially filled. Request-building failures and
     * request/response failures are logged, not thrown. Building the {@link HttpPost} or the
     * timeout configuration can still throw, as before.
     */
    private void fetchVectors(List<String> sentences, List<TextVector> targets) {
        HttpEntity requestEntity;
        try {
            requestEntity = this.buildEntity(sentences);
        } catch (JsonProcessingException e) {
            logger.warn(e.getMessage(), e);
            return;
        }
        HttpPost request = new HttpPost(this.baseUrl);
        request.setEntity(requestEntity);
        request.addHeader("Content-Type", "application/json");
        request.setConfig(RequestConfig.custom()
                .setConnectTimeout(this.timeout)
                .setConnectionRequestTimeout(this.timeout)
                .setSocketTimeout(this.timeout)
                .build());
        try (CloseableHttpClient httpClient = HttpClients.createDefault();
             CloseableHttpResponse response = httpClient.execute(request)) {
            int status = response.getStatusLine().getStatusCode();
            HttpEntity entity = response.getEntity();
            if (status != 200 || entity == null) {
                logger.warn("BertAsAServiceEncoder: unexpected response (status {}, body present: {})", status, entity != null);
                return;
            }
            // The JSON body is decoded as UTF-8 when the server does not declare a charset.
            Map<?, ?> responseBody = this.objectMapper.readValue(EntityUtils.toString(entity, StandardCharsets.UTF_8), Map.class);
            if (!responseBody.containsKey("result")) {
                logger.warn("BertAsAServiceEncoder: response has no \"result\" entry");
                return;
            }
            List<?> outputs = (List<?>) responseBody.get("result");
            if (outputs.size() != targets.size()) {
                throw new IllegalStateException("Input and output vectors does not match!");
            }
            List<Double[]> vectors = new ArrayList<>(outputs.size());
            for (Object output : outputs) {
                vectors.add(toVector((List<?>) output));
            }
            for (int i = 0; i < targets.size(); ++i) {
                targets.get(i).setVector(vectors.get(i));
            }
        } catch (Exception e) {
            logger.warn("Something gone wrong in GET document for BertAsAServiceEncoder!", e);
        }
    }

    /**
     * Converts one JSON array of numbers to a {@code Double[]}.
     *
     * <p>Jackson maps JSON integers to {@code Integer}/{@code Long} and fractions to
     * {@code Double}. Going through {@link Number#doubleValue()} therefore accepts both, where a
     * direct {@code toArray(Double[]::new)} would throw {@code ArrayStoreException}. JSON
     * {@code null} elements are kept as {@code null}.
     */
    private static Double[] toVector(List<?> values) {
        Double[] vector = new Double[values.size()];
        for (int i = 0; i < vector.length; ++i) {
            Object value = values.get(i);
            vector[i] = value == null ? null : ((Number) value).doubleValue();
        }
        return vector;
    }

    /** Returns the vector list stored under {@code BULK_FIELD} in {@code document}. */
    @SuppressWarnings("unchecked")
    private static List<TextVector> bulkVectors(TextVectorDocument document) {
        // Cast via List<?> so this compiles independently of the declared value type of getFields().
        return (List<TextVector>) (List<?>) document.getFields().get(BULK_FIELD);
    }

    /**
     * Embeds all fields of one document.
     *
     * @return a {@link TextVectorDocument} with the same id as {@code doc}, holding one vector
     *         list per field
     */
    @Override
    public TextVectorDocument embed(Document<String> doc, TextNormalizerService normalizer, boolean autosplit) {
        List<TextVectorDocument> vectorDocs = this.embedBulk(Arrays.asList(doc), normalizer, autosplit);
        if (vectorDocs.size() == 1) {
            return vectorDocs.get(0);
        }
        return new TextVectorDocument(doc.getId());
    }

    /**
     * Embeds all fields of each document. Each document costs one HTTP request that covers all
     * of its fields.
     *
     * @return one {@link TextVectorDocument} per input document, in input order. It maps each
     *         field name to the {@link TextVector}s of that field's text.
     */
    @Override
    public List<TextVectorDocument> embedBulk(List<Document<String>> docs, TextNormalizerService normalizer, boolean autosplit) {
        List<TextVectorDocument> vectorDocs = new ArrayList<>(docs.size());
        for (Document<String> doc : docs) {
            TextVectorDocument vectorDoc = new TextVectorDocument(doc.getId());
            // Snapshot names and texts in one pass so the i-th result is paired with the i-th
            // field without relying on two iterations of the map producing the same order.
            List<String> fieldNames = new ArrayList<>();
            List<String> fieldTexts = new ArrayList<>();
            for (Map.Entry<String, String> entry : doc.getFields().entrySet()) {
                fieldNames.add(entry.getKey());
                fieldTexts.add(entry.getValue());
            }
            List<TextVectorDocument> embedded = this.embedTextBulk(fieldTexts, normalizer, autosplit);
            if (embedded.size() == fieldNames.size()) {
                for (int i = 0; i < fieldNames.size(); ++i) {
                    vectorDoc.getFields().put(fieldNames.get(i), embedded.get(i).getFields().get(BULK_FIELD));
                }
            }
            vectorDocs.add(vectorDoc);
        }
        return vectorDocs;
    }

    /** Serializes {@code sentences} into the UTF-8 JSON request body expected by the service. */
    private HttpEntity buildEntity(List<String> sentences) throws JsonProcessingException {
        Map<String, Object> body = new HashMap<>();
        body.put("id", "");
        body.put("texts", sentences);
        body.put("is_tokenized", false);
        String payload = this.objectMapper.writeValueAsString(body);
        return new StringEntity(payload, "UTF-8");
    }
}

