/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connector.einsteinai.internal.modelsapi.helpers;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.mulesoft.connector.einsteinai.api.metadata.EinsteinResponseAttributes;
import com.mulesoft.connector.einsteinai.api.metadata.ResponseParameters;
import com.mulesoft.connector.einsteinai.internal.connection.EinsteinConnection;
import com.mulesoft.connector.einsteinai.internal.error.EinsteinErrorType;
import com.mulesoft.connector.einsteinai.internal.helpers.HttpRequestHelper;
import com.mulesoft.connector.einsteinai.internal.helpers.ThrowingFunction;
import com.mulesoft.connector.einsteinai.internal.modelsapi.dto.EinsteinEmbeddingResponseDTO;
import com.mulesoft.connector.einsteinai.internal.modelsapi.helpers.ResponseHelper;
import com.mulesoft.connector.einsteinai.internal.modelsapi.models.ParamsEmbeddingDocumentDetails;
import com.mulesoft.connector.einsteinai.internal.modelsapi.models.ParamsEmbeddingModelDetails;
import com.mulesoft.connector.einsteinai.internal.modelsapi.models.ParamsModelDetails;
import com.mulesoft.connector.einsteinai.internal.modelsapi.models.RAGParamsModelDetails;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeoutException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.tika.exception.TikaException;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.parser.AutoDetectParser;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.parser.txt.TXTParser;
import org.apache.tika.sax.BodyContentHandler;
import org.json.JSONArray;
import org.json.JSONObject;
import org.json.JSONTokener;
import org.mule.runtime.api.util.MultiMap;
import org.mule.runtime.extension.api.error.ErrorTypeDefinition;
import org.mule.runtime.extension.api.exception.ModuleException;
import org.mule.runtime.extension.api.runtime.operation.Result;
import org.mule.runtime.extension.api.runtime.process.CompletionCallback;
import org.mule.runtime.http.api.domain.entity.EmptyHttpEntity;
import org.mule.runtime.http.api.domain.entity.HttpEntity;
import org.mule.runtime.http.api.domain.entity.InputStreamHttpEntity;
import org.mule.runtime.http.api.domain.message.request.HttpRequest;
import org.mule.runtime.http.api.domain.message.request.HttpRequestBuilder;
import org.mule.runtime.http.api.domain.message.response.HttpResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.ContentHandler;
import org.xml.sax.SAXException;

/**
 * Sends requests to the Einstein platform Models API ({@code <instance>/einstein/platform/v1/models/<model><resource>})
 * and post-processes the responses for the connector operations.
 *
 * <p>The callback-based methods ({@link #executeGenerateText}, {@link #generateChatFromMessages},
 * {@link #generateEmbeddingFromText}) send asynchronously and complete the supplied Mule {@link CompletionCallback}.
 * The remaining public methods send synchronously and return either the raw response body or a {@link JSONArray}
 * derived from it.
 */
public class RequestHelper {
    private static final Logger log = LoggerFactory.getLogger(RequestHelper.class);

    /** Path segment inserted between the API instance URL and the model API name. */
    private static final String MODELS_API_PATH = "/einstein/platform/v1/models/";
    private static final String GENERATIONS_RESOURCE = "/generations";
    private static final String CHAT_GENERATIONS_RESOURCE = "/chat-generations";
    private static final String EMBEDDINGS_RESOURCE = "/embeddings";
    /** Maximum number of texts sent in one batched embeddings request. */
    private static final int EMBEDDING_BATCH_SIZE = 100;
    /** Sentinel returned by {@link #extractPayload} when no JSON object is found; it is forwarded as-is. */
    private static final String PAYLOAD_NOT_FOUND = "Payload not found!";
    private static final String JSON_CONTENT_TYPE = "application/json;charset=utf-8";

