%dw 2.0
import fail from dw::Runtime
import * from jsonschema!agents_domain

ns xsi http://www.w3.org/2001/XMLSchema-instance
ns http http://www.mulesoft.org/schema/mule/http
ns a2a http://www.mulesoft.org/schema/mule/a2a
ns ms_einstein_ai http://www.mulesoft.org/schema/mule/ms-einstein-ai
ns mcp http://www.mulesoft.org/schema/mule/mcp
ns agents_conductor http://www.mulesoft.org/schema/mule/agents-conductor
ns ee http://www.mulesoft.org/schema/mule/ee/core
ns oauth http://www.mulesoft.org/schema/mule/oauth

/**
 * Builds one Einstein AI `config` element per LLM provider entry.
 * Each config is named after the provider's key and connects with OAuth client credentials.
 * NOTE(review): the provider's `baseUrl` is what gets passed as `tokenUrl` — confirm this is intended.
 */
fun createLLMProvidersConfig(llmProvider: { _?: EinsteinLLM }) = do {
  var einsteinConfigs = llmProvider pluck ((provider, providerName) -> {
      ms_einstein_ai#config @(name: providerName): {
          ms_einstein_ai#"oauth-client-credentials-connection": {
              ms_einstein_ai#"oauth-client-credentials" @(clientId: provider.clientId, clientSecret: provider.clientSecret, tokenUrl: provider.baseUrl): {}
            }
        }
    })
  ---
  { (einsteinConfigs) }
}

/**
 * Builds one MCP `client-config` element per MCP server.
 *
 * Each array entry is an object keyed by server name. Every server is reached through the
 * egress gateway at `${egressgw.url}/<serverName>`. Servers that declare an `sse` transport
 * get an `sse-client-connection`. All others get a `streamable-http-client-connection`,
 * with `mcpEndpointPath` set only when a path is configured.
 */
fun createMCPClientConfig(mcpServers: Array<{ _?:MCPServer }> | Null) =
  {
    (mcpServers map ((mcpServer) -> {
              (mcpServer pluck ((serverValue, serverName, index) -> {
                        mcp#"client-config" @(name: serverName, clientName: serverName, clientVersion: "1.0.0"):
                          if (serverValue.transport.sse?)
                            {
                              mcp#"sse-client-connection" @(serverUrl: "\${egressgw.url}/$(serverName)", sseEndpointPath: serverValue.transport.sse.ssePath): {
                                  (createMCPAuthenticationElements(serverValue))
                                }
                            }
                          else
                            {
                              mcp#"streamable-http-client-connection" @(serverUrl: "\${egressgw.url}/$(serverName)", (mcpEndpointPath: serverValue.transport.streamableHttp.path) if(serverValue.transport.streamableHttp.path?)): {
                                  (createMCPAuthenticationElements(serverValue))
                                }
                            }
                      }))
            }))
  }

/**
 * Authentication children shared by both MCP connection kinds:
 *  - Anypoint client credentials are sent as clientId/clientSecret default request headers;
 *  - any other non-null authentication becomes an `mcp:authentication` element;
 *  - no authentication produces no children.
 */
fun createMCPAuthenticationElements(serverValue: { authentication?: Authentication }) =
  if (serverValue.authentication is AnypointClientCredentialsAuth)
    {
      mcp#"default-request-headers": {
          mcp#"default-request-header" @(key: "clientId", value: serverValue.authentication.clientId): {},
          mcp#"default-request-header" @(key: "clientSecret", value: serverValue.authentication.clientSecret): {}
        }
    }
  else if (serverValue.authentication != null)
    {
      mcp#authentication: createAuthenticationObject(serverValue)
    }
  else
    {}
/**
 * Maps a connection's authentication definition to the matching Mule authentication element:
 *  - OAuth2ClientCredentialsAuth -> `oauth:client-credentials-grant-type`
 *  - BasicAuth                   -> `http:basic-authentication`
 * Callers in this file handle AnypointClientCredentialsAuth (sent as headers) and absent
 * authentication themselves before calling this. Other types match no case.
 */
