feat: Improved MCP handling toggle handling
This commit is contained in:
+246
-234
@@ -1,8 +1,10 @@
|
||||
use crate::config::Config;
|
||||
use crate::utils::{abortable_run_with_spinner, AbortSignal};
|
||||
use crate::vault::SECRET_RE;
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use futures_util::future::BoxFuture;
|
||||
use futures_util::{stream, StreamExt, TryStreamExt};
|
||||
use indoc::formatdoc;
|
||||
use rmcp::model::{CallToolRequestParam, CallToolResult};
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::transport::TokioChildProcess;
|
||||
@@ -14,9 +16,8 @@ use std::collections::{HashMap, HashSet};
|
||||
use std::fs::OpenOptions;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::{Arc};
|
||||
use std::sync::Arc;
|
||||
use tokio::process::Command;
|
||||
use crate::vault::SECRET_RE;
|
||||
|
||||
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
|
||||
pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list";
|
||||
@@ -25,292 +26,303 @@ type ConnectedServer = RunningService<RoleClient, ()>;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct McpServersConfig {
|
||||
#[serde(rename = "mcpServers")]
|
||||
mcp_servers: HashMap<String, McpServer>,
|
||||
#[serde(rename = "mcpServers")]
|
||||
mcp_servers: HashMap<String, McpServer>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct McpServer {
|
||||
command: String,
|
||||
args: Option<Vec<String>>,
|
||||
env: Option<HashMap<String, JsonField>>,
|
||||
cwd: Option<String>,
|
||||
command: String,
|
||||
args: Option<Vec<String>>,
|
||||
env: Option<HashMap<String, JsonField>>,
|
||||
cwd: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum JsonField {
|
||||
Str(String),
|
||||
Bool(bool),
|
||||
Int(i64),
|
||||
Str(String),
|
||||
Bool(bool),
|
||||
Int(i64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct McpRegistry {
|
||||
log_path: Option<PathBuf>,
|
||||
config: Option<McpServersConfig>,
|
||||
servers: HashMap<String, Arc<RunningService<RoleClient, ()>>>,
|
||||
log_path: Option<PathBuf>,
|
||||
config: Option<McpServersConfig>,
|
||||
servers: HashMap<String, Arc<RunningService<RoleClient, ()>>>,
|
||||
}
|
||||
|
||||
impl McpRegistry {
|
||||
pub async fn init(
|
||||
log_path: Option<PathBuf>,
|
||||
start_mcp_servers: bool,
|
||||
use_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
config: &Config,
|
||||
) -> Result<Self> {
|
||||
let mut registry = Self {
|
||||
log_path,
|
||||
..Default::default()
|
||||
};
|
||||
if !Config::mcp_config_file().try_exists().with_context(|| {
|
||||
format!(
|
||||
"Failed to check MCP config file at {}",
|
||||
Config::mcp_config_file().display()
|
||||
)
|
||||
})? {
|
||||
debug!(
|
||||
pub async fn init(
|
||||
log_path: Option<PathBuf>,
|
||||
start_mcp_servers: bool,
|
||||
use_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
config: &Config,
|
||||
) -> Result<Self> {
|
||||
let mut registry = Self {
|
||||
log_path,
|
||||
..Default::default()
|
||||
};
|
||||
if !Config::mcp_config_file().try_exists().with_context(|| {
|
||||
format!(
|
||||
"Failed to check MCP config file at {}",
|
||||
Config::mcp_config_file().display()
|
||||
)
|
||||
})? {
|
||||
debug!(
|
||||
"MCP config file does not exist at {}, skipping MCP initialization",
|
||||
Config::mcp_config_file().display()
|
||||
);
|
||||
return Ok(registry);
|
||||
}
|
||||
let err = || {
|
||||
format!(
|
||||
"Failed to load MCP config file at {}",
|
||||
Config::mcp_config_file().display()
|
||||
)
|
||||
};
|
||||
let content = tokio::fs::read_to_string(Config::mcp_config_file())
|
||||
.await
|
||||
.with_context(err)?;
|
||||
let mut missing_secrets = vec![];
|
||||
let parsed_content = SECRET_RE.replace_all(&content, |caps: &fancy_regex::Captures<'_>| {
|
||||
let secret = config.vault
|
||||
.get_secret(&caps[1], false);
|
||||
match secret {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
missing_secrets.push(caps[1].to_string());
|
||||
"".to_string()
|
||||
}
|
||||
}
|
||||
});
|
||||
return Ok(registry);
|
||||
}
|
||||
let err = || {
|
||||
format!(
|
||||
"Failed to load MCP config file at {}",
|
||||
Config::mcp_config_file().display()
|
||||
)
|
||||
};
|
||||
let content = tokio::fs::read_to_string(Config::mcp_config_file())
|
||||
.await
|
||||
.with_context(err)?;
|
||||
|
||||
if !missing_secrets.is_empty() {
|
||||
return Err(anyhow!("MCP config file references secrets that are missing from the vault: {:?}", missing_secrets));
|
||||
}
|
||||
if content.trim().is_empty() {
|
||||
debug!("MCP config file is empty, skipping MCP initialization");
|
||||
return Ok(registry);
|
||||
}
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(&parsed_content).with_context(err)?;
|
||||
registry.config = Some(config);
|
||||
let mut missing_secrets = vec![];
|
||||
let parsed_content = SECRET_RE.replace_all(&content, |caps: &fancy_regex::Captures<'_>| {
|
||||
let secret = config.vault.get_secret(&caps[1], false);
|
||||
match secret {
|
||||
Ok(s) => s,
|
||||
Err(_) => {
|
||||
missing_secrets.push(caps[1].to_string());
|
||||
"".to_string()
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if start_mcp_servers {
|
||||
abortable_run_with_spinner(
|
||||
registry.start_select_mcp_servers(use_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
if !missing_secrets.is_empty() {
|
||||
return Err(anyhow!(formatdoc!(
|
||||
"
|
||||
MCP config file references secrets that are missing from the vault: {:?}
|
||||
Please add these secrets to the vault and try again.",
|
||||
missing_secrets
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(registry)
|
||||
}
|
||||
let mcp_servers_config: McpServersConfig =
|
||||
serde_json::from_str(&parsed_content).with_context(err)?;
|
||||
registry.config = Some(mcp_servers_config);
|
||||
|
||||
pub async fn reinit(
|
||||
registry: McpRegistry,
|
||||
use_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
debug!("Reinitializing MCP registry");
|
||||
debug!("Stopping all MCP servers");
|
||||
let mut new_registry = abortable_run_with_spinner(
|
||||
registry.stop_all_servers(),
|
||||
"Stopping MCP servers",
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
if start_mcp_servers && config.mcp_servers {
|
||||
abortable_run_with_spinner(
|
||||
registry.start_select_mcp_servers(use_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
abortable_run_with_spinner(
|
||||
new_registry.start_select_mcp_servers(use_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
Ok(registry)
|
||||
}
|
||||
|
||||
Ok(new_registry)
|
||||
}
|
||||
pub async fn reinit(
|
||||
registry: McpRegistry,
|
||||
use_mcp_servers: Option<String>,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
debug!("Reinitializing MCP registry");
|
||||
debug!("Stopping all MCP servers");
|
||||
let mut new_registry = abortable_run_with_spinner(
|
||||
registry.stop_all_servers(),
|
||||
"Stopping MCP servers",
|
||||
abort_signal.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option<String>) -> Result<()> {
|
||||
if self.config.is_none() {
|
||||
debug!("MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization");
|
||||
return Ok(());
|
||||
}
|
||||
abortable_run_with_spinner(
|
||||
new_registry.start_select_mcp_servers(use_mcp_servers),
|
||||
"Loading MCP servers",
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
|
||||
if let Some(servers) = use_mcp_servers {
|
||||
debug!("Starting selected MCP servers: {:?}", servers);
|
||||
let config = self
|
||||
.config
|
||||
.as_ref()
|
||||
.with_context(|| "MCP Config not defined. Cannot start servers")?;
|
||||
let mcp_servers = config.mcp_servers.clone();
|
||||
Ok(new_registry)
|
||||
}
|
||||
|
||||
let enabled_servers: HashSet<String> =
|
||||
servers.split(',').map(|s| s.trim().to_string()).collect();
|
||||
let server_ids: Vec<String> = if servers == "all" {
|
||||
mcp_servers.into_keys().collect()
|
||||
} else {
|
||||
mcp_servers
|
||||
.into_keys()
|
||||
.filter(|id| enabled_servers.contains(id))
|
||||
.collect()
|
||||
};
|
||||
async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option<String>) -> Result<()> {
|
||||
if self.config.is_none() {
|
||||
debug!("MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let results: Vec<(String, Arc<_>)> = stream::iter(
|
||||
server_ids
|
||||
.into_iter()
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
)
|
||||
.buffer_unordered(num_cpus::get())
|
||||
.try_collect()
|
||||
.await?;
|
||||
if let Some(servers) = use_mcp_servers {
|
||||
debug!("Starting selected MCP servers: {:?}", servers);
|
||||
let config = self
|
||||
.config
|
||||
.as_ref()
|
||||
.with_context(|| "MCP Config not defined. Cannot start servers")?;
|
||||
let mcp_servers = config.mcp_servers.clone();
|
||||
|
||||
self.servers = results.into_iter().collect();
|
||||
}
|
||||
let enabled_servers: HashSet<String> =
|
||||
servers.split(',').map(|s| s.trim().to_string()).collect();
|
||||
let server_ids: Vec<String> = if servers == "all" {
|
||||
mcp_servers.into_keys().collect()
|
||||
} else {
|
||||
mcp_servers
|
||||
.into_keys()
|
||||
.filter(|id| enabled_servers.contains(id))
|
||||
.collect()
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
let results: Vec<(String, Arc<_>)> = stream::iter(
|
||||
server_ids
|
||||
.into_iter()
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
)
|
||||
.buffer_unordered(num_cpus::get())
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
async fn start_server(&self, id: String) -> Result<(String, Arc<ConnectedServer>)> {
|
||||
let server = self
|
||||
.config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(&id))
|
||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||
let mut cmd = Command::new(&server.command);
|
||||
if let Some(args) = &server.args {
|
||||
cmd.args(args);
|
||||
}
|
||||
if let Some(env) = &server.env {
|
||||
let env: HashMap<String, String> = env
|
||||
.iter()
|
||||
.map(|(k, v)| match v {
|
||||
JsonField::Str(s) => (k.clone(), s.clone()),
|
||||
JsonField::Bool(b) => (k.clone(), b.to_string()),
|
||||
JsonField::Int(i) => (k.clone(), i.to_string()),
|
||||
})
|
||||
.collect();
|
||||
cmd.envs(env);
|
||||
}
|
||||
if let Some(cwd) = &server.cwd {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
self.servers = results.into_iter().collect();
|
||||
}
|
||||
|
||||
let transport = if let Some(log_path) = self.log_path.as_ref() {
|
||||
cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let log_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_path)?;
|
||||
let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
|
||||
transport
|
||||
} else {
|
||||
TokioChildProcess::new(cmd)?
|
||||
};
|
||||
async fn start_server(&self, id: String) -> Result<(String, Arc<ConnectedServer>)> {
|
||||
let server = self
|
||||
.config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(&id))
|
||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||
let mut cmd = Command::new(&server.command);
|
||||
if let Some(args) = &server.args {
|
||||
cmd.args(args);
|
||||
}
|
||||
if let Some(env) = &server.env {
|
||||
let env: HashMap<String, String> = env
|
||||
.iter()
|
||||
.map(|(k, v)| match v {
|
||||
JsonField::Str(s) => (k.clone(), s.clone()),
|
||||
JsonField::Bool(b) => (k.clone(), b.to_string()),
|
||||
JsonField::Int(i) => (k.clone(), i.to_string()),
|
||||
})
|
||||
.collect();
|
||||
cmd.envs(env);
|
||||
}
|
||||
if let Some(cwd) = &server.cwd {
|
||||
cmd.current_dir(cwd);
|
||||
}
|
||||
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
|
||||
);
|
||||
debug!(
|
||||
let transport = if let Some(log_path) = self.log_path.as_ref() {
|
||||
cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
|
||||
|
||||
let log_file = OpenOptions::new()
|
||||
.create(true)
|
||||
.append(true)
|
||||
.open(log_path)?;
|
||||
let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
|
||||
transport
|
||||
} else {
|
||||
TokioChildProcess::new(cmd)?
|
||||
};
|
||||
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
|
||||
);
|
||||
debug!(
|
||||
"Available tools for MCP server {id}: {:?}",
|
||||
service.list_tools(None).await?
|
||||
);
|
||||
|
||||
info!("Started MCP server: {id}");
|
||||
info!("Started MCP server: {id}");
|
||||
|
||||
Ok((id.to_string(), service))
|
||||
}
|
||||
Ok((id.to_string(), service))
|
||||
}
|
||||
|
||||
pub async fn stop_all_servers(mut self) -> Result<Self> {
|
||||
for (id, server) in self.servers {
|
||||
Arc::try_unwrap(server)
|
||||
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
|
||||
.cancel()
|
||||
.await
|
||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||
info!("Stopped MCP server: {id}");
|
||||
}
|
||||
pub async fn stop_all_servers(mut self) -> Result<Self> {
|
||||
for (id, server) in self.servers {
|
||||
Arc::try_unwrap(server)
|
||||
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
|
||||
.cancel()
|
||||
.await
|
||||
.with_context(|| format!("Failed to stop MCP server: {id}"))?;
|
||||
info!("Stopped MCP server: {id}");
|
||||
}
|
||||
|
||||
self.servers = HashMap::new();
|
||||
self.servers = HashMap::new();
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
pub fn list_started_servers(&self) -> Vec<String> {
|
||||
self.servers.keys().cloned().collect()
|
||||
}
|
||||
pub fn list_started_servers(&self) -> Vec<String> {
|
||||
self.servers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
pub fn list_configured_servers(&self) -> Vec<String> {
|
||||
if let Some(config) = &self.config {
|
||||
config.mcp_servers.keys().cloned().collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
pub fn list_configured_servers(&self) -> Vec<String> {
|
||||
if let Some(config) = &self.config {
|
||||
config.mcp_servers.keys().cloned().collect()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
}
|
||||
|
||||
pub fn catalog(&self) -> BoxFuture<'static, Result<Value>> {
|
||||
let servers: Vec<(String, Arc<ConnectedServer>)> = self
|
||||
.servers
|
||||
.iter()
|
||||
.map(|(id, s)| (id.clone(), s.clone()))
|
||||
.collect();
|
||||
pub fn catalog(&self) -> BoxFuture<'static, Result<Value>> {
|
||||
let servers: Vec<(String, Arc<ConnectedServer>)> = self
|
||||
.servers
|
||||
.iter()
|
||||
.map(|(id, s)| (id.clone(), s.clone()))
|
||||
.collect();
|
||||
|
||||
Box::pin(async move {
|
||||
let mut out = Vec::with_capacity(servers.len());
|
||||
for (id, server) in servers {
|
||||
let tools = server.list_tools(None).await?;
|
||||
let resources = server.list_resources(None).await.unwrap_or_default();
|
||||
// TODO implement prompt sampling for MCP servers
|
||||
// let prompts = server.service.list_prompts(None).await.unwrap_or_default();
|
||||
out.push(json!({
|
||||
Box::pin(async move {
|
||||
let mut out = Vec::with_capacity(servers.len());
|
||||
for (id, server) in servers {
|
||||
let tools = server.list_tools(None).await?;
|
||||
let resources = server.list_resources(None).await.unwrap_or_default();
|
||||
// TODO implement prompt sampling for MCP servers
|
||||
// let prompts = server.service.list_prompts(None).await.unwrap_or_default();
|
||||
out.push(json!({
|
||||
"server": id,
|
||||
"tools": tools,
|
||||
"resources": resources,
|
||||
}));
|
||||
}
|
||||
Ok(Value::Array(out))
|
||||
})
|
||||
}
|
||||
}
|
||||
Ok(Value::Array(out))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn invoke(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> BoxFuture<'static, Result<CallToolResult>> {
|
||||
let server = self
|
||||
.servers
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("Invoked MCP server does not exist: {server}"));
|
||||
pub fn invoke(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> BoxFuture<'static, Result<CallToolResult>> {
|
||||
let server = self
|
||||
.servers
|
||||
.get(server)
|
||||
.cloned()
|
||||
.with_context(|| format!("Invoked MCP server does not exist: {server}"));
|
||||
|
||||
let tool = tool.to_owned();
|
||||
Box::pin(async move {
|
||||
let server = server?;
|
||||
let call_tool_request = CallToolRequestParam {
|
||||
name: Cow::Owned(tool.to_owned()),
|
||||
arguments: arguments.as_object().cloned(),
|
||||
};
|
||||
let tool = tool.to_owned();
|
||||
Box::pin(async move {
|
||||
let server = server?;
|
||||
let call_tool_request = CallToolRequestParam {
|
||||
name: Cow::Owned(tool.to_owned()),
|
||||
arguments: arguments.as_object().cloned(),
|
||||
};
|
||||
|
||||
let result = server.call_tool(call_tool_request).await?;
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
let result = server.call_tool(call_tool_request).await?;
|
||||
Ok(result)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.servers.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user