Compare commits
2 Commits
cc50d39ab4
...
693e2d9672
| Author | SHA1 | Date | |
|---|---|---|---|
|
693e2d9672
|
|||
|
16f324cefc
|
+29
-1
@@ -5,9 +5,9 @@ use crate::utils::list_file_names;
|
|||||||
use crate::vault::Vault;
|
use crate::vault::Vault;
|
||||||
use clap_complete::{CompletionCandidate, Shell, generate};
|
use clap_complete::{CompletionCandidate, Shell, generate};
|
||||||
use clap_complete_nushell::Nushell;
|
use clap_complete_nushell::Nushell;
|
||||||
use std::env;
|
|
||||||
use std::ffi::OsStr;
|
use std::ffi::OsStr;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
use std::{env, fs};
|
||||||
|
|
||||||
const COYOTE_CLI_NAME: &str = "coyote";
|
const COYOTE_CLI_NAME: &str = "coyote";
|
||||||
|
|
||||||
@@ -134,6 +134,34 @@ pub(super) fn session_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
|||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn mcp_server_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||||
|
let cur = current.to_string_lossy();
|
||||||
|
let content = match fs::read_to_string(paths::mcp_config_file()) {
|
||||||
|
Ok(c) => c,
|
||||||
|
Err(_) => return vec![],
|
||||||
|
};
|
||||||
|
let json: serde_json::Value = match serde_json::from_str(&content) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => return vec![],
|
||||||
|
};
|
||||||
|
let servers = match json.get("mcpServers").and_then(|v| v.as_object()) {
|
||||||
|
Some(s) => s,
|
||||||
|
None => return vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
servers
|
||||||
|
.iter()
|
||||||
|
.filter(|(_, v)| {
|
||||||
|
v.get("type")
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.map(|t| t == "http" || t == "sse")
|
||||||
|
.unwrap_or(false)
|
||||||
|
})
|
||||||
|
.filter(|(k, _)| k.starts_with(&*cur))
|
||||||
|
.map(|(k, _)| CompletionCandidate::new(k))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) fn secrets_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
pub(super) fn secrets_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||||
let cur = current.to_string_lossy();
|
let cur = current.to_string_lossy();
|
||||||
match load_app_config_for_completion() {
|
match load_app_config_for_completion() {
|
||||||
|
|||||||
+5
-2
@@ -1,8 +1,8 @@
|
|||||||
mod completer;
|
mod completer;
|
||||||
|
|
||||||
use crate::cli::completer::{
|
use crate::cli::completer::{
|
||||||
ShellCompletion, agent_completer, macro_completer, model_completer, rag_completer,
|
ShellCompletion, agent_completer, macro_completer, mcp_server_completer, model_completer,
|
||||||
role_completer, secrets_completer, session_completer,
|
rag_completer, role_completer, secrets_completer, session_completer,
|
||||||
};
|
};
|
||||||
use crate::config::{AssetCategory, InstallFilter, MemoryScope};
|
use crate::config::{AssetCategory, InstallFilter, MemoryScope};
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
@@ -171,6 +171,9 @@ pub struct Cli {
|
|||||||
/// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name)
|
/// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name)
|
||||||
#[arg(long, exclusive = true, value_name = "CLIENT_NAME")]
|
#[arg(long, exclusive = true, value_name = "CLIENT_NAME")]
|
||||||
pub authenticate: Option<Option<String>>,
|
pub authenticate: Option<Option<String>>,
|
||||||
|
/// Authenticate with an OAuth-protected remote MCP server (e.g., --auth-mcp server_name)
|
||||||
|
#[arg(long, exclusive = true, value_name = "SERVER_NAME", add = ArgValueCompleter::new(mcp_server_completer))]
|
||||||
|
pub auth_mcp: Option<String>,
|
||||||
/// Generate static shell completion scripts
|
/// Generate static shell completion scripts
|
||||||
#[arg(long, value_name = "SHELL", value_enum)]
|
#[arg(long, value_name = "SHELL", value_enum)]
|
||||||
pub completions: Option<ShellCompletion>,
|
pub completions: Option<ShellCompletion>,
|
||||||
|
|||||||
+10
-4
@@ -53,6 +53,10 @@ pub trait OAuthProvider: Send + Sync {
|
|||||||
fn extra_request_headers(&self) -> Vec<(&str, &str)> {
|
fn extra_request_headers(&self) -> Vec<(&str, &str)> {
|
||||||
vec![]
|
vec![]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fixed_redirect_uri(&self) -> Option<String> {
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -72,14 +76,16 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
|
|||||||
|
|
||||||
let state = Uuid::new_v4().to_string();
|
let state = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
let redirect_uri = if provider.uses_localhost_redirect() {
|
let (redirect_uri, use_callback_listener) = if let Some(fixed) = provider.fixed_redirect_uri() {
|
||||||
|
(fixed, true)
|
||||||
|
} else if provider.uses_localhost_redirect() {
|
||||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||||
let port = listener.local_addr()?.port();
|
let port = listener.local_addr()?.port();
|
||||||
let uri = format!("http://127.0.0.1:{port}/callback");
|
let uri = format!("http://127.0.0.1:{port}/callback");
|
||||||
drop(listener);
|
drop(listener);
|
||||||
uri
|
(uri, true)
|
||||||
} else {
|
} else {
|
||||||
provider.redirect_uri().to_string()
|
(provider.redirect_uri().to_string(), false)
|
||||||
};
|
};
|
||||||
|
|
||||||
let encoded_scopes = urlencoding::encode(provider.scopes());
|
let encoded_scopes = urlencoding::encode(provider.scopes());
|
||||||
@@ -112,7 +118,7 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
|
|||||||
|
|
||||||
let _ = open::that(&authorize_url);
|
let _ = open::that(&authorize_url);
|
||||||
|
|
||||||
let (code, returned_state) = if provider.uses_localhost_redirect() {
|
let (code, returned_state) = if use_callback_listener {
|
||||||
listen_for_oauth_callback(&redirect_uri)?
|
listen_for_oauth_callback(&redirect_uri)?
|
||||||
} else {
|
} else {
|
||||||
let input = Text::new("Paste the authorization code:").prompt()?;
|
let input = Text::new("Paste the authorization code:").prompt()?;
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
use crate::mcp::{ConnectedServer, JsonField, McpServer, McpTransportType, spawn_mcp_server};
|
use crate::mcp::{
|
||||||
|
ConnectedServer, JsonField, McpServer, McpTransportType, oauth, spawn_mcp_server,
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
@@ -99,7 +101,12 @@ impl McpFactory {
|
|||||||
return Ok(existing);
|
return Ok(existing);
|
||||||
}
|
}
|
||||||
|
|
||||||
let handle = spawn_mcp_server(spec, log_path).await?;
|
let bearer_token = if spec.is_remote() {
|
||||||
|
oauth::load_valid_mcp_token(name)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let handle = spawn_mcp_server(spec, log_path, bearer_token).await?;
|
||||||
self.insert_active(key, &handle);
|
self.insert_active(key, &handle);
|
||||||
Ok(handle)
|
Ok(handle)
|
||||||
}
|
}
|
||||||
@@ -125,6 +132,7 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,6 +149,7 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: Some(url.to_string()),
|
url: Some(url.to_string()),
|
||||||
headers,
|
headers,
|
||||||
|
oauth_client_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2340,6 +2340,17 @@ impl RequestContext {
|
|||||||
}
|
}
|
||||||
_ => vec![],
|
_ => vec![],
|
||||||
};
|
};
|
||||||
|
} else if cmd == ".mcp" && args.first() == Some(&"auth") && args.len() == 2 {
|
||||||
|
if let Some(mcp_config) = &self.app.mcp_config {
|
||||||
|
values = super::map_completion_values(
|
||||||
|
mcp_config
|
||||||
|
.mcp_servers
|
||||||
|
.iter()
|
||||||
|
.filter(|(_, spec)| spec.is_remote())
|
||||||
|
.map(|(name, _)| name.clone())
|
||||||
|
.collect(),
|
||||||
|
);
|
||||||
|
}
|
||||||
} else if (cmd == ".edit" && args.first() == Some(&"skill") && args.len() == 2)
|
} else if (cmd == ".edit" && args.first() == Some(&"skill") && args.len() == 2)
|
||||||
|| (cmd == ".skill" && args.first() == Some(&"load") && args.len() == 2)
|
|| (cmd == ".skill" && args.first() == Some(&"load") && args.len() == 2)
|
||||||
{
|
{
|
||||||
@@ -3687,6 +3698,7 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
+45
-2
@@ -28,11 +28,12 @@ use crate::config::{
|
|||||||
install_builtins, list_agents, load_env_file, macro_execute, sync_models,
|
install_builtins, list_agents, load_env_file, macro_execute, sync_models,
|
||||||
};
|
};
|
||||||
use crate::function::supervisor::{GuardrailAction, check_pending_agents_guardrail};
|
use crate::function::supervisor::{GuardrailAction, check_pending_agents_guardrail};
|
||||||
|
use crate::mcp::McpServersConfig;
|
||||||
use crate::render::{prompt_theme, render_error};
|
use crate::render::{prompt_theme, render_error};
|
||||||
use crate::repl::Repl;
|
use crate::repl::Repl;
|
||||||
use crate::utils::*;
|
use crate::utils::*;
|
||||||
use crate::vault::Vault;
|
use crate::vault::{Vault, interpolate_secrets};
|
||||||
use anyhow::{Result, anyhow, bail};
|
use anyhow::{Context, Result, anyhow, bail};
|
||||||
use clap::{CommandFactory, Parser};
|
use clap::{CommandFactory, Parser};
|
||||||
use clap_complete::CompleteEnv;
|
use clap_complete::CompleteEnv;
|
||||||
use client::ClientConfig;
|
use client::ClientConfig;
|
||||||
@@ -120,6 +121,48 @@ async fn main() -> Result<()> {
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(server_name) = &cli.auth_mcp {
|
||||||
|
let cfg = Config::load_with_interpolation(true).await?;
|
||||||
|
let app_config = AppConfig::from_config(cfg)?;
|
||||||
|
let vault = Vault::init(&app_config)?;
|
||||||
|
let mcp_path = paths::mcp_config_file();
|
||||||
|
if !mcp_path.exists() {
|
||||||
|
bail!(
|
||||||
|
"No MCP configuration file found at '{}'",
|
||||||
|
mcp_path.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw = tokio::fs::read_to_string(&mcp_path)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to read MCP config at '{}'", mcp_path.display()))?;
|
||||||
|
|
||||||
|
let (content, missing) = interpolate_secrets(&raw, &vault)?;
|
||||||
|
if !missing.is_empty() {
|
||||||
|
bail!(
|
||||||
|
"MCP config references vault secrets that are missing: {:?}",
|
||||||
|
missing
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mcp_config: McpServersConfig =
|
||||||
|
serde_json::from_str(&content).context("Failed to parse MCP config file")?;
|
||||||
|
let spec = mcp_config
|
||||||
|
.mcp_servers
|
||||||
|
.get(server_name.as_str())
|
||||||
|
.ok_or_else(|| anyhow!("MCP server '{server_name}' not found in mcp.json"))?;
|
||||||
|
if !spec.is_remote() {
|
||||||
|
bail!(
|
||||||
|
"MCP server '{server_name}' is a stdio server; OAuth is only supported for http/sse servers"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = spec.url.as_deref().expect("validated: remote spec has url");
|
||||||
|
mcp::oauth::run_mcp_oauth_flow(server_name, url, spec.oauth_client_id.as_deref()).await?;
|
||||||
|
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
if vault_flags {
|
if vault_flags {
|
||||||
let cfg = Config::load_with_interpolation(true).await?;
|
let cfg = Config::load_with_interpolation(true).await?;
|
||||||
let app_config = AppConfig::from_config(cfg)?;
|
let app_config = AppConfig::from_config(cfg)?;
|
||||||
|
|||||||
+166
-9
@@ -1,3 +1,4 @@
|
|||||||
|
pub(crate) mod oauth;
|
||||||
mod sse_transport;
|
mod sse_transport;
|
||||||
|
|
||||||
use crate::config::AppConfig;
|
use crate::config::AppConfig;
|
||||||
@@ -73,6 +74,8 @@ pub(crate) struct McpServer {
|
|||||||
pub url: Option<String>,
|
pub url: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub headers: Option<IndexMap<String, String>>,
|
pub headers: Option<IndexMap<String, String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub oauth_client_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl McpServer {
|
impl McpServer {
|
||||||
@@ -107,10 +110,10 @@ impl McpServer {
|
|||||||
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
|
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
if self.url.is_some() || self.headers.is_some() {
|
if self.url.is_some() || self.headers.is_some() || self.oauth_client_id.is_some() {
|
||||||
return Err(anyhow!(
|
return Err(anyhow!(
|
||||||
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \
|
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \
|
||||||
(url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
(url/headers/oauth_client_id). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -237,7 +240,7 @@ impl McpRegistry {
|
|||||||
|
|
||||||
debug!("Starting selected MCP servers: {:?}", ids_to_start);
|
debug!("Starting selected MCP servers: {:?}", ids_to_start);
|
||||||
|
|
||||||
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
let results: Vec<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> = stream::iter(
|
||||||
ids_to_start
|
ids_to_start
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|id| async { self.start_server(id).await }),
|
.map(|id| async { self.start_server(id).await }),
|
||||||
@@ -246,7 +249,7 @@ impl McpRegistry {
|
|||||||
.try_collect()
|
.try_collect()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
for (id, server, catalog) in results {
|
for (id, server, catalog) in results.into_iter().flatten() {
|
||||||
self.servers.insert(id.clone(), server);
|
self.servers.insert(id.clone(), server);
|
||||||
self.catalogs.insert(id, catalog);
|
self.catalogs.insert(id, catalog);
|
||||||
}
|
}
|
||||||
@@ -257,14 +260,30 @@ impl McpRegistry {
|
|||||||
async fn start_server(
|
async fn start_server(
|
||||||
&self,
|
&self,
|
||||||
id: String,
|
id: String,
|
||||||
) -> Result<(String, Arc<ConnectedServer>, ServerCatalog)> {
|
) -> Result<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> {
|
||||||
let spec = self
|
let spec = self
|
||||||
.config
|
.config
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|c| c.mcp_servers.get(&id))
|
.and_then(|c| c.mcp_servers.get(&id))
|
||||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||||
|
|
||||||
let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?;
|
let bearer_token = if spec.is_remote() {
|
||||||
|
oauth::load_valid_mcp_token(&id)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let service = match spawn_mcp_server(spec, self.log_path.as_deref(), bearer_token).await {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) if is_auth_required_error(&e) => {
|
||||||
|
warn!(
|
||||||
|
"MCP server '{id}' requires OAuth authentication. \
|
||||||
|
Run `.mcp auth {id}` in the REPL to authenticate, then restart Coyote."
|
||||||
|
);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
};
|
||||||
|
|
||||||
let tools = service.list_tools(None).await?;
|
let tools = service.list_tools(None).await?;
|
||||||
debug!("Available tools for MCP server {id}: {tools:?}");
|
debug!("Available tools for MCP server {id}: {tools:?}");
|
||||||
@@ -289,7 +308,7 @@ impl McpRegistry {
|
|||||||
|
|
||||||
info!("Started MCP server: {id}");
|
info!("Started MCP server: {id}");
|
||||||
|
|
||||||
Ok((id.to_string(), service, catalog))
|
Ok(Some((id.to_string(), service, catalog)))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> {
|
fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> {
|
||||||
@@ -337,15 +356,18 @@ impl McpRegistry {
|
|||||||
pub(crate) async fn spawn_mcp_server(
|
pub(crate) async fn spawn_mcp_server(
|
||||||
spec: &McpServer,
|
spec: &McpServer,
|
||||||
log_path: Option<&Path>,
|
log_path: Option<&Path>,
|
||||||
|
bearer_token: Option<String>,
|
||||||
) -> Result<Arc<ConnectedServer>> {
|
) -> Result<Arc<ConnectedServer>> {
|
||||||
match spec.transport_type {
|
match spec.transport_type {
|
||||||
McpTransportType::Http => {
|
McpTransportType::Http => {
|
||||||
let url = spec.url.as_deref().expect("validated: http spec has url");
|
let url = spec.url.as_deref().expect("validated: http spec has url");
|
||||||
spawn_http_mcp_server(url, spec.headers.as_ref()).await
|
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
|
||||||
|
spawn_http_mcp_server(url, headers.as_ref()).await
|
||||||
}
|
}
|
||||||
McpTransportType::Sse => {
|
McpTransportType::Sse => {
|
||||||
let url = spec.url.as_deref().expect("validated: sse spec has url");
|
let url = spec.url.as_deref().expect("validated: sse spec has url");
|
||||||
spawn_sse_mcp_server(url, spec.headers.as_ref()).await
|
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
|
||||||
|
spawn_sse_mcp_server(url, headers.as_ref()).await
|
||||||
}
|
}
|
||||||
McpTransportType::Stdio => {
|
McpTransportType::Stdio => {
|
||||||
let command = spec
|
let command = spec
|
||||||
@@ -357,6 +379,30 @@ pub(crate) async fn spawn_mcp_server(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn merge_bearer_token(
|
||||||
|
headers: Option<&IndexMap<String, String>>,
|
||||||
|
bearer_token: Option<String>,
|
||||||
|
) -> Option<IndexMap<String, String>> {
|
||||||
|
match (headers, bearer_token) {
|
||||||
|
(None, None) => None,
|
||||||
|
(Some(h), None) => Some(h.clone()),
|
||||||
|
(None, Some(token)) => {
|
||||||
|
let mut m = IndexMap::new();
|
||||||
|
m.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||||
|
Some(m)
|
||||||
|
}
|
||||||
|
(Some(h), Some(token)) => {
|
||||||
|
let mut m = h.clone();
|
||||||
|
m.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||||
|
Some(m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_auth_required_error(e: &anyhow::Error) -> bool {
|
||||||
|
e.to_string().contains("Auth required")
|
||||||
|
}
|
||||||
|
|
||||||
async fn spawn_http_mcp_server(
|
async fn spawn_http_mcp_server(
|
||||||
url: &str,
|
url: &str,
|
||||||
headers: Option<&IndexMap<String, String>>,
|
headers: Option<&IndexMap<String, String>>,
|
||||||
@@ -465,6 +511,7 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -477,6 +524,7 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: Some(url.to_string()),
|
url: Some(url.to_string()),
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,6 +537,7 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: Some(url.to_string()),
|
url: Some(url.to_string()),
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,6 +555,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn validate_stdio_with_command_succeeds() {
|
fn validate_stdio_with_command_succeeds() {
|
||||||
let spec = stdio_server("npx");
|
let spec = stdio_server("npx");
|
||||||
|
|
||||||
assert!(spec.validate("test").is_ok());
|
assert!(spec.validate("test").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -519,8 +569,11 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("missing a \"command\" field"));
|
assert!(err.to_string().contains("missing a \"command\" field"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -534,8 +587,11 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: Some("http://localhost".into()),
|
url: Some("http://localhost".into()),
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("remote fields"));
|
assert!(err.to_string().contains("remote fields"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -551,14 +607,18 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: Some(headers),
|
headers: Some(headers),
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("remote fields"));
|
assert!(err.to_string().contains("remote fields"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn validate_http_with_url_succeeds() {
|
fn validate_http_with_url_succeeds() {
|
||||||
let spec = http_server("http://localhost:8080");
|
let spec = http_server("http://localhost:8080");
|
||||||
|
|
||||||
assert!(spec.validate("test").is_ok());
|
assert!(spec.validate("test").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -572,8 +632,11 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("missing a \"url\" field"));
|
assert!(err.to_string().contains("missing a \"url\" field"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -587,8 +650,11 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: Some("http://localhost".into()),
|
url: Some("http://localhost".into()),
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("stdio fields"));
|
assert!(err.to_string().contains("stdio fields"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -602,8 +668,11 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: Some("http://localhost".into()),
|
url: Some("http://localhost".into()),
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("stdio fields"));
|
assert!(err.to_string().contains("stdio fields"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -617,14 +686,18 @@ mod tests {
|
|||||||
cwd: Some("/tmp".into()),
|
cwd: Some("/tmp".into()),
|
||||||
url: Some("http://localhost".into()),
|
url: Some("http://localhost".into()),
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("stdio fields"));
|
assert!(err.to_string().contains("stdio fields"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn validate_sse_with_url_succeeds() {
|
fn validate_sse_with_url_succeeds() {
|
||||||
let spec = sse_server("http://sse.example.com");
|
let spec = sse_server("http://sse.example.com");
|
||||||
|
|
||||||
assert!(spec.validate("test").is_ok());
|
assert!(spec.validate("test").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -638,8 +711,11 @@ mod tests {
|
|||||||
cwd: None,
|
cwd: None,
|
||||||
url: None,
|
url: None,
|
||||||
headers: None,
|
headers: None,
|
||||||
|
oauth_client_id: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let err = spec.validate("test").unwrap_err();
|
let err = spec.validate("test").unwrap_err();
|
||||||
|
|
||||||
assert!(err.to_string().contains("missing a \"url\" field"));
|
assert!(err.to_string().contains("missing a \"url\" field"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,9 +741,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||||
|
|
||||||
assert!(config.mcp_servers.contains_key("my-server"));
|
assert!(config.mcp_servers.contains_key("my-server"));
|
||||||
|
|
||||||
let spec = &config.mcp_servers["my-server"];
|
let spec = &config.mcp_servers["my-server"];
|
||||||
|
|
||||||
assert_eq!(spec.transport_type, McpTransportType::Stdio);
|
assert_eq!(spec.transport_type, McpTransportType::Stdio);
|
||||||
assert_eq!(spec.command.as_deref(), Some("npx"));
|
assert_eq!(spec.command.as_deref(), Some("npx"));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -688,7 +768,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}"#;
|
}"#;
|
||||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||||
|
|
||||||
let spec = &config.mcp_servers["remote"];
|
let spec = &config.mcp_servers["remote"];
|
||||||
|
|
||||||
assert_eq!(spec.transport_type, McpTransportType::Http);
|
assert_eq!(spec.transport_type, McpTransportType::Http);
|
||||||
assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp"));
|
assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp"));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@@ -713,7 +795,9 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}"#;
|
}"#;
|
||||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||||
|
|
||||||
let env = config.mcp_servers["s"].env.as_ref().unwrap();
|
let env = config.mcp_servers["s"].env.as_ref().unwrap();
|
||||||
|
|
||||||
assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello"));
|
assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello"));
|
||||||
assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true)));
|
assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true)));
|
||||||
assert!(matches!(env["INT_VAR"], JsonField::Int(42)));
|
assert!(matches!(env["INT_VAR"], JsonField::Int(42)));
|
||||||
@@ -727,7 +811,9 @@ mod tests {
|
|||||||
"remote-api": { "type": "http", "url": "http://api.example.com" }
|
"remote-api": { "type": "http", "url": "http://api.example.com" }
|
||||||
}
|
}
|
||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||||
|
|
||||||
assert_eq!(config.mcp_servers.len(), 2);
|
assert_eq!(config.mcp_servers.len(), 2);
|
||||||
assert!(config.mcp_servers.contains_key("github"));
|
assert!(config.mcp_servers.contains_key("github"));
|
||||||
assert!(config.mcp_servers.contains_key("remote-api"));
|
assert!(config.mcp_servers.contains_key("remote-api"));
|
||||||
@@ -736,7 +822,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn deserialize_empty_servers_map() {
|
fn deserialize_empty_servers_map() {
|
||||||
let json = r#"{ "mcpServers": {} }"#;
|
let json = r#"{ "mcpServers": {} }"#;
|
||||||
|
|
||||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||||
|
|
||||||
assert!(config.mcp_servers.is_empty());
|
assert!(config.mcp_servers.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -751,77 +839,96 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}"#;
|
}"#;
|
||||||
|
|
||||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||||
|
|
||||||
assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work"));
|
assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_all_returns_all_configured_servers() {
|
fn resolve_all_returns_all_configured_servers() {
|
||||||
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
||||||
|
|
||||||
let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
||||||
ids.sort();
|
ids.sort();
|
||||||
|
|
||||||
assert_eq!(ids, vec!["github", "jira", "slack"]);
|
assert_eq!(ids, vec!["github", "jira", "slack"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_comma_separated_returns_matching_servers() {
|
fn resolve_comma_separated_returns_matching_servers() {
|
||||||
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
||||||
|
|
||||||
let mut ids =
|
let mut ids =
|
||||||
registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()]));
|
registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()]));
|
||||||
ids.sort();
|
ids.sort();
|
||||||
|
|
||||||
assert_eq!(ids, vec!["github", "jira"]);
|
assert_eq!(ids, vec!["github", "jira"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_single_server_name() {
|
fn resolve_single_server_name() {
|
||||||
let registry = make_registry_with_config(&["github", "slack"]);
|
let registry = make_registry_with_config(&["github", "slack"]);
|
||||||
|
|
||||||
let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()]));
|
let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()]));
|
||||||
|
|
||||||
assert_eq!(ids, vec!["slack"]);
|
assert_eq!(ids, vec!["slack"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_none_returns_empty() {
|
fn resolve_none_returns_empty() {
|
||||||
let registry = make_registry_with_config(&["github"]);
|
let registry = make_registry_with_config(&["github"]);
|
||||||
|
|
||||||
let ids = registry.resolve_server_ids(None);
|
let ids = registry.resolve_server_ids(None);
|
||||||
|
|
||||||
assert!(ids.is_empty());
|
assert!(ids.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_no_config_returns_empty() {
|
fn resolve_no_config_returns_empty() {
|
||||||
let registry = McpRegistry::default();
|
let registry = McpRegistry::default();
|
||||||
|
|
||||||
let ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
let ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
||||||
|
|
||||||
assert!(ids.is_empty());
|
assert!(ids.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_nonexistent_server_filtered_out() {
|
fn resolve_nonexistent_server_filtered_out() {
|
||||||
let registry = make_registry_with_config(&["github"]);
|
let registry = make_registry_with_config(&["github"]);
|
||||||
|
|
||||||
let ids = registry
|
let ids = registry
|
||||||
.resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()]));
|
.resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()]));
|
||||||
|
|
||||||
assert_eq!(ids, vec!["github"]);
|
assert_eq!(ids, vec!["github"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_all_nonexistent_returns_empty() {
|
fn resolve_all_nonexistent_returns_empty() {
|
||||||
let registry = make_registry_with_config(&["github"]);
|
let registry = make_registry_with_config(&["github"]);
|
||||||
|
|
||||||
let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()]));
|
let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()]));
|
||||||
|
|
||||||
assert!(ids.is_empty());
|
assert!(ids.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_trims_whitespace() {
|
fn resolve_trims_whitespace() {
|
||||||
let registry = make_registry_with_config(&["github", "slack"]);
|
let registry = make_registry_with_config(&["github", "slack"]);
|
||||||
|
|
||||||
let mut ids = registry.resolve_server_ids(Some(vec![
|
let mut ids = registry.resolve_server_ids(Some(vec![
|
||||||
" github ".to_string(),
|
" github ".to_string(),
|
||||||
" slack ".to_string(),
|
" slack ".to_string(),
|
||||||
]));
|
]));
|
||||||
ids.sort();
|
ids.sort();
|
||||||
|
|
||||||
assert_eq!(ids, vec!["github", "slack"]);
|
assert_eq!(ids, vec!["github", "slack"]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn registry_default_is_empty() {
|
fn registry_default_is_empty() {
|
||||||
let registry = McpRegistry::default();
|
let registry = McpRegistry::default();
|
||||||
|
|
||||||
assert!(registry.is_empty());
|
assert!(registry.is_empty());
|
||||||
assert!(registry.list_started_servers().is_empty());
|
assert!(registry.list_started_servers().is_empty());
|
||||||
assert!(registry.mcp_config().is_none());
|
assert!(registry.mcp_config().is_none());
|
||||||
@@ -831,6 +938,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn registry_with_config_reports_config() {
|
fn registry_with_config_reports_config() {
|
||||||
let registry = make_registry_with_config(&["github"]);
|
let registry = make_registry_with_config(&["github"]);
|
||||||
|
|
||||||
assert!(registry.mcp_config().is_some());
|
assert!(registry.mcp_config().is_some());
|
||||||
assert!(
|
assert!(
|
||||||
registry
|
registry
|
||||||
@@ -847,4 +955,53 @@ mod tests {
|
|||||||
assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search");
|
assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search");
|
||||||
assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe");
|
assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn merge_bearer_token_both_none_returns_none() {
|
||||||
|
assert!(merge_bearer_token(None, None).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn merge_bearer_token_headers_only_passes_through() {
|
||||||
|
let mut h = IndexMap::new();
|
||||||
|
h.insert("X-Key".to_string(), "val".to_string());
|
||||||
|
|
||||||
|
let result = merge_bearer_token(Some(&h), None).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result["X-Key"], "val");
|
||||||
|
assert!(!result.contains_key("Authorization"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn merge_bearer_token_token_only_injects_bearer() {
|
||||||
|
let result = merge_bearer_token(None, Some("tok123".to_string())).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result["Authorization"], "Bearer tok123");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn merge_bearer_token_both_merges_and_overrides_authorization() {
|
||||||
|
let mut h = IndexMap::new();
|
||||||
|
h.insert("Authorization".to_string(), "old".to_string());
|
||||||
|
h.insert("X-Custom".to_string(), "keep".to_string());
|
||||||
|
|
||||||
|
let result = merge_bearer_token(Some(&h), Some("newtoken".to_string())).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result["Authorization"], "Bearer newtoken");
|
||||||
|
assert_eq!(result["X-Custom"], "keep");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_auth_required_error_matches_rmcp_message() {
|
||||||
|
let e = anyhow!("Auth required, when send initialize request");
|
||||||
|
|
||||||
|
assert!(is_auth_required_error(&e));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_auth_required_error_does_not_match_unrelated() {
|
||||||
|
let e = anyhow!("Connection refused");
|
||||||
|
|
||||||
|
assert!(!is_auth_required_error(&e));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,325 @@
|
|||||||
|
use crate::client::oauth::{OAuthProvider, load_oauth_tokens, run_oauth_flow};
|
||||||
|
use crate::config::paths;
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use chrono::Utc;
|
||||||
|
use inquire::Text;
|
||||||
|
use log::warn;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs;
|
||||||
|
use std::net::TcpListener;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ProtectedResourceMetadata {
|
||||||
|
#[serde(default)]
|
||||||
|
authorization_servers: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OAuthServerMetadata {
|
||||||
|
authorization_endpoint: String,
|
||||||
|
token_endpoint: String,
|
||||||
|
#[serde(default)]
|
||||||
|
scopes_supported: Vec<String>,
|
||||||
|
registration_endpoint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
struct McpRegistration {
|
||||||
|
client_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct McpOAuthProvider {
|
||||||
|
client_id: String,
|
||||||
|
authorize_url: String,
|
||||||
|
token_url: String,
|
||||||
|
scopes: String,
|
||||||
|
fixed_redirect: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OAuthProvider for McpOAuthProvider {
|
||||||
|
fn provider_name(&self) -> &str {
|
||||||
|
"MCP"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_id(&self) -> &str {
|
||||||
|
&self.client_id
|
||||||
|
}
|
||||||
|
|
||||||
|
fn authorize_url(&self) -> &str {
|
||||||
|
&self.authorize_url
|
||||||
|
}
|
||||||
|
|
||||||
|
fn token_url(&self) -> &str {
|
||||||
|
&self.token_url
|
||||||
|
}
|
||||||
|
|
||||||
|
fn redirect_uri(&self) -> &str {
|
||||||
|
""
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scopes(&self) -> &str {
|
||||||
|
&self.scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
fn uses_localhost_redirect(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fixed_redirect_uri(&self) -> Option<String> {
|
||||||
|
Some(self.fixed_redirect.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run_mcp_oauth_flow(
|
||||||
|
server_name: &str,
|
||||||
|
server_url: &str,
|
||||||
|
configured_client_id: Option<&str>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let metadata = discover_oauth_metadata(server_url).await?;
|
||||||
|
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||||
|
let port = listener.local_addr()?.port();
|
||||||
|
drop(listener);
|
||||||
|
let redirect_uri = format!("http://127.0.0.1:{port}/callback");
|
||||||
|
|
||||||
|
let client_id = if let Some(id) = configured_client_id {
|
||||||
|
id.to_string()
|
||||||
|
} else if let Some(cached) = load_registered_client_id(server_name) {
|
||||||
|
cached
|
||||||
|
} else if let Some(reg_endpoint) = &metadata.registration_endpoint {
|
||||||
|
match register_client(reg_endpoint, &redirect_uri).await {
|
||||||
|
Ok(id) => {
|
||||||
|
let _ = save_registered_client_id(server_name, &id);
|
||||||
|
id
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Dynamic client registration failed: {e}. Falling back to manual entry.");
|
||||||
|
Text::new("Enter the OAuth client ID for this MCP server:")
|
||||||
|
.prompt()
|
||||||
|
.context("Failed to read client ID")?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Text::new("Enter the OAuth client ID for this MCP server:")
|
||||||
|
.prompt()
|
||||||
|
.context("Failed to read client ID")?
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = McpOAuthProvider {
|
||||||
|
client_id,
|
||||||
|
authorize_url: metadata.authorization_endpoint,
|
||||||
|
token_url: metadata.token_endpoint,
|
||||||
|
scopes: metadata.scopes_supported.join(" "),
|
||||||
|
fixed_redirect: redirect_uri,
|
||||||
|
};
|
||||||
|
|
||||||
|
run_oauth_flow(&provider, &mcp_token_key(server_name)).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_valid_mcp_token(server_name: &str) -> Option<String> {
|
||||||
|
let tokens = load_oauth_tokens(&mcp_token_key(server_name))?;
|
||||||
|
if Utc::now().timestamp() < tokens.expires_at {
|
||||||
|
Some(tokens.access_token)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mcp_token_key(server_name: &str) -> String {
|
||||||
|
format!("mcp_{server_name}")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_registered_client_id(server_name: &str) -> Option<String> {
|
||||||
|
let path = paths::oauth_tokens_path().join(format!("mcp_{server_name}_registration.json"));
|
||||||
|
let content = fs::read_to_string(path).ok()?;
|
||||||
|
let reg: McpRegistration = serde_json::from_str(&content).ok()?;
|
||||||
|
|
||||||
|
Some(reg.client_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn save_registered_client_id(server_name: &str, client_id: &str) -> Result<()> {
|
||||||
|
let dir = paths::oauth_tokens_path();
|
||||||
|
fs::create_dir_all(&dir)?;
|
||||||
|
|
||||||
|
let path = dir.join(format!("mcp_{server_name}_registration.json"));
|
||||||
|
let reg = McpRegistration {
|
||||||
|
client_id: client_id.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
fs::write(path, serde_json::to_string_pretty(®)?)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn register_client(endpoint: &str, redirect_uri: &str) -> Result<String> {
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"client_name": "Coyote",
|
||||||
|
"redirect_uris": [redirect_uri],
|
||||||
|
"grant_types": ["authorization_code", "refresh_token"],
|
||||||
|
"response_types": ["code"],
|
||||||
|
"token_endpoint_auth_method": "none"
|
||||||
|
});
|
||||||
|
|
||||||
|
let response: serde_json::Value = Client::new()
|
||||||
|
.post(endpoint)
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.context("Failed to reach registration endpoint")?
|
||||||
|
.json()
|
||||||
|
.await
|
||||||
|
.context("Failed to parse registration response")?;
|
||||||
|
|
||||||
|
response["client_id"]
|
||||||
|
.as_str()
|
||||||
|
.ok_or_else(|| anyhow!("Missing client_id in registration response: {response}"))
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn discover_oauth_metadata(server_url: &str) -> Result<OAuthServerMetadata> {
|
||||||
|
let base = extract_base_url(server_url)?;
|
||||||
|
let client = Client::new();
|
||||||
|
|
||||||
|
// RFC 9728: try protected resource metadata first; it points to the auth server
|
||||||
|
let pr_url = format!("{base}/.well-known/oauth-protected-resource");
|
||||||
|
if let Ok(resp) = client.get(&pr_url).send().await
|
||||||
|
&& resp.status().is_success()
|
||||||
|
&& let Ok(pr) = resp.json::<ProtectedResourceMetadata>().await
|
||||||
|
&& let Some(auth_server) = pr.authorization_servers.first()
|
||||||
|
{
|
||||||
|
let as_url = format!("{auth_server}/.well-known/oauth-authorization-server");
|
||||||
|
if let Ok(resp) = client.get(&as_url).send().await
|
||||||
|
&& resp.status().is_success()
|
||||||
|
&& let Ok(meta) = resp.json::<OAuthServerMetadata>().await
|
||||||
|
{
|
||||||
|
return Ok(meta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let as_url = format!("{base}/.well-known/oauth-authorization-server");
|
||||||
|
let resp = client
|
||||||
|
.get(&as_url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to reach {as_url}"))?;
|
||||||
|
|
||||||
|
if resp.status().is_success() {
|
||||||
|
return resp
|
||||||
|
.json::<OAuthServerMetadata>()
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to parse OAuth metadata from {as_url}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(anyhow!(
|
||||||
|
"Could not discover OAuth metadata for '{server_url}'.\n\
|
||||||
|
Tried:\n {pr_url}\n {as_url}\n\
|
||||||
|
Ensure the server supports MCP OAuth discovery, or consult its documentation."
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_base_url(url: &str) -> Result<String> {
|
||||||
|
let parsed = Url::parse(url).with_context(|| format!("Invalid URL: {url}"))?;
|
||||||
|
let scheme = parsed.scheme();
|
||||||
|
let host = parsed
|
||||||
|
.host_str()
|
||||||
|
.ok_or_else(|| anyhow!("No host in URL: {url}"))?;
|
||||||
|
let port = parsed.port().map(|p| format!(":{p}")).unwrap_or_default();
|
||||||
|
|
||||||
|
Ok(format!("{scheme}://{host}{port}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::utils::get_env_name;
|
||||||
|
use serial_test::serial;
|
||||||
|
use std::{
|
||||||
|
env, fs,
|
||||||
|
time::{self, SystemTime},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn with_temp_cache<F: FnOnce()>(f: F) {
|
||||||
|
let unique = SystemTime::now()
|
||||||
|
.duration_since(time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos();
|
||||||
|
let root = env::temp_dir().join(format!("coyote-mcp-oauth-test-{unique}"));
|
||||||
|
fs::create_dir_all(&root).unwrap();
|
||||||
|
let env_key = get_env_name("cache_dir");
|
||||||
|
let prev = env::var_os(&env_key);
|
||||||
|
unsafe {
|
||||||
|
env::set_var(&env_key, &root);
|
||||||
|
}
|
||||||
|
f();
|
||||||
|
unsafe {
|
||||||
|
match prev {
|
||||||
|
Some(v) => env::set_var(&env_key, v),
|
||||||
|
None => env::remove_var(&env_key),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let _ = fs::remove_dir_all(&root);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_base_url_strips_path_and_query() {
|
||||||
|
let result = extract_base_url("https://mcp.notion.com/mcp?foo=bar").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result, "https://mcp.notion.com");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_base_url_preserves_explicit_port() {
|
||||||
|
let result = extract_base_url("http://localhost:8080/mcp").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result, "http://localhost:8080");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_base_url_standard_port_omitted() {
|
||||||
|
let result = extract_base_url("https://example.com/mcp/v1").unwrap();
|
||||||
|
|
||||||
|
assert_eq!(result, "https://example.com");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_base_url_rejects_invalid_url() {
|
||||||
|
assert!(extract_base_url("not-a-url").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn registered_client_id_roundtrip() {
|
||||||
|
with_temp_cache(|| {
|
||||||
|
save_registered_client_id("notion", "client-xyz-123").unwrap();
|
||||||
|
|
||||||
|
let loaded = load_registered_client_id("notion");
|
||||||
|
|
||||||
|
assert_eq!(loaded, Some("client-xyz-123".to_string()));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn load_registered_client_id_returns_none_for_missing() {
|
||||||
|
with_temp_cache(|| {
|
||||||
|
let loaded = load_registered_client_id("no-such-server");
|
||||||
|
|
||||||
|
assert!(loaded.is_none());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[serial]
|
||||||
|
fn registered_client_id_second_save_overwrites_first() {
|
||||||
|
with_temp_cache(|| {
|
||||||
|
save_registered_client_id("github", "first-id").unwrap();
|
||||||
|
save_registered_client_id("github", "second-id").unwrap();
|
||||||
|
|
||||||
|
let loaded = load_registered_client_id("github");
|
||||||
|
|
||||||
|
assert_eq!(loaded, Some("second-id".to_string()));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
+58
-4
@@ -20,7 +20,7 @@ use crate::utils::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use crate::sandbox::SANDBOX_ENV_FLAG;
|
use crate::sandbox::SANDBOX_ENV_FLAG;
|
||||||
use crate::{config, graph, resolve_oauth_client};
|
use crate::{config, graph, mcp, resolve_oauth_client};
|
||||||
use anyhow::{Context, Result, bail};
|
use anyhow::{Context, Result, bail};
|
||||||
use crossterm::cursor::SetCursorStyle;
|
use crossterm::cursor::SetCursorStyle;
|
||||||
use fancy_regex::Regex;
|
use fancy_regex::Regex;
|
||||||
@@ -49,7 +49,7 @@ pub const DEFAULT_CONTINUATION_PROMPT: &str = indoc! {"
|
|||||||
4. Continue with the next pending item now. Call tools immediately."
|
4. Continue with the next pending item now. Call tools immediately."
|
||||||
};
|
};
|
||||||
|
|
||||||
static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| {
|
static REPL_COMMANDS: LazyLock<[ReplCommand; 50]> = LazyLock::new(|| {
|
||||||
[
|
[
|
||||||
ReplCommand::new(".help", "Show this help guide", AssertState::pass()),
|
ReplCommand::new(".help", "Show this help guide", AssertState::pass()),
|
||||||
ReplCommand::new(".info", "Show system info", AssertState::pass()),
|
ReplCommand::new(".info", "Show system info", AssertState::pass()),
|
||||||
@@ -63,6 +63,11 @@ static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| {
|
|||||||
"Authenticate the current model client via OAuth (if configured)",
|
"Authenticate the current model client via OAuth (if configured)",
|
||||||
AssertState::pass(),
|
AssertState::pass(),
|
||||||
),
|
),
|
||||||
|
ReplCommand::new(
|
||||||
|
".mcp auth",
|
||||||
|
"Authenticate with an MCP server via OAuth",
|
||||||
|
AssertState::pass(),
|
||||||
|
),
|
||||||
ReplCommand::new(
|
ReplCommand::new(
|
||||||
".edit config",
|
".edit config",
|
||||||
"Modify configuration file",
|
"Modify configuration file",
|
||||||
@@ -541,6 +546,55 @@ pub async fn run_repl_command(
|
|||||||
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
|
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
|
||||||
oauth::run_oauth_flow(&*provider, &client_name).await?;
|
oauth::run_oauth_flow(&*provider, &client_name).await?;
|
||||||
}
|
}
|
||||||
|
".mcp" => match args {
|
||||||
|
Some(args) => {
|
||||||
|
let mut parts = args.splitn(2, char::is_whitespace);
|
||||||
|
let sub = parts.next().unwrap_or("").trim();
|
||||||
|
let rest = parts.next().map(str::trim).unwrap_or("");
|
||||||
|
match sub {
|
||||||
|
"auth" => {
|
||||||
|
if rest.is_empty() {
|
||||||
|
println!("Usage: .mcp auth <server_name>");
|
||||||
|
} else {
|
||||||
|
let server_name = rest;
|
||||||
|
let server_spec = ctx
|
||||||
|
.app
|
||||||
|
.mcp_config
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|c| c.mcp_servers.get(server_name))
|
||||||
|
.cloned();
|
||||||
|
match server_spec {
|
||||||
|
None => bail!(
|
||||||
|
"MCP server '{}' not found in config. \
|
||||||
|
Check your mcp_config.json.",
|
||||||
|
server_name
|
||||||
|
),
|
||||||
|
Some(spec) if !spec.is_remote() => bail!(
|
||||||
|
"MCP server '{}' uses stdio transport; \
|
||||||
|
OAuth is only supported for http/sse servers.",
|
||||||
|
server_name
|
||||||
|
),
|
||||||
|
Some(spec) => {
|
||||||
|
let url = spec
|
||||||
|
.url
|
||||||
|
.as_deref()
|
||||||
|
.expect("validated: remote spec has url");
|
||||||
|
let client_id = spec.oauth_client_id.as_deref();
|
||||||
|
mcp::oauth::run_mcp_oauth_flow(server_name, url, client_id)
|
||||||
|
.await?;
|
||||||
|
println!(
|
||||||
|
"Authentication saved. \
|
||||||
|
Restart Coyote to connect to '{server_name}'."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => unknown_command()?,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => println!("Usage: .mcp auth <server_name>"),
|
||||||
|
},
|
||||||
".prompt" => match args {
|
".prompt" => match args {
|
||||||
Some(text) => {
|
Some(text) => {
|
||||||
let app = Arc::clone(&ctx.app.config);
|
let app = Arc::clone(&ctx.app.config);
|
||||||
@@ -1415,8 +1469,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn repl_commands_has_49_entries() {
|
fn repl_commands_has_50_entries() {
|
||||||
assert_eq!(REPL_COMMANDS.len(), 49);
|
assert_eq!(REPL_COMMANDS.len(), 50);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
Reference in New Issue
Block a user