From 16f324cefc0ab632aaaab0b3e2f247d90f4e5712 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Thu, 2 Jul 2026 14:43:24 -0600 Subject: [PATCH] feat: add OAuth authentication support for remote MCP servers --- src/client/oauth.rs | 14 +- src/config/mcp_factory.rs | 13 +- src/config/request_context.rs | 12 ++ src/mcp/mod.rs | 175 +++++++++++++++++- src/mcp/oauth.rs | 325 ++++++++++++++++++++++++++++++++++ src/repl/mod.rs | 62 ++++++- 6 files changed, 582 insertions(+), 19 deletions(-) create mode 100644 src/mcp/oauth.rs diff --git a/src/client/oauth.rs b/src/client/oauth.rs index 0a9555f..e2e9e4d 100644 --- a/src/client/oauth.rs +++ b/src/client/oauth.rs @@ -53,6 +53,10 @@ pub trait OAuthProvider: Send + Sync { fn extra_request_headers(&self) -> Vec<(&str, &str)> { vec![] } + + fn fixed_redirect_uri(&self) -> Option { + None + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -72,14 +76,16 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> let state = Uuid::new_v4().to_string(); - let redirect_uri = if provider.uses_localhost_redirect() { + let (redirect_uri, use_callback_listener) = if let Some(fixed) = provider.fixed_redirect_uri() { + (fixed, true) + } else if provider.uses_localhost_redirect() { let listener = TcpListener::bind("127.0.0.1:0")?; let port = listener.local_addr()?.port(); let uri = format!("http://127.0.0.1:{port}/callback"); drop(listener); - uri + (uri, true) } else { - provider.redirect_uri().to_string() + (provider.redirect_uri().to_string(), false) }; let encoded_scopes = urlencoding::encode(provider.scopes()); @@ -112,7 +118,7 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) -> let _ = open::that(&authorize_url); - let (code, returned_state) = if provider.uses_localhost_redirect() { + let (code, returned_state) = if use_callback_listener { listen_for_oauth_callback(&redirect_uri)? } else { let input = Text::new("Paste the authorization code:").prompt()?; diff --git a/src/config/mcp_factory.rs b/src/config/mcp_factory.rs index 1f05401..5d8cca7 100644 --- a/src/config/mcp_factory.rs +++ b/src/config/mcp_factory.rs @@ -1,4 +1,6 @@ -use crate::mcp::{ConnectedServer, JsonField, McpServer, McpTransportType, spawn_mcp_server}; +use crate::mcp::{ + ConnectedServer, JsonField, McpServer, McpTransportType, oauth, spawn_mcp_server, +}; use anyhow::Result; use parking_lot::Mutex; @@ -99,7 +101,12 @@ impl McpFactory { return Ok(existing); } - let handle = spawn_mcp_server(spec, log_path).await?; + let bearer_token = if spec.is_remote() { + oauth::load_valid_mcp_token(name) + } else { + None + }; + let handle = spawn_mcp_server(spec, log_path, bearer_token).await?; self.insert_active(key, &handle); Ok(handle) } @@ -125,6 +132,7 @@ mod tests { cwd: None, url: None, headers: None, + oauth_client_id: None, } } @@ -141,6 +149,7 @@ mod tests { cwd: None, url: Some(url.to_string()), headers, + oauth_client_id: None, } } diff --git a/src/config/request_context.rs b/src/config/request_context.rs index 3d21aa2..7410a15 100644 --- a/src/config/request_context.rs +++ b/src/config/request_context.rs @@ -2340,6 +2340,17 @@ impl RequestContext { } _ => vec![], }; + } else if cmd == ".mcp" && args.first() == Some(&"auth") && args.len() == 2 { + if let Some(mcp_config) = &self.app.mcp_config { + values = super::map_completion_values( + mcp_config + .mcp_servers + .iter() + .filter(|(_, spec)| spec.is_remote()) + .map(|(name, _)| name.clone()) + .collect(), + ); + } } else if (cmd == ".edit" && args.first() == Some(&"skill") && args.len() == 2) || (cmd == ".skill" && args.first() == Some(&"load") && args.len() == 2) { @@ -3687,6 +3698,7 @@ mod tests { cwd: None, url: None, headers: None, + oauth_client_id: None, }, ); } diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 150bcec..57a78de 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod oauth; mod sse_transport; use crate::config::AppConfig; @@ -73,6 +74,8 @@ pub(crate) struct McpServer { pub url: Option, #[serde(skip_serializing_if = "Option::is_none")] pub headers: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub oauth_client_id: Option, } impl McpServer { @@ -107,10 +110,10 @@ impl McpServer { "MCP server '{name}' is missing a \"command\" field (required for stdio transport)" )); } - if self.url.is_some() || self.headers.is_some() { + if self.url.is_some() || self.headers.is_some() || self.oauth_client_id.is_some() { return Err(anyhow!( "MCP server '{name}' has type \"stdio\" but also specifies remote fields \ - (url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"." + (url/headers/oauth_client_id). Remove the remote fields or change the type to \"http\" or \"sse\"." )); } } @@ -237,7 +240,7 @@ impl McpRegistry { debug!("Starting selected MCP servers: {:?}", ids_to_start); - let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( + let results: Vec, ServerCatalog)>> = stream::iter( ids_to_start .into_iter() .map(|id| async { self.start_server(id).await }), @@ -246,7 +249,7 @@ impl McpRegistry { .try_collect() .await?; - for (id, server, catalog) in results { + for (id, server, catalog) in results.into_iter().flatten() { self.servers.insert(id.clone(), server); self.catalogs.insert(id, catalog); } @@ -257,14 +260,30 @@ impl McpRegistry { async fn start_server( &self, id: String, - ) -> Result<(String, Arc, ServerCatalog)> { + ) -> Result, ServerCatalog)>> { let spec = self .config .as_ref() .and_then(|c| c.mcp_servers.get(&id)) .with_context(|| format!("MCP server not found in config: {id}"))?; - let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?; + let bearer_token = if spec.is_remote() { + oauth::load_valid_mcp_token(&id) + } else { + None + }; + + let service = match spawn_mcp_server(spec, self.log_path.as_deref(), bearer_token).await { + Ok(s) => s, + Err(e) if is_auth_required_error(&e) => { + warn!( + "MCP server '{id}' requires OAuth authentication. \ + Run `.mcp auth {id}` in the REPL to authenticate, then restart Coyote." + ); + return Ok(None); + } + Err(e) => return Err(e), + }; let tools = service.list_tools(None).await?; debug!("Available tools for MCP server {id}: {tools:?}"); @@ -289,7 +308,7 @@ impl McpRegistry { info!("Started MCP server: {id}"); - Ok((id.to_string(), service, catalog)) + Ok(Some((id.to_string(), service, catalog))) } fn resolve_server_ids(&self, enabled_mcp_servers: Option>) -> Vec { @@ -337,15 +356,18 @@ impl McpRegistry { pub(crate) async fn spawn_mcp_server( spec: &McpServer, log_path: Option<&Path>, + bearer_token: Option, ) -> Result> { match spec.transport_type { McpTransportType::Http => { let url = spec.url.as_deref().expect("validated: http spec has url"); - spawn_http_mcp_server(url, spec.headers.as_ref()).await + let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token); + spawn_http_mcp_server(url, headers.as_ref()).await } McpTransportType::Sse => { let url = spec.url.as_deref().expect("validated: sse spec has url"); - spawn_sse_mcp_server(url, spec.headers.as_ref()).await + let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token); + spawn_sse_mcp_server(url, headers.as_ref()).await } McpTransportType::Stdio => { let command = spec @@ -357,6 +379,30 @@ pub(crate) async fn spawn_mcp_server( } } +fn merge_bearer_token( + headers: Option<&IndexMap>, + bearer_token: Option, +) -> Option> { + match (headers, bearer_token) { + (None, None) => None, + (Some(h), None) => Some(h.clone()), + (None, Some(token)) => { + let mut m = IndexMap::new(); + m.insert("Authorization".to_string(), format!("Bearer {token}")); + Some(m) + } + (Some(h), Some(token)) => { + let mut m = h.clone(); + m.insert("Authorization".to_string(), format!("Bearer {token}")); + Some(m) + } + } +} + +fn is_auth_required_error(e: &anyhow::Error) -> bool { + e.to_string().contains("Auth required") +} + async fn spawn_http_mcp_server( url: &str, headers: Option<&IndexMap>, @@ -465,6 +511,7 @@ mod tests { cwd: None, url: None, headers: None, + oauth_client_id: None, } } @@ -477,6 +524,7 @@ mod tests { cwd: None, url: Some(url.to_string()), headers: None, + oauth_client_id: None, } } @@ -489,6 +537,7 @@ mod tests { cwd: None, url: Some(url.to_string()), headers: None, + oauth_client_id: None, } } @@ -506,6 +555,7 @@ mod tests { #[test] fn validate_stdio_with_command_succeeds() { let spec = stdio_server("npx"); + assert!(spec.validate("test").is_ok()); } @@ -519,8 +569,11 @@ mod tests { cwd: None, url: None, headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("missing a \"command\" field")); } @@ -534,8 +587,11 @@ mod tests { cwd: None, url: Some("http://localhost".into()), headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("remote fields")); } @@ -551,14 +607,18 @@ mod tests { cwd: None, url: None, headers: Some(headers), + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("remote fields")); } #[test] fn validate_http_with_url_succeeds() { let spec = http_server("http://localhost:8080"); + assert!(spec.validate("test").is_ok()); } @@ -572,8 +632,11 @@ mod tests { cwd: None, url: None, headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("missing a \"url\" field")); } @@ -587,8 +650,11 @@ mod tests { cwd: None, url: Some("http://localhost".into()), headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("stdio fields")); } @@ -602,8 +668,11 @@ mod tests { cwd: None, url: Some("http://localhost".into()), headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("stdio fields")); } @@ -617,14 +686,18 @@ mod tests { cwd: Some("/tmp".into()), url: Some("http://localhost".into()), headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("stdio fields")); } #[test] fn validate_sse_with_url_succeeds() { let spec = sse_server("http://sse.example.com"); + assert!(spec.validate("test").is_ok()); } @@ -638,8 +711,11 @@ mod tests { cwd: None, url: None, headers: None, + oauth_client_id: None, }; + let err = spec.validate("test").unwrap_err(); + assert!(err.to_string().contains("missing a \"url\" field")); } @@ -665,9 +741,13 @@ mod tests { } } }"#; + let config: McpServersConfig = serde_json::from_str(json).unwrap(); + assert!(config.mcp_servers.contains_key("my-server")); + let spec = &config.mcp_servers["my-server"]; + assert_eq!(spec.transport_type, McpTransportType::Stdio); assert_eq!(spec.command.as_deref(), Some("npx")); assert_eq!( @@ -688,7 +768,9 @@ mod tests { } }"#; let config: McpServersConfig = serde_json::from_str(json).unwrap(); + let spec = &config.mcp_servers["remote"]; + assert_eq!(spec.transport_type, McpTransportType::Http); assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp")); assert_eq!( @@ -713,7 +795,9 @@ mod tests { } }"#; let config: McpServersConfig = serde_json::from_str(json).unwrap(); + let env = config.mcp_servers["s"].env.as_ref().unwrap(); + assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello")); assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true))); assert!(matches!(env["INT_VAR"], JsonField::Int(42))); @@ -727,7 +811,9 @@ mod tests { "remote-api": { "type": "http", "url": "http://api.example.com" } } }"#; + let config: McpServersConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.mcp_servers.len(), 2); assert!(config.mcp_servers.contains_key("github")); assert!(config.mcp_servers.contains_key("remote-api")); @@ -736,7 +822,9 @@ mod tests { #[test] fn deserialize_empty_servers_map() { let json = r#"{ "mcpServers": {} }"#; + let config: McpServersConfig = serde_json::from_str(json).unwrap(); + assert!(config.mcp_servers.is_empty()); } @@ -751,77 +839,96 @@ mod tests { } } }"#; + let config: McpServersConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work")); } #[test] fn resolve_all_returns_all_configured_servers() { let registry = make_registry_with_config(&["github", "slack", "jira"]); + let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()])); ids.sort(); + assert_eq!(ids, vec!["github", "jira", "slack"]); } #[test] fn resolve_comma_separated_returns_matching_servers() { let registry = make_registry_with_config(&["github", "slack", "jira"]); + let mut ids = registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()])); ids.sort(); + assert_eq!(ids, vec!["github", "jira"]); } #[test] fn resolve_single_server_name() { let registry = make_registry_with_config(&["github", "slack"]); + let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()])); + assert_eq!(ids, vec!["slack"]); } #[test] fn resolve_none_returns_empty() { let registry = make_registry_with_config(&["github"]); + let ids = registry.resolve_server_ids(None); + assert!(ids.is_empty()); } #[test] fn resolve_no_config_returns_empty() { let registry = McpRegistry::default(); + let ids = registry.resolve_server_ids(Some(vec!["all".to_string()])); + assert!(ids.is_empty()); } #[test] fn resolve_nonexistent_server_filtered_out() { let registry = make_registry_with_config(&["github"]); + let ids = registry .resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()])); + assert_eq!(ids, vec!["github"]); } #[test] fn resolve_all_nonexistent_returns_empty() { let registry = make_registry_with_config(&["github"]); + let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()])); + assert!(ids.is_empty()); } #[test] fn resolve_trims_whitespace() { let registry = make_registry_with_config(&["github", "slack"]); + let mut ids = registry.resolve_server_ids(Some(vec![ " github ".to_string(), " slack ".to_string(), ])); ids.sort(); + assert_eq!(ids, vec!["github", "slack"]); } #[test] fn registry_default_is_empty() { let registry = McpRegistry::default(); + assert!(registry.is_empty()); assert!(registry.list_started_servers().is_empty()); assert!(registry.mcp_config().is_none()); @@ -831,6 +938,7 @@ mod tests { #[test] fn registry_with_config_reports_config() { let registry = make_registry_with_config(&["github"]); + assert!(registry.mcp_config().is_some()); assert!( registry @@ -847,4 +955,53 @@ mod tests { assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search"); assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe"); } + + #[test] + fn merge_bearer_token_both_none_returns_none() { + assert!(merge_bearer_token(None, None).is_none()); + } + + #[test] + fn merge_bearer_token_headers_only_passes_through() { + let mut h = IndexMap::new(); + h.insert("X-Key".to_string(), "val".to_string()); + + let result = merge_bearer_token(Some(&h), None).unwrap(); + + assert_eq!(result["X-Key"], "val"); + assert!(!result.contains_key("Authorization")); + } + + #[test] + fn merge_bearer_token_token_only_injects_bearer() { + let result = merge_bearer_token(None, Some("tok123".to_string())).unwrap(); + + assert_eq!(result["Authorization"], "Bearer tok123"); + } + + #[test] + fn merge_bearer_token_both_merges_and_overrides_authorization() { + let mut h = IndexMap::new(); + h.insert("Authorization".to_string(), "old".to_string()); + h.insert("X-Custom".to_string(), "keep".to_string()); + + let result = merge_bearer_token(Some(&h), Some("newtoken".to_string())).unwrap(); + + assert_eq!(result["Authorization"], "Bearer newtoken"); + assert_eq!(result["X-Custom"], "keep"); + } + + #[test] + fn is_auth_required_error_matches_rmcp_message() { + let e = anyhow!("Auth required, when send initialize request"); + + assert!(is_auth_required_error(&e)); + } + + #[test] + fn is_auth_required_error_does_not_match_unrelated() { + let e = anyhow!("Connection refused"); + + assert!(!is_auth_required_error(&e)); + } } diff --git a/src/mcp/oauth.rs b/src/mcp/oauth.rs new file mode 100644 index 0000000..0ee0236 --- /dev/null +++ b/src/mcp/oauth.rs @@ -0,0 +1,325 @@ +use crate::client::oauth::{OAuthProvider, load_oauth_tokens, run_oauth_flow}; +use crate::config::paths; +use anyhow::{Context, Result, anyhow}; +use chrono::Utc; +use inquire::Text; +use log::warn; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::net::TcpListener; +use url::Url; + +#[derive(Debug, Deserialize)] +struct ProtectedResourceMetadata { + #[serde(default)] + authorization_servers: Vec, +} + +#[derive(Debug, Deserialize)] +struct OAuthServerMetadata { + authorization_endpoint: String, + token_endpoint: String, + #[serde(default)] + scopes_supported: Vec, + registration_endpoint: Option, +} + +#[derive(Serialize, Deserialize)] +struct McpRegistration { + client_id: String, +} + +struct McpOAuthProvider { + client_id: String, + authorize_url: String, + token_url: String, + scopes: String, + fixed_redirect: String, +} + +impl OAuthProvider for McpOAuthProvider { + fn provider_name(&self) -> &str { + "MCP" + } + + fn client_id(&self) -> &str { + &self.client_id + } + + fn authorize_url(&self) -> &str { + &self.authorize_url + } + + fn token_url(&self) -> &str { + &self.token_url + } + + fn redirect_uri(&self) -> &str { + "" + } + + fn scopes(&self) -> &str { + &self.scopes + } + + fn uses_localhost_redirect(&self) -> bool { + false + } + + fn fixed_redirect_uri(&self) -> Option { + Some(self.fixed_redirect.clone()) + } +} + +pub async fn run_mcp_oauth_flow( + server_name: &str, + server_url: &str, + configured_client_id: Option<&str>, +) -> Result<()> { + let metadata = discover_oauth_metadata(server_url).await?; + + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + drop(listener); + let redirect_uri = format!("http://127.0.0.1:{port}/callback"); + + let client_id = if let Some(id) = configured_client_id { + id.to_string() + } else if let Some(cached) = load_registered_client_id(server_name) { + cached + } else if let Some(reg_endpoint) = &metadata.registration_endpoint { + match register_client(reg_endpoint, &redirect_uri).await { + Ok(id) => { + let _ = save_registered_client_id(server_name, &id); + id + } + Err(e) => { + warn!("Dynamic client registration failed: {e}. Falling back to manual entry."); + Text::new("Enter the OAuth client ID for this MCP server:") + .prompt() + .context("Failed to read client ID")? + } + } + } else { + Text::new("Enter the OAuth client ID for this MCP server:") + .prompt() + .context("Failed to read client ID")? + }; + + let provider = McpOAuthProvider { + client_id, + authorize_url: metadata.authorization_endpoint, + token_url: metadata.token_endpoint, + scopes: metadata.scopes_supported.join(" "), + fixed_redirect: redirect_uri, + }; + + run_oauth_flow(&provider, &mcp_token_key(server_name)).await +} + +pub fn load_valid_mcp_token(server_name: &str) -> Option { + let tokens = load_oauth_tokens(&mcp_token_key(server_name))?; + if Utc::now().timestamp() < tokens.expires_at { + Some(tokens.access_token) + } else { + None + } +} + +fn mcp_token_key(server_name: &str) -> String { + format!("mcp_{server_name}") +} + +fn load_registered_client_id(server_name: &str) -> Option { + let path = paths::oauth_tokens_path().join(format!("mcp_{server_name}_registration.json")); + let content = fs::read_to_string(path).ok()?; + let reg: McpRegistration = serde_json::from_str(&content).ok()?; + + Some(reg.client_id) +} + +fn save_registered_client_id(server_name: &str, client_id: &str) -> Result<()> { + let dir = paths::oauth_tokens_path(); + fs::create_dir_all(&dir)?; + + let path = dir.join(format!("mcp_{server_name}_registration.json")); + let reg = McpRegistration { + client_id: client_id.to_string(), + }; + + fs::write(path, serde_json::to_string_pretty(®)?)?; + + Ok(()) +} + +async fn register_client(endpoint: &str, redirect_uri: &str) -> Result { + let body = serde_json::json!({ + "client_name": "Coyote", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none" + }); + + let response: serde_json::Value = Client::new() + .post(endpoint) + .json(&body) + .send() + .await + .context("Failed to reach registration endpoint")? + .json() + .await + .context("Failed to parse registration response")?; + + response["client_id"] + .as_str() + .ok_or_else(|| anyhow!("Missing client_id in registration response: {response}")) + .map(|s| s.to_string()) +} + +async fn discover_oauth_metadata(server_url: &str) -> Result { + let base = extract_base_url(server_url)?; + let client = Client::new(); + + // RFC 9728: try protected resource metadata first; it points to the auth server + let pr_url = format!("{base}/.well-known/oauth-protected-resource"); + if let Ok(resp) = client.get(&pr_url).send().await + && resp.status().is_success() + && let Ok(pr) = resp.json::().await + && let Some(auth_server) = pr.authorization_servers.first() + { + let as_url = format!("{auth_server}/.well-known/oauth-authorization-server"); + if let Ok(resp) = client.get(&as_url).send().await + && resp.status().is_success() + && let Ok(meta) = resp.json::().await + { + return Ok(meta); + } + } + + let as_url = format!("{base}/.well-known/oauth-authorization-server"); + let resp = client + .get(&as_url) + .send() + .await + .with_context(|| format!("Failed to reach {as_url}"))?; + + if resp.status().is_success() { + return resp + .json::() + .await + .with_context(|| format!("Failed to parse OAuth metadata from {as_url}")); + } + + Err(anyhow!( + "Could not discover OAuth metadata for '{server_url}'.\n\ + Tried:\n {pr_url}\n {as_url}\n\ + Ensure the server supports MCP OAuth discovery, or consult its documentation." + )) +} + +fn extract_base_url(url: &str) -> Result { + let parsed = Url::parse(url).with_context(|| format!("Invalid URL: {url}"))?; + let scheme = parsed.scheme(); + let host = parsed + .host_str() + .ok_or_else(|| anyhow!("No host in URL: {url}"))?; + let port = parsed.port().map(|p| format!(":{p}")).unwrap_or_default(); + + Ok(format!("{scheme}://{host}{port}")) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::get_env_name; + use serial_test::serial; + use std::{ + env, fs, + time::{self, SystemTime}, + }; + + fn with_temp_cache(f: F) { + let unique = SystemTime::now() + .duration_since(time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + let root = env::temp_dir().join(format!("coyote-mcp-oauth-test-{unique}")); + fs::create_dir_all(&root).unwrap(); + let env_key = get_env_name("cache_dir"); + let prev = env::var_os(&env_key); + unsafe { + env::set_var(&env_key, &root); + } + f(); + unsafe { + match prev { + Some(v) => env::set_var(&env_key, v), + None => env::remove_var(&env_key), + } + } + let _ = fs::remove_dir_all(&root); + } + + #[test] + fn extract_base_url_strips_path_and_query() { + let result = extract_base_url("https://mcp.notion.com/mcp?foo=bar").unwrap(); + + assert_eq!(result, "https://mcp.notion.com"); + } + + #[test] + fn extract_base_url_preserves_explicit_port() { + let result = extract_base_url("http://localhost:8080/mcp").unwrap(); + + assert_eq!(result, "http://localhost:8080"); + } + + #[test] + fn extract_base_url_standard_port_omitted() { + let result = extract_base_url("https://example.com/mcp/v1").unwrap(); + + assert_eq!(result, "https://example.com"); + } + + #[test] + fn extract_base_url_rejects_invalid_url() { + assert!(extract_base_url("not-a-url").is_err()); + } + + #[test] + #[serial] + fn registered_client_id_roundtrip() { + with_temp_cache(|| { + save_registered_client_id("notion", "client-xyz-123").unwrap(); + + let loaded = load_registered_client_id("notion"); + + assert_eq!(loaded, Some("client-xyz-123".to_string())); + }); + } + + #[test] + #[serial] + fn load_registered_client_id_returns_none_for_missing() { + with_temp_cache(|| { + let loaded = load_registered_client_id("no-such-server"); + + assert!(loaded.is_none()); + }); + } + + #[test] + #[serial] + fn registered_client_id_second_save_overwrites_first() { + with_temp_cache(|| { + save_registered_client_id("github", "first-id").unwrap(); + save_registered_client_id("github", "second-id").unwrap(); + + let loaded = load_registered_client_id("github"); + + assert_eq!(loaded, Some("second-id".to_string())); + }); + } +} diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 07aeab7..c63270f 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -20,7 +20,7 @@ use crate::utils::{ }; use crate::sandbox::SANDBOX_ENV_FLAG; -use crate::{config, graph, resolve_oauth_client}; +use crate::{config, graph, mcp, resolve_oauth_client}; use anyhow::{Context, Result, bail}; use crossterm::cursor::SetCursorStyle; use fancy_regex::Regex; @@ -49,7 +49,7 @@ pub const DEFAULT_CONTINUATION_PROMPT: &str = indoc! {" 4. Continue with the next pending item now. Call tools immediately." }; -static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| { +static REPL_COMMANDS: LazyLock<[ReplCommand; 50]> = LazyLock::new(|| { [ ReplCommand::new(".help", "Show this help guide", AssertState::pass()), ReplCommand::new(".info", "Show system info", AssertState::pass()), @@ -63,6 +63,11 @@ static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| { "Authenticate the current model client via OAuth (if configured)", AssertState::pass(), ), + ReplCommand::new( + ".mcp auth", + "Authenticate with an MCP server via OAuth", + AssertState::pass(), + ), ReplCommand::new( ".edit config", "Modify configuration file", @@ -541,6 +546,55 @@ pub async fn run_repl_command( let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?; oauth::run_oauth_flow(&*provider, &client_name).await?; } + ".mcp" => match args { + Some(args) => { + let mut parts = args.splitn(2, char::is_whitespace); + let sub = parts.next().unwrap_or("").trim(); + let rest = parts.next().map(str::trim).unwrap_or(""); + match sub { + "auth" => { + if rest.is_empty() { + println!("Usage: .mcp auth "); + } else { + let server_name = rest; + let server_spec = ctx + .app + .mcp_config + .as_ref() + .and_then(|c| c.mcp_servers.get(server_name)) + .cloned(); + match server_spec { + None => bail!( + "MCP server '{}' not found in config. \ + Check your mcp_config.json.", + server_name + ), + Some(spec) if !spec.is_remote() => bail!( + "MCP server '{}' uses stdio transport; \ + OAuth is only supported for http/sse servers.", + server_name + ), + Some(spec) => { + let url = spec + .url + .as_deref() + .expect("validated: remote spec has url"); + let client_id = spec.oauth_client_id.as_deref(); + mcp::oauth::run_mcp_oauth_flow(server_name, url, client_id) + .await?; + println!( + "Authentication saved. \ + Restart Coyote to connect to '{server_name}'." + ); + } + } + } + } + _ => unknown_command()?, + } + } + None => println!("Usage: .mcp auth "), + }, ".prompt" => match args { Some(text) => { let app = Arc::clone(&ctx.app.config); @@ -1415,8 +1469,8 @@ mod tests { } #[test] - fn repl_commands_has_49_entries() { - assert_eq!(REPL_COMMANDS.len(), 49); + fn repl_commands_has_50_entries() { + assert_eq!(REPL_COMMANDS.len(), 50); } #[test]