/*
 * Decompiled with CFR 0.152.
 */
package org.mule.extension.mulechain.helpers;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.tika.exception.TikaException;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.parser.pdf.PDFParser;
import org.apache.tika.sax.BodyContentHandler;
import org.json.JSONArray;
import org.json.JSONObject;
import org.mule.extension.mulechain.helpers.AwsbedrockPayloadHelper;
import org.mule.extension.mulechain.internal.AwsbedrockConfiguration;
import org.mule.extension.mulechain.internal.embeddings.AwsbedrockParametersEmbedding;
import org.mule.extension.mulechain.internal.embeddings.AwsbedrockParametersEmbeddingDocument;
import org.xml.sax.ContentHandler;
import org.xml.sax.SAXException;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

/**
 * Builds model-specific embedding request payloads for Amazon Bedrock, invokes the Bedrock
 * runtime to obtain embeddings, and implements a simple ad-hoc RAG ranking of PDF text chunks
 * against a prompt using cosine similarity.
 */
public class AwsbedrockEmbeddingPayloadHelper {

    private static final String TITAN_TEXT_V1 = "amazon.titan-embed-text-v1";
    private static final String TITAN_TEXT_V2 = "amazon.titan-embed-text-v2:0";
    private static final String TITAN_IMAGE_V1 = "amazon.titan-embed-image-v1";
    private static final String COHERE_EMBED = "cohere.embed";

    /** Output length always requested from the Titan multimodal model; it is not configurable here. */
    private static final int TITAN_IMAGE_OUTPUT_EMBEDDING_LENGTH = 256;

    /** Payload for Titan Text Embeddings v1: only the input text. */
    private static String getAmazonTitanEmbeddingG1(String prompt) {
        return new JSONObject().put("inputText", prompt).toString();
    }

    /**
     * Payload for Titan Text Embeddings v2, which additionally accepts an output dimension and a
     * normalization flag. Both are passed through as-is; a {@code null} value omits the key.
     */
    private static String getAmazonTitanEmbeddingG2(String prompt, Object dimensions, Object normalize) {
        return new JSONObject()
                .put("inputText", prompt)
                .put("dimensions", dimensions)
                .put("normalize", normalize)
                .toString();
    }

    /** Payload for the Titan multimodal embedding model, using a fixed output embedding length. */
    private static String getAmazonTitanImageEmbeddingG1(String prompt) {
        JSONObject embeddingConfig = new JSONObject()
                .put("outputEmbeddingLength", TITAN_IMAGE_OUTPUT_EMBEDDING_LENGTH);
        return new JSONObject()
                .put("inputText", prompt)
                .put("embeddingConfig", embeddingConfig)
                .toString();
    }

    /**
     * Payload for Cohere embed models: the prompt is split into sentences on literal '.' characters
     * and each non-blank, trimmed sentence becomes one entry of {@code texts}.
     */
    private static String getCohereEmbeddingModel(String prompt) {
        JSONArray texts = new JSONArray();
        // "\\." splits on a literal dot; a bare "." is a regex matching every character and
        // would leave no sentences at all.
        for (String sentence : prompt.split("\\.")) {
            String trimmed = sentence.trim();
            if (!trimmed.isEmpty()) {
                texts.put(trimmed);
            }
        }
        if (texts.length() == 0) {
            // No sentence delimiters or only blanks: send the prompt itself rather than an empty list.
            texts.put(prompt);
        }
        return new JSONObject()
                .put("texts", texts)
                .put("input_type", "search_query")
                .toString();
    }