fun createAuthenticationObject(value: { authentication?: Authentication }) =
  value.authentication match {
    // Optional scopes are joined into the `scopes` attribute; `tokenUrl` is emitted exactly once.
    case cc is OAuth2ClientCredentialsAuth -> {
        oauth#"client-credentials-grant-type" @(
          clientId: cc.clientId,
          clientSecret: cc.clientSecret,
          tokenUrl: cc.tokenUrl,
          (scopes: cc.scopes joinBy ",") if (cc.scopes?)): {}
      }
    case ba is BasicAuth -> {
        http#"basic-authentication" @(
          username: ba.username,
          password: ba.password): {}
      }
  }


/**
 * Shared builder for A2A `client-config` elements, one per agent key.
 *
 * Each config is named "<agentName>-a2a-client-config" and connects to "<baseUrl>/<agentName>".
 * Authentication handling:
 *  - Anypoint client credentials are sent as clientId/clientSecret default headers;
 *  - any other authentication becomes an `a2a:authentication` element;
 *  - no authentication produces no children.
 * A null `agents` yields null, as `pluck` does.
 */
fun createA2AClientConfigs(agents: { _?: { authentication?: Authentication } } | Null, baseUrl: String) =
  agents pluck ((value, agentName, index) -> {
          a2a#"client-config" @(name: "$(agentName)-a2a-client-config"): {
              a2a#"client-connection" @(serverUrl: "$(baseUrl)/$(agentName)"): {
                  (
                      if(value.authentication is AnypointClientCredentialsAuth) {
                        a2a#"default-headers": {
                            a2a#"default-header" @(key: "clientId", value: value.authentication.clientId): {},
                            a2a#"default-header" @(key: "clientSecret", value: value.authentication.clientSecret): {}
                          }
                      } else if(value.authentication?) {
                        a2a#authentication: createAuthenticationObject(value)
                      } else {}
                  )
                }
            }
        })

/** A2A client configs for agents reached through the agent registry (`${agentregistry.url}`). */
fun createAgentClientConfig(agents: { _?: { authentication?: Authentication } } | Null) =
  createA2AClientConfigs(agents, "\${agentregistry.url}")

/** A2A client configs for external agents reached through the egress gateway (`${egressgw.url}`). */
fun createExternalAgentClientConfig(agents: { _?: { authentication?: Authentication } } | Null) =
  createA2AClientConfigs(agents, "\${egressgw.url}")

/**
 * Converts one agent-card security scheme into the matching `a2a:*-security-scheme` element.
 *
 * @param securitySchemeName the key under which the scheme is declared in the agent card
 * @param securityScheme the scheme definition; its concrete type selects the element emitted
 * Fails with "Unsupported security scheme type: ..." for any other type.
 */
fun createSecurityScheme(securitySchemeName: String, securityScheme: SecurityScheme) =
  //TODO add support for other security schemes
  securityScheme match {
    // Only the client-credentials OAuth2 flow is mapped.
    case oauth2 is OAuth2SecurityScheme ->
      {
        a2a#"oauth2-security-scheme" @(securitySchemeName: securitySchemeName): {
            a2a#flows: {
                (a2a#"client-credentials" @(tokenUrl: oauth2.flows.clientCredentials.tokenUrl): {
                    a2a#"oauth2-scopes": {
                        // `scopes` maps scope name -> short description (A2A OAuthFlow.scopes),
                        // so pluck yields (description, name): the name is the key.
                        (oauth2.flows.clientCredentials.scopes pluck ((scopeDescription, scopeName) -> {
                                  a2a#"oauth2-scope" @(name: scopeName as String): {
                                      a2a#description: scopeDescription as CData
                                    }
                                }))
                      }
                  }) if (oauth2.flows.clientCredentials?)
              }
          }
      }
    case apikey is APIKeySecurityScheme ->
      {
        a2a#"api-key-security-scheme" @(name: apikey.name, in: upper(apikey.in), securitySchemeName: securitySchemeName): {
            a2a#description: (apikey.description default "") as CData
          }
      }
    case httpAuth is HTTPAuthSecurityScheme ->
      {
        a2a#"http-security-scheme" @(securitySchemeName: securitySchemeName, scheme: httpAuth.scheme, (bearerFormat: httpAuth.bearerFormat) if(httpAuth.bearerFormat?)): {
            a2a#description: (httpAuth.description default "") as CData
          }
      }
    case oidc is OpenIdConnectSecurityScheme ->
      {
        a2a#"open-id-connect-security-scheme" @(securitySchemeName: securitySchemeName, openIdConnectUrl: oidc.openIdConnectUrl): {
            a2a#description: (oidc.description default "") as CData
          }
      }
    else -> fail("Unsupported security scheme type: " ++ (typeOf(securityScheme) as String))
  }
  
