feat: add OAuth authentication support for remote MCP servers

This commit is contained in:
2026-07-02 14:43:24 -06:00
parent cc50d39ab4
commit 16f324cefc
6 changed files with 582 additions and 19 deletions
+166 -9
View File
@@ -1,3 +1,4 @@
pub(crate) mod oauth;
mod sse_transport;
use crate::config::AppConfig;
@@ -73,6 +74,8 @@ pub(crate) struct McpServer {
pub url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<IndexMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth_client_id: Option<String>,
}
impl McpServer {
@@ -107,10 +110,10 @@ impl McpServer {
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
));
}
if self.url.is_some() || self.headers.is_some() {
if self.url.is_some() || self.headers.is_some() || self.oauth_client_id.is_some() {
return Err(anyhow!(
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \
(url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"."
(url/headers/oauth_client_id). Remove the remote fields or change the type to \"http\" or \"sse\"."
));
}
}
@@ -237,7 +240,7 @@ impl McpRegistry {
debug!("Starting selected MCP servers: {:?}", ids_to_start);
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
let results: Vec<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> = stream::iter(
ids_to_start
.into_iter()
.map(|id| async { self.start_server(id).await }),
@@ -246,7 +249,7 @@ impl McpRegistry {
.try_collect()
.await?;
for (id, server, catalog) in results {
for (id, server, catalog) in results.into_iter().flatten() {
self.servers.insert(id.clone(), server);
self.catalogs.insert(id, catalog);
}
@@ -257,14 +260,30 @@ impl McpRegistry {
async fn start_server(
&self,
id: String,
) -> Result<(String, Arc<ConnectedServer>, ServerCatalog)> {
) -> Result<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> {
let spec = self
.config
.as_ref()
.and_then(|c| c.mcp_servers.get(&id))
.with_context(|| format!("MCP server not found in config: {id}"))?;
let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?;
let bearer_token = if spec.is_remote() {
oauth::load_valid_mcp_token(&id)
} else {
None
};
let service = match spawn_mcp_server(spec, self.log_path.as_deref(), bearer_token).await {
Ok(s) => s,
Err(e) if is_auth_required_error(&e) => {
warn!(
"MCP server '{id}' requires OAuth authentication. \
Run `.mcp auth {id}` in the REPL to authenticate, then restart Coyote."
);
return Ok(None);
}
Err(e) => return Err(e),
};
let tools = service.list_tools(None).await?;
debug!("Available tools for MCP server {id}: {tools:?}");
@@ -289,7 +308,7 @@ impl McpRegistry {
info!("Started MCP server: {id}");
Ok((id.to_string(), service, catalog))
Ok(Some((id.to_string(), service, catalog)))
}
fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> {
@@ -337,15 +356,18 @@ impl McpRegistry {
pub(crate) async fn spawn_mcp_server(
spec: &McpServer,
log_path: Option<&Path>,
bearer_token: Option<String>,
) -> Result<Arc<ConnectedServer>> {
match spec.transport_type {
McpTransportType::Http => {
let url = spec.url.as_deref().expect("validated: http spec has url");
spawn_http_mcp_server(url, spec.headers.as_ref()).await
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
spawn_http_mcp_server(url, headers.as_ref()).await
}
McpTransportType::Sse => {
let url = spec.url.as_deref().expect("validated: sse spec has url");
spawn_sse_mcp_server(url, spec.headers.as_ref()).await
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
spawn_sse_mcp_server(url, headers.as_ref()).await
}
McpTransportType::Stdio => {
let command = spec
@@ -357,6 +379,30 @@ pub(crate) async fn spawn_mcp_server(
}
}
fn merge_bearer_token(
headers: Option<&IndexMap<String, String>>,
bearer_token: Option<String>,
) -> Option<IndexMap<String, String>> {
match (headers, bearer_token) {
(None, None) => None,
(Some(h), None) => Some(h.clone()),
(None, Some(token)) => {
let mut m = IndexMap::new();
m.insert("Authorization".to_string(), format!("Bearer {token}"));
Some(m)
}
(Some(h), Some(token)) => {
let mut m = h.clone();
m.insert("Authorization".to_string(), format!("Bearer {token}"));
Some(m)
}
}
}
fn is_auth_required_error(e: &anyhow::Error) -> bool {
e.to_string().contains("Auth required")
}
async fn spawn_http_mcp_server(
url: &str,
headers: Option<&IndexMap<String, String>>,
@@ -465,6 +511,7 @@ mod tests {
cwd: None,
url: None,
headers: None,
oauth_client_id: None,
}
}
@@ -477,6 +524,7 @@ mod tests {
cwd: None,
url: Some(url.to_string()),
headers: None,
oauth_client_id: None,
}
}
@@ -489,6 +537,7 @@ mod tests {
cwd: None,
url: Some(url.to_string()),
headers: None,
oauth_client_id: None,
}
}
@@ -506,6 +555,7 @@ mod tests {
#[test]
fn validate_stdio_with_command_succeeds() {
let spec = stdio_server("npx");
assert!(spec.validate("test").is_ok());
}
@@ -519,8 +569,11 @@ mod tests {
cwd: None,
url: None,
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("missing a \"command\" field"));
}
@@ -534,8 +587,11 @@ mod tests {
cwd: None,
url: Some("http://localhost".into()),
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("remote fields"));
}
@@ -551,14 +607,18 @@ mod tests {
cwd: None,
url: None,
headers: Some(headers),
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("remote fields"));
}
#[test]
fn validate_http_with_url_succeeds() {
let spec = http_server("http://localhost:8080");
assert!(spec.validate("test").is_ok());
}
@@ -572,8 +632,11 @@ mod tests {
cwd: None,
url: None,
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("missing a \"url\" field"));
}
@@ -587,8 +650,11 @@ mod tests {
cwd: None,
url: Some("http://localhost".into()),
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("stdio fields"));
}
@@ -602,8 +668,11 @@ mod tests {
cwd: None,
url: Some("http://localhost".into()),
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("stdio fields"));
}
@@ -617,14 +686,18 @@ mod tests {
cwd: Some("/tmp".into()),
url: Some("http://localhost".into()),
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("stdio fields"));
}
#[test]
fn validate_sse_with_url_succeeds() {
let spec = sse_server("http://sse.example.com");
assert!(spec.validate("test").is_ok());
}
@@ -638,8 +711,11 @@ mod tests {
cwd: None,
url: None,
headers: None,
oauth_client_id: None,
};
let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("missing a \"url\" field"));
}
@@ -665,9 +741,13 @@ mod tests {
}
}
}"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert!(config.mcp_servers.contains_key("my-server"));
let spec = &config.mcp_servers["my-server"];
assert_eq!(spec.transport_type, McpTransportType::Stdio);
assert_eq!(spec.command.as_deref(), Some("npx"));
assert_eq!(
@@ -688,7 +768,9 @@ mod tests {
}
}"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap();
let spec = &config.mcp_servers["remote"];
assert_eq!(spec.transport_type, McpTransportType::Http);
assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp"));
assert_eq!(
@@ -713,7 +795,9 @@ mod tests {
}
}"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap();
let env = config.mcp_servers["s"].env.as_ref().unwrap();
assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello"));
assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true)));
assert!(matches!(env["INT_VAR"], JsonField::Int(42)));
@@ -727,7 +811,9 @@ mod tests {
"remote-api": { "type": "http", "url": "http://api.example.com" }
}
}"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.mcp_servers.len(), 2);
assert!(config.mcp_servers.contains_key("github"));
assert!(config.mcp_servers.contains_key("remote-api"));
@@ -736,7 +822,9 @@ mod tests {
#[test]
fn deserialize_empty_servers_map() {
let json = r#"{ "mcpServers": {} }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert!(config.mcp_servers.is_empty());
}
@@ -751,77 +839,96 @@ mod tests {
}
}
}"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work"));
}
#[test]
fn resolve_all_returns_all_configured_servers() {
let registry = make_registry_with_config(&["github", "slack", "jira"]);
let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
ids.sort();
assert_eq!(ids, vec!["github", "jira", "slack"]);
}
#[test]
fn resolve_comma_separated_returns_matching_servers() {
let registry = make_registry_with_config(&["github", "slack", "jira"]);
let mut ids =
registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()]));
ids.sort();
assert_eq!(ids, vec!["github", "jira"]);
}
#[test]
fn resolve_single_server_name() {
let registry = make_registry_with_config(&["github", "slack"]);
let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()]));
assert_eq!(ids, vec!["slack"]);
}
#[test]
fn resolve_none_returns_empty() {
let registry = make_registry_with_config(&["github"]);
let ids = registry.resolve_server_ids(None);
assert!(ids.is_empty());
}
#[test]
fn resolve_no_config_returns_empty() {
let registry = McpRegistry::default();
let ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
assert!(ids.is_empty());
}
#[test]
fn resolve_nonexistent_server_filtered_out() {
let registry = make_registry_with_config(&["github"]);
let ids = registry
.resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()]));
assert_eq!(ids, vec!["github"]);
}
#[test]
fn resolve_all_nonexistent_returns_empty() {
let registry = make_registry_with_config(&["github"]);
let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()]));
assert!(ids.is_empty());
}
#[test]
fn resolve_trims_whitespace() {
let registry = make_registry_with_config(&["github", "slack"]);
let mut ids = registry.resolve_server_ids(Some(vec![
" github ".to_string(),
" slack ".to_string(),
]));
ids.sort();
assert_eq!(ids, vec!["github", "slack"]);
}
#[test]
fn registry_default_is_empty() {
let registry = McpRegistry::default();
assert!(registry.is_empty());
assert!(registry.list_started_servers().is_empty());
assert!(registry.mcp_config().is_none());
@@ -831,6 +938,7 @@ mod tests {
#[test]
fn registry_with_config_reports_config() {
let registry = make_registry_with_config(&["github"]);
assert!(registry.mcp_config().is_some());
assert!(
registry
@@ -847,4 +955,53 @@ mod tests {
assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search");
assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe");
}
#[test]
fn merge_bearer_token_both_none_returns_none() {
assert!(merge_bearer_token(None, None).is_none());
}
#[test]
fn merge_bearer_token_headers_only_passes_through() {
let mut h = IndexMap::new();
h.insert("X-Key".to_string(), "val".to_string());
let result = merge_bearer_token(Some(&h), None).unwrap();
assert_eq!(result["X-Key"], "val");
assert!(!result.contains_key("Authorization"));
}
#[test]
fn merge_bearer_token_token_only_injects_bearer() {
let result = merge_bearer_token(None, Some("tok123".to_string())).unwrap();
assert_eq!(result["Authorization"], "Bearer tok123");
}
#[test]
fn merge_bearer_token_both_merges_and_overrides_authorization() {
let mut h = IndexMap::new();
h.insert("Authorization".to_string(), "old".to_string());
h.insert("X-Custom".to_string(), "keep".to_string());
let result = merge_bearer_token(Some(&h), Some("newtoken".to_string())).unwrap();
assert_eq!(result["Authorization"], "Bearer newtoken");
assert_eq!(result["X-Custom"], "keep");
}
#[test]
fn is_auth_required_error_matches_rmcp_message() {
let e = anyhow!("Auth required, when send initialize request");
assert!(is_auth_required_error(&e));
}
#[test]
fn is_auth_required_error_does_not_match_unrelated() {
let e = anyhow!("Connection refused");
assert!(!is_auth_required_error(&e));
}
}