/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connectors.inference.internal.helpers.payload;

import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.mulesoft.connectors.inference.api.request.ChatPayloadRecord;
import com.mulesoft.connectors.inference.api.request.FunctionDefinitionRecord;
import com.mulesoft.connectors.inference.internal.connection.types.TextGenerationConnection;
import com.mulesoft.connectors.inference.internal.connection.types.VisionModelConnection;
import com.mulesoft.connectors.inference.internal.dto.imagegeneration.DefaultImageRequestPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.imagegeneration.ImageGenerationRequestPayloadDTO;
import com.mulesoft.connectors.inference.internal.dto.moderation.ModerationRequestPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.DefaultRequestPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.TextGenerationRequestPayloadDTO;
import com.mulesoft.connectors.inference.internal.dto.vision.Content;
import com.mulesoft.connectors.inference.internal.dto.vision.DefaultVisionRequestPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.vision.ImageUrl;
import com.mulesoft.connectors.inference.internal.dto.vision.ImageUrlContent;
import com.mulesoft.connectors.inference.internal.dto.vision.Message;
import com.mulesoft.connectors.inference.internal.dto.vision.TextContent;
import com.mulesoft.connectors.inference.internal.dto.vision.VisionRequestPayloadDTO;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URLConnection;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds request payload DTOs for text generation, vision, image generation and moderation
 * requests.
 *
 * <p>JSON inputs supplied as {@link InputStream}s are deserialized with the {@link ObjectMapper}
 * given to the constructor. All other payloads are assembled directly from the connection
 * settings and the method arguments.
 */
public class RequestPayloadHelper {
    private static final Logger logger = LoggerFactory.getLogger(RequestPayloadHelper.class);

    /** Standard Base64 alphabet, optionally followed by up to two '=' padding characters. */
    private static final Pattern BASE64_PATTERN = Pattern.compile("^[A-Za-z0-9+/]*={0,2}$");

    /** MIME type used for Base64 images whose decoded bytes are not recognized. */
    private static final String DEFAULT_IMAGE_MIME_TYPE = "image/jpeg";

    private static final String ROLE_USER = "user";
    private static final String ROLE_SYSTEM = "system";

    protected final ObjectMapper objectMapper;

    /**
     * Creates a helper.
     *
     * @param objectMapper mapper used to deserialize the JSON message, tool and moderation inputs
     */
    public RequestPayloadHelper(ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
    }

    /**
     * Builds a text-generation payload containing a single {@code "user"} message and no tools.
     *
     * @param connection supplies the model name and the sampling settings
     * @param prompt the user message content
     * @return the text-generation payload
     */
    public TextGenerationRequestPayloadDTO buildChatAnswerPromptPayload(TextGenerationConnection connection, String prompt) {
        return this.buildPayload(connection, List.of(new ChatPayloadRecord(ROLE_USER, prompt)), null);
    }

    /**
     * Deserializes a JSON list of chat messages and builds a text-generation payload without tools.
     *
     * @param connection supplies the model name and the sampling settings
     * @param messages JSON input that deserializes to a list of {@link ChatPayloadRecord}
     * @return the text-generation payload
     * @throws IOException if the input cannot be read or mapped to chat messages
     */
    public TextGenerationRequestPayloadDTO parseAndBuildChatCompletionPayload(TextGenerationConnection connection, InputStream messages) throws IOException {
        List<ChatPayloadRecord> messagesList = this.readList(messages, ChatPayloadRecord.class);
        return this.buildPayload(connection, messagesList, null);
    }

    /**
     * Builds a text-generation payload from the connection's model name, max tokens, temperature
     * and top-p, plus the given messages and tools.
     *
     * @param connection supplies the model name and the sampling settings
     * @param messages the chat messages, passed through unchanged
     * @param tools the tool definitions, or {@code null} for none (passed through unchanged)
     * @return the text-generation payload
     */
    public TextGenerationRequestPayloadDTO buildPayload(TextGenerationConnection connection, List<ChatPayloadRecord> messages, List<FunctionDefinitionRecord> tools) {
        return new DefaultRequestPayloadRecord(connection.getModelName(), messages, connection.getMaxTokens(), connection.getTemperature(), connection.getTopP(), tools);
    }

