/*
 * Decompiled with CFR 0.152.
 */
package org.mule.extension.mulechain.helpers;

import java.awt.image.BufferedImage;
import java.awt.image.RenderedImage;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.util.Base64;
import javax.imageio.ImageIO;
import org.json.JSONArray;
import org.json.JSONObject;
import org.mule.extension.mulechain.helpers.AwsbedrockPayloadHelper;
import org.mule.extension.mulechain.internal.AwsbedrockConfiguration;
import org.mule.extension.mulechain.internal.image.AwsbedrockImageParameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

/**
 * Builds model-specific request payloads for AWS Bedrock text-to-image models, invokes the model
 * and writes the returned image to disk as PNG.
 *
 * <p>The model family is chosen by substring match on the model id: {@code amazon.titan-image},
 * {@code amazon.nova} (which reuses the Titan payload and response format) and
 * {@code stability.stable-diffusion-xl}.
 */
public class AwsbedrockImagePayloadHelper {
    private static final Logger logger = LoggerFactory.getLogger(AwsbedrockImagePayloadHelper.class);

    /** Model-id fragment for Amazon Titan image models. */
    private static final String TITAN_IMAGE_MODEL = "amazon.titan-image";
    /** Model-id fragment for Amazon Nova models. */
    private static final String NOVA_MODEL = "amazon.nova";
    /** Model-id fragment for Stability AI SDXL models. */
    private static final String STABILITY_SDXL_MODEL = "stability.stable-diffusion-xl";

    /**
     * Builds the {@code TEXT_IMAGE} request body for Amazon Titan image models.
     *
     * @param prompt text describing the image to generate
     * @param avoidInImage text describing what to keep out of the image; omitted from the payload
     *     when null or blank
     * @param awsBedrockParameters image count, height, width, cfg scale and seed
     * @return the JSON request body
     */
    private static String getAmazonTitanImage(String prompt, String avoidInImage, AwsbedrockImageParameters awsBedrockParameters) {
        JSONObject textToImageParams = new JSONObject().put("text", prompt);
        // Only send negativeText when there is something to avoid.
        // NOTE(review): Bedrock's Titan schema documents a minimum length for negativeText,
        // so an empty value is left out rather than sent.
        if (avoidInImage != null && !avoidInImage.trim().isEmpty()) {
            textToImageParams.put("negativeText", avoidInImage);
        }
        JSONObject imageGenerationConfig = new JSONObject()
                .put("numberOfImages", (Object) awsBedrockParameters.getNumOfImages())
                .put("height", (Object) awsBedrockParameters.getHeight())
                .put("width", (Object) awsBedrockParameters.getWidth())
                .put("cfgScale", (Object) awsBedrockParameters.getCfgScale())
                .put("seed", (Object) awsBedrockParameters.getSeed());
        return new JSONObject()
                .put("taskType", "TEXT_IMAGE")
                .put("textToImageParams", textToImageParams)
                .put("imageGenerationConfig", imageGenerationConfig)
                .toString();
    }

    /**
     * Builds the request body for Amazon Nova image models. It is identical to the Titan
     * {@code TEXT_IMAGE} payload.
     */
    private static String getAmazonNovaImage(String prompt, String avoidInImage, AwsbedrockImageParameters awsBedrockParameters) {
        return getAmazonTitanImage(prompt, avoidInImage, awsBedrockParameters);
    }

    /**
     * Builds the request body for Stability AI SDXL models.
     *
     * <p>NOTE(review): {@code avoidInImage} is not used here, and the single prompt is sent with
     * weight 0. Confirm both are intended; Stability's API expresses negative prompts as extra
     * {@code text_prompts} entries with negative weight.
     *
     * @return the JSON request body
     */
    private static String getStabilityAiDiffusionImage(String prompt, String avoidInImage, AwsbedrockImageParameters awsBedrockParameters) {
        JSONArray textPromptsArray = new JSONArray().put(new JSONObject().put("text", prompt).put("weight", 0));
        JSONObject json = new JSONObject()
                .put("text_prompts", textPromptsArray)
                .put("height", (Object) awsBedrockParameters.getHeight())
                .put("width", (Object) awsBedrockParameters.getWidth())
                .put("cfg_scale", (Object) awsBedrockParameters.getCfgScale())
                .put("seed", (Object) awsBedrockParameters.getSeed());
        return json.toString();
    }

    /**
     * Selects and builds the request body for the configured model.
     *
     * @return the JSON request body
     * @throws IllegalArgumentException if the model name matches no supported family
     */
    private static String identifyPayload(String prompt, String avoidInImage, AwsbedrockImageParameters awsBedrockParameters) {
        String modelName = awsBedrockParameters.getModelName();
        if (modelName.contains(TITAN_IMAGE_MODEL)) {
            return getAmazonTitanImage(prompt, avoidInImage, awsBedrockParameters);
        }
        if (modelName.contains(NOVA_MODEL)) {
            return getAmazonNovaImage(prompt, avoidInImage, awsBedrockParameters);
        }
        if (modelName.contains(STABILITY_SDXL_MODEL)) {
            return getStabilityAiDiffusionImage(prompt, avoidInImage, awsBedrockParameters);
        }
        // Fail fast instead of sending a non-JSON placeholder body to Bedrock.
        throw new IllegalArgumentException("Unsupported model: " + modelName);
    }

