feat: Improved MCP server spinup and spindown when switching contexts or settings in the REPL: Modify existing config rather than stopping all servers always and re-initializing if unnecessary

This commit is contained in:
2026-02-20 14:36:34 -07:00
parent 15a293204f
commit e6e99b6926
+71 -51
View File
@@ -158,27 +158,31 @@ impl McpRegistry {
}
pub async fn reinit(
registry: McpRegistry,
mut registry: McpRegistry,
enabled_mcp_servers: Option<String>,
abort_signal: AbortSignal,
) -> Result<Self> {
debug!("Reinitializing MCP registry");
debug!("Stopping all MCP servers");
let mut new_registry = abortable_run_with_spinner(
registry.stop_all_servers(),
"Stopping MCP servers",
let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone());
let desired_set: HashSet<String> = desired_ids.iter().cloned().collect();
debug!("Stopping unused MCP servers");
abortable_run_with_spinner(
registry.stop_unused_servers(&desired_set),
"Stopping unused MCP servers",
abort_signal.clone(),
)
.await?;
abortable_run_with_spinner(
new_registry.start_select_mcp_servers(enabled_mcp_servers),
registry.start_select_mcp_servers(enabled_mcp_servers),
"Loading MCP servers",
abort_signal,
)
.await?;
Ok(new_registry)
Ok(registry)
}
async fn start_select_mcp_servers(
@@ -192,43 +196,29 @@ impl McpRegistry {
return Ok(());
}
if let Some(servers) = enabled_mcp_servers {
debug!("Starting selected MCP servers: {:?}", servers);
let config = self
.config
.as_ref()
.with_context(|| "MCP Config not defined. Cannot start servers")?;
let mcp_servers = config.mcp_servers.clone();
let desired_ids = self.resolve_server_ids(enabled_mcp_servers);
let ids_to_start: Vec<String> = desired_ids.into_iter()
.filter(|id| !self.servers.contains_key(id))
.collect();
let enabled_servers: HashSet<String> =
servers.split(',').map(|s| s.trim().to_string()).collect();
let server_ids: Vec<String> = if servers == "all" {
mcp_servers.into_keys().collect()
} else {
mcp_servers
.into_keys()
.filter(|id| enabled_servers.contains(id))
.collect()
};
if ids_to_start.is_empty() {
return Ok(());
}
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
server_ids
.into_iter()
.map(|id| async { self.start_server(id).await }),
)
.buffer_unordered(num_cpus::get())
.try_collect()
.await?;
debug!("Starting selected MCP servers: {:?}", ids_to_start);
self.servers = results
.clone()
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
ids_to_start
.into_iter()
.map(|(id, server, _)| (id, server))
.collect();
self.catalogs = results
.into_iter()
.map(|(id, _, catalog)| (id, catalog))
.collect();
.map(|id| async { self.start_server(id).await }),
)
.buffer_unordered(num_cpus::get())
.try_collect()
.await?;
for (id, server, catalog) in results {
self.servers.insert(id.clone(), server);
self.catalogs.insert(id, catalog);
}
Ok(())
@@ -309,19 +299,49 @@ impl McpRegistry {
Ok((id.to_string(), service, catalog))
}
pub async fn stop_all_servers(mut self) -> Result<Self> {
for (id, server) in self.servers {
Arc::try_unwrap(server)
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
.cancel()
.await
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
info!("Stopped MCP server: {id}");
fn resolve_server_ids(&self, enabled_mcp_servers: Option<String>) -> Vec<String> {
if let Some(config) = &self.config
&& let Some(servers) = enabled_mcp_servers {
if servers == "all" {
config.mcp_servers.keys().cloned().collect()
} else {
let enabled_servers: HashSet<String> =
servers.split(',').map(|s| s.trim().to_string()).collect();
config.mcp_servers
.keys()
.filter(|id| enabled_servers.contains(*id))
.cloned()
.collect()
}
} else {
vec![]
}
}
pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet<String>) -> Result<()> {
let mut ids_to_remove = Vec::new();
for (id, _) in self.servers.iter() {
if !keep_ids.contains(id) {
ids_to_remove.push(id.clone());
}
}
self.servers = HashMap::new();
Ok(self)
for id in ids_to_remove {
if let Some(server) = self.servers.remove(&id) {
match Arc::try_unwrap(server) {
Ok(server_inner) => {
server_inner.cancel().await
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
info!("Stopped MCP server: {id}");
}
Err(_) => {
info!("Detaching from MCP server: {id} (still in use)");
}
}
self.catalogs.remove(&id);
}
}
Ok(())
}
pub fn list_started_servers(&self) -> Vec<String> {