    /** ObjectMapper is thread-safe once configured, so one shared instance avoids per-call construction. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private static final Pattern URL_PATTERN =
            Pattern.compile("(https?://[\\w\\-\\.]+(?:\\.[\\w\\-]+)+(?:[\\w\\-.,@?^=%&:/~+#]*[\\w\\-@?^=%&/~+#])?)");
    /** Greedy match from the first '{' to the last '}' on a single line ('.' does not cross newlines). */
    private static final Pattern JSON_OBJECT_PATTERN = Pattern.compile("\\{.*\\}");
    /** Control characters other than tab, LF and CR. */
    private static final Pattern CONTROL_CHARS_PATTERN = Pattern.compile("[\\p{Cc}&&[^\\t\\n\\r]]");
    private static final Pattern LINE_BREAK_RUN_PATTERN = Pattern.compile("(\\r?\\n)+");
    private static final Pattern PARAGRAPH_SPLIT_PATTERN = Pattern.compile("\\r?\\n+");

    private final EinsteinConnection einsteinConnection;

    public RequestHelper(EinsteinConnection einsteinConnection) {
        this.einsteinConnection = einsteinConnection;
    }

    /**
     * Asynchronously generates text for a single prompt via {@code /generations}.
     *
     * @param prompt       prompt text sent as the {@code prompt} field
     * @param paramDetails model API name, locale and probability
     * @param callback     completed with the response converted by {@code ResponseHelper::createEinsteinFormattedResponse}
     */
    public void executeGenerateText(String prompt, ParamsModelDetails paramDetails, CompletionCallback<InputStream, EinsteinResponseAttributes> callback) {
        String payload = constructPayload(prompt, paramDetails.getLocale(), paramDetails.getProbability());
        executeEinsteinRequest(toJsonEntity(payload), paramDetails.getModelApiName(), "POST", GENERATIONS_RESOURCE,
                callback, ResponseHelper::createEinsteinFormattedResponse);
    }

    /**
     * Asynchronously generates a chat reply from a JSON array of messages via {@code /chat-generations}.
     *
     * @param messages     JSON array text; it is parsed with {@code new JSONArray(messages)}
     * @param paramDetails model API name, locale and probability
     * @param callback     completed with the response converted by {@code ResponseHelper::createEinsteinChatFromMessagesResponse}
     */
    public void generateChatFromMessages(String messages, ParamsModelDetails paramDetails, CompletionCallback<InputStream, ResponseParameters> callback) {
        String payload = constructPayloadWithMessages(messages, paramDetails);
        executeEinsteinRequest(toJsonEntity(payload), paramDetails.getModelApiName(), "POST", CHAT_GENERATIONS_RESOURCE,
                callback, ResponseHelper::createEinsteinChatFromMessagesResponse);
    }

    /**
     * Asynchronously requests an embedding for a single text via {@code /embeddings}.
     *
     * @param text         text to embed
     * @param paramDetails supplies the model API name
     * @param callback     completed with the response converted by {@code ResponseHelper::createEinsteinEmbeddingResponse}
     */
    public void generateEmbeddingFromText(String text, ParamsEmbeddingModelDetails paramDetails, CompletionCallback<InputStream, ResponseParameters> callback) {
        String payload = constructEmbeddingJsonPayload(text);
        executeEinsteinRequest(toJsonEntity(payload), paramDetails.getModelApiName(), "POST", EMBEDDINGS_RESOURCE,
                callback, ResponseHelper::createEinsteinEmbeddingResponse);
    }

