diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 492eeba..5213bcd 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -158,27 +158,31 @@ impl McpRegistry { } pub async fn reinit( - registry: McpRegistry, + mut registry: McpRegistry, enabled_mcp_servers: Option, abort_signal: AbortSignal, ) -> Result { debug!("Reinitializing MCP registry"); - debug!("Stopping all MCP servers"); - let mut new_registry = abortable_run_with_spinner( - registry.stop_all_servers(), - "Stopping MCP servers", + + let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone()); + let desired_set: HashSet = desired_ids.iter().cloned().collect(); + + debug!("Stopping unused MCP servers"); + abortable_run_with_spinner( + registry.stop_unused_servers(&desired_set), + "Stopping unused MCP servers", abort_signal.clone(), ) .await?; abortable_run_with_spinner( - new_registry.start_select_mcp_servers(enabled_mcp_servers), + registry.start_select_mcp_servers(enabled_mcp_servers), "Loading MCP servers", abort_signal, ) .await?; - Ok(new_registry) + Ok(registry) } async fn start_select_mcp_servers( @@ -192,43 +196,29 @@ impl McpRegistry { return Ok(()); } - if let Some(servers) = enabled_mcp_servers { - debug!("Starting selected MCP servers: {:?}", servers); - let config = self - .config - .as_ref() - .with_context(|| "MCP Config not defined. Cannot start servers")?; - let mcp_servers = config.mcp_servers.clone(); + let desired_ids = self.resolve_server_ids(enabled_mcp_servers); + let ids_to_start: Vec = desired_ids.into_iter() + .filter(|id| !self.servers.contains_key(id)) + .collect(); - let enabled_servers: HashSet = - servers.split(',').map(|s| s.trim().to_string()).collect(); - let server_ids: Vec = if servers == "all" { - mcp_servers.into_keys().collect() - } else { - mcp_servers - .into_keys() - .filter(|id| enabled_servers.contains(id)) - .collect() - }; + if ids_to_start.is_empty() { + return Ok(()); + } - let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( - server_ids - .into_iter() - .map(|id| async { self.start_server(id).await }), - ) - .buffer_unordered(num_cpus::get()) - .try_collect() - .await?; + debug!("Starting selected MCP servers: {:?}", ids_to_start); - self.servers = results - .clone() + let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( + ids_to_start .into_iter() - .map(|(id, server, _)| (id, server)) - .collect(); - self.catalogs = results - .into_iter() - .map(|(id, _, catalog)| (id, catalog)) - .collect(); + .map(|id| async { self.start_server(id).await }), + ) + .buffer_unordered(num_cpus::get()) + .try_collect() + .await?; + + for (id, server, catalog) in results { + self.servers.insert(id.clone(), server); + self.catalogs.insert(id, catalog); } Ok(()) @@ -309,19 +299,49 @@ impl McpRegistry { Ok((id.to_string(), service, catalog)) } - pub async fn stop_all_servers(mut self) -> Result { - for (id, server) in self.servers { - Arc::try_unwrap(server) - .map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))? - .cancel() - .await - .with_context(|| format!("Failed to stop MCP server: {id}"))?; - info!("Stopped MCP server: {id}"); + fn resolve_server_ids(&self, enabled_mcp_servers: Option) -> Vec { + if let Some(config) = &self.config + && let Some(servers) = enabled_mcp_servers { + if servers == "all" { + config.mcp_servers.keys().cloned().collect() + } else { + let enabled_servers: HashSet = + servers.split(',').map(|s| s.trim().to_string()).collect(); + config.mcp_servers + .keys() + .filter(|id| enabled_servers.contains(*id)) + .cloned() + .collect() + } + } else { + vec![] + } + } + + pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet) -> Result<()> { + let mut ids_to_remove = Vec::new(); + for (id, _) in self.servers.iter() { + if !keep_ids.contains(id) { + ids_to_remove.push(id.clone()); + } } - self.servers = HashMap::new(); - - Ok(self) + for id in ids_to_remove { + if let Some(server) = self.servers.remove(&id) { + match Arc::try_unwrap(server) { + Ok(server_inner) => { + server_inner.cancel().await + .with_context(|| format!("Failed to stop MCP server: {id}"))?; + info!("Stopped MCP server: {id}"); + } + Err(_) => { + info!("Detaching from MCP server: {id} (still in use)"); + } + } + self.catalogs.remove(&id); + } + } + Ok(()) } pub fn list_started_servers(&self) -> Vec {