/** Conductor config name for an agent: its card name with every space replaced by "_". */
fun getConductorConfigName(agent: Agent) = do {
  var displayName = agent.card.name
  ---
  displayName replace " " with "_"
}

/**
 * Builds one agents-conductor `config` element per agent, named by getConductorConfigName.
 * The loop and error limits are emitted only when the agent spec sets them.
 */
fun createConductorConfig(agents: { _?: Agent }) =
  agents pluck ((agent) -> do {
          var spec = agent.spec
          ---
          {
            agents_conductor#"config" @(
              name: getConductorConfigName(agent),
              (maxNumberOfLoops: spec.maxNumberOfLoops) if (spec.maxNumberOfLoops != null),
              (maxConsecutiveErrors: spec.maxConsecutiveErrors) if (spec.maxConsecutiveErrors != null)): null
          }
        })
  
/**
 * Builds one A2A `server-config` element per agent.
 *
 * Each config is named "<agentName>-a2a-listener-config", listens on the shared
 * "http-listener-config" at path "/<agentName>", and publishes an agent card built from
 * `agent.card`: skills, provider, default I/O modes, and optional security requirements
 * and schemes.
 */
fun createAgentListenerConfig(agents: { _?: Agent }) =
  agents pluck ((agent, agentName, index) -> {
          a2a#"server-config" @(name: "$(agentName)-a2a-listener-config"): {
              a2a#connection @(listenerConfig: "http-listener-config", agentPath: "/$(agentName)"): null,
              a2a#card @(name: agent.card.name, url: "\${agentregistry.url}/$(agentName)", version: "1.0.0"): {
                  a2a#description: agent.card.description as CData,
                  a2a#skills: {
                      (agent.card.skills map ((skill) -> {
                                a2a#"agent-skill" @(id: skill.id, name: skill.name): {
                                    a2a#description: skill.description as CData,
                                    (a2a#tags: {
                                        (skill.tags map ((tag) -> {
                                                  a2a#tag @(value: tag): {}
                                                }))
                                      }) if (!isEmpty(skill.tags)),
                                    (a2a#examples: {
                                        (skill.examples map ((example) -> {
                                                  a2a#example @(value: example): {}
                                                }))
                                      }) if (!isEmpty(skill.examples)),
                                    (a2a#"input-modes": {
                                        (skill.inputModes map ((inputMode) -> {
                                                  a2a#"input-mode" @(value: inputMode): {}
                                                }))
                                      }) if (!isEmpty(skill.inputModes)),
                                    (a2a#"output-modes": {
                                        (skill.outputModes map ((outputMode) -> {
                                                  a2a#"output-mode" @(value: outputMode): {}
                                                }))
                                      }) if (!isEmpty(skill.outputModes))
                                  }
                              }))
                    },
                  // Organization and url both come from the card's provider object.
                  a2a#provider @(organization: agent.card.provider.organization, url: agent.card.provider.url): null,
                  a2a#"default-input-modes": {
                      (agent.card.defaultInputModes map ((inputMode) -> {
                                a2a#"default-input-mode" @(value: inputMode): {}
                              }))
                    },
                  a2a#"default-output-modes": {
                      (agent.card.defaultOutputModes map ((outputMode) -> {
                                a2a#"default-output-mode" @(value: outputMode): {}
                              }))
                    },
                  // Each security requirement is a single-key object: scheme name -> list of scopes.
                  (a2a#security: {
                      (agent.card.security map ((schemeDefinition) -> do {
                                var schemeName = keysOf(schemeDefinition)[0]
                                var requiredScopes = schemeDefinition[schemeName]
                                ---
                                a2a#"security-requirement" @(securitySchemeName: schemeName): {
                                      a2a#scopes: {
                                          (requiredScopes map ((scope) -> {
                                                    a2a#scope @(value: scope): {}
                                                  }))
                                        }
                                    }
                              }))
                    }) if (agent.card.security?),
                  (a2a#"security-schemes": {
                      (agent.card.securitySchemes pluck ((securityScheme, securitySchemeName, index) ->
                              createSecurityScheme(securitySchemeName, securityScheme)
                            ))
                    }) if (agent.card.securitySchemes?)
                }
            }
        })