    /**
     * Sends an asynchronous Models API request. When the request completes, the outcome is passed to
     * {@code HttpRequestHelper.handleHttpResponse} together with the callback and converter.
     *
     * @param payload           request body; {@code null} sends an {@link EmptyHttpEntity}
     * @param modelApiName      model API name inserted into the URL path
     * @param httpMethod        HTTP method, e.g. {@code "POST"}
     * @param resource          resource suffix appended after the model name, e.g. {@code "/generations"}
     * @param callback          Mule completion callback
     * @param responseConverter converts the response body into the operation {@link Result}
     */
    private <A> void executeEinsteinRequest(HttpEntity payload, String modelApiName, String httpMethod, String resource,
                                            CompletionCallback<InputStream, A> callback,
                                            ThrowingFunction<InputStream, Result<InputStream, A>> responseConverter) {
        String urlString = buildModelUrl(modelApiName, resource);
        log.debug("Einstein Request URL: {}", urlString);
        HttpEntity entity = payload != null ? payload : new EmptyHttpEntity();
        CompletableFuture<HttpResponse> completableFuture = einsteinConnection.getHttpClient()
                .sendAsync(buildRequest(urlString, einsteinConnection.getAccessToken(), httpMethod, entity));
        completableFuture.whenComplete((response, exception) -> HttpRequestHelper.handleHttpResponse(
                response, exception, EinsteinErrorType.MODELS_API_ERROR, callback, responseConverter));
    }

    /**
     * Extracts text from a document stream, splits it into a corpus and returns one embedding per non-empty entry.
     * The {@code PARAGRAPH} option uses batched requests; any other option embeds entries one request at a time.
     * The stream is closed before this method returns.
     *
     * @param inputStream              document content; must not be {@code null}
     * @param embeddingDocumentDetails file type, split option and model API name
     * @return JSON array of embedding vectors
     * @throws IllegalArgumentException if {@code inputStream} is {@code null} or the split option is unknown
     */
    public JSONArray generateEmbeddingFromFileInputStream(InputStream inputStream, ParamsEmbeddingDocumentDetails embeddingDocumentDetails) throws IOException, TikaException, SAXException, TimeoutException {
        if (inputStream == null) {
            throw new IllegalArgumentException("Input stream is null.");
        }
        List<List<Double>> allEmbeddings;
        try (BufferedInputStream bufferedInputStream = new BufferedInputStream(inputStream)) {
            List<String> corpus = createCorpusListFromStream(bufferedInputStream,
                    embeddingDocumentDetails.getFileType(), embeddingDocumentDetails.getOptionType());
            allEmbeddings = "PARAGRAPH".equalsIgnoreCase(embeddingDocumentDetails.getOptionType())
                    ? getBatchCorpusEmbeddings(embeddingDocumentDetails.getModelApiName(), corpus)
                    : getCorpusEmbeddings(embeddingDocumentDetails.getModelApiName(), corpus);
        }
        return new JSONArray(allEmbeddings);
    }

    /**
     * Synchronously calls {@code /generations} with {@code text} as the prompt.
     *
     * @return the stream returned by {@code HttpRequestHelper.handleHttpResponse}
     */
    public InputStream executeRAG(String text, RAGParamsModelDetails paramDetails) throws IOException, TimeoutException {
        String payload = constructPayload(text, paramDetails.getLocale(), paramDetails.getProbability());
        return executeEinsteinRequest(payload, paramDetails.getModelApiName(), GENERATIONS_RESOURCE);
    }

    /**
     * Runs a tool-assisted generation.
     * <ol>
     *   <li>Asks the model with {@code prompt} and scans the raw response text for URLs.</li>
     *   <li>If none are found, asks the model with {@code originalPrompt} and returns that response.</li>
     *   <li>Otherwise extracts the first {@code {...}} span of {@code generation.generatedText} as the tool body.
     *       It calls the first URL as configured in the tools JSON ({@code inputStream}), then asks the model again
     *       with {@code "data: <tool response>, question: <originalPrompt>"}.</li>
     * </ol>
     * The tools config stream is closed only on the tool-call path.
     */
    public InputStream executeTools(String originalPrompt, String prompt, InputStream inputStream, ParamsModelDetails paramDetails) throws IOException, TimeoutException {
        String modelApiName = paramDetails.getModelApiName();
        String payload = constructPayload(prompt, paramDetails.getLocale(), paramDetails.getProbability());
        String intermediateAnswer = HttpRequestHelper.readResponseStream(
                executeEinsteinRequest(payload, modelApiName, GENERATIONS_RESOURCE));
        List<String> urls = extractUrls(intermediateAnswer);
        if (urls.isEmpty()) {
            String directPayload = constructPayload(originalPrompt, paramDetails.getLocale(), paramDetails.getProbability());
            return executeEinsteinRequest(directPayload, modelApiName, GENERATIONS_RESOURCE);
        }
        String generatedText = new JSONObject(intermediateAnswer).getJSONObject("generation").getString("generatedText");
        String toolResponse = getAttributes(urls.get(0), inputStream, extractPayload(generatedText));
        String finalPayload = constructPayload("data: " + toolResponse + ", question: " + originalPrompt,
                paramDetails.getLocale(), paramDetails.getProbability());
        return executeEinsteinRequest(finalPayload, modelApiName, GENERATIONS_RESOURCE);
    }

