/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vertexai.embedding.multimodal;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Value;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.metadata.DefaultUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.content.Media;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.DocumentEmbeddingModel;
import org.springframework.ai.embedding.DocumentEmbeddingRequest;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.EmbeddingResultMetadata;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingConnectionDetails;
import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils;
import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingModelName;
import org.springframework.ai.vertexai.embedding.multimodal.VertexAiMultimodalEmbeddingOptions;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.util.StringUtils;

public class VertexAiMultimodalEmbeddingModel
implements DocumentEmbeddingModel {
    private static final Logger logger = LoggerFactory.getLogger(VertexAiMultimodalEmbeddingModel.class);
    private static final MimeType TEXT_MIME_TYPE = MimeTypeUtils.parseMimeType((String)"text/*");
    private static final MimeType IMAGE_MIME_TYPE = MimeTypeUtils.parseMimeType((String)"image/*");
    private static final MimeType VIDEO_MIME_TYPE = MimeTypeUtils.parseMimeType((String)"video/*");
    private static final List<MimeType> SUPPORTED_IMAGE_MIME_SUB_TYPES = List.of(MimeTypeUtils.IMAGE_JPEG, MimeTypeUtils.IMAGE_GIF, MimeTypeUtils.IMAGE_PNG, MimeTypeUtils.parseMimeType((String)"image/bmp"));
    private static final Map<String, Integer> KNOWN_EMBEDDING_DIMENSIONS = Stream.of(VertexAiMultimodalEmbeddingModelName.values()).collect(Collectors.toMap(VertexAiMultimodalEmbeddingModelName::getName, VertexAiMultimodalEmbeddingModelName::getDimensions));
    public final VertexAiMultimodalEmbeddingOptions defaultOptions;
    private final VertexAiEmbeddingConnectionDetails connectionDetails;

    public VertexAiMultimodalEmbeddingModel(VertexAiEmbeddingConnectionDetails connectionDetails, VertexAiMultimodalEmbeddingOptions defaultEmbeddingOptions) {
        Assert.notNull((Object)defaultEmbeddingOptions, (String)"VertexAiMultimodalEmbeddingOptions must not be null");
        this.defaultOptions = defaultEmbeddingOptions;
        this.connectionDetails = connectionDetails;
    }

    public EmbeddingResponse call(DocumentEmbeddingRequest request) {
        EmbeddingResponse finalResponse = new EmbeddingResponse(List.of());
        VertexAiMultimodalEmbeddingOptions mergedOptions = this.defaultOptions;
        if (request.getOptions() != null) {
            VertexAiMultimodalEmbeddingOptions defaultOptionsCopy = VertexAiMultimodalEmbeddingOptions.builder().from(this.defaultOptions).build();
            mergedOptions = (VertexAiMultimodalEmbeddingOptions)ModelOptionsUtils.merge((Object)request.getOptions(), (Object)defaultOptionsCopy, VertexAiMultimodalEmbeddingOptions.class);
        }
        try (PredictionServiceClient client = PredictionServiceClient.create((PredictionServiceSettings)this.connectionDetails.getPredictionServiceSettings());){
            EndpointName endpointName = this.connectionDetails.getEndpointName(mergedOptions.getModel());
            for (Document document : request.getInstructions()) {
                EmbeddingResponse singleDocResponse = this.doSingleDocumentPrediction(client, endpointName, document, mergedOptions);
                ArrayList mergedEmbeddings = new ArrayList(finalResponse.getResults());
                mergedEmbeddings.addAll(singleDocResponse.getResults());
                finalResponse = new EmbeddingResponse(mergedEmbeddings, singleDocResponse.getMetadata());
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        return finalResponse;
    }

    /*
     * Enabled aggressive block sorting
     */
    private EmbeddingResponse doSingleDocumentPrediction(PredictionServiceClient client, EndpointName endpointName, Document document, VertexAiMultimodalEmbeddingOptions mergedOptions) throws InvalidProtocolBufferException {
        Media media;
        VertexAiEmbeddingUtils.MultimodalInstanceBuilder instanceBuilder = VertexAiEmbeddingUtils.MultimodalInstanceBuilder.of();
        EnumMap<EmbeddingResultMetadata.ModalityType, DocumentMetadata> documentMetadata = new EnumMap<EmbeddingResultMetadata.ModalityType, DocumentMetadata>(EmbeddingResultMetadata.ModalityType.class);
        if (mergedOptions.getDimensions() != null) {
            instanceBuilder.dimension(mergedOptions.getDimensions());
        }
        if (StringUtils.hasText((String)document.getText())) {
            instanceBuilder.text(document.getText());
            documentMetadata.put(EmbeddingResultMetadata.ModalityType.TEXT, new DocumentMetadata(document.getId(), MimeTypeUtils.TEXT_PLAIN, document.getText()));
        }
        if ((media = document.getMedia()) != null) {
            if (media.getMimeType().isCompatibleWith(TEXT_MIME_TYPE)) {
                instanceBuilder.text(media.getData().toString());
                documentMetadata.put(EmbeddingResultMetadata.ModalityType.TEXT, new DocumentMetadata(document.getId(), MimeTypeUtils.TEXT_PLAIN, media.getData()));
                if (StringUtils.hasText((String)document.getText())) {
                    logger.warn("Media type String overrides the Document text content!");
                }
            } else if (media.getMimeType().isCompatibleWith(IMAGE_MIME_TYPE)) {
                if (!SUPPORTED_IMAGE_MIME_SUB_TYPES.contains(media.getMimeType())) {
                    logger.warn("Unsupported image mime type: {}", (Object)media.getMimeType());
                    throw new IllegalArgumentException("Unsupported image mime type: " + String.valueOf(media.getMimeType()));
                }
                instanceBuilder.image(VertexAiEmbeddingUtils.ImageBuilder.of(media.getMimeType()).imageData(media.getData()).build());
                documentMetadata.put(EmbeddingResultMetadata.ModalityType.IMAGE, new DocumentMetadata(document.getId(), media.getMimeType(), media.getData()));
            } else {
                if (!media.getMimeType().isCompatibleWith(VIDEO_MIME_TYPE)) {
                    logger.warn("Unsupported media type: {}", (Object)media.getMimeType());
                    throw new IllegalArgumentException("Unsupported media type: " + String.valueOf(media.getMimeType()));
                }
                instanceBuilder.video(VertexAiEmbeddingUtils.VideoBuilder.of(media.getMimeType()).videoData(media.getData()).startOffsetSec(mergedOptions.getVideoStartOffsetSec()).endOffsetSec(mergedOptions.getVideoEndOffsetSec()).intervalSec(mergedOptions.getVideoIntervalSec()).build());
                documentMetadata.put(EmbeddingResultMetadata.ModalityType.VIDEO, new DocumentMetadata(document.getId(), media.getMimeType(), media.getData()));
            }
        }
        List<Value> instances = List.of(VertexAiEmbeddingUtils.valueOf(instanceBuilder.build()));
        PredictRequest.Builder predictRequestBuilder = PredictRequest.newBuilder().setEndpoint(endpointName.toString()).setParameters(VertexAiEmbeddingUtils.jsonToValue(ModelOptionsUtils.toJsonString(Map.of()))).addAllInstances(instances);
        PredictResponse embeddingResponse = client.predict(predictRequestBuilder.build());
        int index = 0;
        ArrayList<Embedding> embeddingList = new ArrayList<Embedding>();
        for (Value prediction : embeddingResponse.getPredictionsList()) {
            Value videoEmbeddings;
            DocumentMetadata docMetadata;
            if (prediction.getStructValue().containsFields("textEmbedding")) {
                Value textEmbedding = prediction.getStructValue().getFieldsOrThrow("textEmbedding");
                float[] textVector = VertexAiEmbeddingUtils.toVector(textEmbedding);
                docMetadata = (DocumentMetadata)documentMetadata.get(EmbeddingResultMetadata.ModalityType.TEXT);
                embeddingList.add(new Embedding(textVector, Integer.valueOf(index++), new EmbeddingResultMetadata(docMetadata.documentId, EmbeddingResultMetadata.ModalityType.TEXT, docMetadata.mimeType, docMetadata.data)));
            }
            if (prediction.getStructValue().containsFields("imageEmbedding")) {
                Value imageEmbedding = prediction.getStructValue().getFieldsOrThrow("imageEmbedding");
                float[] imageVector = VertexAiEmbeddingUtils.toVector(imageEmbedding);
                docMetadata = (DocumentMetadata)documentMetadata.get(EmbeddingResultMetadata.ModalityType.IMAGE);
                embeddingList.add(new Embedding(imageVector, Integer.valueOf(index++), new EmbeddingResultMetadata(docMetadata.documentId, EmbeddingResultMetadata.ModalityType.IMAGE, docMetadata.mimeType, docMetadata.data)));
            }
            if (!prediction.getStructValue().containsFields("videoEmbeddings") || !(videoEmbeddings = prediction.getStructValue().getFieldsOrThrow("videoEmbeddings")).getListValue().getValues(0).getStructValue().containsFields("embedding")) continue;
            Value embeddings = videoEmbeddings.getListValue().getValues(0).getStructValue().getFieldsOrThrow("embedding");
            float[] videoVector = VertexAiEmbeddingUtils.toVector(embeddings);
            DocumentMetadata docMetadata2 = (DocumentMetadata)documentMetadata.get(EmbeddingResultMetadata.ModalityType.VIDEO);
            embeddingList.add(new Embedding(videoVector, Integer.valueOf(index++), new EmbeddingResultMetadata(docMetadata2.documentId, EmbeddingResultMetadata.ModalityType.VIDEO, docMetadata2.mimeType, docMetadata2.data)));
        }
        String deploymentModelId = embeddingResponse.getDeployedModelId();
        Map<String, Object> metadataToUse = Map.of("deployment-model-id", StringUtils.hasText((String)deploymentModelId) ? deploymentModelId : "unknown");
        EmbeddingResponseMetadata responseMetadata = this.generateResponseMetadata(mergedOptions.getModel(), 0, metadataToUse);
        return new EmbeddingResponse(embeddingList, responseMetadata);
    }

    private EmbeddingResponseMetadata generateResponseMetadata(String model, Integer totalTokens, Map<String, Object> metadataToUse) {
        DefaultUsage usage = this.getDefaultUsage(totalTokens);
        return new EmbeddingResponseMetadata(model, (Usage)usage, metadataToUse);
    }

    private DefaultUsage getDefaultUsage(Integer totalTokens) {
        return new DefaultUsage(Integer.valueOf(0), Integer.valueOf(0), totalTokens);
    }

    public int dimensions() {
        return KNOWN_EMBEDDING_DIMENSIONS.getOrDefault(this.defaultOptions.getModel(), 768);
    }

    record DocumentMetadata(String documentId, MimeType mimeType, Object data) {
    }
}

