use crate::config::Config; use crate::config::paths; use crate::utils::{AbortSignal, abortable_run_with_spinner}; use crate::vault::interpolate_secrets; use anyhow::{Context, Result, anyhow}; use futures_util::{StreamExt, TryStreamExt, stream}; use indoc::formatdoc; use rmcp::service::RunningService; use rmcp::transport::TokioChildProcess; use rmcp::{RoleClient, ServiceExt}; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fs::OpenOptions; use std::path::{Path, PathBuf}; use std::process::Stdio; use std::sync::Arc; use tokio::process::Command; pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke"; pub const MCP_SEARCH_META_FUNCTION_NAME_PREFIX: &str = "mcp_search"; pub const MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX: &str = "mcp_describe"; pub type ConnectedServer = RunningService; #[derive(Clone, Debug, Default, Serialize)] pub struct CatalogItem { pub name: String, pub server: String, pub description: String, } #[derive(Debug)] struct ServerCatalog { items: HashMap, } impl Clone for ServerCatalog { fn clone(&self) -> Self { Self { items: self.items.clone(), } } } #[derive(Debug, Clone, Deserialize)] pub(crate) struct McpServersConfig { #[serde(rename = "mcpServers")] pub mcp_servers: HashMap, } #[derive(Debug, Clone, Deserialize)] pub(crate) struct McpServer { pub command: String, pub args: Option>, pub env: Option>, pub cwd: Option, } #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] pub(crate) enum JsonField { Str(String), Bool(bool), Int(i64), } #[derive(Debug, Clone, Default)] pub struct McpRegistry { log_path: Option, config: Option, servers: HashMap>, catalogs: HashMap, } impl McpRegistry { pub async fn init( log_path: Option, start_mcp_servers: bool, enabled_mcp_servers: Option, abort_signal: AbortSignal, config: &Config, ) -> Result { let mut registry = Self { log_path, ..Default::default() }; if !paths::mcp_config_file().try_exists().with_context(|| { format!( "Failed to check MCP config file at {}", paths::mcp_config_file().display() ) })? { debug!( "MCP config file does not exist at {}, skipping MCP initialization", paths::mcp_config_file().display() ); return Ok(registry); } let err = || { format!( "Failed to load MCP config file at {}", paths::mcp_config_file().display() ) }; let content = tokio::fs::read_to_string(paths::mcp_config_file()) .await .with_context(err)?; if content.trim().is_empty() { debug!("MCP config file is empty, skipping MCP initialization"); return Ok(registry); } let (parsed_content, missing_secrets) = interpolate_secrets(&content, &config.vault); if !missing_secrets.is_empty() { return Err(anyhow!(formatdoc!( " MCP config file references secrets that are missing from the vault: {:?} Please add these secrets to the vault and try again.", missing_secrets ))); } let mcp_servers_config: McpServersConfig = serde_json::from_str(&parsed_content).with_context(err)?; registry.config = Some(mcp_servers_config); if start_mcp_servers && config.mcp_server_support { abortable_run_with_spinner( registry.start_select_mcp_servers(enabled_mcp_servers), "Loading MCP servers", abort_signal, ) .await?; } Ok(registry) } async fn start_select_mcp_servers( &mut self, enabled_mcp_servers: Option, ) -> Result<()> { if self.config.is_none() { debug!( "MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization" ); return Ok(()); } let desired_ids = self.resolve_server_ids(enabled_mcp_servers); let ids_to_start: Vec = desired_ids .into_iter() .filter(|id| !self.servers.contains_key(id)) .collect(); if ids_to_start.is_empty() { return Ok(()); } debug!("Starting selected MCP servers: {:?}", ids_to_start); let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( ids_to_start .into_iter() .map(|id| async { self.start_server(id).await }), ) .buffer_unordered(num_cpus::get()) .try_collect() .await?; for (id, server, catalog) in results { self.servers.insert(id.clone(), server); self.catalogs.insert(id, catalog); } Ok(()) } async fn start_server( &self, id: String, ) -> Result<(String, Arc, ServerCatalog)> { let spec = self .config .as_ref() .and_then(|c| c.mcp_servers.get(&id)) .with_context(|| format!("MCP server not found in config: {id}"))?; let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?; let tools = service.list_tools(None).await?; debug!("Available tools for MCP server {id}: {tools:?}"); let mut items_vec = Vec::new(); for t in tools.tools { let name = t.name.to_string(); let description = t.description.unwrap_or_default().to_string(); items_vec.push(CatalogItem { name, server: id.clone(), description, }); } let mut items_map = HashMap::new(); items_vec.into_iter().for_each(|it| { items_map.insert(it.name.clone(), it); }); let catalog = ServerCatalog { items: items_map }; info!("Started MCP server: {id}"); Ok((id.to_string(), service, catalog)) } fn resolve_server_ids(&self, enabled_mcp_servers: Option) -> Vec { if let Some(config) = &self.config && let Some(servers) = enabled_mcp_servers { if servers == "all" { config.mcp_servers.keys().cloned().collect() } else { let enabled_servers: HashSet = servers.split(',').map(|s| s.trim().to_string()).collect(); config .mcp_servers .keys() .filter(|id| enabled_servers.contains(*id)) .cloned() .collect() } } else { vec![] } } pub fn running_servers(&self) -> &HashMap> { &self.servers } pub fn list_started_servers(&self) -> Vec { self.servers.keys().cloned().collect() } pub fn is_empty(&self) -> bool { self.servers.is_empty() } pub fn mcp_config(&self) -> Option<&McpServersConfig> { self.config.as_ref() } pub fn log_path(&self) -> Option<&PathBuf> { self.log_path.as_ref() } } pub(crate) async fn spawn_mcp_server( spec: &McpServer, log_path: Option<&Path>, ) -> Result> { let mut cmd = Command::new(&spec.command); if let Some(args) = &spec.args { cmd.args(args); } if let Some(env) = &spec.env { let env: HashMap = env .iter() .map(|(k, v)| match v { JsonField::Str(s) => (k.clone(), s.clone()), JsonField::Bool(b) => (k.clone(), b.to_string()), JsonField::Int(i) => (k.clone(), i.to_string()), }) .collect(); cmd.envs(env); } if let Some(cwd) = &spec.cwd { cmd.current_dir(cwd); } let transport = if let Some(log_path) = log_path { cmd.stdin(Stdio::piped()).stdout(Stdio::piped()); let log_file = OpenOptions::new() .create(true) .append(true) .open(log_path)?; let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?; transport } else { TokioChildProcess::new(cmd)? }; let service = Arc::new( ().serve(transport) .await .with_context(|| format!("Failed to start MCP server: {}", &spec.command))?, ); Ok(service) }