    /**
     * Ranks corpus entries extracted from a document by cosine similarity to the prompt's embedding.
     *
     * @return JSON array of {@code "<score> - <text>"} strings, highest score first
     */
    public JSONArray embeddingFileQuery(String prompt, InputStream inputStream, String modelName, String fileType, String optionType) throws IOException, TikaException, SAXException, TimeoutException {
        List<Double> queryEmbedding = getQueryEmbedding(constructEmbeddingJsonPayload(prompt), modelName);
        // Drop empty entries up front so that corpus index i always matches embedding/score index i.
        List<String> corpus = createCorpusListFromStream(inputStream, fileType, optionType).stream()
                .filter(text -> text != null && !text.isEmpty())
                .collect(Collectors.toList());
        List<List<Double>> corpusEmbeddings = getCorpusEmbeddings(modelName, corpus);
        List<Double> similarityScores = corpusEmbeddings.stream()
                .map(corpusEmbedding -> calculateCosineSimilarity(queryEmbedding, corpusEmbedding))
                .collect(Collectors.toList());
        return new JSONArray(rankResults(corpus, similarityScores));
    }

    /** Splits a document into a corpus: one entry for {@code FULL}, otherwise one per paragraph. */
    private List<String> createCorpusListFromStream(InputStream inputStream, String fileType, String splitOption) throws TikaException, IOException, SAXException {
        return "FULL".equalsIgnoreCase(splitOption)
                ? Collections.singletonList(splitFullDocument(inputStream, fileType))
                : Arrays.asList(splitByType(inputStream, fileType, splitOption));
    }

    /** Requests one embedding per non-empty corpus entry, one HTTP request per entry. */
    private List<List<Double>> getCorpusEmbeddings(String modelName, List<String> corpus) throws IOException, TimeoutException {
        List<List<Double>> corpusEmbeddings = new ArrayList<>(corpus.size());
        for (String text : corpus) {
            if (text == null || text.isEmpty()) {
                continue;
            }
            corpusEmbeddings.add(getQueryEmbedding(constructEmbeddingJsonPayload(text), modelName));
        }
        return corpusEmbeddings;
    }

    /**
     * Requests embeddings in batches of {@link #EMBEDDING_BATCH_SIZE} texts and returns every embedding of every
     * batch. (Previously only the first embedding of each batch was kept.)
     *
     * @throws ModuleException with {@code MODELS_API_ERROR} when a request fails with an I/O error or timeout
     */
    private List<List<Double>> getBatchCorpusEmbeddings(String modelName, List<String> corpus) {
        List<List<Double>> allEmbeddings = new ArrayList<>(corpus.size());
        for (int start = 0; start < corpus.size(); start += EMBEDDING_BATCH_SIZE) {
            List<String> batch = corpus.subList(start, Math.min(start + EMBEDDING_BATCH_SIZE, corpus.size()));
            try {
                EinsteinEmbeddingResponseDTO response = requestEmbeddings(constructEmbeddingJsonPayload(batch), modelName);
                // NOTE(review): assumes the API returns embeddings in input order — confirm against API docs.
                response.getEmbeddings().forEach(embedding -> allEmbeddings.add(embedding.getEmbeddings()));
            } catch (IOException | TimeoutException e) {
                throw new ModuleException("Error fetching embeddings", EinsteinErrorType.MODELS_API_ERROR, e);
            }
        }
        return allEmbeddings;
    }

