/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connectors.inference.internal.helpers.payload;

import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.mulesoft.connectors.inference.api.request.ChatPayloadRecord;
import com.mulesoft.connectors.inference.api.request.Function;
import com.mulesoft.connectors.inference.api.request.FunctionDefinitionRecord;
import com.mulesoft.connectors.inference.api.request.FunctionSchema;
import com.mulesoft.connectors.inference.internal.connection.types.TextGenerationConnection;
import com.mulesoft.connectors.inference.internal.connection.types.VisionModelConnection;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.TextGenerationRequestPayloadDTO;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.gemini.ContentRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.gemini.FunctionDeclarationsWrapper;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.gemini.GeminiGenerationConfigRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.gemini.GeminiPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.gemini.PartRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.gemini.SystemInstructionRecord;
import com.mulesoft.connectors.inference.internal.dto.vision.VisionRequestPayloadDTO;
import com.mulesoft.connectors.inference.internal.dto.vision.gemini.InlineData;
import com.mulesoft.connectors.inference.internal.dto.vision.gemini.Part;
import com.mulesoft.connectors.inference.internal.dto.vision.gemini.VisionContentRecord;
import com.mulesoft.connectors.inference.internal.helpers.payload.RequestPayloadHelper;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds Google Gemini request payloads for the text-generation and vision operations.
 *
 * <p>OpenAI-style inputs (chat message arrays, tool/function definitions) are translated into
 * Gemini's {@code contents} / {@code systemInstruction} / {@code tools} structure via the
 * {@code gemini} DTO records.
 */
public class GeminiRequestPayloadHelper extends RequestPayloadHelper {
    private static final Logger logger = LoggerFactory.getLogger(GeminiRequestPayloadHelper.class);

    /** Role used for end-user turns (same string in OpenAI and Gemini formats). */
    private static final String USER_ROLE = "user";
    /** OpenAI role for assistant turns; Gemini calls this role {@code "model"}. */
    private static final String OPENAI_ASSISTANT_ROLE = "assistant";
    /** Gemini role for assistant turns. */
    private static final String GEMINI_MODEL_ROLE = "model";
    /**
     * OpenAI role for system prompts. Gemini rejects this role inside {@code contents};
     * such messages are sent as {@code systemInstruction} instead.
     */
    private static final String OPENAI_SYSTEM_ROLE = "system";

    public GeminiRequestPayloadHelper(ObjectMapper objectMapper) {
        super(objectMapper);
    }

    /**
     * Builds a single-turn payload whose only content is {@code prompt} as a user message.
     * No system instruction or tools are set; safety settings are an empty list.
     */
    @Override
    public TextGenerationRequestPayloadDTO buildChatAnswerPromptPayload(TextGenerationConnection connection, String prompt, Map<String, Object> additionalRequestAttributes) {
        return this.buildGeminiPayload(connection, prompt, Collections.emptyList(), null, Collections.emptyList(), additionalRequestAttributes);
    }

