/*
 * Decompiled with CFR 0.152.
 */
package com.mulesoft.connectors.inference.internal.helpers.payload;

import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.mulesoft.connectors.inference.api.request.ChatPayloadRecord;
import com.mulesoft.connectors.inference.api.request.FunctionDefinitionRecord;
import com.mulesoft.connectors.inference.internal.connection.types.TextGenerationConnection;
import com.mulesoft.connectors.inference.internal.connection.types.VisionModelConnection;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.DefaultRequestPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.TextGenerationRequestPayloadDTO;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.vertexai.google.ContentRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.vertexai.google.PartRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.vertexai.google.SystemInstructionRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.vertexai.google.VertexAIGoogleGenerationConfigRecord;
import com.mulesoft.connectors.inference.internal.dto.textgeneration.vertexai.google.VertexAIGooglePayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.vision.DefaultVisionRequestPayloadRecord;
import com.mulesoft.connectors.inference.internal.dto.vision.VisionRequestPayloadDTO;
import com.mulesoft.connectors.inference.internal.dto.vision.vertexai.FileData;
import com.mulesoft.connectors.inference.internal.dto.vision.vertexai.InlineData;
import com.mulesoft.connectors.inference.internal.dto.vision.vertexai.Part;
import com.mulesoft.connectors.inference.internal.dto.vision.vertexai.VisionContentRecord;
import com.mulesoft.connectors.inference.internal.helpers.payload.RequestPayloadHelper;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds request payloads for models served through Google Vertex AI.
 *
 * <p>The provider is inferred from the connection's model name via {@link #getProviderByModel(String)}.
 * Google (Gemini) models get the Vertex AI Google payload shape ({@link VertexAIGooglePayloadRecord});
 * other models fall back to the generic payload records where an operation allows it, or are rejected.
 */
public class VertexAIRequestPayloadHelper extends RequestPayloadHelper {
    private static final Logger logger = LoggerFactory.getLogger(VertexAIRequestPayloadHelper.class);
    public static final String GOOGLE_PROVIDER_TYPE = "Google";
    public static final String ANTHROPIC_PROVIDER_TYPE = "Anthropic";
    public static final String META_PROVIDER_TYPE = "Meta";
    public static final String VERTEX_AI_ANTHROPIC_VERSION_VALUE = "vertex-2023-10-16";
    private static final String DEFAULT_MIME_TYPE = "image/jpeg";
    /** Value returned by {@link #getProviderByModel(String)} when no provider can be inferred. */
    private static final String UNKNOWN_PROVIDER_TYPE = "Unknown";
    /** Upper-case model-name prefix identifying Google Gemini models. */
    private static final String GOOGLE_MODEL_PREFIX = "GEMINI";

    public VertexAIRequestPayloadHelper(ObjectMapper objectMapper) {
        super(objectMapper);
    }

    /**
     * Builds a single-turn chat payload for {@code prompt}.
     *
     * <p>Gemini models get a Vertex AI Google payload with one "user" content entry and no
     * system instruction; other models get a {@link DefaultRequestPayloadRecord} with one
     * "user" message.
     */
    @Override
    public TextGenerationRequestPayloadDTO buildChatAnswerPromptPayload(TextGenerationConnection connection, String prompt) {
        if (isGoogleModel(connection.getModelName())) {
            return this.buildVertexAIGooglePayload(connection, prompt, Collections.emptyList(), null, Collections.emptyList());
        }
        return this.getDefaultRequestPayloadDTO(connection, List.of(new ChatPayloadRecord("user", prompt)));
    }