/**
 * Builds a bare A2A client config named exactly `agentName`.
 * It points at `${agentregistry.url}/<agentName>` and has no authentication children.
 */
fun createRemoteAgentConfig(agentName: String) = do {
  var registryUrl = "\${agentregistry.url}/$(agentName)"
  ---
  {
    a2a#"client-config" @(name: agentName): {
        a2a#"client-connection" @(serverUrl: registryUrl): {}
      }
  }
}


/** Builds one Mule flow per agent via createAgentFlow, passing the full agent map and the LLM providers through. */
fun createAgents(agents: { _?: Agent }, llmProviders: { _?: EinsteinLLM }) =
  agents pluck ((agent, agentName) ->
    createAgentFlow(agentName as String, agent, agents, llmProviders)
  )

/**
 * Builds the Mule flow "<agentName>-flow" for a single agent:
 *   1. an A2A task listener bound to "<agentName>-a2a-listener-config";
 *   2. an agents-conductor agent loop using the agent's LLM provider, instructions,
 *      MCP tools and linked A2A agents;
 *   3. a transform that shapes the loop result into an A2A "task" response;
 *   4. an on-error-continue handler that returns a failed A2A status carrying the error description.
 *
 * @param agentName    agent key; used to derive the flow and config names
 * @param agent        the agent definition
 * @param allAgents    all agents (NOTE(review): not referenced in this body — confirm whether still needed)
 * @param llmProviders LLM providers keyed by name; agent.spec.llm is looked up here
 */