    /** Requests embeddings for a prepared JSON body and returns the first embedding vector. */
    private List<Double> getQueryEmbedding(String body, String modelName) throws IOException, TimeoutException {
        return requestEmbeddings(body, modelName).getEmbeddings().get(0).getEmbeddings();
    }

    /** Calls {@code /embeddings} and deserializes the response, always closing the response stream. */
    private EinsteinEmbeddingResponseDTO requestEmbeddings(String body, String modelName) throws IOException, TimeoutException {
        try (InputStream embeddingResponse = executeEinsteinRequest(body, modelName, EMBEDDINGS_RESOURCE)) {
            return OBJECT_MAPPER.readValue(embeddingResponse, EinsteinEmbeddingResponseDTO.class);
        }
    }

    /**
     * Cosine similarity of two equal-length vectors.
     *
     * @return similarity, or {@code 0.0} when either vector has zero norm (instead of NaN)
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private double calculateCosineSimilarity(List<Double> embeddingList, List<Double> corpusEmbedding) {
        if (embeddingList.size() != corpusEmbedding.size()) {
            throw new IllegalArgumentException("Embedding dimension mismatch: " + embeddingList.size()
                    + " vs " + corpusEmbedding.size());
        }
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < embeddingList.size(); i++) {
            double a = embeddingList.get(i);
            double b = corpusEmbedding.get(i);
            dotProduct += a * b;
            normA += a * a;
            normB += b * b;
        }
        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0.0 ? 0.0 : dotProduct / denominator;
    }

    /**
     * Orders corpus entries by descending score (stable for ties).
     *
     * @param corpus           entries, index-aligned with {@code similarityScores}
     * @param similarityScores one score per corpus entry
     * @return {@code "<score> - <text>"} strings, or an empty list if either input is empty
     */
    private List<String> rankResults(List<String> corpus, List<Double> similarityScores) {
        if (similarityScores.isEmpty() || corpus.isEmpty()) {
            return Collections.emptyList();
        }
        return IntStream.range(0, corpus.size())
                .boxed()
                .sorted((i, j) -> Double.compare(similarityScores.get(j), similarityScores.get(i)))
                .map(index -> similarityScores.get(index) + " - " + corpus.get(index))
                .collect(Collectors.toList());
    }

    /**
     * Synchronously POSTs a UTF-8 JSON payload to a Models API resource.
     *
     * @return the stream returned by {@code HttpRequestHelper.handleHttpResponse}
     */
    private InputStream executeEinsteinRequest(String payload, String modelName, String resource) throws IOException, TimeoutException {
        String urlString = buildModelUrl(modelName, resource);
        log.debug("Einstein Request URL: {}", urlString);
        HttpResponse httpResponse = einsteinConnection.getHttpClient()
                .send(buildRequest(urlString, einsteinConnection.getAccessToken(), "POST", toJsonEntity(payload)));
        return HttpRequestHelper.handleHttpResponse(httpResponse, EinsteinErrorType.MODELS_API_ERROR);
    }

    /** Builds {@code <instance url>/einstein/platform/v1/models/<model><resource>}. */
    private String buildModelUrl(String modelApiName, String resource) {
        return einsteinConnection.getApiInstanceUrl() + MODELS_API_PATH + modelApiName + resource;
    }

    /** Wraps a string payload as a UTF-8 encoded entity, matching the declared {@code charset=utf-8}. */
    private HttpEntity toJsonEntity(String payload) {
        return new InputStreamHttpEntity(new ByteArrayInputStream(payload.getBytes(StandardCharsets.UTF_8)));
    }

