feat: Secret injection into the MCP configuration

This commit is contained in:
2025-10-15 16:06:59 -06:00
parent df8b326d89
commit 39fc863e22
5 changed files with 4410 additions and 4380 deletions
+504 -505
View File
File diff suppressed because it is too large Load Diff
+2640 -2642
View File
File diff suppressed because it is too large Load Diff
+236 -210
View File
@@ -14,8 +14,9 @@ use std::collections::{HashMap, HashSet};
use std::fs::OpenOptions; use std::fs::OpenOptions;
use std::path::PathBuf; use std::path::PathBuf;
use std::process::Stdio; use std::process::Stdio;
use std::sync::Arc; use std::sync::{Arc};
use tokio::process::Command; use tokio::process::Command;
use crate::vault::SECRET_RE;
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke"; pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list"; pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list";
@@ -24,267 +25,292 @@ type ConnectedServer = RunningService<RoleClient, ()>;
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct McpServersConfig { struct McpServersConfig {
#[serde(rename = "mcpServers")] #[serde(rename = "mcpServers")]
mcp_servers: HashMap<String, McpServer>, mcp_servers: HashMap<String, McpServer>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct McpServer { struct McpServer {
command: String, command: String,
args: Option<Vec<String>>, args: Option<Vec<String>>,
env: Option<HashMap<String, JsonField>>, env: Option<HashMap<String, JsonField>>,
cwd: Option<String>, cwd: Option<String>,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
#[serde(untagged)] #[serde(untagged)]
enum JsonField { enum JsonField {
Str(String), Str(String),
Bool(bool), Bool(bool),
Int(i64), Int(i64),
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct McpRegistry { pub struct McpRegistry {
log_path: Option<PathBuf>, log_path: Option<PathBuf>,
config: Option<McpServersConfig>, config: Option<McpServersConfig>,
servers: HashMap<String, Arc<RunningService<RoleClient, ()>>>, servers: HashMap<String, Arc<RunningService<RoleClient, ()>>>,
} }
impl McpRegistry { impl McpRegistry {
pub async fn init( pub async fn init(
log_path: Option<PathBuf>, log_path: Option<PathBuf>,
start_mcp_servers: bool, start_mcp_servers: bool,
use_mcp_servers: Option<String>, use_mcp_servers: Option<String>,
abort_signal: AbortSignal, abort_signal: AbortSignal,
) -> Result<Self> { config: &Config,
let mut registry = Self { ) -> Result<Self> {
log_path, let mut registry = Self {
..Default::default() log_path,
}; ..Default::default()
if !Config::mcp_config_file().try_exists().with_context(|| { };
format!( if !Config::mcp_config_file().try_exists().with_context(|| {
"Failed to check MCP config file at {}", format!(
Config::mcp_config_file().display() "Failed to check MCP config file at {}",
) Config::mcp_config_file().display()
})? { )
debug!( })? {
debug!(
"MCP config file does not exist at {}, skipping MCP initialization", "MCP config file does not exist at {}, skipping MCP initialization",
Config::mcp_config_file().display() Config::mcp_config_file().display()
); );
return Ok(registry); return Ok(registry);
} }
let err = || { let err = || {
format!( format!(
"Failed to load MCP config file at {}", "Failed to load MCP config file at {}",
Config::mcp_config_file().display() Config::mcp_config_file().display()
) )
}; };
let content = tokio::fs::read_to_string(Config::mcp_config_file()) let content = tokio::fs::read_to_string(Config::mcp_config_file())
.await .await
.with_context(err)?; .with_context(err)?;
let config: McpServersConfig = serde_json::from_str(&content).with_context(err)?; let mut missing_secrets = vec![];
registry.config = Some(config); let parsed_content = SECRET_RE.replace_all(&content, |caps: &fancy_regex::Captures<'_>| {
let secret = config.vault
.get_secret(&caps[1], false);
match secret {
Ok(s) => s,
Err(_) => {
missing_secrets.push(caps[1].to_string());
"".to_string()
}
}
});
if start_mcp_servers { if !missing_secrets.is_empty() {
abortable_run_with_spinner( return Err(anyhow!("MCP config file references secrets that are missing from the vault: {:?}", missing_secrets));
registry.start_select_mcp_servers(use_mcp_servers), }
"Loading MCP servers",
abort_signal,
)
.await?;
}
Ok(registry) let config: McpServersConfig = serde_json::from_str(&parsed_content).with_context(err)?;
} registry.config = Some(config);
pub async fn reinit( if start_mcp_servers {
registry: McpRegistry, abortable_run_with_spinner(
use_mcp_servers: Option<String>, registry.start_select_mcp_servers(use_mcp_servers),
abort_signal: AbortSignal, "Loading MCP servers",
) -> Result<Self> { abort_signal,
debug!("Reinitializing MCP registry"); )
debug!("Stopping all MCP servers"); .await?;
let mut new_registry = abortable_run_with_spinner( }
registry.stop_all_servers(),
"Stopping MCP servers",
abort_signal.clone(),
)
.await?;
abortable_run_with_spinner( Ok(registry)
new_registry.start_select_mcp_servers(use_mcp_servers), }
"Loading MCP servers",
abort_signal,
)
.await?;
Ok(new_registry) pub async fn reinit(
} registry: McpRegistry,
use_mcp_servers: Option<String>,
abort_signal: AbortSignal,
) -> Result<Self> {
debug!("Reinitializing MCP registry");
debug!("Stopping all MCP servers");
let mut new_registry = abortable_run_with_spinner(
registry.stop_all_servers(),
"Stopping MCP servers",
abort_signal.clone(),
)
.await?;
async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option<String>) -> Result<()> { abortable_run_with_spinner(
if self.config.is_none() { new_registry.start_select_mcp_servers(use_mcp_servers),
debug!("MCP config is not present; assuming MCP servers are disabled globally. skipping MCP initialization"); "Loading MCP servers",
return Ok(()); abort_signal,
} )
.await?;
debug!("Starting selected MCP servers: {:?}", use_mcp_servers); Ok(new_registry)
}
if let Some(servers) = use_mcp_servers { async fn start_select_mcp_servers(&mut self, use_mcp_servers: Option<String>) -> Result<()> {
let config = self if self.config.is_none() {
.config debug!("MCP config is not present; assuming MCP servers are disabled globally. Skipping MCP initialization");
.as_ref() return Ok(());
.with_context(|| "MCP Config not defined. Cannot start servers")?; }
let mcp_servers = config.mcp_servers.clone();
let enabled_servers: HashSet<String> = if let Some(servers) = use_mcp_servers {
servers.split(',').map(|s| s.trim().to_string()).collect(); debug!("Starting selected MCP servers: {:?}", servers);
let server_ids: Vec<String> = if servers == "all" { let config = self
mcp_servers.into_keys().collect() .config
} else { .as_ref()
mcp_servers .with_context(|| "MCP Config not defined. Cannot start servers")?;
.into_keys() let mcp_servers = config.mcp_servers.clone();
.filter(|id| enabled_servers.contains(id))
.collect()
};
let results: Vec<(String, Arc<_>)> = stream::iter( let enabled_servers: HashSet<String> =
server_ids servers.split(',').map(|s| s.trim().to_string()).collect();
.into_iter() let server_ids: Vec<String> = if servers == "all" {
.map(|id| async { self.start_server(id).await }), mcp_servers.into_keys().collect()
) } else {
.buffer_unordered(num_cpus::get()) mcp_servers
.try_collect() .into_keys()
.await?; .filter(|id| enabled_servers.contains(id))
.collect()
};
self.servers = results.into_iter().collect(); let results: Vec<(String, Arc<_>)> = stream::iter(
} server_ids
.into_iter()
.map(|id| async { self.start_server(id).await }),
)
.buffer_unordered(num_cpus::get())
.try_collect()
.await?;
Ok(()) self.servers = results.into_iter().collect();
} }
async fn start_server(&self, id: String) -> Result<(String, Arc<ConnectedServer>)> { Ok(())
let server = self }
.config
.as_ref()
.and_then(|c| c.mcp_servers.get(&id))
.with_context(|| format!("MCP server not found in config: {id}"))?;
let mut cmd = Command::new(&server.command);
if let Some(args) = &server.args {
cmd.args(args);
}
if let Some(env) = &server.env {
let env: HashMap<String, String> = env
.iter()
.map(|(k, v)| match v {
JsonField::Str(s) => (k.clone(), s.clone()),
JsonField::Bool(b) => (k.clone(), b.to_string()),
JsonField::Int(i) => (k.clone(), i.to_string()),
})
.collect();
cmd.envs(env);
}
if let Some(cwd) = &server.cwd {
cmd.current_dir(cwd);
}
let transport = if let Some(log_path) = self.log_path.as_ref() { async fn start_server(&self, id: String) -> Result<(String, Arc<ConnectedServer>)> {
cmd.stdin(Stdio::piped()).stdout(Stdio::piped()); let server = self
.config
.as_ref()
.and_then(|c| c.mcp_servers.get(&id))
.with_context(|| format!("MCP server not found in config: {id}"))?;
let mut cmd = Command::new(&server.command);
if let Some(args) = &server.args {
cmd.args(args);
}
if let Some(env) = &server.env {
let env: HashMap<String, String> = env
.iter()
.map(|(k, v)| match v {
JsonField::Str(s) => (k.clone(), s.clone()),
JsonField::Bool(b) => (k.clone(), b.to_string()),
JsonField::Int(i) => (k.clone(), i.to_string()),
})
.collect();
cmd.envs(env);
}
if let Some(cwd) = &server.cwd {
cmd.current_dir(cwd);
}
let log_file = OpenOptions::new() let transport = if let Some(log_path) = self.log_path.as_ref() {
.create(true) cmd.stdin(Stdio::piped()).stdout(Stdio::piped());
.append(true)
.open(log_path)?;
let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
transport
} else {
TokioChildProcess::new(cmd)?
};
let service = Arc::new( let log_file = OpenOptions::new()
().serve(transport) .create(true)
.await .append(true)
.with_context(|| format!("Failed to start MCP server: {}", &server.command))?, .open(log_path)?;
); let (transport, _) = TokioChildProcess::builder(cmd).stderr(log_file).spawn()?;
debug!( transport
} else {
TokioChildProcess::new(cmd)?
};
let service = Arc::new(
().serve(transport)
.await
.with_context(|| format!("Failed to start MCP server: {}", &server.command))?,
);
debug!(
"Available tools for MCP server {id}: {:?}", "Available tools for MCP server {id}: {:?}",
service.list_tools(None).await? service.list_tools(None).await?
); );
info!("Started MCP server: {id}"); info!("Started MCP server: {id}");
Ok((id.to_string(), service)) Ok((id.to_string(), service))
} }
pub async fn stop_all_servers(mut self) -> Result<Self> { pub async fn stop_all_servers(mut self) -> Result<Self> {
for (id, server) in self.servers { for (id, server) in self.servers {
Arc::try_unwrap(server) Arc::try_unwrap(server)
.map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))? .map_err(|_| anyhow!("Failed to unwrap Arc for MCP server: {id}"))?
.cancel() .cancel()
.await .await
.with_context(|| format!("Failed to stop MCP server: {id}"))?; .with_context(|| format!("Failed to stop MCP server: {id}"))?;
info!("Stopped MCP server: {id}"); info!("Stopped MCP server: {id}");
} }
self.servers = HashMap::new(); self.servers = HashMap::new();
Ok(self) Ok(self)
} }
pub fn list_servers(&self) -> Vec<String> { pub fn list_started_servers(&self) -> Vec<String> {
self.servers.keys().cloned().collect() self.servers.keys().cloned().collect()
} }
pub fn catalog(&self) -> BoxFuture<'static, Result<Value>> { pub fn list_configured_servers(&self) -> Vec<String> {
let servers: Vec<(String, Arc<ConnectedServer>)> = self if let Some(config) = &self.config {
.servers config.mcp_servers.keys().cloned().collect()
.iter() } else {
.map(|(id, s)| (id.clone(), s.clone())) vec![]
.collect(); }
}
Box::pin(async move { pub fn catalog(&self) -> BoxFuture<'static, Result<Value>> {
let mut out = Vec::with_capacity(servers.len()); let servers: Vec<(String, Arc<ConnectedServer>)> = self
for (id, server) in servers { .servers
let tools = server.list_tools(None).await?; .iter()
let resources = server.list_resources(None).await.unwrap_or_default(); .map(|(id, s)| (id.clone(), s.clone()))
// TODO implement prompt sampling for MCP servers .collect();
// let prompts = server.service.list_prompts(None).await.unwrap_or_default();
out.push(json!({ Box::pin(async move {
let mut out = Vec::with_capacity(servers.len());
for (id, server) in servers {
let tools = server.list_tools(None).await?;
let resources = server.list_resources(None).await.unwrap_or_default();
// TODO implement prompt sampling for MCP servers
// let prompts = server.service.list_prompts(None).await.unwrap_or_default();
out.push(json!({
"server": id, "server": id,
"tools": tools, "tools": tools,
"resources": resources, "resources": resources,
})); }));
} }
Ok(Value::Array(out)) Ok(Value::Array(out))
}) })
} }
pub fn invoke( pub fn invoke(
&self, &self,
server: &str, server: &str,
tool: &str, tool: &str,
arguments: Value, arguments: Value,
) -> BoxFuture<'static, Result<CallToolResult>> { ) -> BoxFuture<'static, Result<CallToolResult>> {
let server = self let server = self
.servers .servers
.get(server) .get(server)
.cloned() .cloned()
.with_context(|| format!("Invoked MCP server does not exist: {server}")); .with_context(|| format!("Invoked MCP server does not exist: {server}"));
let tool = tool.to_owned(); let tool = tool.to_owned();
Box::pin(async move { Box::pin(async move {
let server = server?; let server = server?;
let call_tool_request = CallToolRequestParam { let call_tool_request = CallToolRequestParam {
name: Cow::Owned(tool.to_owned()), name: Cow::Owned(tool.to_owned()),
arguments: arguments.as_object().cloned(), arguments: arguments.as_object().cloned(),
}; };
let result = server.call_tool(call_tool_request).await?; let result = server.call_tool(call_tool_request).await?;
Ok(result) Ok(result)
}) })
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.servers.is_empty() self.servers.is_empty()
} }
} }
+935 -935
View File
File diff suppressed because it is too large Load Diff
+95 -88
View File
@@ -1,125 +1,132 @@
mod utils; mod utils;
use std::sync::LazyLock;
use crate::cli::Cli; use crate::cli::Cli;
use crate::config::Config; use crate::config::Config;
use crate::vault::utils::ensure_password_file_initialized; use crate::vault::utils::ensure_password_file_initialized;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use fancy_regex::Regex;
use gman::providers::local::LocalProvider; use gman::providers::local::LocalProvider;
use gman::providers::SecretProvider; use gman::providers::SecretProvider;
use inquire::{required, Password, PasswordDisplayMode}; use inquire::{required, Password, PasswordDisplayMode};
use tokio::runtime::Handle; use tokio::runtime::Handle;
pub static SECRET_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\{\{(.+)}}").unwrap());
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub struct Vault { pub struct Vault {
local_provider: LocalProvider, local_provider: LocalProvider,
} }
impl Vault { impl Vault {
pub fn init(config: &Config) -> Self { pub fn init(config: &Config) -> Self {
let vault_password_file = config.vault_password_file(); let vault_password_file = config.vault_password_file();
let mut local_provider = LocalProvider { let mut local_provider = LocalProvider {
password_file: Some(vault_password_file), password_file: Some(vault_password_file),
git_branch: None, git_branch: None,
..LocalProvider::default() ..LocalProvider::default()
}; };
ensure_password_file_initialized(&mut local_provider) ensure_password_file_initialized(&mut local_provider)
.expect("Failed to initialize password file"); .expect("Failed to initialize password file");
Self { local_provider } Self { local_provider }
} }
pub fn add_secret(&self, secret_name: &str) -> Result<()> { pub fn add_secret(&self, secret_name: &str) -> Result<()> {
let secret_value = Password::new("Enter the secret value:") let secret_value = Password::new("Enter the secret value:")
.with_validator(required!()) .with_validator(required!())
.with_display_mode(PasswordDisplayMode::Masked) .with_display_mode(PasswordDisplayMode::Masked)
.prompt() .prompt()
.with_context(|| "unable to read secret from input")?; .with_context(|| "unable to read secret from input")?;
let h = Handle::current(); let h = Handle::current();
tokio::task::block_in_place(|| { tokio::task::block_in_place(|| {
h.block_on(self.local_provider.set_secret(secret_name, &secret_value)) h.block_on(self.local_provider.set_secret(secret_name, &secret_value))
})?; })?;
println!("✓ Secret '{secret_name}' added to the vault."); println!("✓ Secret '{secret_name}' added to the vault.");
Ok(()) Ok(())
} }
pub fn get_secret(&self, secret_name: &str) -> Result<()> { pub fn get_secret(&self, secret_name: &str, display_output: bool) -> Result<String> {
let h = Handle::current(); let h = Handle::current();
let secret = tokio::task::block_in_place(|| { let secret = tokio::task::block_in_place(|| {
h.block_on(self.local_provider.get_secret(secret_name)) h.block_on(self.local_provider.get_secret(secret_name))
})?; })?;
println!("{}", secret);
Ok(()) if display_output {
} println!("{}", secret);
}
pub fn update_secret(&self, secret_name: &str) -> Result<()> { Ok(secret)
let secret_value = Password::new("Enter the secret value:") }
.with_validator(required!())
.with_display_mode(PasswordDisplayMode::Masked)
.prompt()
.with_context(|| "unable to read secret from input")?;
let h = Handle::current();
tokio::task::block_in_place(|| {
h.block_on(
self.local_provider
.update_secret(secret_name, &secret_value),
)
})?;
println!("✓ Secret '{secret_name}' updated in the vault.");
Ok(()) pub fn update_secret(&self, secret_name: &str) -> Result<()> {
} let secret_value = Password::new("Enter the secret value:")
.with_validator(required!())
.with_display_mode(PasswordDisplayMode::Masked)
.prompt()
.with_context(|| "unable to read secret from input")?;
let h = Handle::current();
tokio::task::block_in_place(|| {
h.block_on(
self.local_provider
.update_secret(secret_name, &secret_value),
)
})?;
println!("✓ Secret '{secret_name}' updated in the vault.");
pub fn delete_secret(&self, secret_name: &str) -> Result<()> { Ok(())
let h = Handle::current(); }
tokio::task::block_in_place(|| h.block_on(self.local_provider.delete_secret(secret_name)))?;
println!("✓ Secret '{secret_name}' deleted from the vault.");
Ok(()) pub fn delete_secret(&self, secret_name: &str) -> Result<()> {
} let h = Handle::current();
tokio::task::block_in_place(|| h.block_on(self.local_provider.delete_secret(secret_name)))?;
println!("✓ Secret '{secret_name}' deleted from the vault.");
pub fn list_secrets(&self, display_output: bool) -> Result<Vec<String>> { Ok(())
let h = Handle::current(); }
let secrets =
tokio::task::block_in_place(|| h.block_on(self.local_provider.list_secrets()))?;
if display_output { pub fn list_secrets(&self, display_output: bool) -> Result<Vec<String>> {
if secrets.is_empty() { let h = Handle::current();
println!("The vault is empty."); let secrets =
} else { tokio::task::block_in_place(|| h.block_on(self.local_provider.list_secrets()))?;
for key in &secrets {
println!("{}", key);
}
}
}
Ok(secrets) if display_output {
} if secrets.is_empty() {
println!("The vault is empty.");
} else {
for key in &secrets {
println!("{}", key);
}
}
}
pub fn handle_vault_flags(cli: Cli, config: Config) -> Result<()> { Ok(secrets)
if let Some(secret_name) = cli.add_secret { }
config.vault.add_secret(&secret_name)?;
}
if let Some(secret_name) = cli.get_secret { pub fn handle_vault_flags(cli: Cli, config: Config) -> Result<()> {
config.vault.get_secret(&secret_name)?; if let Some(secret_name) = cli.add_secret {
} config.vault.add_secret(&secret_name)?;
}
if let Some(secret_name) = cli.update_secret { if let Some(secret_name) = cli.get_secret {
config.vault.update_secret(&secret_name)?; config.vault.get_secret(&secret_name, true)?;
} }
if let Some(secret_name) = cli.delete_secret { if let Some(secret_name) = cli.update_secret {
config.vault.delete_secret(&secret_name)?; config.vault.update_secret(&secret_name)?;
} }
if cli.list_secrets { if let Some(secret_name) = cli.delete_secret {
config.vault.list_secrets(true)?; config.vault.delete_secret(&secret_name)?;
} }
Ok(()) if cli.list_secrets {
} config.vault.list_secrets(true)?;
}
Ok(())
}
} }