fun createAgentFlow(agentName: String, agent: Agent, allAgents: { _?: Agent }, llmProviders: { _?: EinsteinLLM }) =
  {
    flow @(name: "$(agentName)-flow"): {
        // Flow source: A2A tasks received on this agent's listener config.
        a2a#"task-listener" @("config-ref": "$(agentName)-a2a-listener-config"): null,
        try: {
            // The "#[...]" strings are Mule runtime expressions, evaluated per message rather than here.
            // taskId: the message's taskId, else its first referenceTaskId, else null.
            agents_conductor#"agent-loop" @(
              taskId: "#[if (payload.message.taskId != null) payload.message.taskId else if (payload.message.referenceTaskIds != null and sizeOf(payload.message.referenceTaskIds) > 0) payload.message.referenceTaskIds[0] else null]", 
              contextId: "#[payload.message.contextId]", 
              "config-ref": getConductorConfigName(agent)): {
                // Prompt: the text of the message parts joined with newlines.
                agents_conductor#prompt: "#[payload.message.parts.text joinBy \"\\n\" onNull \"\"]" as CData,
                agents_conductor#llm: do {
                    var llm = agent.spec.llm
                    // `!` asserts the provider exists, so an unknown llm name is an error here.
                    var llmProvider = llmProviders[llm]! 
                    ---                
                    {
                      // The Einstein config is referenced by the provider key; optional settings are emitted only when set.
                      agents_conductor#"einstein-settings" @(
                        einsteinAiConfigRef: llm, 
                        (probability: llmProvider.probability) if(llmProvider.probability != null), 
                        (modelApiName: llmProvider.modelName) if(llmProvider.modelName != null),
                        (locale: llmProvider.locale) if(llmProvider.locale != null)
                        ): {}
                    }
                  },
                (agents_conductor#instructions: (agent.spec.instructions joinBy "\n") as CData) if (!isEmpty(agent.spec.instructions)),
                // Only tools with an `mcp` entry contribute an mcp-server (others map to []).
                // The mcp-servers element itself has no condition. Note: the inner `tool` lambda
                // parameter shadows the outer one and refers to an allowed tool name.
                (agents_conductor#"mcp-servers": { (
                        agent.spec.tools flatMap ((tool) ->
                              if (tool.mcp?)
                                [ {
                                    agents_conductor#"mcp-server" @(mcpClientConfigRef: tool.mcp.server): {
                                        (agents_conductor#"allowed-tools": {
                                            (tool.mcp.allowed map ((tool) -> {
                                                      agents_conductor#"allowed-tool" @(value: tool): null
                                                    }))
                                          }) if (!isEmpty(tool.mcp.allowed))
                                      }
                                  }]
                              else []
                            )
          )}),
                // One a2a-client per key of each link definition, referencing "<key>-a2a-client-config".
                (agents_conductor#"a2a-clients": {
                    (agent.spec.links map ((linkDef) -> {
                              (linkDef pluck ((value, key, index) -> {
                                        agents_conductor#"a2a-client" @(a2aClientConfigRef: "$(key)-a2a-client-config"): {}
                                      }))
                            }))
                  }) if (!isEmpty(agent.spec.links))
              },
            // Success response: an A2A task whose state is "completed" (goalComplete), "input-required"
            // (inputRequired) or "failed", with payload.response as a single text artifact.
            ee#transform: {
                ee#message: {
                    ee#"set-payload": "%dw 2.0\noutput application/json\n--- \n{\n    \"id\": payload.taskId,\n    \"contextId\": payload.contextId,\n    \"status\": {\n      \"state\": \n        if(payload.goalComplete) \"completed\" \n        else if (payload.inputRequired) \"input-required\" \n        else \"failed\",\n    },\n    \"artifacts\": [\n      {\n        \"artifactId\": uuid(),\n        \"parts\": [\n          {\n            \"kind\": \"text\",\n            \"text\": payload.response\n            }\n        ]\n      }\n    ],\n    \"kind\": \"task\"\n  }" as CData
                  }
              },
            // Any error: respond with a "failed" status whose message text is error.description;
            // the kind is "task" when payload.id exists, otherwise "message".
            "error-handler": {
                "on-error-continue" @(enableNotifications: "true", logException: "true"): {
                    ee#transform: {
                        ee#message: {
                            ee#"set-payload": "%dw 2.0\noutput application/json\n--- \n{\n    (\"id\": payload.id) if (payload.id?),\n    (\"contextId\": payload.contextId) if (payload.contextId?),\n    \"status\": {\n      \"state\": \"failed\",\n \"message\": {\n \"role\": \"agent\",\n \"messageId\": uuid(),\n \"parts\": [\n          {\n            \"kind\": \"text\",\n            \"text\": error.description\n            }\n        ]\n }  \n} ,\n    \"kind\": if (payload.id?) \"task\" else \"message\"\n  }" as CData
                          }
                      }
                  }
              }
          },
      }
  }


/**
 * Builds the shared HTTP listener config "http-listener-config".
 * It binds to 0.0.0.0 on the `${http.port}` property; agent A2A listeners reference it by name.
 */
fun createHttpListenerConfig() = do {
  var listenerName = "http-listener-config"
  var listenerPort = "\${http.port}"
  ---
  {
    http#"listener-config" @(name: listenerName): {
        http#"listener-connection" @(host: "0.0.0.0", port: listenerPort): {}
      }
  }
}

/** Builds the `configuration-properties` element that loads placeholder values from "config.yaml". */
fun createConfigurationElement() = do {
  var propertiesFile = "config.yaml"
  ---
  {
    "configuration-properties" @(file: propertiesFile): {}
  }
}
