/*
 * (c) 2025 MuleSoft, Inc. The software in this package is published under the terms of the Commercial Free Software license V.1 a copy of which has been included with this distribution in the LICENSE.md file.
 */
package com.mulesoft.modules.agent.conductor.internal.llm.client.einstein;

import static org.mule.runtime.api.meta.model.parameter.ParameterGroupModel.DEFAULT_GROUP_NAME;

import org.mule.runtime.extension.api.client.ExtensionsClient;

import com.mulesoft.modules.agent.conductor.api.model.llm.einstein.EinsteinModel;
import com.mulesoft.modules.agent.conductor.internal.llm.client.LLMClient;
import com.mulesoft.modules.agent.conductor.internal.llm.LLMOrchestrationRequest;
import com.mulesoft.modules.agent.conductor.internal.state.model.LLMOutput;
import com.mulesoft.modules.agent.conductor.internal.serializer.LLMSerializer;

import java.io.InputStream;
import java.util.concurrent.CompletableFuture;

/**
 * {@link LLMClient} implementation that runs the {@code "CHAT-answer-prompt"} operation of the
 * {@code "Einstein AI"} extension through an {@link ExtensionsClient} and hands the operation's
 * output to an {@link LLMSerializer}.
 */
public class EinsteinClient implements LLMClient {

  private static final String EXTENSION_NAME = "Einstein AI";
  private static final String CHAT_ANSWER_PROMPT_OPERATION = "CHAT-answer-prompt";
  private static final String SETTINGS_PARAM_GROUP = "Additional properties";

  private final String configRef;
  private final EinsteinModel model;
  private final ExtensionsClient extensionsClient;
  private final LLMSerializer serializer;

  /**
   * Creates a client bound to one configuration reference and one model definition.
   *
   * @param configRef        value passed as the config reference of every executed operation
   * @param model            source of the {@code modelApiName}, {@code probability} and {@code locale} parameters
   * @param extensionsClient client used to execute the extension operation
   * @param serializer       serializer used to turn the operation output into an {@link LLMOutput}
   */
  public EinsteinClient(String configRef, EinsteinModel model, ExtensionsClient extensionsClient, LLMSerializer serializer) {
    this.configRef = configRef;
    this.model = model;
    this.extensionsClient = extensionsClient;
    this.serializer = serializer;
  }

  /**
   * Executes the chat operation with the prompt carried by {@code request} and the settings read
   * from the configured model, then parses the operation output into the next step.
   *
   * @param request orchestration request whose prompt is sent as the {@code prompt} parameter
   * @return a future completed with the {@link LLMOutput} produced by the serializer
   */
  @Override
  public CompletableFuture<LLMOutput> reasonNextStep(LLMOrchestrationRequest request) {
    return extensionsClient.<InputStream, Object>execute(EXTENSION_NAME, CHAT_ANSWER_PROMPT_OPERATION, operation -> {
      // The prompt goes to the default parameter group; the model settings go to the
      // "Additional properties" group.
      operation.withConfigRef(configRef)
          .withParameter(DEFAULT_GROUP_NAME, "prompt", request.getPrompt())
          .withParameter(SETTINGS_PARAM_GROUP, "modelApiName", model.getModelApiName())
          .withParameter(SETTINGS_PARAM_GROUP, "probability", model.getProbability())
          .withParameter(SETTINGS_PARAM_GROUP, "locale", model.getLocale());
    }).thenApply(operationResult -> toNextStep(operationResult.getOutput()));
  }

  /** Delegates to the serializer to parse the operation output stream into an {@link LLMOutput}. */
  private LLMOutput toNextStep(InputStream operationOutput) {
    return serializer.parseNextStep(operationOutput);
  }
}