    /**
     * Builds the payload for every supported model that takes no per-request options
     * (Titan text v1, Titan multimodal v1, Cohere embed).
     *
     * @throws IllegalArgumentException if {@code modelName} matches no supported model
     */
    private static String payloadForModelWithoutOptions(String prompt, String modelName) {
        if (modelName.contains(TITAN_TEXT_V1)) {
            return getAmazonTitanEmbeddingG1(prompt);
        }
        if (modelName.contains(TITAN_IMAGE_V1)) {
            return getAmazonTitanImageEmbeddingG1(prompt);
        }
        if (modelName.contains(COHERE_EMBED)) {
            return getCohereEmbeddingModel(prompt);
        }
        throw new IllegalArgumentException("Unsupported model: " + modelName);
    }

    /**
     * Selects and builds the request payload for the configured embedding model.
     *
     * @throws IllegalArgumentException if the model is not supported
     */
    private static String identifyPayload(String prompt, AwsbedrockParametersEmbedding awsBedrockParameters) {
        String modelName = awsBedrockParameters.getModelName();
        if (modelName.contains(TITAN_TEXT_V2)) {
            return getAmazonTitanEmbeddingG2(prompt, awsBedrockParameters.getDimension(), awsBedrockParameters.getNormalize());
        }
        return payloadForModelWithoutOptions(prompt, modelName);
    }

    /**
     * Document-parameter variant of {@link #identifyPayload}; produces identical payloads.
     *
     * @throws IllegalArgumentException if the model is not supported
     */
    private static String identifyPayloadDoc(String prompt, AwsbedrockParametersEmbeddingDocument awsBedrockParameters) {
        String modelName = awsBedrockParameters.getModelName();
        if (modelName.contains(TITAN_TEXT_V2)) {
            return getAmazonTitanEmbeddingG2(prompt, awsBedrockParameters.getDimension(), awsBedrockParameters.getNormalize());
        }
        return payloadForModelWithoutOptions(prompt, modelName);
    }

    /**
     * Creates a Bedrock runtime client for {@code region}. Session credentials are used when a
     * non-empty session token is configured, otherwise basic access-key credentials.
     * The caller owns the returned client and must close it.
     */
    private static BedrockRuntimeClient createClient(AwsbedrockConfiguration configuration, Region region) {
        String sessionToken = configuration.getAwsSessionToken();
        AwsCredentials credentials = (sessionToken == null || sessionToken.isEmpty())
                ? AwsBasicCredentials.create(configuration.getAwsAccessKeyId(), configuration.getAwsSecretAccessKey())
                : AwsSessionCredentials.create(configuration.getAwsAccessKeyId(), configuration.getAwsSecretAccessKey(), sessionToken);
        AwsCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(credentials);
        return BedrockRuntimeClient.builder()
                .credentialsProvider(credentialsProvider)
                .region(region)
                .build();
    }

    /** Builds a JSON-in/JSON-out InvokeModel request carrying {@code nativeRequest} as its UTF-8 body. */
    private static InvokeModelRequest createInvokeRequest(String modelId, String nativeRequest) {
        return InvokeModelRequest.builder()
                .modelId(modelId)
                .contentType("application/json")
                .accept("application/json")
                .body(SdkBytes.fromUtf8String(nativeRequest))
                .build();
    }

    /** Invokes {@code modelId} on an existing client and parses the UTF-8 response body as JSON. */
    private static JSONObject invokeEmbedding(BedrockRuntimeClient bedrock, String modelId, String body) {
        InvokeModelResponse response = bedrock.invokeModel(createInvokeRequest(modelId, body));
        String responseBody = new String(response.body().asByteArray(), StandardCharsets.UTF_8);
        return new JSONObject(responseBody);
    }

    /**
     * Sends {@code body} to the given Bedrock model and returns the parsed JSON response.
     * A client is created for this call and closed afterwards.
     *
     * @param modelId       Bedrock model identifier
     * @param body          model-specific JSON request payload
     * @param configuration credentials source
     * @param region        AWS region to call
     * @return the parsed response body
     * @throws IOException declared for backward compatibility with existing callers
     */
    public static JSONObject generateEmbedding(String modelId, String body, AwsbedrockConfiguration configuration, Region region) throws IOException {
        try (BedrockRuntimeClient bedrock = createClient(configuration, region)) {
            return invokeEmbedding(bedrock, modelId, body);
        }
    }

