feat: Improved MCP server spinup and spindown when switching contexts or settings in the REPL: Modify existing config rather than stopping all servers always and re-initializing if unnecessary

This commit is contained in:
2026-02-20 14:36:34 -07:00
parent 15a293204f
commit e6e99b6926
+71 -51
View File
@@ -158,27 +158,31 @@ impl McpRegistry {
} }
pub async fn reinit( pub async fn reinit(
registry: McpRegistry, mut registry: McpRegistry,
enabled_mcp_servers: Option<String>, enabled_mcp_servers: Option<String>,
abort_signal: AbortSignal, abort_signal: AbortSignal,
) -> Result<Self> { ) -> Result<Self> {
debug!("Reinitializing MCP registry"); debug!("Reinitializing MCP registry");
debug!("Stopping all MCP servers");
let mut new_registry = abortable_run_with_spinner( let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone());
registry.stop_all_servers(), let desired_set: HashSet<String> = desired_ids.iter().cloned().collect();
"Stopping MCP servers",
debug!("Stopping unused MCP servers");
abortable_run_with_spinner(
registry.stop_unused_servers(&desired_set),
"Stopping unused MCP servers",
abort_signal.clone(), abort_signal.clone(),
) )
.await?; .await?;
abortable_run_with_spinner( abortable_run_with_spinner(
new_registry.start_select_mcp_servers(enabled_mcp_servers), registry.start_select_mcp_servers(enabled_mcp_servers),
"Loading MCP servers", "Loading MCP servers",
abort_signal, abort_signal,
) )
.await?; .await?;
Ok(new_registry) Ok(registry)
} }
async fn start_select_mcp_servers( async fn start_select_mcp_servers(
@@ -192,43 +196,29 @@ impl McpRegistry {
return Ok(()); return Ok(());
} }
if let Some(servers) = enabled_mcp_servers { let desired_ids = self.resolve_server_ids(enabled_mcp_servers);
debug!("Starting selected MCP servers: {:?}", servers); let ids_to_start: Vec<String> = desired_ids.into_iter()
let config = self .filter(|id| !self.servers.contains_key(id))
.config .collect();
.as_ref()
.with_context(|| "MCP Config not defined. Cannot start servers")?;
let mcp_servers = config.mcp_servers.clone();
let enabled_servers: HashSet<String> = if ids_to_start.is_empty() {
servers.split(',').map(|s| s.trim().to_string()).collect(); return Ok(());
let server_ids: Vec<String> = if servers == "all" { }
mcp_servers.into_keys().collect()
} else {
mcp_servers
.into_keys()
.filter(|id| enabled_servers.contains(id))
.collect()
};
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( debug!("Starting selected MCP servers: {:?}", ids_to_start);
server_ids
.into_iter()
.map(|id| async { self.start_server(id).await }),
)
.buffer_unordered(num_cpus::get())
.try_collect()
.await?;
self.servers = results let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
.clone() ids_to_start
.into_iter() .into_iter()
.map(|(id, server, _)| (id, server)) .map(|id| async { self.start_server(id).await }),
.collect(); )
self.catalogs = results .buffer_unordered(num_cpus::get())
.into_iter() .try_collect()
.map(|(id, _, catalog)| (id, catalog)) .await?;
.collect();
for (id, server, catalog) in results {
self.servers.insert(id.clone(), server);
self.catalogs.insert(id, catalog);
} }
Ok(()) Ok(())
@@ -309,19 +299,49 @@ impl McpRegistry {
Ok((id.to_string(), service, catalog)) Ok((id.to_string(), service, catalog))
} }
pub async fn stop_all_servers(mut self) -> Result<Self> { fn resolve_server_ids(&self, enabled_mcp_servers: Option<String>) -> Vec<String> {
for (id, server) in self.servers { if let Some(config) = &self.config
Arc::try_unwrap(server) && let Some(servers) = enabled_mcp_servers {
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))? if servers == "all" {
.cancel() config.mcp_servers.keys().cloned().collect()
.await } else {
.with_context(|| format!("Failed to stop MCP server: {id}"))?; let enabled_servers: HashSet<String> =
info!("Stopped MCP server: {id}"); servers.split(',').map(|s| s.trim().to_string()).collect();
config.mcp_servers
.keys()
.filter(|id| enabled_servers.contains(*id))
.cloned()
.collect()
}
} else {
vec![]
}
}
pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet<String>) -> Result<()> {
let mut ids_to_remove = Vec::new();
for (id, _) in self.servers.iter() {
if !keep_ids.contains(id) {
ids_to_remove.push(id.clone());
}
} }
self.servers = HashMap::new(); for id in ids_to_remove {
if let Some(server) = self.servers.remove(&id) {
Ok(self) match Arc::try_unwrap(server) {
Ok(server_inner) => {
server_inner.cancel().await
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
info!("Stopped MCP server: {id}");
}
Err(_) => {
info!("Detaching from MCP server: {id} (still in use)");
}
}
self.catalogs.remove(&id);
}
}
Ok(())
} }
pub fn list_started_servers(&self) -> Vec<String> { pub fn list_started_servers(&self) -> Vec<String> {