feat: Improved MCP server spinup and spindown when switching contexts or settings in the REPL: Modify existing config rather than stopping all servers always and re-initializing if unnecessary

This commit is contained in:
2026-02-20 14:36:34 -07:00
parent 15a293204f
commit e6e99b6926
+64 -44
View File
@@ -158,27 +158,31 @@ impl McpRegistry {
} }
pub async fn reinit( pub async fn reinit(
registry: McpRegistry, mut registry: McpRegistry,
enabled_mcp_servers: Option<String>, enabled_mcp_servers: Option<String>,
abort_signal: AbortSignal, abort_signal: AbortSignal,
) -> Result<Self> { ) -> Result<Self> {
debug!("Reinitializing MCP registry"); debug!("Reinitializing MCP registry");
debug!("Stopping all MCP servers");
let mut new_registry = abortable_run_with_spinner( let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone());
registry.stop_all_servers(), let desired_set: HashSet<String> = desired_ids.iter().cloned().collect();
"Stopping MCP servers",
debug!("Stopping unused MCP servers");
abortable_run_with_spinner(
registry.stop_unused_servers(&desired_set),
"Stopping unused MCP servers",
abort_signal.clone(), abort_signal.clone(),
) )
.await?; .await?;
abortable_run_with_spinner( abortable_run_with_spinner(
new_registry.start_select_mcp_servers(enabled_mcp_servers), registry.start_select_mcp_servers(enabled_mcp_servers),
"Loading MCP servers", "Loading MCP servers",
abort_signal, abort_signal,
) )
.await?; .await?;
Ok(new_registry) Ok(registry)
} }
async fn start_select_mcp_servers( async fn start_select_mcp_servers(
@@ -192,27 +196,19 @@ impl McpRegistry {
return Ok(()); return Ok(());
} }
if let Some(servers) = enabled_mcp_servers { let desired_ids = self.resolve_server_ids(enabled_mcp_servers);
debug!("Starting selected MCP servers: {:?}", servers); let ids_to_start: Vec<String> = desired_ids.into_iter()
let config = self .filter(|id| !self.servers.contains_key(id))
.config .collect();
.as_ref()
.with_context(|| "MCP Config not defined. Cannot start servers")?;
let mcp_servers = config.mcp_servers.clone();
let enabled_servers: HashSet<String> = if ids_to_start.is_empty() {
servers.split(',').map(|s| s.trim().to_string()).collect(); return Ok(());
let server_ids: Vec<String> = if servers == "all" { }
mcp_servers.into_keys().collect()
} else { debug!("Starting selected MCP servers: {:?}", ids_to_start);
mcp_servers
.into_keys()
.filter(|id| enabled_servers.contains(id))
.collect()
};
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
server_ids ids_to_start
.into_iter() .into_iter()
.map(|id| async { self.start_server(id).await }), .map(|id| async { self.start_server(id).await }),
) )
@@ -220,15 +216,9 @@ impl McpRegistry {
.try_collect() .try_collect()
.await?; .await?;
self.servers = results for (id, server, catalog) in results {
.clone() self.servers.insert(id.clone(), server);
.into_iter() self.catalogs.insert(id, catalog);
.map(|(id, server, _)| (id, server))
.collect();
self.catalogs = results
.into_iter()
.map(|(id, _, catalog)| (id, catalog))
.collect();
} }
Ok(()) Ok(())
@@ -309,19 +299,49 @@ impl McpRegistry {
Ok((id.to_string(), service, catalog)) Ok((id.to_string(), service, catalog))
} }
pub async fn stop_all_servers(mut self) -> Result<Self> { fn resolve_server_ids(&self, enabled_mcp_servers: Option<String>) -> Vec<String> {
for (id, server) in self.servers { if let Some(config) = &self.config
Arc::try_unwrap(server) && let Some(servers) = enabled_mcp_servers {
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))? if servers == "all" {
.cancel() config.mcp_servers.keys().cloned().collect()
.await } else {
let enabled_servers: HashSet<String> =
servers.split(',').map(|s| s.trim().to_string()).collect();
config.mcp_servers
.keys()
.filter(|id| enabled_servers.contains(*id))
.cloned()
.collect()
}
} else {
vec![]
}
}
pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet<String>) -> Result<()> {
let mut ids_to_remove = Vec::new();
for (id, _) in self.servers.iter() {
if !keep_ids.contains(id) {
ids_to_remove.push(id.clone());
}
}
for id in ids_to_remove {
if let Some(server) = self.servers.remove(&id) {
match Arc::try_unwrap(server) {
Ok(server_inner) => {
server_inner.cancel().await
.with_context(|| format!("Failed to stop MCP server: {id}"))?; .with_context(|| format!("Failed to stop MCP server: {id}"))?;
info!("Stopped MCP server: {id}"); info!("Stopped MCP server: {id}");
} }
Err(_) => {
self.servers = HashMap::new(); info!("Detaching from MCP server: {id} (still in use)");
}
Ok(self) }
self.catalogs.remove(&id);
}
}
Ok(())
} }
pub fn list_started_servers(&self) -> Vec<String> { pub fn list_started_servers(&self) -> Vec<String> {