    /**
     * Builds a payload whose system message is {@code template + " - " + instructions} and whose
     * user message is {@code data}. No tools are included.
     *
     * @param connection supplies the model name and the sampling settings
     * @param template first part of the system prompt
     * @param instructions second part of the system prompt
     * @param data the user message content
     * @return the text-generation payload
     */
    public TextGenerationRequestPayloadDTO buildPromptTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data) {
        List<ChatPayloadRecord> messages = this.createMessagesArrayWithSystemPrompt(formatSystemPrompt(template, instructions), data);
        return this.buildPayload(connection, messages, null);
    }

    /**
     * Like {@link #buildToolsTemplatePayload(TextGenerationConnection, String, String, String, List)},
     * but first deserializes the tools from a JSON list of {@link FunctionDefinitionRecord}.
     *
     * @param connection supplies the model name and the sampling settings
     * @param template first part of the system prompt
     * @param instructions second part of the system prompt
     * @param data the user message content
     * @param tools JSON input that deserializes to a list of {@link FunctionDefinitionRecord}
     * @return the text-generation payload
     * @throws IOException if the tools input cannot be read or mapped
     */
    public TextGenerationRequestPayloadDTO buildToolsTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, InputStream tools) throws IOException {
        List<FunctionDefinitionRecord> toolsRecord = this.parseInputStreamToTools(tools);
        logger.debug("toolsArray: {}", toolsRecord);
        return this.buildToolsTemplatePayload(connection, template, instructions, data, toolsRecord);
    }

    /**
     * Builds a payload whose system message is {@code template + " - " + instructions}, whose
     * user message is {@code data}, and which carries the given tools.
     *
     * @param connection supplies the model name and the sampling settings
     * @param template first part of the system prompt
     * @param instructions second part of the system prompt
     * @param data the user message content
     * @param tools the tool definitions, passed through unchanged
     * @return the text-generation payload
     */
    public TextGenerationRequestPayloadDTO buildToolsTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, List<FunctionDefinitionRecord> tools) {
        List<ChatPayloadRecord> messages = this.createMessagesArrayWithSystemPrompt(formatSystemPrompt(template, instructions), data);
        return this.buildPayload(connection, messages, tools);
    }

    /**
     * Builds an image-generation payload for the given model and prompt.
     *
     * @param model the model name
     * @param prompt the image prompt
     * @return the image-generation payload; its third component is the fixed value
     *     {@code "b64_json"} (presumably the requested response format; confirm against
     *     {@link DefaultImageRequestPayloadRecord})
     */
    public ImageGenerationRequestPayloadDTO createRequestImageGeneration(String model, String prompt) {
        return new DefaultImageRequestPayloadRecord(model, prompt, "b64_json");
    }

    /**
     * Builds a vision payload with one {@code "user"} message that contains a text part and an
     * image part.
     *
     * <p>If {@code imageUrl} is a non-empty Base64 string (see {@link #isBase64String(String)}),
     * it is wrapped as a {@code data:<mime>;base64,<data>} URL. Otherwise it is used as-is.
     *
     * @param connection supplies the model name and the sampling settings
     * @param prompt the text part of the message
     * @param imageUrl an image URL or raw Base64 image data
     * @return the vision payload
     * @throws IOException if sniffing the MIME type of Base64 image data fails
     */
    public VisionRequestPayloadDTO createRequestImageURL(VisionModelConnection connection, String prompt, String imageUrl) throws IOException {
        ArrayList<Content> contents = new ArrayList<>();
        contents.add(new TextContent("text", prompt));
        contents.add(new ImageUrlContent("image_url", new ImageUrl(this.getImageUrl(imageUrl))));
        Message message = new Message(ROLE_USER, contents);
        return new DefaultVisionRequestPayloadRecord(connection.getModelName(), List.of(message), connection.getMaxTokens(), connection.getTemperature(), connection.getTopP());
    }

    /**
     * Builds a moderation payload. The input is read as untyped JSON ({@code Object.class}), so any
     * JSON value (string, array, object, ...) is accepted.
     *
     * @param modelName the moderation model name
     * @param text JSON input holding the content to moderate
     * @return the moderation payload
     * @throws IOException if the input cannot be read or parsed as JSON
     */
    public ModerationRequestPayloadRecord getModerationRequestPayload(String modelName, InputStream text) throws IOException {
        Object input = this.objectMapper.readValue(text, Object.class);
        return new ModerationRequestPayloadRecord(input, modelName);
    }

    /**
     * Creates an immutable two-element message list: a {@code "system"} message followed by a
     * {@code "user"} message.
     *
     * @param systemContent the system message content
     * @param userContent the user message content
     * @return an unmodifiable list {@code [system, user]}
     */
    protected List<ChatPayloadRecord> createMessagesArrayWithSystemPrompt(String systemContent, String userContent) {
        ChatPayloadRecord systemMessage = new ChatPayloadRecord(ROLE_SYSTEM, systemContent);
        ChatPayloadRecord userMessage = new ChatPayloadRecord(ROLE_USER, userContent);
        return List.of(systemMessage, userMessage);
    }

    /**
     * Deserializes a JSON list of tool definitions.
     *
     * @param inputStream JSON input that deserializes to a list of {@link FunctionDefinitionRecord}
     * @return the parsed tool definitions
     * @throws IOException if the input cannot be read or mapped
     */
    protected List<FunctionDefinitionRecord> parseInputStreamToTools(InputStream inputStream) throws IOException {
        return this.readList(inputStream, FunctionDefinitionRecord.class);
    }

    /**
     * Decodes Base64 data and guesses its content type with
     * {@link URLConnection#guessContentTypeFromStream(InputStream)}.
     *
     * <p>NOTE(review): that JDK method also recognizes non-image types (e.g. HTML or XML). Those
     * would be returned as-is, not replaced by the image fallback.
     *
     * @param base64String Base64 data in the standard alphabet
     * @return the guessed MIME type, or {@code "image/jpeg"} if it is not recognized
     * @throws IOException if reading the decoded bytes fails
     * @throws IllegalArgumentException if {@code base64String} is not valid Base64
     */
    protected String getMimeType(String base64String) throws IOException {
        byte[] decodedBytes = Base64.getDecoder().decode(base64String);
        // guessContentTypeFromStream needs mark/reset support, which ByteArrayInputStream has.
        ByteArrayInputStream inputStream = new ByteArrayInputStream(decodedBytes);
        String mimeType = URLConnection.guessContentTypeFromStream(inputStream);
        return mimeType != null ? mimeType : DEFAULT_IMAGE_MIME_TYPE;
    }

    /**
     * Tests whether {@code str} is non-empty, standard-alphabet Base64 with a length that is a
     * multiple of 4.
     *
     * @param str the candidate string, may be {@code null}
     * @return {@code true} if {@code str} is non-empty and decodes as Base64
     */
    protected boolean isBase64String(String str) {
        // "" matches the pattern and decodes to zero bytes. Accepting it would turn an empty image
        // reference into a bogus "data:image/jpeg;base64," URL, so reject it up front.
        if (str == null || str.isEmpty() || str.length() % 4 != 0 || !BASE64_PATTERN.matcher(str).matches()) {
            return false;
        }
        try {
            Base64.getDecoder().decode(str);
            return true;
        } catch (IllegalArgumentException ignored) {
            // Not decodable: treat it as an ordinary URL rather than failing.
            return false;
        }
    }

    /** Wraps raw Base64 image data in a data URL; returns any other value unchanged. */
    private String getImageUrl(String imageUrl) throws IOException {
        return this.isBase64String(imageUrl) ? "data:" + this.getMimeType(imageUrl) + ";base64," + imageUrl : imageUrl;
    }

    /** Joins the template and instructions into the system prompt, as {@code "<template> - <instructions>"}. */
    private static String formatSystemPrompt(String template, String instructions) {
        return template + " - " + instructions;
    }

    /** Deserializes a JSON list whose elements map to {@code elementType}, without unchecked casts. */
    private <T> List<T> readList(InputStream source, Class<T> elementType) throws IOException {
        JavaType listType = this.objectMapper.getTypeFactory().constructCollectionType(List.class, elementType);
        return this.objectMapper.readValue(source, listType);
    }
}