    /**
     * Extracts plain text with Tika and normalizes it: NBSP becomes a space, zero-width space and BOM are removed,
     * and control characters other than tab/LF/CR are removed. The result is trimmed.
     * {@code PDF} uses {@link AutoDetectParser}; any other file type is parsed with {@link TXTParser}.
     */
    private String getFileTypeContextFromFile(InputStream inputStream, String fileType) throws IOException, SAXException, TikaException {
        BodyContentHandler handler = new BodyContentHandler(-1);
        Metadata metadata = new Metadata();
        ParseContext parseContext = new ParseContext();
        if ("PDF".equalsIgnoreCase(fileType)) {
            new AutoDetectParser().parse(inputStream, handler, metadata, parseContext);
        } else {
            new TXTParser().parse(inputStream, handler, metadata, parseContext);
        }
        String content = handler.toString()
                .replace("\u00a0", " ")
                .replace("\u200b", "")
                .replace("\ufeff", "");
        return CONTROL_CHARS_PATTERN.matcher(content).replaceAll("").trim();
    }

    /** Returns the whole document as one string, with runs of line breaks collapsed to a single {@code \n}. */
    private String splitFullDocument(InputStream inputStream, String fileType) throws TikaException, IOException, SAXException {
        String content = getFileTypeContextFromFile(inputStream, fileType);
        return LINE_BREAK_RUN_PATTERN.matcher(content).replaceAll("\n").trim();
    }

    /** Extracts the document text and splits it according to {@code splitOption}. */
    private String[] splitByType(InputStream inputStream, String fileType, String splitOption) throws IOException, SAXException, TikaException {
        return splitContentByParagraph(getFileTypeContextFromFile(inputStream, fileType), splitOption);
    }

    /** Only {@code PARAGRAPH} is supported; anything else throws {@link IllegalArgumentException}. */
    private String[] splitContentByParagraph(String text, String option) {
        if ("PARAGRAPH".equalsIgnoreCase(option)) {
            return splitByParagraphs(text);
        }
        throw new IllegalArgumentException("Unknown split option: " + option);
    }

    /** Splits on line breaks and drops blank entries. */
    private String[] splitByParagraphs(String text) {
        return removeEmptyStrings(PARAGRAPH_SPLIT_PATTERN.split(text));
    }

    /** Trims each entry and removes nulls and entries that are empty after trimming. */
    private String[] removeEmptyStrings(String[] array) {
        return Arrays.stream(array)
                .filter(Objects::nonNull)
                .map(String::trim)
                .filter(trimmed -> !trimmed.isEmpty())
                .toArray(String[]::new);
    }

    /** Returns every http(s) URL found in {@code input}, in order; empty (never {@code null}) if none. */
    private List<String> extractUrls(String input) {
        Matcher matcher = URL_PATTERN.matcher(input);
        List<String> urls = new ArrayList<>();
        while (matcher.find()) {
            urls.add(matcher.group(1));
        }
        return urls;
    }

    /** Builds a {@code /generations} body: {@code prompt}, {@code localization} and empty {@code tags}. */
    private String constructPayload(String prompt, String locale, Number probability) {
        JSONObject jsonPayload = new JSONObject();
        jsonPayload.put("prompt", prompt);
        jsonPayload.put("localization", buildLocalization(locale, probability));
        jsonPayload.put("tags", new JSONObject());
        return jsonPayload.toString();
    }

    /** Builds a {@code /chat-generations} body: parsed {@code messages}, {@code localization} and empty {@code tags}. */
    private String constructPayloadWithMessages(String message, ParamsModelDetails paramsModelDetails) {
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("messages", new JSONArray(message));
        jsonObject.put("localization", buildLocalization(paramsModelDetails.getLocale(), paramsModelDetails.getProbability()));
        jsonObject.put("tags", new JSONObject());
        return jsonObject.toString();
    }

    /**
     * Builds the shared {@code localization} object. It contains {@code defaultLocale}, a single
     * {@code inputLocales} entry carrying {@code probability}, and {@code expectedLocales}.
     */
    private JSONObject buildLocalization(String locale, Number probability) {
        JSONObject inputLocale = new JSONObject();
        inputLocale.put("locale", locale);
        inputLocale.put("probability", probability);
        JSONArray inputLocales = new JSONArray();
        inputLocales.put(inputLocale);
        JSONArray expectedLocales = new JSONArray();
        expectedLocales.put(locale);
        JSONObject localization = new JSONObject();
        localization.put("defaultLocale", locale);
        localization.put("inputLocales", inputLocales);
        localization.put("expectedLocales", expectedLocales);
        return localization;
    }

