feat: Improved MCP server spinup and spindown when switching contexts or settings in the REPL: Modify existing config rather than stopping all servers always and re-initializing if unnecessary
This commit is contained in:
+64
-44
@@ -158,27 +158,31 @@ impl McpRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn reinit(
|
pub async fn reinit(
|
||||||
registry: McpRegistry,
|
mut registry: McpRegistry,
|
||||||
enabled_mcp_servers: Option<String>,
|
enabled_mcp_servers: Option<String>,
|
||||||
abort_signal: AbortSignal,
|
abort_signal: AbortSignal,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
debug!("Reinitializing MCP registry");
|
debug!("Reinitializing MCP registry");
|
||||||
debug!("Stopping all MCP servers");
|
|
||||||
let mut new_registry = abortable_run_with_spinner(
|
let desired_ids = registry.resolve_server_ids(enabled_mcp_servers.clone());
|
||||||
registry.stop_all_servers(),
|
let desired_set: HashSet<String> = desired_ids.iter().cloned().collect();
|
||||||
"Stopping MCP servers",
|
|
||||||
|
debug!("Stopping unused MCP servers");
|
||||||
|
abortable_run_with_spinner(
|
||||||
|
registry.stop_unused_servers(&desired_set),
|
||||||
|
"Stopping unused MCP servers",
|
||||||
abort_signal.clone(),
|
abort_signal.clone(),
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
abortable_run_with_spinner(
|
abortable_run_with_spinner(
|
||||||
new_registry.start_select_mcp_servers(enabled_mcp_servers),
|
registry.start_select_mcp_servers(enabled_mcp_servers),
|
||||||
"Loading MCP servers",
|
"Loading MCP servers",
|
||||||
abort_signal,
|
abort_signal,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(new_registry)
|
Ok(registry)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn start_select_mcp_servers(
|
async fn start_select_mcp_servers(
|
||||||
@@ -192,27 +196,19 @@ impl McpRegistry {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(servers) = enabled_mcp_servers {
|
let desired_ids = self.resolve_server_ids(enabled_mcp_servers);
|
||||||
debug!("Starting selected MCP servers: {:?}", servers);
|
let ids_to_start: Vec<String> = desired_ids.into_iter()
|
||||||
let config = self
|
.filter(|id| !self.servers.contains_key(id))
|
||||||
.config
|
.collect();
|
||||||
.as_ref()
|
|
||||||
.with_context(|| "MCP Config not defined. Cannot start servers")?;
|
|
||||||
let mcp_servers = config.mcp_servers.clone();
|
|
||||||
|
|
||||||
let enabled_servers: HashSet<String> =
|
if ids_to_start.is_empty() {
|
||||||
servers.split(',').map(|s| s.trim().to_string()).collect();
|
return Ok(());
|
||||||
let server_ids: Vec<String> = if servers == "all" {
|
}
|
||||||
mcp_servers.into_keys().collect()
|
|
||||||
} else {
|
debug!("Starting selected MCP servers: {:?}", ids_to_start);
|
||||||
mcp_servers
|
|
||||||
.into_keys()
|
|
||||||
.filter(|id| enabled_servers.contains(id))
|
|
||||||
.collect()
|
|
||||||
};
|
|
||||||
|
|
||||||
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
||||||
server_ids
|
ids_to_start
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|id| async { self.start_server(id).await }),
|
.map(|id| async { self.start_server(id).await }),
|
||||||
)
|
)
|
||||||
@@ -220,15 +216,9 @@ impl McpRegistry {
|
|||||||
.try_collect()
|
.try_collect()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
self.servers = results
|
for (id, server, catalog) in results {
|
||||||
.clone()
|
self.servers.insert(id.clone(), server);
|
||||||
.into_iter()
|
self.catalogs.insert(id, catalog);
|
||||||
.map(|(id, server, _)| (id, server))
|
|
||||||
.collect();
|
|
||||||
self.catalogs = results
|
|
||||||
.into_iter()
|
|
||||||
.map(|(id, _, catalog)| (id, catalog))
|
|
||||||
.collect();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@@ -309,19 +299,49 @@ impl McpRegistry {
|
|||||||
Ok((id.to_string(), service, catalog))
|
Ok((id.to_string(), service, catalog))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn stop_all_servers(mut self) -> Result<Self> {
|
fn resolve_server_ids(&self, enabled_mcp_servers: Option<String>) -> Vec<String> {
|
||||||
for (id, server) in self.servers {
|
if let Some(config) = &self.config
|
||||||
Arc::try_unwrap(server)
|
&& let Some(servers) = enabled_mcp_servers {
|
||||||
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
|
if servers == "all" {
|
||||||
.cancel()
|
config.mcp_servers.keys().cloned().collect()
|
||||||
.await
|
} else {
|
||||||
|
let enabled_servers: HashSet<String> =
|
||||||
|
servers.split(',').map(|s| s.trim().to_string()).collect();
|
||||||
|
config.mcp_servers
|
||||||
|
.keys()
|
||||||
|
.filter(|id| enabled_servers.contains(*id))
|
||||||
|
.cloned()
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop_unused_servers(&mut self, keep_ids: &HashSet<String>) -> Result<()> {
|
||||||
|
let mut ids_to_remove = Vec::new();
|
||||||
|
for (id, _) in self.servers.iter() {
|
||||||
|
if !keep_ids.contains(id) {
|
||||||
|
ids_to_remove.push(id.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id in ids_to_remove {
|
||||||
|
if let Some(server) = self.servers.remove(&id) {
|
||||||
|
match Arc::try_unwrap(server) {
|
||||||
|
Ok(server_inner) => {
|
||||||
|
server_inner.cancel().await
|
||||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||||
info!("Stopped MCP server: {id}");
|
info!("Stopped MCP server: {id}");
|
||||||
}
|
}
|
||||||
|
Err(_) => {
|
||||||
self.servers = HashMap::new();
|
info!("Detaching from MCP server: {id} (still in use)");
|
||||||
|
}
|
||||||
Ok(self)
|
}
|
||||||
|
self.catalogs.remove(&id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn list_started_servers(&self) -> Vec<String> {
|
pub fn list_started_servers(&self) -> Vec<String> {
|
||||||
|
|||||||
Reference in New Issue
Block a user