    /**
     * Builds a payload where {@code template} and {@code instructions} (joined with
     * {@code " - "}) form the system instruction and {@code data} is the user message.
     */
    @Override
    public TextGenerationRequestPayloadDTO buildPromptTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, Map<String, Object> additionalRequestAttributes) {
        SystemInstructionRecord systemInstructionRecord = this.buildSystemInstruction(template, instructions);
        return this.buildGeminiPayload(connection, data, Collections.emptyList(), systemInstructionRecord, Collections.emptyList(), additionalRequestAttributes);
    }

    /**
     * Parses an OpenAI-style JSON array of chat messages and converts it to a Gemini payload.
     *
     * <p>Messages with role {@code "system"} are collected, in order, into the payload's
     * system instruction, because Gemini does not accept that role inside {@code contents}.
     * All other messages become {@code contents} entries; {@code "assistant"} is renamed to
     * {@code "model"}.
     *
     * @param messages JSON array of {@link ChatPayloadRecord}; a JSON {@code null} is treated as empty
     * @throws IOException if the stream cannot be read or parsed
     */
    @Override
    public TextGenerationRequestPayloadDTO parseAndBuildChatCompletionPayload(TextGenerationConnection connection, InputStream messages, Map<String, Object> additionalRequestAttributes) throws IOException {
        JavaType messageListType = this.objectMapper.getTypeFactory().constructCollectionType(List.class, ChatPayloadRecord.class);
        List<ChatPayloadRecord> openAIFormatMessages = this.objectMapper.readValue(messages, messageListType);
        if (openAIFormatMessages == null) {
            openAIFormatMessages = Collections.emptyList();
        }

        List<PartRecord> systemParts = new ArrayList<>();
        List<ContentRecord> contentRecords = new ArrayList<>();
        for (ChatPayloadRecord msg : openAIFormatMessages) {
            if (OPENAI_SYSTEM_ROLE.equals(msg.role())) {
                systemParts.add(new PartRecord(msg.content(), null));
            } else {
                contentRecords.add(this.convertToGeminiFormat(msg));
            }
        }
        // Omit systemInstruction entirely when the conversation has no system messages.
        SystemInstructionRecord systemInstruction = systemParts.isEmpty() ? null : new SystemInstructionRecord(systemParts);

        return new GeminiPayloadRecord<ContentRecord>(
                Collections.unmodifiableList(contentRecords),
                systemInstruction,
                this.buildGeminiGenerationConfig(connection.getMaxTokens(), connection.getTemperature(), connection.getTopP(), additionalRequestAttributes),
                null,
                null);
    }

    /**
     * Parses an OpenAI-style JSON array of tool definitions from {@code tools} and delegates to
     * {@link #buildToolsTemplatePayload(TextGenerationConnection, String, String, String, List, Map)}.
     *
     * @throws IOException if the stream cannot be read or parsed
     */
    @Override
    public TextGenerationRequestPayloadDTO buildToolsTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, InputStream tools, Map<String, Object> additionalRequestAttributes) throws IOException {
        JavaType toolListType = this.objectMapper.getTypeFactory().constructCollectionType(List.class, FunctionDefinitionRecord.class);
        List<FunctionDefinitionRecord> openAIFormatTools = this.objectMapper.readValue(tools, toolListType);
        return this.buildToolsTemplatePayload(connection, template, instructions, data, openAIFormatTools, additionalRequestAttributes);
    }

    /**
     * Builds a tool-calling payload: {@code template} and {@code instructions} (joined with
     * {@code " - "}) form the system instruction, {@code data} is the user message, and the
     * OpenAI-style tools are converted to Gemini function declarations.
     * Both the converted declarations and the final payload are logged at debug level.
     */
    @Override
    public TextGenerationRequestPayloadDTO buildToolsTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, List<FunctionDefinitionRecord> openAIFormatTools, Map<String, Object> additionalRequestAttributes) {
        List<Function> functionDeclarations = this.getGeminiCompatibleFunctionList(openAIFormatTools);
        logger.debug("functionDeclarations: {}", functionDeclarations);
        SystemInstructionRecord systemInstructionRecord = this.buildSystemInstruction(template, instructions);
        GeminiPayloadRecord<ContentRecord> geminiPayload = this.buildGeminiPayload(connection, data, Collections.emptyList(), systemInstructionRecord, functionDeclarations, additionalRequestAttributes);
        logger.debug("geminiPayload: {}", geminiPayload);
        return geminiPayload;
    }

    /**
     * Builds a vision payload containing the image and the prompt as one user turn.
     *
     * <p>Despite the name, only base64-encoded image data is accepted here; anything
     * {@code isBase64String} rejects causes an {@link UnsupportedOperationException}.
     */
    @Override
    public VisionRequestPayloadDTO createRequestImageURL(VisionModelConnection connection, String prompt, String imageUrl, Map<String, Object> additionalRequestAttributes) throws IOException {
        VisionContentRecord content = this.getGoogleVisionContentRecord(prompt, imageUrl);
        return this.buildVisionRequestPayload(connection, List.of(content), additionalRequestAttributes);
    }

    /** Creates the system instruction {@code "<template> - <instructions>"} shared by the template operations. */
    private SystemInstructionRecord buildSystemInstruction(String template, String instructions) {
        PartRecord partRecord = new PartRecord(template + " - " + instructions, null);
        return new SystemInstructionRecord(List.of(partRecord));
    }

    /**
     * Converts OpenAI-style tool definitions into Gemini {@link Function} declarations.
     *
     * <p>Null tools, tools without a {@code function}, and functions whose {@code parameters}
     * schema is null are skipped. NOTE(review): Gemini permits parameterless function
     * declarations; confirm that dropping them here is intended.
     *
     * @return the converted declarations, or an empty list when {@code openAIFormatTools} is null
     */
    private List<Function> getGeminiCompatibleFunctionList(List<FunctionDefinitionRecord> openAIFormatTools) {
        if (openAIFormatTools == null) {
            return Collections.emptyList();
        }
        return openAIFormatTools.stream()
                .filter(tool -> tool != null)
                .map(FunctionDefinitionRecord::function)
                .filter(function -> function != null && function.parameters() != null)
                .map(function -> new Function(function.name(), function.description(), this.mapGeminiCompatibleFunctionSchema(function.parameters())))
                .toList();
    }

    /**
     * Builds the single user turn for a vision request: the inline image first, then the prompt text.
     *
     * @throws UnsupportedOperationException if {@code imageUrl} is not base64 data
     */
    private VisionContentRecord getGoogleVisionContentRecord(String prompt, String imageUrl) throws IOException {
        // Reading images by URI is not implemented for Gemini; reject before building anything.
        if (!this.isBase64String(imageUrl)) {
            throw new UnsupportedOperationException("Image Read By URI Operation not supported");
        }
        InlineData inlineData = new InlineData(this.getMimeType(imageUrl), imageUrl);
        List<Part> parts = new ArrayList<>(2);
        parts.add(new Part(inlineData, null, null));
        parts.add(new Part(null, null, prompt));
        return new VisionContentRecord(USER_ROLE, parts);
    }

    /**
     * Wraps the given vision contents in a Gemini payload. No system instruction, safety
     * settings, or tools are set.
     */
    public VisionRequestPayloadDTO buildVisionRequestPayload(VisionModelConnection connection, List<Object> messagesArray, Map<String, Object> additionalRequestAttributes) {
        return new GeminiPayloadRecord<Object>(messagesArray, null, this.buildGeminiGenerationConfig(connection.getMaxTokens(), connection.getTemperature(), connection.getTopP(), additionalRequestAttributes), null, null);
    }

    /**
     * Builds a single-turn text payload with {@code prompt} as the only user message.
     *
     * @param safetySettings passed through; {@code null} is replaced with an empty list
     * @param systemInstruction may be {@code null} to omit it
     * @param functions when null or empty, the payload's {@code tools} is {@code null};
     *        otherwise all functions are wrapped in a single {@link FunctionDeclarationsWrapper}
     */
    private GeminiPayloadRecord<ContentRecord> buildGeminiPayload(TextGenerationConnection connection, String prompt, List<String> safetySettings, SystemInstructionRecord systemInstruction, List<Function> functions, Map<String, Object> additionalRequestAttributes) {
        PartRecord partRecord = new PartRecord(prompt, null);
        ContentRecord contentRecord = new ContentRecord(USER_ROLE, List.of(partRecord));
        List<FunctionDeclarationsWrapper> tools = (functions == null || functions.isEmpty())
                ? null
                : List.of(new FunctionDeclarationsWrapper(functions));
        return new GeminiPayloadRecord<ContentRecord>(
                List.of(contentRecord),
                systemInstruction,
                this.buildGeminiGenerationConfig(connection.getMaxTokens(), connection.getTemperature(), connection.getTopP(), additionalRequestAttributes),
                safetySettings != null ? safetySettings : Collections.emptyList(),
                tools);
    }

    /**
     * Creates the generation config from the connection's sampling limits plus any extra attributes.
     * The first argument, {@code List.of("TEXT")}, presumably sets the response modality to
     * text only; confirm against {@link GeminiGenerationConfigRecord}.
     */
    private GeminiGenerationConfigRecord buildGeminiGenerationConfig(Number maxTokens, Number temperature, Number topP, Map<String, Object> additionalRequestAttributes) {
        return new GeminiGenerationConfigRecord(List.of("TEXT"), temperature, topP, maxTokens, additionalRequestAttributes);
    }

    /**
     * Recursively copies a JSON schema, keeping only type, description, enum values,
     * properties, required, and items. Every other schema keyword is set to {@code null}.
     * NOTE(review): this is presumably because Gemini supports only a subset of JSON Schema;
     * confirm against the Gemini Schema reference.
     *
     * @return the sanitized schema, or {@code null} if {@code parameters} is null
     */
    private FunctionSchema mapGeminiCompatibleFunctionSchema(FunctionSchema parameters) {
        if (parameters == null) {
            return null;
        }
        return new FunctionSchema(parameters.type(), parameters.description(), parameters.enumValues(), null, null, null, null, null, null, null, this.deepSanitizeProperties(parameters.properties()), parameters.required(), null, null, null, this.mapGeminiCompatibleFunctionSchema(parameters.items()), null, null, null, null, null, null, null, null, null, null, null, null);
    }

    /**
     * Sanitizes every property schema via {@link #mapGeminiCompatibleFunctionSchema}.
     * A {@link LinkedHashMap} is used so the declared property order survives into the payload.
     *
     * @return {@code properties} itself when it is null or empty, otherwise a new sanitized map
     */
    private Map<String, FunctionSchema> deepSanitizeProperties(Map<String, FunctionSchema> properties) {
        if (properties == null || properties.isEmpty()) {
            return properties;
        }
        Map<String, FunctionSchema> sanitizedProperties = new LinkedHashMap<>();
        for (Map.Entry<String, FunctionSchema> entry : properties.entrySet()) {
            sanitizedProperties.put(entry.getKey(), this.mapGeminiCompatibleFunctionSchema(entry.getValue()));
        }
        return sanitizedProperties;
    }

    /**
     * Converts one OpenAI-style message into a Gemini content entry with a single text part,
     * renaming {@code "assistant"} to {@code "model"}. Other roles are passed through unchanged.
     */
    private ContentRecord convertToGeminiFormat(ChatPayloadRecord msg) {
        String role = OPENAI_ASSISTANT_ROLE.equals(msg.role()) ? GEMINI_MODEL_ROLE : msg.role();
        PartRecord part = new PartRecord(msg.content(), null);
        return new ContentRecord(role, List.of(part));
    }
}

