feat: Improved MCP server spinup and spindown when switching contexts or settings in the REPL: Modify existing config rather than stopping all servers always and re-initializing if unnecessary
This commit is contained in:
+71
-51
@@ -158,27 +158,31 @@ impl McpRegistry {
|
||||
}
|
||||
|
||||
pub async fn reinit(
|
||||
registry: McpRegistry,
|
||||
mut registry: McpRegistry,
|
||||
enabled_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
debug!("Reinitializing MCP registry");
|
||||
debug!("Stopping all MCP servers");
|
||||
let mut new_registry = abortable_run_with_spinner(
|
||||
registry.stop_all_servers(),
|
||||
"Stopping MCP servers",
|
||||
|
||||
let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone());
|
||||
let desired_set: HashSet<String> = desired_ids.iter().cloned().collect();
|
||||
|
||||
debug!("Stopping unused MCP servers");
|
||||
abortable_run_with_spinner(
|
||||
registry.stop_unused_servers(&desired_set),
|
||||
"Stopping unused MCP servers",
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
abortable_run_with_spinner(
|
||||
new_registry.start_select_mcp_servers(enabled_mcp_servers),
|
||||
registry.start_select_mcp_servers(enabled_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(new_registry)
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
async fn start_select_mcp_servers(
|
||||
@@ -192,43 +196,29 @@ impl McpRegistry {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(servers) = enabled_mcp_servers {
|
||||
debug!("Starting selected MCP servers: {:?}", servers);
|
||||
let config = self
|
||||
.config
|
||||
.as_ref()
|
||||
.with_context(|| "MCP Config not defined. Cannot start servers")?;
|
||||
let mcp_servers = config.mcp_servers.clone();
|
||||
let desired_ids = self.resolve_server_ids(enabled_mcp_servers);
|
||||
let ids_to_start: Vec<String> = desired_ids.into_iter()
|
||||
.filter(|id| !self.servers.contains_key(id))
|
||||
.collect();
|
||||
|
||||
let enabled_servers: HashSet<String> =
|
||||
servers.split(',').map(|s| s.trim().to_string()).collect();
|
||||
let server_ids: Vec<String> = if servers == "all" {
|
||||
mcp_servers.into_keys().collect()
|
||||
} else {
|
||||
mcp_servers
|
||||
.into_keys()
|
||||
.filter(|id| enabled_servers.contains(id))
|
||||
.collect()
|
||||
};
|
||||
if ids_to_start.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
||||
server_ids
|
||||
.into_iter()
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
)
|
||||
.buffer_unordered(num_cpus::get())
|
||||
.try_collect()
|
||||
.await?;
|
||||
debug!("Starting selected MCP servers: {:?}", ids_to_start);
|
||||
|
||||
self.servers = results
|
||||
.clone()
|
||||
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
||||
ids_to_start
|
||||
.into_iter()
|
||||
.map(|(id, server, _)| (id, server))
|
||||
.collect();
|
||||
self.catalogs = results
|
||||
.into_iter()
|
||||
.map(|(id, _, catalog)| (id, catalog))
|
||||
.collect();
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
)
|
||||
.buffer_unordered(num_cpus::get())
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
for (id, server, catalog) in results {
|
||||
self.servers.insert(id.clone(), server);
|
||||
self.catalogs.insert(id, catalog);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -309,19 +299,49 @@ impl McpRegistry {
|
||||
Ok((id.to_string(), service, catalog))
|
||||
}
|
||||
|
||||
pub async fn stop_all_servers(mut self) -> Result<Self> {
|
||||
for (id, server) in self.servers {
|
||||
Arc::try_unwrap(server)
|
||||
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
|
||||
.cancel()
|
||||
.await
|
||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||
info!("Stopped MCP server: {id}");
|
||||
fn resolve_server_ids(&self, enabled_mcp_servers: Option<String>) -> Vec<String> {
|
||||
if let Some(config) = &self.config
|
||||
&& let Some(servers) = enabled_mcp_servers {
|
||||
if servers == "all" {
|
||||
config.mcp_servers.keys().cloned().collect()
|
||||
} else {
|
||||
let enabled_servers: HashSet<String> =
|
||||
servers.split(',').map(|s| s.trim().to_string()).collect();
|
||||
config.mcp_servers
|
||||
.keys()
|
||||
.filter(|id| enabled_servers.contains(*id))
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet<String>) -> Result<()> {
|
||||
let mut ids_to_remove = Vec::new();
|
||||
for (id, _) in self.servers.iter() {
|
||||
if !keep_ids.contains(id) {
|
||||
ids_to_remove.push(id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
self.servers = HashMap::new();
|
||||
|
||||
Ok(self)
|
||||
for id in ids_to_remove {
|
||||
if let Some(server) = self.servers.remove(&id) {
|
||||
match Arc::try_unwrap(server) {
|
||||
Ok(server_inner) => {
|
||||
server_inner.cancel().await
|
||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||
info!("Stopped MCP server: {id}");
|
||||
}
|
||||
Err(_) => {
|
||||
info!("Detaching from MCP server: {id} (still in use)");
|
||||
}
|
||||
}
|
||||
self.catalogs.remove(&id);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn list_started_servers(&self) -> Vec<String> {
|
||||
|
||||
Reference in New Issue
Block a user