    /**
     * Creates a Bedrock runtime client with static credentials. Session credentials are used when
     * a non-empty session token is configured; otherwise basic key/secret credentials are used.
     * The caller owns the returned client and must close it.
     */
    private static BedrockRuntimeClient createClient(AwsbedrockConfiguration configuration, Region region) {
        String sessionToken = configuration.getAwsSessionToken();
        AwsCredentials awsCredentials;
        if (sessionToken == null || sessionToken.isEmpty()) {
            awsCredentials = AwsBasicCredentials.create(configuration.getAwsAccessKeyId(), configuration.getAwsSecretAccessKey());
        } else {
            awsCredentials = AwsSessionCredentials.create(configuration.getAwsAccessKeyId(), configuration.getAwsSecretAccessKey(), sessionToken);
        }
        AwsCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(awsCredentials);
        return BedrockRuntimeClient.builder()
                .credentialsProvider(credentialsProvider)
                .region(region)
                .build();
    }

    /** Wraps the UTF-8 JSON body and model id into an {@link InvokeModelRequest}. */
    private static InvokeModelRequest createInvokeRequest(String modelId, String nativeRequest) {
        return InvokeModelRequest.builder()
                .body(SdkBytes.fromUtf8String(nativeRequest))
                .modelId(modelId)
                .build();
    }

    /**
     * Invokes the model and decodes the first image in its response.
     *
     * @param modelId Bedrock model id; also selects how the response is parsed
     * @param body JSON request body
     * @param configuration AWS credentials
     * @param region AWS region to call
     * @return the decoded bytes of the first returned image, or {@code null} if the model id
     *     matches no supported family
     * @throws RuntimeException if the response carries a non-null {@code error} field
     * @throws IOException declared for callers; not thrown by this implementation
     */
    public static byte[] generateImage(String modelId, String body, AwsbedrockConfiguration configuration, Region region) throws IOException {
        InvokeModelResponse response;
        // A new client is built per call, so close it to release its HTTP resources.
        try (BedrockRuntimeClient bedrock = createClient(configuration, region)) {
            response = bedrock.invokeModel(createInvokeRequest(modelId, body));
        }
        JSONObject responseBody = new JSONObject(response.body().asUtf8String());
        // Check for a model-reported error before reading images. An error response may have no
        // images, and reading them first would hide the error behind a JSONException.
        String error = responseBody.optString("error", null);
        if (error != null) {
            throw new RuntimeException("Image generation error. Error is " + error);
        }
        String base64Image;
        if (modelId.contains(TITAN_IMAGE_MODEL) || modelId.contains(NOVA_MODEL)) {
            base64Image = responseBody.getJSONArray("images").getString(0);
        } else if (modelId.contains(STABILITY_SDXL_MODEL)) {
            base64Image = responseBody.getJSONArray("artifacts").getJSONObject(0).getString("base64");
        } else {
            return null;
        }
        return Base64.getDecoder().decode(base64Image);
    }

    /**
     * Generates an image from a prompt and writes it as PNG to {@code fullPath}.
     *
     * @param prompt text describing the image to generate
     * @param avoidInImage text describing what to keep out of the image (used by Titan/Nova)
     * @param fullPath destination file path
     * @param configuration AWS credentials
     * @param awsBedrockParameters model name, region and generation settings
     * @return {@code fullPath} on success, or {@code null} if generation, decoding or writing
     *     failed (the failure is logged)
     */
    public static String invokeModel(String prompt, String avoidInImage, String fullPath, AwsbedrockConfiguration configuration, AwsbedrockImageParameters awsBedrockParameters) {
        Region region = AwsbedrockPayloadHelper.getRegion(awsBedrockParameters.getRegion());
        String modelId = awsBedrockParameters.getModelName();
        try {
            String body = identifyPayload(prompt, avoidInImage, awsBedrockParameters);
            logger.info(body);
            byte[] imageBytes = generateImage(modelId, body, configuration, region);
            if (imageBytes == null) {
                logger.error("Error: no image returned for model {}", modelId);
                return null;
            }
            BufferedImage bufferedImage;
            try (ByteArrayInputStream bis = new ByteArrayInputStream(imageBytes)) {
                bufferedImage = ImageIO.read(bis);
            }
            // ImageIO.read returns null when no reader recognises the bytes. Check this before
            // writing, because ImageIO.write rejects a null image.
            if (bufferedImage == null) {
                logger.warn("Failed to generate image.");
                return null;
            }
            ImageIO.write(bufferedImage, "png", new File(fullPath));
            logger.info("Successfully generated image.");
            return fullPath;
        } catch (Exception e) {
            logger.error("Error: {}", e.getMessage(), e);
            return null;
        }
    }
}