    /**
     * Builds a templated prompt payload.
     *
     * <p>The system prompt is {@code template + " - " + instructions}, and {@code data} is sent
     * as the user content. For Gemini models the system prompt becomes the request's system
     * instruction. Other models use the base class's system-prompt message array.
     */
    @Override
    public TextGenerationRequestPayloadDTO buildPromptTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data) {
        String systemPrompt = template + " - " + instructions;
        if (isGoogleModel(connection.getModelName())) {
            SystemInstructionRecord systemInstruction = new SystemInstructionRecord(List.of(new PartRecord(systemPrompt)));
            return this.buildVertexAIGooglePayload(connection, data, Collections.emptyList(), systemInstruction, Collections.emptyList());
        }
        List<ChatPayloadRecord> messagesArray = this.createMessagesArrayWithSystemPrompt(systemPrompt, data);
        return this.buildPayload(connection, messagesArray, null);
    }

    /**
     * Parses {@code messages} as a JSON array of {@link ContentRecord} and wraps it in a Vertex AI
     * Google payload.
     *
     * @throws UnsupportedOperationException if the connection's model is not a Gemini model; this
     *         check runs before the stream is read
     * @throws IOException if {@code messages} cannot be read or deserialized
     */
    @Override
    public TextGenerationRequestPayloadDTO parseAndBuildChatCompletionPayload(TextGenerationConnection connection, InputStream messages) throws IOException {
        // Fail fast on unsupported models so callers get the real reason rather than a
        // parse error for a request that could never be served.
        if (!isGoogleModel(connection.getModelName())) {
            throw new UnsupportedOperationException("Model not supported: " + connection.getModelName());
        }
        JavaType contentListType = this.objectMapper.getTypeFactory().constructCollectionType(List.class, ContentRecord.class);
        List<ContentRecord> messagesList = this.objectMapper.readValue(messages, contentListType);
        return new VertexAIGooglePayloadRecord<ContentRecord>(messagesList, null, this.buildVertexAIGoogleGenerationConfig(connection.getMaxTokens(), connection.getTemperature(), connection.getTopP()), null, null);
    }

    /**
     * Function calling with pre-parsed tool definitions is not implemented for Vertex AI.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public TextGenerationRequestPayloadDTO buildToolsTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, List<FunctionDefinitionRecord> tools) {
        throw new UnsupportedOperationException("Currently not supported");
    }

    /**
     * Function calling is not supported for Vertex AI models; {@code tools} is not read.
     *
     * @throws IllegalArgumentException always, naming the inferred provider and model
     */
    @Override
    public TextGenerationRequestPayloadDTO buildToolsTemplatePayload(TextGenerationConnection connection, String template, String instructions, String data, InputStream tools) throws IOException {
        String provider = VertexAIRequestPayloadHelper.getProviderByModel(connection.getModelName());
        throw new IllegalArgumentException(provider + ":" + connection.getModelName() + " on Vertex AI do not currently support function calling at this time.");
    }

    /**
     * Builds a vision payload combining {@code imageUrl} and {@code prompt} for a Gemini model.
     *
     * @param imageUrl either a base64 string (sent inline) or a URL/URI (sent as file data)
     * @throws IllegalArgumentException if the connection's model is not a Gemini model
     */
    @Override
    public VisionRequestPayloadDTO createRequestImageURL(VisionModelConnection connection, String prompt, String imageUrl) throws IOException {
        if (!isGoogleModel(connection.getModelName())) {
            throw new IllegalArgumentException("Unknown provider");
        }
        VisionContentRecord content = this.getGoogleVisionContentRecord(prompt, imageUrl);
        return this.buildVisionRequestPayload(connection, List.of(content));
    }

    /**
     * Infers the model provider from a Vertex AI model name.
     *
     * @param modelName the configured model name; may be {@code null}
     * @return {@link #GOOGLE_PROVIDER_TYPE} for names starting with "gemini" (case-insensitive),
     *         otherwise {@code "Unknown"}
     */
    public static String getProviderByModel(String modelName) {
        logger.debug("model name {}", modelName);
        if (modelName == null || modelName.isEmpty()) {
            return UNKNOWN_PROVIDER_TYPE;
        }
        // Locale.ROOT keeps the match stable under locales with special casing rules
        // (in Turkish, "gemini".toUpperCase() yields "GEMİNİ" and would not match).
        if (modelName.toUpperCase(Locale.ROOT).startsWith(GOOGLE_MODEL_PREFIX)) {
            return GOOGLE_PROVIDER_TYPE;
        }
        return UNKNOWN_PROVIDER_TYPE;
    }

    /** Returns whether {@code modelName} resolves to the Google (Gemini) provider. */
    private static boolean isGoogleModel(String modelName) {
        return GOOGLE_PROVIDER_TYPE.equals(getProviderByModel(modelName));
    }

    /**
     * Builds a single "user" content entry: the image part first, then the text prompt.
     * Base64 input becomes inline data; anything else is treated as a file reference whose
     * MIME type is guessed from its extension.
     */
    private VisionContentRecord getGoogleVisionContentRecord(String prompt, String imageUrl) throws IOException {
        List<Part> parts = new ArrayList<>();
        if (this.isBase64String(imageUrl)) {
            InlineData inlineData = new InlineData(this.getMimeType(imageUrl), imageUrl);
            parts.add(new Part(inlineData, null, null));
        } else {
            FileData fileData = new FileData(this.getMimeTypeFromUrl(imageUrl), imageUrl);
            parts.add(new Part(null, fileData, null));
        }
        parts.add(new Part(null, null, prompt));
        return new VisionContentRecord("user", parts);
    }

    /**
     * Wraps vision contents in the provider-specific payload: Vertex AI Google shape for
     * Gemini models, {@link DefaultVisionRequestPayloadRecord} otherwise.
     */
    private VisionRequestPayloadDTO buildVisionRequestPayload(VisionModelConnection connection, List<Object> messagesArray) {
        if (isGoogleModel(connection.getModelName())) {
            return new VertexAIGooglePayloadRecord<Object>(messagesArray, null, this.buildVertexAIGoogleGenerationConfig(connection.getMaxTokens(), connection.getTemperature(), connection.getTopP()), null, null);
        }
        return this.getDefaultVisionRequestPayloadDTO(connection, messagesArray);
    }

    /** Builds the generic text payload from the connection's model and sampling settings. */
    private DefaultRequestPayloadRecord getDefaultRequestPayloadDTO(TextGenerationConnection connection, List<ChatPayloadRecord> chatPayloadRecordList) {
        return new DefaultRequestPayloadRecord(connection.getModelName(), chatPayloadRecordList, connection.getMaxTokens(), connection.getTemperature(), connection.getTopP(), null);
    }

    /** Builds the generic vision payload from the connection's model and sampling settings. */
    private DefaultVisionRequestPayloadRecord getDefaultVisionRequestPayloadDTO(VisionModelConnection connection, List<Object> chatPayloadRecordList) {
        return new DefaultVisionRequestPayloadRecord(connection.getModelName(), chatPayloadRecordList, connection.getMaxTokens(), connection.getTemperature(), connection.getTopP());
    }

    /**
     * Builds a Vertex AI Google payload with a single "user" content entry holding {@code prompt}.
     * An empty or {@code null} tool list is sent as {@code null}.
     */
    private VertexAIGooglePayloadRecord<ContentRecord> buildVertexAIGooglePayload(TextGenerationConnection connection, String prompt, List<String> safetySettings, SystemInstructionRecord systemInstruction, List<FunctionDefinitionRecord> tools) {
        ContentRecord contentRecord = new ContentRecord("user", List.of(new PartRecord(prompt)));
        List<FunctionDefinitionRecord> effectiveTools = (tools == null || tools.isEmpty()) ? null : tools;
        return new VertexAIGooglePayloadRecord<ContentRecord>(List.of(contentRecord), systemInstruction, this.buildVertexAIGoogleGenerationConfig(connection.getMaxTokens(), connection.getTemperature(), connection.getTopP()), safetySettings, effectiveTools);
    }

    /** Builds a generation config requesting TEXT output with the given sampling settings. */
    private VertexAIGoogleGenerationConfigRecord buildVertexAIGoogleGenerationConfig(Number maxTokens, Number temperature, Number topP) {
        return new VertexAIGoogleGenerationConfigRecord(List.of("TEXT"), temperature, topP, maxTokens);
    }

    /**
     * Guesses a MIME type from the file extension in the path of {@code imageUrl}.
     *
     * <p>The query string and fragment are ignored. Only a dot after the last '/' counts as an
     * extension, so host names such as "example.com" are not mistaken for one. Unrecognized or
     * missing extensions map to {@code image/jpeg}.
     */
    private String getMimeTypeFromUrl(String imageUrl) {
        if (imageUrl == null || imageUrl.isBlank()) {
            return DEFAULT_MIME_TYPE;
        }
        String path = imageUrl.trim();
        int queryIndex = path.indexOf('?');
        if (queryIndex >= 0) {
            path = path.substring(0, queryIndex);
        }
        int fragmentIndex = path.indexOf('#');
        if (fragmentIndex >= 0) {
            path = path.substring(0, fragmentIndex);
        }
        int lastDotIndex = path.lastIndexOf('.');
        if (lastDotIndex == -1 || lastDotIndex < path.lastIndexOf('/')) {
            return DEFAULT_MIME_TYPE;
        }
        String extension = path.substring(lastDotIndex).toLowerCase(Locale.ROOT);
        return switch (extension) {
            case ".png" -> "image/png";
            case ".webp" -> "image/webp";
            case ".pdf" -> "application/pdf";
            default -> DEFAULT_MIME_TYPE;
        };
    }
}

