feat: add OAuth authentication support for remote MCP servers
This commit is contained in:
+166
-9
@@ -1,3 +1,4 @@
|
||||
pub(crate) mod oauth;
|
||||
mod sse_transport;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
@@ -73,6 +74,8 @@ pub(crate) struct McpServer {
|
||||
pub url: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<IndexMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_client_id: Option<String>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
@@ -107,10 +110,10 @@ impl McpServer {
|
||||
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
|
||||
));
|
||||
}
|
||||
if self.url.is_some() || self.headers.is_some() {
|
||||
if self.url.is_some() || self.headers.is_some() || self.oauth_client_id.is_some() {
|
||||
return Err(anyhow!(
|
||||
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \
|
||||
(url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
||||
(url/headers/oauth_client_id). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -237,7 +240,7 @@ impl McpRegistry {
|
||||
|
||||
debug!("Starting selected MCP servers: {:?}", ids_to_start);
|
||||
|
||||
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
||||
let results: Vec<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> = stream::iter(
|
||||
ids_to_start
|
||||
.into_iter()
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
@@ -246,7 +249,7 @@ impl McpRegistry {
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
for (id, server, catalog) in results {
|
||||
for (id, server, catalog) in results.into_iter().flatten() {
|
||||
self.servers.insert(id.clone(), server);
|
||||
self.catalogs.insert(id, catalog);
|
||||
}
|
||||
@@ -257,14 +260,30 @@ impl McpRegistry {
|
||||
async fn start_server(
|
||||
&self,
|
||||
id: String,
|
||||
) -> Result<(String, Arc<ConnectedServer>, ServerCatalog)> {
|
||||
) -> Result<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> {
|
||||
let spec = self
|
||||
.config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(&id))
|
||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||
|
||||
let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?;
|
||||
let bearer_token = if spec.is_remote() {
|
||||
oauth::load_valid_mcp_token(&id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let service = match spawn_mcp_server(spec, self.log_path.as_deref(), bearer_token).await {
|
||||
Ok(s) => s,
|
||||
Err(e) if is_auth_required_error(&e) => {
|
||||
warn!(
|
||||
"MCP server '{id}' requires OAuth authentication. \
|
||||
Run `.mcp auth {id}` in the REPL to authenticate, then restart Coyote."
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
let tools = service.list_tools(None).await?;
|
||||
debug!("Available tools for MCP server {id}: {tools:?}");
|
||||
@@ -289,7 +308,7 @@ impl McpRegistry {
|
||||
|
||||
info!("Started MCP server: {id}");
|
||||
|
||||
Ok((id.to_string(), service, catalog))
|
||||
Ok(Some((id.to_string(), service, catalog)))
|
||||
}
|
||||
|
||||
fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> {
|
||||
@@ -337,15 +356,18 @@ impl McpRegistry {
|
||||
pub(crate) async fn spawn_mcp_server(
|
||||
spec: &McpServer,
|
||||
log_path: Option<&Path>,
|
||||
bearer_token: Option<String>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
match spec.transport_type {
|
||||
McpTransportType::Http => {
|
||||
let url = spec.url.as_deref().expect("validated: http spec has url");
|
||||
spawn_http_mcp_server(url, spec.headers.as_ref()).await
|
||||
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
|
||||
spawn_http_mcp_server(url, headers.as_ref()).await
|
||||
}
|
||||
McpTransportType::Sse => {
|
||||
let url = spec.url.as_deref().expect("validated: sse spec has url");
|
||||
spawn_sse_mcp_server(url, spec.headers.as_ref()).await
|
||||
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
|
||||
spawn_sse_mcp_server(url, headers.as_ref()).await
|
||||
}
|
||||
McpTransportType::Stdio => {
|
||||
let command = spec
|
||||
@@ -357,6 +379,30 @@ pub(crate) async fn spawn_mcp_server(
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_bearer_token(
|
||||
headers: Option<&IndexMap<String, String>>,
|
||||
bearer_token: Option<String>,
|
||||
) -> Option<IndexMap<String, String>> {
|
||||
match (headers, bearer_token) {
|
||||
(None, None) => None,
|
||||
(Some(h), None) => Some(h.clone()),
|
||||
(None, Some(token)) => {
|
||||
let mut m = IndexMap::new();
|
||||
m.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||
Some(m)
|
||||
}
|
||||
(Some(h), Some(token)) => {
|
||||
let mut m = h.clone();
|
||||
m.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||
Some(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_auth_required_error(e: &anyhow::Error) -> bool {
|
||||
e.to_string().contains("Auth required")
|
||||
}
|
||||
|
||||
async fn spawn_http_mcp_server(
|
||||
url: &str,
|
||||
headers: Option<&IndexMap<String, String>>,
|
||||
@@ -465,6 +511,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,6 +524,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some(url.to_string()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,6 +537,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some(url.to_string()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,6 +555,7 @@ mod tests {
|
||||
#[test]
|
||||
fn validate_stdio_with_command_succeeds() {
|
||||
let spec = stdio_server("npx");
|
||||
|
||||
assert!(spec.validate("test").is_ok());
|
||||
}
|
||||
|
||||
@@ -519,8 +569,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("missing a \"command\" field"));
|
||||
}
|
||||
|
||||
@@ -534,8 +587,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("remote fields"));
|
||||
}
|
||||
|
||||
@@ -551,14 +607,18 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: Some(headers),
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("remote fields"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_http_with_url_succeeds() {
|
||||
let spec = http_server("http://localhost:8080");
|
||||
|
||||
assert!(spec.validate("test").is_ok());
|
||||
}
|
||||
|
||||
@@ -572,8 +632,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("missing a \"url\" field"));
|
||||
}
|
||||
|
||||
@@ -587,8 +650,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("stdio fields"));
|
||||
}
|
||||
|
||||
@@ -602,8 +668,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("stdio fields"));
|
||||
}
|
||||
|
||||
@@ -617,14 +686,18 @@ mod tests {
|
||||
cwd: Some("/tmp".into()),
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("stdio fields"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_sse_with_url_succeeds() {
|
||||
let spec = sse_server("http://sse.example.com");
|
||||
|
||||
assert!(spec.validate("test").is_ok());
|
||||
}
|
||||
|
||||
@@ -638,8 +711,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("missing a \"url\" field"));
|
||||
}
|
||||
|
||||
@@ -665,9 +741,13 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(config.mcp_servers.contains_key("my-server"));
|
||||
|
||||
let spec = &config.mcp_servers["my-server"];
|
||||
|
||||
assert_eq!(spec.transport_type, McpTransportType::Stdio);
|
||||
assert_eq!(spec.command.as_deref(), Some("npx"));
|
||||
assert_eq!(
|
||||
@@ -688,7 +768,9 @@ mod tests {
|
||||
}
|
||||
}"#;
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
let spec = &config.mcp_servers["remote"];
|
||||
|
||||
assert_eq!(spec.transport_type, McpTransportType::Http);
|
||||
assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp"));
|
||||
assert_eq!(
|
||||
@@ -713,7 +795,9 @@ mod tests {
|
||||
}
|
||||
}"#;
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
let env = config.mcp_servers["s"].env.as_ref().unwrap();
|
||||
|
||||
assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello"));
|
||||
assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true)));
|
||||
assert!(matches!(env["INT_VAR"], JsonField::Int(42)));
|
||||
@@ -727,7 +811,9 @@ mod tests {
|
||||
"remote-api": { "type": "http", "url": "http://api.example.com" }
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(config.mcp_servers.len(), 2);
|
||||
assert!(config.mcp_servers.contains_key("github"));
|
||||
assert!(config.mcp_servers.contains_key("remote-api"));
|
||||
@@ -736,7 +822,9 @@ mod tests {
|
||||
#[test]
|
||||
fn deserialize_empty_servers_map() {
|
||||
let json = r#"{ "mcpServers": {} }"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(config.mcp_servers.is_empty());
|
||||
}
|
||||
|
||||
@@ -751,77 +839,96 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_all_returns_all_configured_servers() {
|
||||
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
||||
|
||||
let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
||||
ids.sort();
|
||||
|
||||
assert_eq!(ids, vec!["github", "jira", "slack"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_comma_separated_returns_matching_servers() {
|
||||
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
||||
|
||||
let mut ids =
|
||||
registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()]));
|
||||
ids.sort();
|
||||
|
||||
assert_eq!(ids, vec!["github", "jira"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_single_server_name() {
|
||||
let registry = make_registry_with_config(&["github", "slack"]);
|
||||
|
||||
let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()]));
|
||||
|
||||
assert_eq!(ids, vec!["slack"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_none_returns_empty() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
let ids = registry.resolve_server_ids(None);
|
||||
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_no_config_returns_empty() {
|
||||
let registry = McpRegistry::default();
|
||||
|
||||
let ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
||||
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_nonexistent_server_filtered_out() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
let ids = registry
|
||||
.resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()]));
|
||||
|
||||
assert_eq!(ids, vec!["github"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_all_nonexistent_returns_empty() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()]));
|
||||
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_trims_whitespace() {
|
||||
let registry = make_registry_with_config(&["github", "slack"]);
|
||||
|
||||
let mut ids = registry.resolve_server_ids(Some(vec![
|
||||
" github ".to_string(),
|
||||
" slack ".to_string(),
|
||||
]));
|
||||
ids.sort();
|
||||
|
||||
assert_eq!(ids, vec!["github", "slack"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_default_is_empty() {
|
||||
let registry = McpRegistry::default();
|
||||
|
||||
assert!(registry.is_empty());
|
||||
assert!(registry.list_started_servers().is_empty());
|
||||
assert!(registry.mcp_config().is_none());
|
||||
@@ -831,6 +938,7 @@ mod tests {
|
||||
#[test]
|
||||
fn registry_with_config_reports_config() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
assert!(registry.mcp_config().is_some());
|
||||
assert!(
|
||||
registry
|
||||
@@ -847,4 +955,53 @@ mod tests {
|
||||
assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search");
|
||||
assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_both_none_returns_none() {
|
||||
assert!(merge_bearer_token(None, None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_headers_only_passes_through() {
|
||||
let mut h = IndexMap::new();
|
||||
h.insert("X-Key".to_string(), "val".to_string());
|
||||
|
||||
let result = merge_bearer_token(Some(&h), None).unwrap();
|
||||
|
||||
assert_eq!(result["X-Key"], "val");
|
||||
assert!(!result.contains_key("Authorization"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_token_only_injects_bearer() {
|
||||
let result = merge_bearer_token(None, Some("tok123".to_string())).unwrap();
|
||||
|
||||
assert_eq!(result["Authorization"], "Bearer tok123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_both_merges_and_overrides_authorization() {
|
||||
let mut h = IndexMap::new();
|
||||
h.insert("Authorization".to_string(), "old".to_string());
|
||||
h.insert("X-Custom".to_string(), "keep".to_string());
|
||||
|
||||
let result = merge_bearer_token(Some(&h), Some("newtoken".to_string())).unwrap();
|
||||
|
||||
assert_eq!(result["Authorization"], "Bearer newtoken");
|
||||
assert_eq!(result["X-Custom"], "keep");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_auth_required_error_matches_rmcp_message() {
|
||||
let e = anyhow!("Auth required, when send initialize request");
|
||||
|
||||
assert!(is_auth_required_error(&e));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_auth_required_error_does_not_match_unrelated() {
|
||||
let e = anyhow!("Connection refused");
|
||||
|
||||
assert!(!is_auth_required_error(&e));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user