From 7ca94f7d1b3bb50b199b0318a74a7d65ef72f1d6 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Fri, 22 May 2026 16:30:45 -0600 Subject: [PATCH] feat: Added MCP config merging support for remote asset installations --- src/config/install_remote.rs | 369 +++++++++++++++++++++++++++++++++- src/config/mcp_factory.rs | 13 +- src/config/request_context.rs | 4 +- src/mcp/mod.rs | 29 ++- src/mcp/sse_transport.rs | 7 +- 5 files changed, 391 insertions(+), 31 deletions(-) diff --git a/src/config/install_remote.rs b/src/config/install_remote.rs index f29c71a..4f85c26 100644 --- a/src/config/install_remote.rs +++ b/src/config/install_remote.rs @@ -1,4 +1,5 @@ use anyhow::{Context, Result, bail}; +use indexmap::IndexMap; use inquire::Select; use std::ffi::{OsStr, OsString}; use std::fs; @@ -6,8 +7,10 @@ use std::path::{Path, PathBuf}; use crate::config::{InstallFilter, paths}; use crate::function::Language; +use crate::mcp::{McpServer, McpServersConfig}; use crate::utils; use crate::utils::IS_STDOUT_TERMINAL; +use crate::vault::{Vault, interpolate_secrets}; pub fn install_remote(git_url: &str, filter: Option, force: bool) -> Result<()> { let (url, reference) = parse_url_with_ref(git_url)?; @@ -32,11 +35,10 @@ pub fn install_remote(git_url: &str, filter: Option, force: bool) apply_plan(&plan, force)?; } - if plan.skipped_mcp_json.is_some() { - println!( - "\nNote: functions/mcp.json detected but MCP merge is not yet wired up \ - (Step 3 of the install-remote rollout)." - ); + if let Some((remote_mcp, local_mcp)) = &plan.mcp_json { + let local = local_mcp.exists().then_some(local_mcp.as_path()); + let report = merge_mcp_json(local, remote_mcp, local_mcp, force)?; + print_mcp_merge_report(&report); } Ok(()) @@ -335,7 +337,7 @@ struct PlannedFile { struct InstallPlan { files: Vec, - skipped_mcp_json: Option<(PathBuf, PathBuf)>, + mcp_json: Option<(PathBuf, PathBuf)>, } fn plan_changes(layout: &RemoteLayout) -> Result { @@ -369,15 +371,12 @@ fn plan_changes(layout: &RemoteLayout) -> Result { )?; } - let skipped_mcp_json = layout + let mcp_json = layout .mcp_json .as_ref() .map(|src| (src.clone(), paths::mcp_config_file())); - Ok(InstallPlan { - files, - skipped_mcp_json, - }) + Ok(InstallPlan { files, mcp_json }) } fn plan_dir_into( @@ -611,6 +610,217 @@ fn set_executable_bit_if_script(_path: &Path) -> Result<()> { Ok(()) } +#[derive(Debug)] +struct McpMergeReport { + added: Vec, + kept_local: Vec, + replaced: Vec, + renamed: Vec<(String, String)>, + final_path: PathBuf, + missing_secrets: Vec, +} + +enum McpConflictAction { + KeepLocal, + TakeRemote, + RenameRemote, +} + +fn merge_mcp_json( + local: Option<&Path>, + remote: &Path, + target: &Path, + force: bool, +) -> Result { + let remote_content = fs::read_to_string(remote) + .with_context(|| format!("failed to read remote mcp.json at {}", remote.display()))?; + let remote_config: McpServersConfig = serde_json::from_str(&remote_content) + .with_context(|| format!("failed to parse remote mcp.json at {}", remote.display()))?; + + let mut merged = if let Some(local_path) = local { + let content = fs::read_to_string(local_path).with_context(|| { + format!("failed to read local mcp.json at {}", local_path.display()) + })?; + serde_json::from_str::(&content).with_context(|| { + format!("failed to parse local mcp.json at {}", local_path.display()) + })? + } else { + McpServersConfig { + mcp_servers: IndexMap::new(), + } + }; + + let final_path = target.to_path_buf(); + let mut report = McpMergeReport { + added: Vec::new(), + kept_local: Vec::new(), + replaced: Vec::new(), + renamed: Vec::new(), + final_path: final_path.clone(), + missing_secrets: Vec::new(), + }; + let mut to_validate: Vec = Vec::new(); + + for (name, remote_server) in remote_config.mcp_servers { + if let Some(local_server) = merged.mcp_servers.get(&name) { + if local_server == &remote_server { + continue; + } + match resolve_mcp_conflict(&name, force)? { + McpConflictAction::KeepLocal => report.kept_local.push(name), + McpConflictAction::TakeRemote => { + merged.mcp_servers.insert(name.clone(), remote_server); + report.replaced.push(name.clone()); + to_validate.push(name); + } + McpConflictAction::RenameRemote => { + let new_name = unique_renamed_key(&name, &merged.mcp_servers); + merged.mcp_servers.insert(new_name.clone(), remote_server); + report.renamed.push((name, new_name.clone())); + to_validate.push(new_name); + } + } + } else { + merged.mcp_servers.insert(name.clone(), remote_server); + report.added.push(name.clone()); + to_validate.push(name); + } + } + + for key in &to_validate { + let spec = merged + .mcp_servers + .get(key) + .expect("entry was just inserted"); + spec.validate(key).with_context(|| { + format!("MCP server '{key}' failed validation; refusing to write merged mcp.json") + })?; + } + + let serialized = + serde_json::to_string_pretty(&merged).context("failed to serialize merged mcp.json")?; + write_atomically(&final_path, &serialized)?; + + let vault = Vault::init_bare(); + let (_parsed, missing) = interpolate_secrets(&serialized, &vault); + let mut deduped: Vec = Vec::new(); + for s in missing { + if !deduped.contains(&s) { + deduped.push(s); + } + } + report.missing_secrets = deduped; + + Ok(report) +} + +fn resolve_mcp_conflict(name: &str, force: bool) -> Result { + if force { + return Ok(McpConflictAction::TakeRemote); + } + if !*IS_STDOUT_TERMINAL { + bail!( + "MCP server '{name}' already exists locally. Refusing to merge non-interactively. \ + Re-run with --install-force or in a terminal." + ); + } + let rename_label = format!("rename remote as \"{name}-remote\""); + let prompt = format!("Conflict on MCP server '{name}'"); + let choice = Select::new( + &prompt, + vec![ + "keep local".to_string(), + "take remote".to_string(), + rename_label.clone(), + "abort merge".to_string(), + ], + ) + .prompt() + .with_context(|| "failed to read MCP conflict choice")?; + + if choice == "keep local" { + Ok(McpConflictAction::KeepLocal) + } else if choice == "take remote" { + Ok(McpConflictAction::TakeRemote) + } else if choice == rename_label { + Ok(McpConflictAction::RenameRemote) + } else if choice == "abort merge" { + bail!("Aborted MCP merge by user.") + } else { + unreachable!("inquire::Select returned an unexpected option") + } +} + +fn unique_renamed_key(name: &str, existing: &IndexMap) -> String { + let base = format!("{name}-remote"); + if !existing.contains_key(&base) { + return base; + } + for i in 2..=u32::MAX { + let candidate = format!("{name}-remote-{i}"); + if !existing.contains_key(&candidate) { + return candidate; + } + } + unreachable!("ran out of suffix variants") +} + +fn write_atomically(path: &Path, content: &str) -> Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("failed to create directory {}", parent.display()))?; + } + let tmp_path = path.with_extension("json.tmp"); + fs::write(&tmp_path, content) + .with_context(|| format!("failed to write {}", tmp_path.display()))?; + fs::rename(&tmp_path, path).with_context(|| { + format!( + "failed to rename {} to {}", + tmp_path.display(), + path.display() + ) + })?; + Ok(()) +} + +fn print_mcp_merge_report(report: &McpMergeReport) { + println!("\nMCP merge ({}):", report.final_path.display()); + println!( + " added: {}, replaced: {}, kept local: {}, renamed: {}", + report.added.len(), + report.replaced.len(), + report.kept_local.len(), + report.renamed.len() + ); + if !report.added.is_empty() { + println!(" + new servers: {}", report.added.join(", ")); + } + if !report.replaced.is_empty() { + println!(" ~ replaced: {}", report.replaced.join(", ")); + } + if !report.kept_local.is_empty() { + println!(" = kept local: {}", report.kept_local.join(", ")); + } + if !report.renamed.is_empty() { + let pairs: Vec = report + .renamed + .iter() + .map(|(orig, new_)| format!("{orig} -> {new_}")) + .collect(); + println!(" > renamed: {}", pairs.join(", ")); + } + if !report.missing_secrets.is_empty() { + println!("\nMissing vault secrets referenced by the merged mcp.json:"); + for name in &report.missing_secrets { + println!(" {{{{ {name} }}}}"); + } + println!( + "\nAdd each missing secret to the vault before starting these MCP servers. \ + For example: `loki --add-secret ` or `.vault add ` in the REPL." + ); + } +} + #[cfg(test)] mod tests { use super::*; @@ -861,4 +1071,141 @@ mod tests { assert_eq!(classify_file(&src, &dst).unwrap(), PlannedKind::Conflict); let _ = fs::remove_dir_all(&dir); } + + fn write_mcp(path: &Path, json: &str) { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).unwrap(); + } + fs::write(path, json).unwrap(); + } + + const FIXTURE_REMOTE: &str = r#"{ + "mcpServers": { + "alpha": {"type": "stdio", "command": "echo", "args": ["a"]}, + "beta": {"type": "stdio", "command": "echo", "args": ["b"]} + } + }"#; + + #[test] + fn unique_renamed_key_appends_remote_suffix() { + let map: IndexMap = IndexMap::new(); + assert_eq!(unique_renamed_key("foo", &map), "foo-remote"); + } + + #[test] + fn unique_renamed_key_appends_numeric_when_remote_taken() { + let mut map: IndexMap = IndexMap::new(); + map.insert( + "foo-remote".to_string(), + serde_json::from_str(r#"{"type":"stdio","command":"x"}"#).unwrap(), + ); + assert_eq!(unique_renamed_key("foo", &map), "foo-remote-2"); + } + + #[test] + fn merge_into_empty_local_adds_all_remote_servers() { + let dir = fresh_temp_dir("merge-empty-"); + let remote = dir.join("remote.json"); + let target = dir.join("target.json"); + write_mcp(&remote, FIXTURE_REMOTE); + + let report = merge_mcp_json(None, &remote, &target, false).unwrap(); + + assert_eq!(report.added, vec!["alpha", "beta"]); + assert!(report.kept_local.is_empty()); + assert!(report.replaced.is_empty()); + assert!(report.renamed.is_empty()); + assert!(target.exists()); + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn merge_force_replaces_local_on_conflict() { + let dir = fresh_temp_dir("merge-force-"); + let remote = dir.join("remote.json"); + let target = dir.join("target.json"); + write_mcp( + &target, + r#"{"mcpServers": {"alpha": {"type": "stdio", "command": "OLD"}}}"#, + ); + write_mcp(&remote, FIXTURE_REMOTE); + + let report = merge_mcp_json(Some(&target), &remote, &target, true).unwrap(); + + assert_eq!(report.added, vec!["beta"]); + assert_eq!(report.replaced, vec!["alpha"]); + + let written = fs::read_to_string(&target).unwrap(); + assert!(written.contains("\"command\": \"echo\""), "got: {written}"); + assert!(!written.contains("OLD")); + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn merge_non_tty_conflict_aborts_without_force() { + let dir = fresh_temp_dir("merge-non-tty-"); + let remote = dir.join("remote.json"); + let target = dir.join("target.json"); + write_mcp( + &target, + r#"{"mcpServers": {"alpha": {"type": "stdio", "command": "LOCAL"}}}"#, + ); + write_mcp(&remote, FIXTURE_REMOTE); + + let err = merge_mcp_json(Some(&target), &remote, &target, false).unwrap_err(); + assert!( + err.to_string() + .contains("Refusing to merge non-interactively"), + "got: {err}" + ); + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn merge_rejects_invalid_remote_server() { + let dir = fresh_temp_dir("merge-invalid-"); + let remote = dir.join("remote.json"); + let target = dir.join("target.json"); + write_mcp(&remote, r#"{"mcpServers": {"broken": {"type": "stdio"}}}"#); + + let err = merge_mcp_json(None, &remote, &target, false).unwrap_err(); + assert!( + format!("{err:#}").contains("missing a \"command\" field"), + "got: {err:#}" + ); + let _ = fs::remove_dir_all(&dir); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn merge_detects_missing_secrets_in_output() { + let dir = fresh_temp_dir("merge-secret-"); + let remote = dir.join("remote.json"); + let target = dir.join("target.json"); + write_mcp( + &remote, + r#"{"mcpServers": {"x": {"type":"stdio","command":"echo","env":{"K":"{{LOKI_TEST_MERGE_SECRET}}"}}}}"#, + ); + + let report = merge_mcp_json(None, &remote, &target, false).unwrap(); + assert_eq!(report.missing_secrets, vec!["LOKI_TEST_MERGE_SECRET"]); + let _ = fs::remove_dir_all(&dir); + } + + #[test] + fn merge_is_idempotent_on_re_run() { + let dir = fresh_temp_dir("merge-idempotent-"); + let remote = dir.join("remote.json"); + let target = dir.join("target.json"); + write_mcp(&remote, FIXTURE_REMOTE); + + merge_mcp_json(None, &remote, &target, false).unwrap(); + let after_first = fs::read(&target).unwrap(); + + let report = merge_mcp_json(Some(&target), &remote, &target, false).unwrap(); + assert!(report.added.is_empty(), "got: {:?}", report.added); + let after_second = fs::read(&target).unwrap(); + + assert_eq!(after_first, after_second); + let _ = fs::remove_dir_all(&dir); + } } diff --git a/src/config/mcp_factory.rs b/src/config/mcp_factory.rs index 9484ba0..1f05401 100644 --- a/src/config/mcp_factory.rs +++ b/src/config/mcp_factory.rs @@ -109,12 +109,13 @@ impl McpFactory { mod tests { use super::*; use crate::mcp::{JsonField, McpServer, McpTransportType}; + use indexmap::IndexMap; use std::collections::HashMap; fn stdio_spec( command: &str, args: Option>, - env: Option>, + env: Option>, ) -> McpServer { McpServer { transport_type: McpTransportType::Stdio, @@ -130,7 +131,7 @@ mod tests { fn remote_spec( transport: McpTransportType, url: &str, - headers: Option>, + headers: Option>, ) -> McpServer { McpServer { transport_type: transport, @@ -145,7 +146,7 @@ mod tests { #[test] fn key_from_stdio_spec_captures_command_args_env() { - let mut env = HashMap::new(); + let mut env = IndexMap::new(); env.insert("TOKEN".into(), JsonField::Str("abc".into())); let spec = stdio_spec("npx", Some(vec!["-y".into(), "server".into()]), Some(env)); let key = McpServerKey::from_spec("my-server", &spec); @@ -163,7 +164,7 @@ mod tests { #[test] fn key_from_stdio_spec_sorts_args_and_env() { - let mut env = HashMap::new(); + let mut env = IndexMap::new(); env.insert("Z_VAR".into(), JsonField::Str("z".into())); env.insert("A_VAR".into(), JsonField::Int(42)); let spec = stdio_spec( @@ -222,7 +223,7 @@ mod tests { #[test] fn key_from_remote_sse_spec_with_sorted_headers() { - let mut hdrs = HashMap::new(); + let mut hdrs = IndexMap::new(); hdrs.insert("Z-Key".into(), "z-val".into()); hdrs.insert("A-Key".into(), "a-val".into()); let spec = remote_spec(McpTransportType::Sse, "http://sse.example.com", Some(hdrs)); @@ -264,7 +265,7 @@ mod tests { #[test] fn key_env_bool_and_int_coerce_to_string() { - let mut env = HashMap::new(); + let mut env = IndexMap::new(); env.insert("FLAG".into(), JsonField::Bool(true)); env.insert("PORT".into(), JsonField::Int(3000)); let spec = stdio_spec("cmd", None, Some(env)); diff --git a/src/config/request_context.rs b/src/config/request_context.rs index 5f30f57..e5be268 100644 --- a/src/config/request_context.rs +++ b/src/config/request_context.rs @@ -29,6 +29,8 @@ use crate::utils::{ use crate::graph; use anyhow::{Context, Error, Result, bail}; +#[cfg(test)] +use indexmap::IndexMap; use indoc::formatdoc; use inquire::{Confirm, MultiSelect, Text, list_option::ListOption, validator::Validation}; use parking_lot::RwLock; @@ -2899,7 +2901,7 @@ mod tests { let mcp_config = if server_names.is_empty() { None } else { - let mut servers = HashMap::new(); + let mut servers = IndexMap::new(); for name in server_names { servers.insert( name.to_string(), diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 5df8af1..61e956a 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -8,6 +8,7 @@ use crate::vault::interpolate_secrets; use anyhow::{Context, Result, anyhow}; use futures_util::{StreamExt, TryStreamExt, stream}; use http::{HeaderName, HeaderValue}; +use indexmap::IndexMap; use indoc::formatdoc; use rmcp::service::RunningService; use rmcp::transport::StreamableHttpClientTransport; @@ -49,23 +50,29 @@ impl Clone for ServerCatalog { } } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize)] pub(crate) struct McpServersConfig { #[serde(rename = "mcpServers")] - pub mcp_servers: HashMap, + pub mcp_servers: IndexMap, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] #[serde(deny_unknown_fields)] pub(crate) struct McpServer { #[serde(rename = "type")] pub transport_type: McpTransportType, + #[serde(skip_serializing_if = "Option::is_none")] pub command: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub args: Option>, - pub env: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, + #[serde(skip_serializing_if = "Option::is_none")] pub cwd: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub url: Option, - pub headers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option>, } impl McpServer { @@ -111,7 +118,7 @@ impl McpServer { } } -#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)] #[serde(rename_all = "lowercase")] pub(crate) enum McpTransportType { Stdio, @@ -119,7 +126,7 @@ pub(crate) enum McpTransportType { Sse, } -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] #[serde(untagged)] pub(crate) enum JsonField { Str(String), @@ -352,7 +359,7 @@ pub(crate) async fn spawn_mcp_server( async fn spawn_http_mcp_server( url: &str, - headers: Option<&HashMap>, + headers: Option<&IndexMap>, ) -> Result> { let transport = if let Some(hdrs) = headers && !hdrs.is_empty() @@ -382,7 +389,7 @@ async fn spawn_http_mcp_server( async fn spawn_sse_mcp_server( url: &str, - headers: Option<&HashMap>, + headers: Option<&IndexMap>, ) -> Result> { let sse = LegacySseTransport::connect(url, headers) .await @@ -482,7 +489,7 @@ mod tests { } fn make_registry_with_config(server_names: &[&str]) -> McpRegistry { - let mut mcp_servers = HashMap::new(); + let mut mcp_servers = IndexMap::new(); for name in server_names { mcp_servers.insert(name.to_string(), stdio_server("echo")); } @@ -530,7 +537,7 @@ mod tests { #[test] fn validate_stdio_with_headers_fails() { - let mut headers = HashMap::new(); + let mut headers = IndexMap::new(); headers.insert("Auth".into(), "Bearer tok".into()); let spec = McpServer { transport_type: McpTransportType::Stdio, diff --git a/src/mcp/sse_transport.rs b/src/mcp/sse_transport.rs index 7dad1c2..dbe0bf4 100644 --- a/src/mcp/sse_transport.rs +++ b/src/mcp/sse_transport.rs @@ -3,12 +3,12 @@ use eventsource_stream::{EventStream, Eventsource}; use fmt::{Display, Formatter}; use futures_util::StreamExt; use futures_util::stream::BoxStream; +use indexmap::IndexMap; use mpsc::error::SendError; use mpsc::{OwnedPermit, Receiver, Sender, channel}; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::{Client, header}; use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; -use std::collections::HashMap; use std::error::Error; use std::fmt; use std::future::Future; @@ -28,7 +28,10 @@ pub struct LegacySseTransport { } impl LegacySseTransport { - pub async fn connect(sse_url: &str, headers: Option<&HashMap>) -> Result { + pub async fn connect( + sse_url: &str, + headers: Option<&IndexMap>, + ) -> Result { let base_url = Url::parse(sse_url).with_context(|| format!("Invalid SSE URL: {sse_url}"))?;