use crate::config::Config; use crate::utils::{AbortSignal, abortable_run_with_spinner}; use crate::vault::interpolate_secrets; use anyhow::{Context, Result, anyhow}; use bm25::{Document, Language, SearchEngine, SearchEngineBuilder}; use futures_util::future::BoxFuture; use futures_util::{StreamExt, TryStreamExt, stream}; use indoc::formatdoc; use rmcp::model::{CallToolRequestParams, CallToolResult}; use rmcp::service::RunningService; use rmcp::transport::TokioChildProcess; use rmcp::{RoleClient, ServiceExt}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; use std::fs::OpenOptions; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; use tokio::process::Command; pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke"; pub const MCP_SEARCH_META_FUNCTION_NAME_PREFIX: &str = "mcp_search"; pub const MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX: &str = "mcp_describe"; type ConnectedServer = RunningService; #[derive(Clone, Debug, Default, Serialize)] pub struct CatalogItem { pub name: String, pub server: String, pub description: String, } #[derive(Debug)] struct ServerCatalog { engine: SearchEngine, items: HashMap, } impl ServerCatalog { pub fn build_bm25(items: &HashMap) -> SearchEngine { let docs = items.values().map(|it| { let contents = format!("{}\n{}\nserver:{}", it.name, it.description, it.server); Document { id: it.name.clone(), contents, } }); SearchEngineBuilder::::with_documents(Language::English, docs).build() } } impl Clone for ServerCatalog { fn clone(&self) -> Self { Self { engine: Self::build_bm25(&self.items), items: self.items.clone(), } } } #[derive(Debug, Clone, Deserialize)] struct McpServersConfig { #[serde(rename = "mcpServers")] mcp_servers: HashMap, } #[derive(Debug, Clone, Deserialize)] struct McpServer { command: String, args: Option>, env: Option>, cwd: Option, } #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] enum JsonField { Str(String), Bool(bool), Int(i64), } #[derive(Debug, Clone, Default)] pub struct McpRegistry { log_path: Option, config: Option, servers: HashMap>, catalogs: HashMap, } impl McpRegistry { pub async fn init( log_path: Option, start_mcp_servers: bool, enabled_mcp_servers: Option, abort_signal: AbortSignal, config: &Config, ) -> Result { let mut registry = Self { log_path, ..Default::default() }; if !Config::mcp_config_file().try_exists().with_context(|| { format!( "Failed to check MCP config file at {}", Config::mcp_config_file().display() ) })? { debug!( "MCP config file does not exist at {}, skipping MCP initialization", Config::mcp_config_file().display() ); return Ok(registry); } let err = || { format!( "Failed to load MCP config file at {}", Config::mcp_config_file().display() ) }; let content = tokio::fs::read_to_string(Config::mcp_config_file()) .await .with_context(err)?; if content.trim().is_empty() { debug!("MCP config file is empty, skipping MCP initialization"); return Ok(registry); } let (parsed_content, missing_secrets) = interpolate_secrets(&content, &config.vault); if !missing_secrets.is_empty() { return Err(anyhow!(formatdoc!( " MCP config file references secrets that are missing from the vault: {:?} Please add these secrets to the vault and try again.", missing_secrets ))); } let mcp_servers_config: McpServersConfig = serde_json::from_str(&parsed_content).with_context(err)?; registry.config = Some(mcp_servers_config); if start_mcp_servers && config.mcp_server_support { abortable_run_with_spinner( registry.start_select_mcp_servers(enabled_mcp_servers), "Loading MCP servers", abort_signal, ) .await?; } Ok(registry) } pub async fn reinit( mut registry: McpRegistry, enabled_mcp_servers: Option, abort_signal: AbortSignal, ) -> Result { debug!("Reinitializing MCP registry"); let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone()); let desired_set: HashSet = desired_ids.iter().cloned().collect(); debug!("Stopping unused MCP servers"); abortable_run_with_spinner( registry.stop_unused_servers(&desired_set), "Stopping unused MCP servers", abort_signal.clone(), ) .await?; abortable_run_with_spinner( registry.start_select_mcp_servers(enabled_mcp_servers), "Loading MCP servers", abort_signal, ) .await?; Ok(registry) } async fn start_select_mcp_servers( &mut self, enabled_mcp_servers: Option, ) -> Result<()> { if self.config.is_none() { debug!( "MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization" ); return Ok(()); } let desired_ids = self.resolve_server_ids(enabled_mcp_servers); let ids_to_start: Vec = desired_ids .into_iter() .filter(|id| !self.servers.contains_key(id)) .collect(); if ids_to_start.is_empty() { return Ok(()); } debug!("Starting selected MCP servers: {:?}", ids_to_start); let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( ids_to_start .into_iter() .map(|id| async { self.start_server(id).await }), ) .buffer_unordered(num_cpus::get()) .try_collect() .await?; for (id, server, catalog) in results { self.servers.insert(id.clone(), server); self.catalogs.insert(id, catalog); } Ok(()) } async fn start_server( &self, id: String, ) -> Result<(String, Arc, ServerCatalog)> { let server = self .config .as_ref() .and_then(|c| c.mcp_servers.get(&id)) .with_context(|| format!("MCP server not found in config: {id}"))?; let mut cmd = Command::new(&server.command); if let Some(args) = &server.args { cmd.args(args); } if let Some(env) = &server.env { let env: HashMap = env .iter() .map(|(k, v)| match v { JsonField::Str(s) => (k.clone(), s.clone()), JsonField::Bool(b) => (k.clone(), b.to_string()), JsonField::Int(i) => (k.clone(), i.to_string()), }) .collect(); cmd.envs(env); } if let Some(cwd) = &server.cwd { cmd.current_dir(cwd); } let transport = if let Some(log_path) = self.log_path.as_ref() { cmd.stdin(Stdio::piped()).stdout(Stdio::piped()); let log_file = OpenOptions::new() .create(true) .append(true) .open(log_path)?; let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?; transport } else { TokioChildProcess::new(cmd)? }; let service = Arc::new( ().serve(transport) .await .with_context(|| format!("Failed to start MCP server: {}", &server.command))?, ); let tools = service.list_tools(None).await?; debug!("Available tools for MCP server {id}: {tools:?}"); let mut items_vec = Vec::new(); for t in tools.tools { let name = t.name.to_string(); let description = t.description.unwrap_or_default().to_string(); items_vec.push(CatalogItem { name, server: id.clone(), description, }); } let mut items_map = HashMap::new(); items_vec.into_iter().for_each(|it| { items_map.insert(it.name.clone(), it); }); let catalog = ServerCatalog { engine: ServerCatalog::build_bm25(&items_map), items: items_map, }; info!("Started MCP server: {id}"); Ok((id.to_string(), service, catalog)) } fn resolve_server_ids(&self, enabled_mcp_servers: Option) -> Vec { if let Some(config) = &self.config && let Some(servers) = enabled_mcp_servers { if servers == "all" { config.mcp_servers.keys().cloned().collect() } else { let enabled_servers: HashSet = servers.split(',').map(|s| s.trim().to_string()).collect(); config .mcp_servers .keys() .filter(|id| enabled_servers.contains(*id)) .cloned() .collect() } } else { vec![] } } pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet) -> Result<()> { let mut ids_to_remove = Vec::new(); for (id, _) in self.servers.iter() { if !keep_ids.contains(id) { ids_to_remove.push(id.clone()); } } for id in ids_to_remove { if let Some(server) = self.servers.remove(&id) { match Arc::try_unwrap(server) { Ok(server_inner) => { server_inner .cancel() .await .with_context(|| format!("Failed to stop MCP server: {id}"))?; info!("Stopped MCP server: {id}"); } Err(_) => { info!("Detaching from MCP server: {id} (still in use)"); } } self.catalogs.remove(&id); } } Ok(()) } pub fn list_started_servers(&self) -> Vec { self.servers.keys().cloned().collect() } pub fn list_configured_servers(&self) -> Vec { if let Some(config) = &self.config { config.mcp_servers.keys().cloned().collect() } else { vec![] } } pub fn search_tools_server(&self, server: &str, query: &str, top_k: usize) -> Vec { let Some(catalog) = self.catalogs.get(server) else { return vec![]; }; let engine = &catalog.engine; let raw = engine.search(query, top_k.min(20)); raw.into_iter() .filter_map(|r| catalog.items.get(&r.document.id)) .take(top_k) .cloned() .collect() } pub async fn describe(&self, server_id: &str, tool: &str) -> Result { let server = self .servers .iter() .filter(|(id, _)| &server_id == id) .map(|(_, s)| s.clone()) .next() .ok_or(anyhow!("{server_id} MCP server not found in config"))?; let tool_schema = server .list_tools(None) .await? .tools .into_iter() .find(|it| it.name == tool) .ok_or(anyhow!( "{tool} not found in {server_id} MCP server catalog" ))? .input_schema; Ok(json!({ "type": "object", "properties": { "tool": { "type": "string", }, "arguments": tool_schema } })) } pub fn invoke( &self, server: &str, tool: &str, arguments: Value, ) -> BoxFuture<'static, Result> { let server = self .servers .get(server) .cloned() .with_context(|| format!("Invoked MCP server does not exist: {server}")); let tool = tool.to_owned(); Box::pin(async move { let server = server?; let call_tool_request = CallToolRequestParams { name: Cow::Owned(tool.to_owned()), arguments: arguments.as_object().cloned(), meta: None, task: None, }; let result = server.call_tool(call_tool_request).await?; Ok(result) }) } pub fn is_empty(&self) -> bool { self.servers.is_empty() } }