    /**
     * Computes an embedding for {@code prompt} with the configured model.
     *
     * @return the raw JSON response as a string, or {@code null} if the model is unsupported or the
     *         Bedrock call fails (the error is printed to stderr)
     */
    public static String invokeModel(String prompt, AwsbedrockConfiguration configuration, AwsbedrockParametersEmbedding awsBedrockParameters) {
        Region region = AwsbedrockPayloadHelper.getRegion(awsBedrockParameters.getRegion());
        String modelId = awsBedrockParameters.getModelName();
        try {
            String body = identifyPayload(prompt, awsBedrockParameters);
            return generateEmbedding(modelId, body, configuration, region).toString();
        }
        catch (Exception e) {
            System.err.println("Error: " + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Ranks chunks of a PDF by cosine similarity to {@code prompt}.
     *
     * <p>The PDF text is used whole when the option type is {@code "FULL"}, otherwise it is split by
     * paragraphs or sentences. The prompt and every non-blank chunk are embedded with one shared
     * Bedrock client.
     *
     * <p>NOTE(review): each response's {@code "embedding"} field is read. Cohere embed responses
     * appear to use {@code "embeddings"} instead; confirm before relying on this for Cohere models.
     *
     * @return a JSON array string of {@code "score - text"} entries, most similar first, or
     *         {@code null} if building payloads or calling Bedrock fails
     * @throws IOException, SAXException, TikaException if the PDF cannot be read or parsed
     */
    public static String InvokeAdhocRAG(String prompt, String filePath, AwsbedrockConfiguration configuration, AwsbedrockParametersEmbeddingDocument awsBedrockParameters) throws IOException, SAXException, TikaException {
        Region region = AwsbedrockPayloadHelper.getRegion(awsBedrockParameters.getRegion());
        String modelId = awsBedrockParameters.getModelName();
        List<String> corpus = "FULL".equals(awsBedrockParameters.getOptionType())
                ? Arrays.asList(splitFullDocument(filePath, awsBedrockParameters))
                : Arrays.asList(splitByType(filePath, awsBedrockParameters));
        try (BedrockRuntimeClient bedrock = createClient(configuration, region)) {
            JSONArray queryEmbedding = invokeEmbedding(bedrock, modelId, identifyPayloadDoc(prompt, awsBedrockParameters))
                    .getJSONArray("embedding");
            // Kept texts and their scores are collected in parallel so indices always line up,
            // even when blank chunks are skipped.
            List<String> rankedTexts = new ArrayList<>();
            List<Double> similarityScores = new ArrayList<>();
            for (String text : corpus) {
                if (text == null || text.trim().isEmpty()) {
                    continue;
                }
                String chunkBody = identifyPayloadDoc(text, awsBedrockParameters);
                JSONArray chunkEmbedding = invokeEmbedding(bedrock, modelId, chunkBody).getJSONArray("embedding");
                rankedTexts.add(text);
                similarityScores.add(calculateCosineSimilarity(queryEmbedding, chunkEmbedding));
            }
            return new JSONArray(rankAndPrintResults(rankedTexts, similarityScores)).toString();
        }
        catch (Exception e) {
            System.err.println("Error: " + e.getMessage());
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Cosine similarity of two equal-length numeric vectors. Returns 0.0 when either vector has zero
     * norm, instead of NaN.
     *
     * @throws IllegalArgumentException if the vectors differ in length
     */
    private static double calculateCosineSimilarity(JSONArray vec1, JSONArray vec2) {
        if (vec1.length() != vec2.length()) {
            throw new IllegalArgumentException("Embedding dimensions differ: " + vec1.length() + " vs " + vec2.length());
        }
        double dotProduct = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < vec1.length(); ++i) {
            double a = vec1.getDouble(i);
            double b = vec2.getDouble(i);
            dotProduct += a * b;
            normA += a * a;
            normB += b * b;
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0;
        }
        return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /**
     * Orders texts by descending score, prints each to stdout, and returns {@code "score - text"}
     * strings in that order.
     *
     * @throws IllegalArgumentException if the two lists differ in size
     */
    private static List<String> rankAndPrintResults(List<String> corpus, List<Double> similarityScores) {
        if (corpus.size() != similarityScores.size()) {
            throw new IllegalArgumentException("Corpus size " + corpus.size() + " does not match score count " + similarityScores.size());
        }
        List<Integer> order = new ArrayList<>(corpus.size());
        for (int i = 0; i < corpus.size(); ++i) {
            order.add(i);
        }
        // Descending by score: compare j against i.
        order.sort((i, j) -> Double.compare(similarityScores.get(j), similarityScores.get(i)));
        System.out.println("Ranked results:");
        List<String> results = new ArrayList<>(order.size());
        for (int index : order) {
            System.out.println("Score: " + similarityScores.get(index) + " - Text: " + corpus.get(index));
            results.add(similarityScores.get(index) + " - " + corpus.get(index));
        }
        return results;
    }

    /**
     * Extracts the body text of the PDF at {@code filePath} with Tika's PDF parser. The stream is
     * always closed. The handler's write limit is disabled (-1) so documents longer than Tika's
     * default 100,000 characters are not rejected.
     */
    private static String getContentFromFile(String filePath) throws IOException, SAXException, TikaException {
        BodyContentHandler handler = new BodyContentHandler(-1);
        try (InputStream input = new FileInputStream(new File(filePath))) {
            new PDFParser().parse(input, handler, new Metadata(), new ParseContext());
        }
        return handler.toString();
    }

    /** Returns the whole PDF text as a single chunk; the parameters argument is currently unused. */
    private static String splitFullDocument(String filePath, AwsbedrockParametersEmbeddingDocument awsBedrockParameters) throws IOException, SAXException, TikaException {
        return getContentFromFile(filePath);
    }

    /** Returns the PDF text split according to the configured option type (PARAGRAPH or SENTENCES). */
    private static String[] splitByType(String filePath, AwsbedrockParametersEmbeddingDocument awsBedrockParameters) throws IOException, SAXException, TikaException {
        String content = getContentFromFile(filePath);
        return splitContent(content, awsBedrockParameters.getOptionType());
    }

    /**
     * Splits {@code text} by {@code "PARAGRAPH"} or {@code "SENTENCES"}.
     *
     * @throws IllegalArgumentException for any other option
     */
    private static String[] splitContent(String text, String option) {
        switch (option) {
            case "PARAGRAPH":
                return splitByParagraphs(text);
            case "SENTENCES":
                return splitBySentences(text);
            default:
                throw new IllegalArgumentException("Unknown split option: " + option);
        }
    }

    /** Splits on blank lines (two consecutive line breaks, LF or CRLF), dropping empty pieces. */
    private static String[] splitByParagraphs(String text) {
        return removeEmptyStrings(text.split("\\r?\\n\\r?\\n"));
    }

    /**
     * Splits on a period followed by whitespace, except after common honorifics
     * (Mr, Mrs, Ms, Dr, Sr, Jr, Prof), dropping empty pieces.
     */
    private static String[] splitBySentences(String text) {
        return removeEmptyStrings(text.split("(?<!Mr|Mrs|Ms|Dr|Sr|Jr|Prof)\\.\\s+"));
    }

    /**
     * Returns a new array containing the non-empty elements of {@code array}, in order.
     * Whitespace-only strings are kept.
     */
    public static String[] removeEmptyStrings(String[] array) {
        List<String> kept = new ArrayList<>(Arrays.asList(array));
        kept.removeIf(String::isEmpty);
        return kept.toArray(new String[0]);
    }
}