    /** Builds {@code {"input": [text]}}. */
    private String constructEmbeddingJsonPayload(String text) {
        JSONArray input = new JSONArray();
        input.put(text);
        JSONObject jsonObject = new JSONObject();
        jsonObject.put("input", input);
        return jsonObject.toString();
    }

    /** Builds {@code {"input": [texts...]}}. */
    private String constructEmbeddingJsonPayload(List<String> texts) {
        JSONObject jsonPayload = new JSONObject();
        jsonPayload.put("input", new JSONArray(texts));
        return jsonPayload.toString();
    }

    /**
     * Finds the tool whose trimmed {@code url} equals {@code url} in the tools config JSON array and invokes it.
     * The config stream is always closed.
     *
     * @return the tool response, or {@code ""} if no configured tool matches
     */
    private String getAttributes(String url, InputStream toolsConfigInputStream, String payload) throws IOException, TimeoutException {
        try (InputStream configStream = toolsConfigInputStream) {
            JSONArray rootArray = new JSONArray(new JSONTokener(configStream));
            for (int i = 0; i < rootArray.length(); i++) {
                JSONObject toolNode = rootArray.getJSONObject(i);
                if (toolNode.getString("url").trim().equals(url)) {
                    return invokeTool(url, toolNode, payload);
                }
            }
            return "";
        }
    }

    /**
     * Calls one configured tool. {@code payload} is sent as the body only for {@code POST}. The optional
     * {@code headers} value, when non-empty, is sent as the {@code Authorization} header.
     */
    private String invokeTool(String url, JSONObject toolNode, String payload) throws IOException, TimeoutException {
        String method = toolNode.getString("method");
        // "headers" is treated as optional; a missing key no longer throws JSONException.
        String authorization = toolNode.optString("headers", "");
        HttpEntity httpEntity = "POST".equalsIgnoreCase(method) ? toJsonEntity(payload) : new EmptyHttpEntity();
        MultiMap<String, String> requestHeaders = new MultiMap<>();
        if (!authorization.isEmpty()) {
            requestHeaders.put("Authorization", authorization);
        }
        requestHeaders.put("Content-Type", JSON_CONTENT_TYPE);
        HttpRequest request = HttpRequest.builder()
                .uri(url)
                .headers(requestHeaders)
                .method(method)
                .entity(httpEntity)
                .build();
        HttpResponse httpResponse = einsteinConnection.getHttpClient().send(request);
        return HttpRequestHelper.handleHttpResponseForTools(httpResponse);
    }

    /** Returns the first single-line {@code {...}} span (greedy), or {@code "Payload not found!"}. */
    private String extractPayload(String payload) {
        Matcher matcher = JSON_OBJECT_PATTERN.matcher(payload);
        return matcher.find() ? matcher.group() : PAYLOAD_NOT_FOUND;
    }

    /** Builds a Models API request carrying the connection headers. */
    private HttpRequest buildRequest(String url, String accessToken, String httpMethod, HttpEntity httpEntity) {
        return HttpRequest.builder()
                .uri(url)
                .headers(addConnectionHeaders(accessToken))
                .method(httpMethod)
                .entity(httpEntity)
                .build();
    }

    /** Headers sent with every Models API request: bearer token, app context, feature id and JSON content type. */
    private MultiMap<String, String> addConnectionHeaders(String accessToken) {
        MultiMap<String, String> multiMap = new MultiMap<>();
        multiMap.put("Authorization", "Bearer " + accessToken);
        multiMap.put("x-sfdc-app-context", "EinsteinGPT");
        multiMap.put("x-client-feature-id", "ai-platform-models-connected-app");
        multiMap.put("Content-Type", JSON_CONTENT_TYPE);
        return multiMap;
    }
}

