Compare commits
9 Commits
cc50d39ab4
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
2ec2aec4c0
|
|||
|
c2cb4ac433
|
|||
|
605a9170b0
|
|||
|
385bd3eda2
|
|||
|
6c3d96ac83
|
|||
|
aa1fe7f7aa
|
|||
|
5e50828108
|
|||
|
693e2d9672
|
|||
|
16f324cefc
|
+29
-1
@@ -5,9 +5,9 @@ use crate::utils::list_file_names;
|
||||
use crate::vault::Vault;
|
||||
use clap_complete::{CompletionCandidate, Shell, generate};
|
||||
use clap_complete_nushell::Nushell;
|
||||
use std::env;
|
||||
use std::ffi::OsStr;
|
||||
use std::io;
|
||||
use std::{env, fs};
|
||||
|
||||
const COYOTE_CLI_NAME: &str = "coyote";
|
||||
|
||||
@@ -134,6 +134,34 @@ pub(super) fn session_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(super) fn mcp_server_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
let content = match fs::read_to_string(paths::mcp_config_file()) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return vec![],
|
||||
};
|
||||
let json: serde_json::Value = match serde_json::from_str(&content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return vec![],
|
||||
};
|
||||
let servers = match json.get("mcpServers").and_then(|v| v.as_object()) {
|
||||
Some(s) => s,
|
||||
None => return vec![],
|
||||
};
|
||||
|
||||
servers
|
||||
.iter()
|
||||
.filter(|(_, v)| {
|
||||
v.get("type")
|
||||
.and_then(|t| t.as_str())
|
||||
.map(|t| t == "http" || t == "sse")
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.filter(|(k, _)| k.starts_with(&*cur))
|
||||
.map(|(k, _)| CompletionCandidate::new(k))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(super) fn secrets_completer(current: &OsStr) -> Vec<CompletionCandidate> {
|
||||
let cur = current.to_string_lossy();
|
||||
match load_app_config_for_completion() {
|
||||
|
||||
+5
-2
@@ -1,8 +1,8 @@
|
||||
mod completer;
|
||||
|
||||
use crate::cli::completer::{
|
||||
ShellCompletion, agent_completer, macro_completer, model_completer, rag_completer,
|
||||
role_completer, secrets_completer, session_completer,
|
||||
ShellCompletion, agent_completer, macro_completer, mcp_server_completer, model_completer,
|
||||
rag_completer, role_completer, secrets_completer, session_completer,
|
||||
};
|
||||
use crate::config::{AssetCategory, InstallFilter, MemoryScope};
|
||||
use anyhow::{Context, Result};
|
||||
@@ -171,6 +171,9 @@ pub struct Cli {
|
||||
/// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name)
|
||||
#[arg(long, exclusive = true, value_name = "CLIENT_NAME")]
|
||||
pub authenticate: Option<Option<String>>,
|
||||
/// Authenticate with an OAuth-protected remote MCP server (e.g., --auth-mcp server_name)
|
||||
#[arg(long, exclusive = true, value_name = "SERVER_NAME", add = ArgValueCompleter::new(mcp_server_completer))]
|
||||
pub auth_mcp: Option<String>,
|
||||
/// Generate static shell completion scripts
|
||||
#[arg(long, value_name = "SHELL", value_enum)]
|
||||
pub completions: Option<ShellCompletion>,
|
||||
|
||||
@@ -133,6 +133,13 @@ impl MessageContent {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_text(&self) -> Option<&str> {
|
||||
match self {
|
||||
MessageContent::Text(text) => Some(text),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
|
||||
match self {
|
||||
MessageContent::Text(text) => *text = replace_fn(text),
|
||||
|
||||
+10
-4
@@ -53,6 +53,10 @@ pub trait OAuthProvider: Send + Sync {
|
||||
fn extra_request_headers(&self) -> Vec<(&str, &str)> {
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn fixed_redirect_uri(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -72,14 +76,16 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
|
||||
|
||||
let state = Uuid::new_v4().to_string();
|
||||
|
||||
let redirect_uri = if provider.uses_localhost_redirect() {
|
||||
let (redirect_uri, use_callback_listener) = if let Some(fixed) = provider.fixed_redirect_uri() {
|
||||
(fixed, true)
|
||||
} else if provider.uses_localhost_redirect() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
let uri = format!("http://127.0.0.1:{port}/callback");
|
||||
drop(listener);
|
||||
uri
|
||||
(uri, true)
|
||||
} else {
|
||||
provider.redirect_uri().to_string()
|
||||
(provider.redirect_uri().to_string(), false)
|
||||
};
|
||||
|
||||
let encoded_scopes = urlencoding::encode(provider.scopes());
|
||||
@@ -112,7 +118,7 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
|
||||
|
||||
let _ = open::that(&authorize_url);
|
||||
|
||||
let (code, returned_state) = if provider.uses_localhost_redirect() {
|
||||
let (code, returned_state) = if use_callback_listener {
|
||||
listen_for_oauth_callback(&redirect_uri)?
|
||||
} else {
|
||||
let input = Text::new("Paste the authorization code:").prompt()?;
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use crate::mcp::{ConnectedServer, JsonField, McpServer, McpTransportType, spawn_mcp_server};
|
||||
use crate::mcp::{
|
||||
ConnectedServer, JsonField, McpServer, McpTransportType, oauth, spawn_mcp_server,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use parking_lot::Mutex;
|
||||
@@ -99,7 +101,12 @@ impl McpFactory {
|
||||
return Ok(existing);
|
||||
}
|
||||
|
||||
let handle = spawn_mcp_server(spec, log_path).await?;
|
||||
let bearer_token = if spec.is_remote() {
|
||||
oauth::load_valid_mcp_token(name)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let handle = spawn_mcp_server(spec, log_path, bearer_token).await?;
|
||||
self.insert_active(key, &handle);
|
||||
Ok(handle)
|
||||
}
|
||||
@@ -125,6 +132,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,6 +149,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some(url.to_string()),
|
||||
headers,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -135,6 +135,7 @@ const RAGS_DIR_NAME: &str = "rags";
|
||||
const FUNCTIONS_DIR_NAME: &str = "functions";
|
||||
const FUNCTIONS_BIN_DIR_NAME: &str = "bin";
|
||||
const AGENTS_DIR_NAME: &str = "agents";
|
||||
const REPL_HISTORY_DIR_NAME: &str = "repl-history";
|
||||
const GLOBAL_TOOLS_DIR_NAME: &str = "tools";
|
||||
const GLOBAL_TOOLS_UTILS_DIR_NAME: &str = "utils";
|
||||
const BASH_PROMPT_UTILS_FILE_NAME: &str = "prompt-utils.sh";
|
||||
|
||||
@@ -8,6 +8,8 @@ use super::{
|
||||
SKILLS_DIR_NAME, WORKSPACE_MEMORY_DIR_NAME,
|
||||
};
|
||||
use crate::client::ProviderModels;
|
||||
use crate::config::REPL_HISTORY_DIR_NAME;
|
||||
use crate::config::session::Session;
|
||||
use crate::utils::{get_env_name, list_file_names, normalize_env_name};
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
@@ -320,6 +322,20 @@ pub fn workspace_memory_dir_for(workspace_root: &Path) -> PathBuf {
|
||||
.join(MEMORY_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn repl_history_dir() -> PathBuf {
|
||||
cache_path().join(REPL_HISTORY_DIR_NAME)
|
||||
}
|
||||
|
||||
pub fn repl_history_file(session: &Option<Session>) -> PathBuf {
|
||||
let history_key = if let Some(session) = &session {
|
||||
format!("session_{}", session.name().replace('/', "_"))
|
||||
} else {
|
||||
"default".to_string()
|
||||
};
|
||||
|
||||
repl_history_dir().join(history_key)
|
||||
}
|
||||
|
||||
pub fn log_config() -> Result<(LevelFilter, Option<PathBuf>)> {
|
||||
let log_level = env::var(get_env_name("log_level"))
|
||||
.ok()
|
||||
|
||||
@@ -2340,6 +2340,17 @@ impl RequestContext {
|
||||
}
|
||||
_ => vec![],
|
||||
};
|
||||
} else if cmd == ".mcp" && args.first() == Some(&"auth") && args.len() == 2 {
|
||||
if let Some(mcp_config) = &self.app.mcp_config {
|
||||
values = super::map_completion_values(
|
||||
mcp_config
|
||||
.mcp_servers
|
||||
.iter()
|
||||
.filter(|(_, spec)| spec.is_remote())
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
} else if (cmd == ".edit" && args.first() == Some(&"skill") && args.len() == 2)
|
||||
|| (cmd == ".skill" && args.first() == Some(&"load") && args.len() == 2)
|
||||
{
|
||||
@@ -3687,6 +3698,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -163,6 +163,14 @@ impl Session {
|
||||
self.messages.is_empty() && self.compressed_messages.is_empty()
|
||||
}
|
||||
|
||||
pub fn messages(&self) -> &[Message] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
pub fn compressed_messages(&self) -> &[Message] {
|
||||
&self.compressed_messages
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
+46
-2
@@ -28,11 +28,12 @@ use crate::config::{
|
||||
install_builtins, list_agents, load_env_file, macro_execute, sync_models,
|
||||
};
|
||||
use crate::function::supervisor::{GuardrailAction, check_pending_agents_guardrail};
|
||||
use crate::mcp::McpServersConfig;
|
||||
use crate::render::{prompt_theme, render_error};
|
||||
use crate::repl::Repl;
|
||||
use crate::utils::*;
|
||||
use crate::vault::Vault;
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use crate::vault::{Vault, interpolate_secrets};
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use clap::{CommandFactory, Parser};
|
||||
use clap_complete::CompleteEnv;
|
||||
use client::ClientConfig;
|
||||
@@ -120,6 +121,49 @@ async fn main() -> Result<()> {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(server_name) = &cli.auth_mcp {
|
||||
let cfg = Config::load_with_interpolation(true).await?;
|
||||
let app_config = AppConfig::from_config(cfg)?;
|
||||
let vault = Vault::init(&app_config)?;
|
||||
let mcp_path = paths::mcp_config_file();
|
||||
if !mcp_path.exists() {
|
||||
bail!(
|
||||
"No MCP configuration file found at '{}'",
|
||||
mcp_path.display()
|
||||
);
|
||||
}
|
||||
|
||||
let raw = tokio::fs::read_to_string(&mcp_path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read MCP config at '{}'", mcp_path.display()))?;
|
||||
|
||||
let (content, missing) = interpolate_secrets(&raw, &vault)?;
|
||||
if !missing.is_empty() {
|
||||
bail!(
|
||||
"MCP config references vault secrets that are missing: {:?}",
|
||||
missing
|
||||
);
|
||||
}
|
||||
|
||||
let mcp_config: McpServersConfig =
|
||||
serde_json::from_str(&content).context("Failed to parse MCP config file")?;
|
||||
let spec = mcp_config
|
||||
.mcp_servers
|
||||
.get(server_name.as_str())
|
||||
.ok_or_else(|| anyhow!("MCP server '{server_name}' not found in mcp.json"))?;
|
||||
if !spec.is_remote() {
|
||||
bail!(
|
||||
"MCP server '{server_name}' is a stdio server; OAuth is only supported for http/sse servers"
|
||||
);
|
||||
}
|
||||
|
||||
let url = spec.url.as_deref().expect("validated: remote spec has url");
|
||||
mcp::oauth::run_mcp_oauth_flow(server_name, url, spec.oauth_client_id.as_deref()).await?;
|
||||
println!("Authentication saved. '{server_name}' is now available for use.");
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if vault_flags {
|
||||
let cfg = Config::load_with_interpolation(true).await?;
|
||||
let app_config = AppConfig::from_config(cfg)?;
|
||||
|
||||
+166
-9
@@ -1,3 +1,4 @@
|
||||
pub(crate) mod oauth;
|
||||
mod sse_transport;
|
||||
|
||||
use crate::config::AppConfig;
|
||||
@@ -73,6 +74,8 @@ pub(crate) struct McpServer {
|
||||
pub url: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub headers: Option<IndexMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub oauth_client_id: Option<String>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
@@ -107,10 +110,10 @@ impl McpServer {
|
||||
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
|
||||
));
|
||||
}
|
||||
if self.url.is_some() || self.headers.is_some() {
|
||||
if self.url.is_some() || self.headers.is_some() || self.oauth_client_id.is_some() {
|
||||
return Err(anyhow!(
|
||||
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \
|
||||
(url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
||||
(url/headers/oauth_client_id). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -237,7 +240,7 @@ impl McpRegistry {
|
||||
|
||||
debug!("Starting selected MCP servers: {:?}", ids_to_start);
|
||||
|
||||
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter(
|
||||
let results: Vec<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> = stream::iter(
|
||||
ids_to_start
|
||||
.into_iter()
|
||||
.map(|id| async { self.start_server(id).await }),
|
||||
@@ -246,7 +249,7 @@ impl McpRegistry {
|
||||
.try_collect()
|
||||
.await?;
|
||||
|
||||
for (id, server, catalog) in results {
|
||||
for (id, server, catalog) in results.into_iter().flatten() {
|
||||
self.servers.insert(id.clone(), server);
|
||||
self.catalogs.insert(id, catalog);
|
||||
}
|
||||
@@ -257,14 +260,30 @@ impl McpRegistry {
|
||||
async fn start_server(
|
||||
&self,
|
||||
id: String,
|
||||
) -> Result<(String, Arc<ConnectedServer>, ServerCatalog)> {
|
||||
) -> Result<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> {
|
||||
let spec = self
|
||||
.config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(&id))
|
||||
.with_context(|| format!("MCP server not found in config: {id}"))?;
|
||||
|
||||
let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?;
|
||||
let bearer_token = if spec.is_remote() {
|
||||
oauth::load_valid_mcp_token(&id)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let service = match spawn_mcp_server(spec, self.log_path.as_deref(), bearer_token).await {
|
||||
Ok(s) => s,
|
||||
Err(e) if is_auth_required_error(&e) => {
|
||||
warn!(
|
||||
"MCP server '{id}' requires OAuth authentication. \
|
||||
Run `.mcp auth {id}` in the REPL to authenticate, then restart Coyote."
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
let tools = service.list_tools(None).await?;
|
||||
debug!("Available tools for MCP server {id}: {tools:?}");
|
||||
@@ -289,7 +308,7 @@ impl McpRegistry {
|
||||
|
||||
info!("Started MCP server: {id}");
|
||||
|
||||
Ok((id.to_string(), service, catalog))
|
||||
Ok(Some((id.to_string(), service, catalog)))
|
||||
}
|
||||
|
||||
fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> {
|
||||
@@ -337,15 +356,18 @@ impl McpRegistry {
|
||||
pub(crate) async fn spawn_mcp_server(
|
||||
spec: &McpServer,
|
||||
log_path: Option<&Path>,
|
||||
bearer_token: Option<String>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
match spec.transport_type {
|
||||
McpTransportType::Http => {
|
||||
let url = spec.url.as_deref().expect("validated: http spec has url");
|
||||
spawn_http_mcp_server(url, spec.headers.as_ref()).await
|
||||
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
|
||||
spawn_http_mcp_server(url, headers.as_ref()).await
|
||||
}
|
||||
McpTransportType::Sse => {
|
||||
let url = spec.url.as_deref().expect("validated: sse spec has url");
|
||||
spawn_sse_mcp_server(url, spec.headers.as_ref()).await
|
||||
let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
|
||||
spawn_sse_mcp_server(url, headers.as_ref()).await
|
||||
}
|
||||
McpTransportType::Stdio => {
|
||||
let command = spec
|
||||
@@ -357,6 +379,30 @@ pub(crate) async fn spawn_mcp_server(
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_bearer_token(
|
||||
headers: Option<&IndexMap<String, String>>,
|
||||
bearer_token: Option<String>,
|
||||
) -> Option<IndexMap<String, String>> {
|
||||
match (headers, bearer_token) {
|
||||
(None, None) => None,
|
||||
(Some(h), None) => Some(h.clone()),
|
||||
(None, Some(token)) => {
|
||||
let mut m = IndexMap::new();
|
||||
m.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||
Some(m)
|
||||
}
|
||||
(Some(h), Some(token)) => {
|
||||
let mut m = h.clone();
|
||||
m.insert("Authorization".to_string(), format!("Bearer {token}"));
|
||||
Some(m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_auth_required_error(e: &anyhow::Error) -> bool {
|
||||
e.to_string().contains("Auth required")
|
||||
}
|
||||
|
||||
async fn spawn_http_mcp_server(
|
||||
url: &str,
|
||||
headers: Option<&IndexMap<String, String>>,
|
||||
@@ -465,6 +511,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -477,6 +524,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some(url.to_string()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -489,6 +537,7 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some(url.to_string()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,6 +555,7 @@ mod tests {
|
||||
#[test]
|
||||
fn validate_stdio_with_command_succeeds() {
|
||||
let spec = stdio_server("npx");
|
||||
|
||||
assert!(spec.validate("test").is_ok());
|
||||
}
|
||||
|
||||
@@ -519,8 +569,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("missing a \"command\" field"));
|
||||
}
|
||||
|
||||
@@ -534,8 +587,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("remote fields"));
|
||||
}
|
||||
|
||||
@@ -551,14 +607,18 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: Some(headers),
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("remote fields"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_http_with_url_succeeds() {
|
||||
let spec = http_server("http://localhost:8080");
|
||||
|
||||
assert!(spec.validate("test").is_ok());
|
||||
}
|
||||
|
||||
@@ -572,8 +632,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("missing a \"url\" field"));
|
||||
}
|
||||
|
||||
@@ -587,8 +650,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("stdio fields"));
|
||||
}
|
||||
|
||||
@@ -602,8 +668,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("stdio fields"));
|
||||
}
|
||||
|
||||
@@ -617,14 +686,18 @@ mod tests {
|
||||
cwd: Some("/tmp".into()),
|
||||
url: Some("http://localhost".into()),
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("stdio fields"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_sse_with_url_succeeds() {
|
||||
let spec = sse_server("http://sse.example.com");
|
||||
|
||||
assert!(spec.validate("test").is_ok());
|
||||
}
|
||||
|
||||
@@ -638,8 +711,11 @@ mod tests {
|
||||
cwd: None,
|
||||
url: None,
|
||||
headers: None,
|
||||
oauth_client_id: None,
|
||||
};
|
||||
|
||||
let err = spec.validate("test").unwrap_err();
|
||||
|
||||
assert!(err.to_string().contains("missing a \"url\" field"));
|
||||
}
|
||||
|
||||
@@ -665,9 +741,13 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(config.mcp_servers.contains_key("my-server"));
|
||||
|
||||
let spec = &config.mcp_servers["my-server"];
|
||||
|
||||
assert_eq!(spec.transport_type, McpTransportType::Stdio);
|
||||
assert_eq!(spec.command.as_deref(), Some("npx"));
|
||||
assert_eq!(
|
||||
@@ -688,7 +768,9 @@ mod tests {
|
||||
}
|
||||
}"#;
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
let spec = &config.mcp_servers["remote"];
|
||||
|
||||
assert_eq!(spec.transport_type, McpTransportType::Http);
|
||||
assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp"));
|
||||
assert_eq!(
|
||||
@@ -713,7 +795,9 @@ mod tests {
|
||||
}
|
||||
}"#;
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
let env = config.mcp_servers["s"].env.as_ref().unwrap();
|
||||
|
||||
assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello"));
|
||||
assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true)));
|
||||
assert!(matches!(env["INT_VAR"], JsonField::Int(42)));
|
||||
@@ -727,7 +811,9 @@ mod tests {
|
||||
"remote-api": { "type": "http", "url": "http://api.example.com" }
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(config.mcp_servers.len(), 2);
|
||||
assert!(config.mcp_servers.contains_key("github"));
|
||||
assert!(config.mcp_servers.contains_key("remote-api"));
|
||||
@@ -736,7 +822,9 @@ mod tests {
|
||||
#[test]
|
||||
fn deserialize_empty_servers_map() {
|
||||
let json = r#"{ "mcpServers": {} }"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(config.mcp_servers.is_empty());
|
||||
}
|
||||
|
||||
@@ -751,77 +839,96 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let config: McpServersConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_all_returns_all_configured_servers() {
|
||||
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
||||
|
||||
let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
||||
ids.sort();
|
||||
|
||||
assert_eq!(ids, vec!["github", "jira", "slack"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_comma_separated_returns_matching_servers() {
|
||||
let registry = make_registry_with_config(&["github", "slack", "jira"]);
|
||||
|
||||
let mut ids =
|
||||
registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()]));
|
||||
ids.sort();
|
||||
|
||||
assert_eq!(ids, vec!["github", "jira"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_single_server_name() {
|
||||
let registry = make_registry_with_config(&["github", "slack"]);
|
||||
|
||||
let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()]));
|
||||
|
||||
assert_eq!(ids, vec!["slack"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_none_returns_empty() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
let ids = registry.resolve_server_ids(None);
|
||||
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_no_config_returns_empty() {
|
||||
let registry = McpRegistry::default();
|
||||
|
||||
let ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
|
||||
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_nonexistent_server_filtered_out() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
let ids = registry
|
||||
.resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()]));
|
||||
|
||||
assert_eq!(ids, vec!["github"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_all_nonexistent_returns_empty() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()]));
|
||||
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_trims_whitespace() {
|
||||
let registry = make_registry_with_config(&["github", "slack"]);
|
||||
|
||||
let mut ids = registry.resolve_server_ids(Some(vec![
|
||||
" github ".to_string(),
|
||||
" slack ".to_string(),
|
||||
]));
|
||||
ids.sort();
|
||||
|
||||
assert_eq!(ids, vec!["github", "slack"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registry_default_is_empty() {
|
||||
let registry = McpRegistry::default();
|
||||
|
||||
assert!(registry.is_empty());
|
||||
assert!(registry.list_started_servers().is_empty());
|
||||
assert!(registry.mcp_config().is_none());
|
||||
@@ -831,6 +938,7 @@ mod tests {
|
||||
#[test]
|
||||
fn registry_with_config_reports_config() {
|
||||
let registry = make_registry_with_config(&["github"]);
|
||||
|
||||
assert!(registry.mcp_config().is_some());
|
||||
assert!(
|
||||
registry
|
||||
@@ -847,4 +955,53 @@ mod tests {
|
||||
assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search");
|
||||
assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_both_none_returns_none() {
|
||||
assert!(merge_bearer_token(None, None).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_headers_only_passes_through() {
|
||||
let mut h = IndexMap::new();
|
||||
h.insert("X-Key".to_string(), "val".to_string());
|
||||
|
||||
let result = merge_bearer_token(Some(&h), None).unwrap();
|
||||
|
||||
assert_eq!(result["X-Key"], "val");
|
||||
assert!(!result.contains_key("Authorization"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_token_only_injects_bearer() {
|
||||
let result = merge_bearer_token(None, Some("tok123".to_string())).unwrap();
|
||||
|
||||
assert_eq!(result["Authorization"], "Bearer tok123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_bearer_token_both_merges_and_overrides_authorization() {
|
||||
let mut h = IndexMap::new();
|
||||
h.insert("Authorization".to_string(), "old".to_string());
|
||||
h.insert("X-Custom".to_string(), "keep".to_string());
|
||||
|
||||
let result = merge_bearer_token(Some(&h), Some("newtoken".to_string())).unwrap();
|
||||
|
||||
assert_eq!(result["Authorization"], "Bearer newtoken");
|
||||
assert_eq!(result["X-Custom"], "keep");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_auth_required_error_matches_rmcp_message() {
|
||||
let e = anyhow!("Auth required, when send initialize request");
|
||||
|
||||
assert!(is_auth_required_error(&e));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_auth_required_error_does_not_match_unrelated() {
|
||||
let e = anyhow!("Connection refused");
|
||||
|
||||
assert!(!is_auth_required_error(&e));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
use crate::client::oauth::{OAuthProvider, TokenRequestFormat, load_oauth_tokens, run_oauth_flow};
|
||||
use crate::config::paths;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::Utc;
|
||||
use inquire::Text;
|
||||
use log::warn;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fs;
|
||||
use std::net::TcpListener;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ProtectedResourceMetadata {
|
||||
#[serde(default)]
|
||||
authorization_servers: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthServerMetadata {
|
||||
authorization_endpoint: String,
|
||||
token_endpoint: String,
|
||||
#[serde(default)]
|
||||
scopes_supported: Vec<String>,
|
||||
registration_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct McpRegistration {
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
struct McpOAuthProvider {
|
||||
client_id: String,
|
||||
authorize_url: String,
|
||||
token_url: String,
|
||||
scopes: String,
|
||||
fixed_redirect: String,
|
||||
}
|
||||
|
||||
impl OAuthProvider for McpOAuthProvider {
|
||||
fn provider_name(&self) -> &str {
|
||||
"MCP"
|
||||
}
|
||||
|
||||
fn client_id(&self) -> &str {
|
||||
&self.client_id
|
||||
}
|
||||
|
||||
fn authorize_url(&self) -> &str {
|
||||
&self.authorize_url
|
||||
}
|
||||
|
||||
fn token_url(&self) -> &str {
|
||||
&self.token_url
|
||||
}
|
||||
|
||||
fn redirect_uri(&self) -> &str {
|
||||
""
|
||||
}
|
||||
|
||||
fn scopes(&self) -> &str {
|
||||
&self.scopes
|
||||
}
|
||||
|
||||
fn token_request_format(&self) -> TokenRequestFormat {
|
||||
TokenRequestFormat::FormUrlEncoded
|
||||
}
|
||||
|
||||
fn uses_localhost_redirect(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn fixed_redirect_uri(&self) -> Option<String> {
|
||||
Some(self.fixed_redirect.clone())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_mcp_oauth_flow(
|
||||
server_name: &str,
|
||||
server_url: &str,
|
||||
configured_client_id: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let metadata = discover_oauth_metadata(server_url).await?;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
let redirect_uri = format!("http://127.0.0.1:{port}/callback");
|
||||
|
||||
let client_id = if let Some(id) = configured_client_id {
|
||||
id.to_string()
|
||||
} else if let Some(cached) = load_registered_client_id(server_name) {
|
||||
cached
|
||||
} else if let Some(reg_endpoint) = &metadata.registration_endpoint {
|
||||
match register_client(reg_endpoint, &redirect_uri).await {
|
||||
Ok(id) => {
|
||||
let _ = save_registered_client_id(server_name, &id);
|
||||
id
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Dynamic client registration failed: {e}. Falling back to manual entry.");
|
||||
Text::new("Enter the OAuth client ID for this MCP server:")
|
||||
.prompt()
|
||||
.context("Failed to read client ID")?
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Text::new("Enter the OAuth client ID for this MCP server:")
|
||||
.prompt()
|
||||
.context("Failed to read client ID")?
|
||||
};
|
||||
|
||||
let provider = McpOAuthProvider {
|
||||
client_id,
|
||||
authorize_url: metadata.authorization_endpoint,
|
||||
token_url: metadata.token_endpoint,
|
||||
scopes: metadata.scopes_supported.join(" "),
|
||||
fixed_redirect: redirect_uri,
|
||||
};
|
||||
|
||||
run_oauth_flow(&provider, &mcp_token_key(server_name)).await
|
||||
}
|
||||
|
||||
pub fn load_valid_mcp_token(server_name: &str) -> Option<String> {
|
||||
let tokens = load_oauth_tokens(&mcp_token_key(server_name))?;
|
||||
if Utc::now().timestamp() < tokens.expires_at {
|
||||
Some(tokens.access_token)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn mcp_token_key(server_name: &str) -> String {
|
||||
format!("mcp_{server_name}")
|
||||
}
|
||||
|
||||
fn load_registered_client_id(server_name: &str) -> Option<String> {
|
||||
let path = paths::oauth_tokens_path().join(format!("mcp_{server_name}_registration.json"));
|
||||
let content = fs::read_to_string(path).ok()?;
|
||||
let reg: McpRegistration = serde_json::from_str(&content).ok()?;
|
||||
|
||||
Some(reg.client_id)
|
||||
}
|
||||
|
||||
fn save_registered_client_id(server_name: &str, client_id: &str) -> Result<()> {
|
||||
let dir = paths::oauth_tokens_path();
|
||||
fs::create_dir_all(&dir)?;
|
||||
|
||||
let path = dir.join(format!("mcp_{server_name}_registration.json"));
|
||||
let reg = McpRegistration {
|
||||
client_id: client_id.to_string(),
|
||||
};
|
||||
|
||||
fs::write(path, serde_json::to_string_pretty(®)?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn register_client(endpoint: &str, redirect_uri: &str) -> Result<String> {
|
||||
let body = serde_json::json!({
|
||||
"client_name": "Coyote",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none"
|
||||
});
|
||||
|
||||
let response: serde_json::Value = Client::new()
|
||||
.post(endpoint)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to reach registration endpoint")?
|
||||
.json()
|
||||
.await
|
||||
.context("Failed to parse registration response")?;
|
||||
|
||||
response["client_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow!("Missing client_id in registration response: {response}"))
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
async fn discover_oauth_metadata(server_url: &str) -> Result<OAuthServerMetadata> {
|
||||
let base = extract_base_url(server_url)?;
|
||||
let client = Client::new();
|
||||
|
||||
// RFC 9728: try protected resource metadata first; it points to the auth server
|
||||
let pr_url = format!("{base}/.well-known/oauth-protected-resource");
|
||||
if let Ok(resp) = client.get(&pr_url).send().await
|
||||
&& resp.status().is_success()
|
||||
&& let Ok(pr) = resp.json::<ProtectedResourceMetadata>().await
|
||||
&& let Some(auth_server) = pr.authorization_servers.first()
|
||||
{
|
||||
let as_url = format!("{auth_server}/.well-known/oauth-authorization-server");
|
||||
if let Ok(resp) = client.get(&as_url).send().await
|
||||
&& resp.status().is_success()
|
||||
&& let Ok(meta) = resp.json::<OAuthServerMetadata>().await
|
||||
{
|
||||
return Ok(meta);
|
||||
}
|
||||
}
|
||||
|
||||
let as_url = format!("{base}/.well-known/oauth-authorization-server");
|
||||
let resp = client
|
||||
.get(&as_url)
|
||||
.send()
|
||||
.await
|
||||
.with_context(|| format!("Failed to reach {as_url}"))?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
return resp
|
||||
.json::<OAuthServerMetadata>()
|
||||
.await
|
||||
.with_context(|| format!("Failed to parse OAuth metadata from {as_url}"));
|
||||
}
|
||||
|
||||
Err(anyhow!(
|
||||
"Could not discover OAuth metadata for '{server_url}'.\n\
|
||||
Tried:\n {pr_url}\n {as_url}\n\
|
||||
Ensure the server supports MCP OAuth discovery, or consult its documentation."
|
||||
))
|
||||
}
|
||||
|
||||
fn extract_base_url(url: &str) -> Result<String> {
|
||||
let parsed = Url::parse(url).with_context(|| format!("Invalid URL: {url}"))?;
|
||||
let scheme = parsed.scheme();
|
||||
let host = parsed
|
||||
.host_str()
|
||||
.ok_or_else(|| anyhow!("No host in URL: {url}"))?;
|
||||
let port = parsed.port().map(|p| format!(":{p}")).unwrap_or_default();
|
||||
|
||||
Ok(format!("{scheme}://{host}{port}"))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::utils::get_env_name;
|
||||
use serial_test::serial;
|
||||
use std::{
|
||||
env, fs,
|
||||
time::{self, SystemTime},
|
||||
};
|
||||
|
||||
fn with_temp_cache<F: FnOnce()>(f: F) {
|
||||
let unique = SystemTime::now()
|
||||
.duration_since(time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let root = env::temp_dir().join(format!("coyote-mcp-oauth-test-{unique}"));
|
||||
fs::create_dir_all(&root).unwrap();
|
||||
let env_key = get_env_name("cache_dir");
|
||||
let prev = env::var_os(&env_key);
|
||||
unsafe {
|
||||
env::set_var(&env_key, &root);
|
||||
}
|
||||
f();
|
||||
unsafe {
|
||||
match prev {
|
||||
Some(v) => env::set_var(&env_key, v),
|
||||
None => env::remove_var(&env_key),
|
||||
}
|
||||
}
|
||||
let _ = fs::remove_dir_all(&root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_base_url_strips_path_and_query() {
|
||||
let result = extract_base_url("https://mcp.notion.com/mcp?foo=bar").unwrap();
|
||||
|
||||
assert_eq!(result, "https://mcp.notion.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_base_url_preserves_explicit_port() {
|
||||
let result = extract_base_url("http://localhost:8080/mcp").unwrap();
|
||||
|
||||
assert_eq!(result, "http://localhost:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_base_url_standard_port_omitted() {
|
||||
let result = extract_base_url("https://example.com/mcp/v1").unwrap();
|
||||
|
||||
assert_eq!(result, "https://example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_base_url_rejects_invalid_url() {
|
||||
assert!(extract_base_url("not-a-url").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn registered_client_id_roundtrip() {
|
||||
with_temp_cache(|| {
|
||||
save_registered_client_id("notion", "client-xyz-123").unwrap();
|
||||
|
||||
let loaded = load_registered_client_id("notion");
|
||||
|
||||
assert_eq!(loaded, Some("client-xyz-123".to_string()));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn load_registered_client_id_returns_none_for_missing() {
|
||||
with_temp_cache(|| {
|
||||
let loaded = load_registered_client_id("no-such-server");
|
||||
|
||||
assert!(loaded.is_none());
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn registered_client_id_second_save_overwrites_first() {
|
||||
with_temp_cache(|| {
|
||||
save_registered_client_id("github", "first-id").unwrap();
|
||||
save_registered_client_id("github", "second-id").unwrap();
|
||||
|
||||
let loaded = load_registered_client_id("github");
|
||||
|
||||
assert_eq!(loaded, Some("second-id".to_string()));
|
||||
});
|
||||
}
|
||||
}
|
||||
+163
-8
@@ -6,7 +6,10 @@ use self::completer::ReplCompleter;
|
||||
use self::highlighter::ReplHighlighter;
|
||||
use self::prompt::ReplPrompt;
|
||||
|
||||
use crate::client::{call_chat_completions, call_chat_completions_streaming, init_client, oauth};
|
||||
use crate::client::{
|
||||
Message, MessageRole, call_chat_completions, call_chat_completions_streaming, init_client,
|
||||
oauth,
|
||||
};
|
||||
use crate::config::{
|
||||
AgentVariables, AppConfig, AssertState, Input, LastMessage, RequestContext, StateFlags,
|
||||
macro_execute,
|
||||
@@ -20,7 +23,7 @@ use crate::utils::{
|
||||
};
|
||||
|
||||
use crate::sandbox::SANDBOX_ENV_FLAG;
|
||||
use crate::{config, graph, resolve_oauth_client};
|
||||
use crate::{config, graph, mcp, resolve_oauth_client};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use crossterm::cursor::SetCursorStyle;
|
||||
use fancy_regex::Regex;
|
||||
@@ -29,9 +32,9 @@ use log::warn;
|
||||
use parking_lot::RwLock;
|
||||
use reedline::CursorConfig;
|
||||
use reedline::{
|
||||
ColumnarMenu, EditCommand, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline,
|
||||
ReedlineEvent, ReedlineMenu, ValidationResult, Validator, Vi, default_emacs_keybindings,
|
||||
default_vi_insert_keybindings, default_vi_normal_keybindings,
|
||||
ColumnarMenu, EditCommand, EditMode, Emacs, FileBackedHistory, KeyCode, KeyModifiers,
|
||||
Keybindings, Reedline, ReedlineEvent, ReedlineMenu, ValidationResult, Validator, Vi,
|
||||
default_emacs_keybindings, default_vi_insert_keybindings, default_vi_normal_keybindings,
|
||||
};
|
||||
use reedline::{MenuBuilder, Signal};
|
||||
use std::sync::LazyLock;
|
||||
@@ -49,7 +52,7 @@ pub const DEFAULT_CONTINUATION_PROMPT: &str = indoc! {"
|
||||
4. Continue with the next pending item now. Call tools immediately."
|
||||
};
|
||||
|
||||
static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| {
|
||||
static REPL_COMMANDS: LazyLock<[ReplCommand; 50]> = LazyLock::new(|| {
|
||||
[
|
||||
ReplCommand::new(".help", "Show this help guide", AssertState::pass()),
|
||||
ReplCommand::new(".info", "Show system info", AssertState::pass()),
|
||||
@@ -63,6 +66,11 @@ static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| {
|
||||
"Authenticate the current model client via OAuth (if configured)",
|
||||
AssertState::pass(),
|
||||
),
|
||||
ReplCommand::new(
|
||||
".mcp auth",
|
||||
"Authenticate with an MCP server via OAuth",
|
||||
AssertState::pass(),
|
||||
),
|
||||
ReplCommand::new(
|
||||
".edit config",
|
||||
"Modify configuration file",
|
||||
@@ -313,6 +321,58 @@ Type ".help" for additional help.
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let (messages_snapshot, compressed_count) = {
|
||||
let ctx = self.ctx.read();
|
||||
if let Some(session) = &ctx.session {
|
||||
let msgs: Vec<Message> = session
|
||||
.messages()
|
||||
.iter()
|
||||
.filter(|m| !m.role.is_system())
|
||||
.cloned()
|
||||
.collect();
|
||||
let compressed = session.compressed_messages().len();
|
||||
(msgs, compressed)
|
||||
} else {
|
||||
(vec![], 0)
|
||||
}
|
||||
};
|
||||
|
||||
if !messages_snapshot.is_empty() || compressed_count > 0 {
|
||||
let app = Arc::clone(&self.ctx.read().app.config);
|
||||
if compressed_count > 0 {
|
||||
println!(
|
||||
"{}",
|
||||
dimmed_text(&format!(
|
||||
"({compressed_count} earlier messages not shown; compressed for context)"
|
||||
))
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
for message in &messages_snapshot {
|
||||
match message.role {
|
||||
MessageRole::User => {
|
||||
if let Some(text) = message.content.as_text() {
|
||||
println!("{}", dimmed_text("You:"));
|
||||
println!("{text}");
|
||||
println!();
|
||||
}
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
if let Some(text) = message.content.as_text() {
|
||||
app.print_markdown(text)?;
|
||||
println!();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
println!("{}", dimmed_text("─── ↑ previous conversation ↑ ───"));
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
if self.abort_signal.aborted_ctrld() {
|
||||
break;
|
||||
@@ -388,6 +448,14 @@ Type ".help" for additional help.
|
||||
editor = editor.with_buffer_editor(command, temp_file);
|
||||
}
|
||||
|
||||
if app.save_shell_history {
|
||||
let ctx = ctx.read();
|
||||
let history_path = paths::repl_history_file(&ctx.session);
|
||||
if let Ok(history) = FileBackedHistory::with_file(1000, history_path) {
|
||||
editor = editor.with_history(Box::new(history));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(editor)
|
||||
}
|
||||
|
||||
@@ -541,6 +609,53 @@ pub async fn run_repl_command(
|
||||
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
|
||||
oauth::run_oauth_flow(&*provider, &client_name).await?;
|
||||
}
|
||||
".mcp" => match args {
|
||||
Some(args) => {
|
||||
let mut parts = args.splitn(2, char::is_whitespace);
|
||||
let sub = parts.next().unwrap_or("").trim();
|
||||
let rest = parts.next().map(str::trim).unwrap_or("");
|
||||
match sub {
|
||||
"auth" => {
|
||||
if rest.is_empty() {
|
||||
println!("Usage: .mcp auth <server_name>");
|
||||
} else {
|
||||
let server_name = rest;
|
||||
let server_spec = ctx
|
||||
.app
|
||||
.mcp_config
|
||||
.as_ref()
|
||||
.and_then(|c| c.mcp_servers.get(server_name))
|
||||
.cloned();
|
||||
match server_spec {
|
||||
None => {
|
||||
bail!("MCP server '{}' not found in mcp.json.", server_name)
|
||||
}
|
||||
Some(spec) if !spec.is_remote() => bail!(
|
||||
"MCP server '{}' uses stdio transport; \
|
||||
OAuth is only supported for http/sse servers.",
|
||||
server_name
|
||||
),
|
||||
Some(spec) => {
|
||||
let url = spec
|
||||
.url
|
||||
.as_deref()
|
||||
.expect("validated: remote spec has url");
|
||||
let client_id = spec.oauth_client_id.as_deref();
|
||||
mcp::oauth::run_mcp_oauth_flow(server_name, url, client_id)
|
||||
.await?;
|
||||
println!(
|
||||
"Authentication saved. \
|
||||
Restart Coyote to connect to '{server_name}'."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => unknown_command()?,
|
||||
}
|
||||
}
|
||||
None => println!("Usage: .mcp auth <server_name>"),
|
||||
},
|
||||
".prompt" => match args {
|
||||
Some(text) => {
|
||||
let app = Arc::clone(&ctx.app.config);
|
||||
@@ -632,6 +747,46 @@ pub async fn run_repl_command(
|
||||
session.set_autonaming(false);
|
||||
}
|
||||
}
|
||||
if let Some(session) = &ctx.session {
|
||||
let messages_snapshot: Vec<Message> = session
|
||||
.messages()
|
||||
.iter()
|
||||
.filter(|m| !m.role.is_system())
|
||||
.cloned()
|
||||
.collect();
|
||||
let compressed_count = session.compressed_messages().len();
|
||||
if !messages_snapshot.is_empty() || compressed_count > 0 {
|
||||
if compressed_count > 0 {
|
||||
println!(
|
||||
"{}",
|
||||
dimmed_text(&format!(
|
||||
"({compressed_count} earlier messages not shown — compressed for context)"
|
||||
))
|
||||
);
|
||||
println!();
|
||||
}
|
||||
for message in &messages_snapshot {
|
||||
match message.role {
|
||||
MessageRole::User => {
|
||||
if let Some(text) = message.content.as_text() {
|
||||
println!("{}", dimmed_text("You:"));
|
||||
println!("{text}");
|
||||
println!();
|
||||
}
|
||||
}
|
||||
MessageRole::Assistant => {
|
||||
if let Some(text) = message.content.as_text() {
|
||||
app.print_markdown(text)?;
|
||||
println!();
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
println!("{}", dimmed_text("─── ↑ previous conversation ↑ ───"));
|
||||
println!();
|
||||
}
|
||||
}
|
||||
}
|
||||
".install" => {
|
||||
let trimmed = args.map(str::trim).unwrap_or("");
|
||||
@@ -1415,8 +1570,8 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn repl_commands_has_49_entries() {
|
||||
assert_eq!(REPL_COMMANDS.len(), 49);
|
||||
fn repl_commands_has_50_entries() {
|
||||
assert_eq!(REPL_COMMANDS.len(), 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user