/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.vertexai;

import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.FunctionCallingConfig;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.ToolConfig;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.ResponseHandler;
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.vertexai.ContentsMapper;
import dev.langchain4j.model.vertexai.FinishReasonMapper;
import dev.langchain4j.model.vertexai.FunctionCallHelper;
import dev.langchain4j.model.vertexai.HarmCategory;
import dev.langchain4j.model.vertexai.ResponseGrounding;
import dev.langchain4j.model.vertexai.SafetySettingsMapper;
import dev.langchain4j.model.vertexai.SafetyThreshold;
import dev.langchain4j.model.vertexai.TokenUsageMapper;
import dev.langchain4j.model.vertexai.ToolCallingMode;
import dev.langchain4j.model.vertexai.spi.VertexAiGeminiChatModelBuilderFactory;
import dev.langchain4j.spi.ServiceHelper;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class VertexAiGeminiChatModel
implements ChatLanguageModel,
Closeable {
    private final GenerativeModel generativeModel;
    private final GenerationConfig generationConfig;
    private final Integer maxRetries;
    private final VertexAI vertexAI;
    private final Map<HarmCategory, SafetyThreshold> safetySettings;
    private final Tool googleSearch;
    private final Tool vertexSearch;
    private final ToolConfig toolConfig;
    private final List<String> allowedFunctionNames;
    private final Boolean logRequests;
    private final Boolean logResponses;
    private static final Logger logger = LoggerFactory.getLogger(VertexAiGeminiChatModel.class);

    public VertexAiGeminiChatModel(String project, String location, String modelName, Float temperature, Integer maxOutputTokens, Integer topK, Float topP, Integer maxRetries, String responseMimeType, Schema responseSchema, Map<HarmCategory, SafetyThreshold> safetySettings, Boolean useGoogleSearch, String vertexSearchDatastore, ToolCallingMode toolCallingMode, List<String> allowedFunctionNames, Boolean logRequests, Boolean logResponses) {
        GenerationConfig.Builder generationConfigBuilder = GenerationConfig.newBuilder();
        if (temperature != null) {
            generationConfigBuilder.setTemperature(temperature.floatValue());
        }
        if (maxOutputTokens != null) {
            generationConfigBuilder.setMaxOutputTokens(maxOutputTokens.intValue());
        }
        if (topK != null) {
            generationConfigBuilder.setTopK((float)topK.intValue());
        }
        if (topP != null) {
            generationConfigBuilder.setTopP(topP.floatValue());
        }
        if (responseMimeType != null) {
            generationConfigBuilder.setResponseMimeType(responseMimeType);
        }
        if (responseSchema != null) {
            generationConfigBuilder.setResponseMimeType("application/json");
            generationConfigBuilder.setResponseSchema(responseSchema);
        }
        this.generationConfig = generationConfigBuilder.build();
        this.safetySettings = safetySettings != null ? new HashMap<HarmCategory, SafetyThreshold>(safetySettings) : Collections.emptyMap();
        this.googleSearch = useGoogleSearch != null && useGoogleSearch != false ? ResponseGrounding.googleSearchTool() : null;
        this.vertexSearch = vertexSearchDatastore != null ? ResponseGrounding.vertexAiSearch(vertexSearchDatastore) : null;
        this.allowedFunctionNames = allowedFunctionNames != null ? Collections.unmodifiableList(allowedFunctionNames) : Collections.emptyList();
        this.toolConfig = toolCallingMode != null ? (toolCallingMode == ToolCallingMode.ANY && allowedFunctionNames != null && !allowedFunctionNames.isEmpty() ? ToolConfig.newBuilder().setFunctionCallingConfig(FunctionCallingConfig.newBuilder().setMode(FunctionCallingConfig.Mode.ANY).addAllAllowedFunctionNames(this.allowedFunctionNames).build()).build() : (toolCallingMode == ToolCallingMode.NONE ? ToolConfig.newBuilder().setFunctionCallingConfig(FunctionCallingConfig.newBuilder().setMode(FunctionCallingConfig.Mode.NONE).build()).build() : ToolConfig.newBuilder().setFunctionCallingConfig(FunctionCallingConfig.newBuilder().setMode(FunctionCallingConfig.Mode.AUTO).build()).build())) : ToolConfig.newBuilder().setFunctionCallingConfig(FunctionCallingConfig.newBuilder().setMode(FunctionCallingConfig.Mode.AUTO).build()).build();
        this.vertexAI = new VertexAI.Builder().setProjectId(ValidationUtils.ensureNotBlank((String)project, (String)"project")).setLocation(ValidationUtils.ensureNotBlank((String)location, (String)"location")).setCustomHeaders(Collections.singletonMap("user-agent", "LangChain4j")).build();
        this.generativeModel = new GenerativeModel(ValidationUtils.ensureNotBlank((String)modelName, (String)"modelName"), this.vertexAI).withGenerationConfig(this.generationConfig);
        this.maxRetries = (Integer)Utils.getOrDefault((Object)maxRetries, (Object)3);
        this.logRequests = logRequests != null ? logRequests : Boolean.valueOf(false);
        this.logResponses = logResponses != null ? logResponses : Boolean.valueOf(false);
    }

    public VertexAiGeminiChatModel(GenerativeModel generativeModel, GenerationConfig generationConfig) {
        this(generativeModel, generationConfig, 3);
    }

    public VertexAiGeminiChatModel(GenerativeModel generativeModel, GenerationConfig generationConfig, Integer maxRetries) {
        this.generationConfig = (GenerationConfig)ValidationUtils.ensureNotNull((Object)generationConfig, (String)"generationConfig");
        this.generativeModel = ((GenerativeModel)ValidationUtils.ensureNotNull((Object)generativeModel, (String)"generativeModel")).withGenerationConfig(generationConfig);
        this.maxRetries = (Integer)Utils.getOrDefault((Object)maxRetries, (Object)3);
        this.vertexAI = null;
        this.safetySettings = Collections.emptyMap();
        this.googleSearch = null;
        this.vertexSearch = null;
        this.toolConfig = ToolConfig.newBuilder().setFunctionCallingConfig(FunctionCallingConfig.newBuilder().setMode(FunctionCallingConfig.Mode.AUTO).build()).build();
        this.allowedFunctionNames = Collections.emptyList();
        this.logRequests = false;
        this.logResponses = false;
    }

    public Response<AiMessage> generate(List<ChatMessage> messages) {
        return this.generate(messages, new ArrayList<ToolSpecification>());
    }

    public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
        Content content;
        List<FunctionCall> functionCalls;
        String modelName = this.generativeModel.getModelName();
        ArrayList<Tool> tools = new ArrayList<Tool>();
        if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
            Tool tool = FunctionCallHelper.convertToolSpecifications(toolSpecifications);
            tools.add(tool);
        }
        if (this.googleSearch != null) {
            tools.add(this.googleSearch);
        }
        if (this.vertexSearch != null) {
            tools.add(this.vertexSearch);
        }
        GenerativeModel model = this.generativeModel.withTools(tools).withToolConfig(this.toolConfig);
        ContentsMapper.InstructionAndContent instructionAndContent = ContentsMapper.splitInstructionAndContent(messages);
        if (instructionAndContent.systemInstruction != null) {
            model = model.withSystemInstruction(instructionAndContent.systemInstruction);
        }
        if (!this.safetySettings.isEmpty()) {
            model = model.withSafetySettings(SafetySettingsMapper.mapSafetySettings(this.safetySettings));
        }
        if (this.logRequests.booleanValue() && logger.isDebugEnabled()) {
            logger.debug("GEMINI ({}) request: {} tools: {}", new Object[]{modelName, instructionAndContent, tools});
        }
        GenerativeModel finalModel = model;
        GenerateContentResponse response = (GenerateContentResponse)RetryUtils.withRetry(() -> finalModel.generateContent(instructionAndContent.contents), (int)this.maxRetries);
        if (this.logResponses.booleanValue() && logger.isDebugEnabled()) {
            logger.debug("GEMINI ({}) response: {}", (Object)modelName, (Object)response);
        }
        if (!(functionCalls = (content = ResponseHandler.getContent((GenerateContentResponse)response)).getPartsList().stream().filter(Part::hasFunctionCall).map(Part::getFunctionCall).collect(Collectors.toList())).isEmpty()) {
            List<ToolExecutionRequest> toolExecutionRequests = FunctionCallHelper.fromFunctionCalls(functionCalls);
            return Response.from((Object)AiMessage.from(toolExecutionRequests), (TokenUsage)TokenUsageMapper.map(response.getUsageMetadata()), (FinishReason)FinishReasonMapper.map(ResponseHandler.getFinishReason((GenerateContentResponse)response)));
        }
        return Response.from((Object)AiMessage.from((String)ResponseHandler.getText((GenerateContentResponse)response)), (TokenUsage)TokenUsageMapper.map(response.getUsageMetadata()), (FinishReason)FinishReasonMapper.map(ResponseHandler.getFinishReason((GenerateContentResponse)response)));
    }

    public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
        if (toolSpecification == null) {
            return this.generate(messages);
        }
        return this.generate(messages, Collections.singletonList(toolSpecification));
    }

    @Override
    public void close() throws IOException {
        if (this.vertexAI != null) {
            this.vertexAI.close();
        }
    }

    public static VertexAiGeminiChatModelBuilder builder() {
        Iterator iterator = ServiceHelper.loadFactories(VertexAiGeminiChatModelBuilderFactory.class).iterator();
        if (iterator.hasNext()) {
            VertexAiGeminiChatModelBuilderFactory factory = (VertexAiGeminiChatModelBuilderFactory)iterator.next();
            return (VertexAiGeminiChatModelBuilder)factory.get();
        }
        return new VertexAiGeminiChatModelBuilder();
    }

    public static class VertexAiGeminiChatModelBuilder {
        private String project;
        private String location;
        private String modelName;
        private Float temperature;
        private Integer maxOutputTokens;
        private Integer topK;
        private Float topP;
        private Integer maxRetries;
        private String responseMimeType;
        private Schema responseSchema;
        private Map<HarmCategory, SafetyThreshold> safetySettings;
        private Boolean useGoogleSearch;
        private String vertexSearchDatastore;
        private ToolCallingMode toolCallingMode;
        private List<String> allowedFunctionNames;
        private Boolean logRequests;
        private Boolean logResponses;

        public VertexAiGeminiChatModelBuilder project(String project) {
            this.project = project;
            return this;
        }

        public VertexAiGeminiChatModelBuilder location(String location) {
            this.location = location;
            return this;
        }

        public VertexAiGeminiChatModelBuilder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public VertexAiGeminiChatModelBuilder temperature(Float temperature) {
            this.temperature = temperature;
            return this;
        }

        public VertexAiGeminiChatModelBuilder maxOutputTokens(Integer maxOutputTokens) {
            this.maxOutputTokens = maxOutputTokens;
            return this;
        }

        public VertexAiGeminiChatModelBuilder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        public VertexAiGeminiChatModelBuilder topP(Float topP) {
            this.topP = topP;
            return this;
        }

        public VertexAiGeminiChatModelBuilder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public VertexAiGeminiChatModelBuilder responseMimeType(String responseMimeType) {
            this.responseMimeType = responseMimeType;
            return this;
        }

        public VertexAiGeminiChatModelBuilder responseSchema(Schema responseSchema) {
            this.responseSchema = responseSchema;
            return this;
        }

        public VertexAiGeminiChatModelBuilder safetySettings(Map<HarmCategory, SafetyThreshold> safetySettings) {
            this.safetySettings = safetySettings;
            return this;
        }

        public VertexAiGeminiChatModelBuilder useGoogleSearch(Boolean useGoogleSearch) {
            this.useGoogleSearch = useGoogleSearch;
            return this;
        }

        public VertexAiGeminiChatModelBuilder vertexSearchDatastore(String vertexSearchDatastore) {
            this.vertexSearchDatastore = vertexSearchDatastore;
            return this;
        }

        public VertexAiGeminiChatModelBuilder toolCallingMode(ToolCallingMode toolCallingMode) {
            this.toolCallingMode = toolCallingMode;
            return this;
        }

        public VertexAiGeminiChatModelBuilder allowedFunctionNames(List<String> allowedFunctionNames) {
            this.allowedFunctionNames = allowedFunctionNames;
            return this;
        }

        public VertexAiGeminiChatModelBuilder logRequests(Boolean logRequests) {
            this.logRequests = logRequests;
            return this;
        }

        public VertexAiGeminiChatModelBuilder logResponses(Boolean logResponses) {
            this.logResponses = logResponses;
            return this;
        }

        public VertexAiGeminiChatModel build() {
            return new VertexAiGeminiChatModel(this.project, this.location, this.modelName, this.temperature, this.maxOutputTokens, this.topK, this.topP, this.maxRetries, this.responseMimeType, this.responseSchema, this.safetySettings, this.useGoogleSearch, this.vertexSearchDatastore, this.toolCallingMode, this.allowedFunctionNames, this.logRequests, this.logResponses);
        }

        public String toString() {
            return "VertexAiGeminiChatModel.VertexAiGeminiChatModelBuilder(project=" + this.project + ", location=" + this.location + ", modelName=" + this.modelName + ", temperature=" + this.temperature + ", maxOutputTokens=" + this.maxOutputTokens + ", topK=" + this.topK + ", topP=" + this.topP + ", maxRetries=" + this.maxRetries + ", responseMimeType=" + this.responseMimeType + ", responseSchema=" + this.responseSchema + ", safetySettings=" + this.safetySettings + ", useGoogleSearch=" + this.useGoogleSearch + ", vertexSearchDatastore=" + this.vertexSearchDatastore + ", toolCallingMode=" + (Object)((Object)this.toolCallingMode) + ", allowedFunctionNames=" + this.allowedFunctionNames + ", logRequests=" + this.logRequests + ", logResponses=" + this.logResponses + ")";
        }
    }
}

