11 Commits

Author SHA1 Message Date
Dark-Alex-17 2ec2aec4c0 style: updated the previous conversation marker a tad
CI / All (macos-latest) (push) Waiting to run
CI / All (windows-latest) (push) Waiting to run
CI / All (ubuntu-latest) (push) Failing after 26s
2026-07-02 16:49:38 -06:00
Dark-Alex-17 c2cb4ac433 feat: Session-specific, file-backed history in the REPL
CI / All (macos-latest) (push) Waiting to run
CI / All (windows-latest) (push) Waiting to run
CI / All (ubuntu-latest) (push) Failing after 25s
2026-07-02 16:44:55 -06:00
Dark-Alex-17 605a9170b0 feat: Replay session output when a user re-enters a session so all output can be seen again 2026-07-02 16:35:10 -06:00
Dark-Alex-17 385bd3eda2 fix: Overrode the default JSON content-type for MCP OAuth so its properly application/x-www-form-urlencoded
CI / All (macos-latest) (push) Waiting to run
CI / All (windows-latest) (push) Waiting to run
CI / All (ubuntu-latest) (push) Failing after 26s
2026-07-02 15:53:29 -06:00
Dark-Alex-17 6c3d96ac83 feat: Added confirmation message after MCP Oauth succeeds when invoked from --auth-mcp
CI / All (macos-latest) (push) Waiting to run
CI / All (windows-latest) (push) Waiting to run
CI / All (ubuntu-latest) (push) Failing after 26s
2026-07-02 15:22:22 -06:00
Dark-Alex-17 aa1fe7f7aa fmt: applied formatting 2026-07-02 15:22:00 -06:00
Dark-Alex-17 5e50828108 fix: typo in mcp file name 2026-07-02 15:20:57 -06:00
Dark-Alex-17 693e2d9672 feat: Created the --auth-mcp CLI flag to let users auth with remote MCP servers without needing to be in the REPL
CI / All (macos-latest) (push) Waiting to run
CI / All (windows-latest) (push) Waiting to run
CI / All (ubuntu-latest) (push) Failing after 26s
2026-07-02 14:51:52 -06:00
Dark-Alex-17 16f324cefc feat: add OAuth authentication support for remote MCP servers 2026-07-02 14:43:24 -06:00
Dark-Alex-17 cc50d39ab4 fix: Added uvx wrapper for macos-based sandboxes
CI / All (macos-latest) (push) Waiting to run
CI / All (windows-latest) (push) Waiting to run
CI / All (ubuntu-latest) (push) Failing after 28s
2026-07-02 12:57:12 -06:00
Dark-Alex-17 fc23b532d9 feat: Added mixin for sisyphus so the ddg MCP server can search arbitrary domains 2026-07-02 12:56:18 -06:00
15 changed files with 821 additions and 30 deletions
+11
View File
@@ -0,0 +1,11 @@
schemaVersion: '1'
kind: mixin
name: sisyphus-ddg
description: >
Allows Sisyphus to hit all domains since it utilizes the DuckDuckGo
MCP server. This allows the MCP server to actually perform web searches
on arbitrary domains and retrieve info for the agent.
network:
allowedDomains:
- '*'
+7 -2
View File
@@ -252,9 +252,14 @@ commands:
bzip2 bzip2
user: '1000' user: '1000'
description: Install system prerequisites (including pandoc for fetch_url_via_curl) description: Install system prerequisites (including pandoc for fetch_url_via_curl)
- command: 'curl -LsSf https://astral.sh/uv/install.sh | sh' - command: |
curl -LsSf https://astral.sh/uv/install.sh | sh
if [ -f "$HOME/.local/bin/uv" ]; then
printf '#!/bin/sh\nexec uv tool run "$@"\n' > "$HOME/.local/bin/uvx"
chmod +x "$HOME/.local/bin/uvx"
fi
user: '1000' user: '1000'
description: Install uv (required for Python-based custom tools) description: Install uv and write a uvx shell wrapper (the installer may place a macOS binary at this path on Docker-for-Mac hosts, which the Linux container cannot execute)
- command: | - command: |
set -euo pipefail set -euo pipefail
USQL_VERSION=0.21.4 USQL_VERSION=0.21.4
+29 -1
View File
@@ -5,9 +5,9 @@ use crate::utils::list_file_names;
use crate::vault::Vault; use crate::vault::Vault;
use clap_complete::{CompletionCandidate, Shell, generate}; use clap_complete::{CompletionCandidate, Shell, generate};
use clap_complete_nushell::Nushell; use clap_complete_nushell::Nushell;
use std::env;
use std::ffi::OsStr; use std::ffi::OsStr;
use std::io; use std::io;
use std::{env, fs};
const COYOTE_CLI_NAME: &str = "coyote"; const COYOTE_CLI_NAME: &str = "coyote";
@@ -134,6 +134,34 @@ pub(super) fn session_completer(current: &OsStr) -> Vec<CompletionCandidate> {
.collect() .collect()
} }
pub(super) fn mcp_server_completer(current: &OsStr) -> Vec<CompletionCandidate> {
let cur = current.to_string_lossy();
let content = match fs::read_to_string(paths::mcp_config_file()) {
Ok(c) => c,
Err(_) => return vec![],
};
let json: serde_json::Value = match serde_json::from_str(&content) {
Ok(v) => v,
Err(_) => return vec![],
};
let servers = match json.get("mcpServers").and_then(|v| v.as_object()) {
Some(s) => s,
None => return vec![],
};
servers
.iter()
.filter(|(_, v)| {
v.get("type")
.and_then(|t| t.as_str())
.map(|t| t == "http" || t == "sse")
.unwrap_or(false)
})
.filter(|(k, _)| k.starts_with(&*cur))
.map(|(k, _)| CompletionCandidate::new(k))
.collect()
}
pub(super) fn secrets_completer(current: &OsStr) -> Vec<CompletionCandidate> { pub(super) fn secrets_completer(current: &OsStr) -> Vec<CompletionCandidate> {
let cur = current.to_string_lossy(); let cur = current.to_string_lossy();
match load_app_config_for_completion() { match load_app_config_for_completion() {
+5 -2
View File
@@ -1,8 +1,8 @@
mod completer; mod completer;
use crate::cli::completer::{ use crate::cli::completer::{
ShellCompletion, agent_completer, macro_completer, model_completer, rag_completer, ShellCompletion, agent_completer, macro_completer, mcp_server_completer, model_completer,
role_completer, secrets_completer, session_completer, rag_completer, role_completer, secrets_completer, session_completer,
}; };
use crate::config::{AssetCategory, InstallFilter, MemoryScope}; use crate::config::{AssetCategory, InstallFilter, MemoryScope};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
@@ -171,6 +171,9 @@ pub struct Cli {
/// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name) /// Authenticate with an LLM provider using OAuth (e.g., --authenticate client_name)
#[arg(long, exclusive = true, value_name = "CLIENT_NAME")] #[arg(long, exclusive = true, value_name = "CLIENT_NAME")]
pub authenticate: Option<Option<String>>, pub authenticate: Option<Option<String>>,
/// Authenticate with an OAuth-protected remote MCP server (e.g., --auth-mcp server_name)
#[arg(long, exclusive = true, value_name = "SERVER_NAME", add = ArgValueCompleter::new(mcp_server_completer))]
pub auth_mcp: Option<String>,
/// Generate static shell completion scripts /// Generate static shell completion scripts
#[arg(long, value_name = "SHELL", value_enum)] #[arg(long, value_name = "SHELL", value_enum)]
pub completions: Option<ShellCompletion>, pub completions: Option<ShellCompletion>,
+7
View File
@@ -133,6 +133,13 @@ impl MessageContent {
} }
} }
pub fn as_text(&self) -> Option<&str> {
match self {
MessageContent::Text(text) => Some(text),
_ => None,
}
}
pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) { pub fn merge_prompt(&mut self, replace_fn: impl Fn(&str) -> String) {
match self { match self {
MessageContent::Text(text) => *text = replace_fn(text), MessageContent::Text(text) => *text = replace_fn(text),
+10 -4
View File
@@ -53,6 +53,10 @@ pub trait OAuthProvider: Send + Sync {
fn extra_request_headers(&self) -> Vec<(&str, &str)> { fn extra_request_headers(&self) -> Vec<(&str, &str)> {
vec![] vec![]
} }
fn fixed_redirect_uri(&self) -> Option<String> {
None
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -72,14 +76,16 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
let state = Uuid::new_v4().to_string(); let state = Uuid::new_v4().to_string();
let redirect_uri = if provider.uses_localhost_redirect() { let (redirect_uri, use_callback_listener) = if let Some(fixed) = provider.fixed_redirect_uri() {
(fixed, true)
} else if provider.uses_localhost_redirect() {
let listener = TcpListener::bind("127.0.0.1:0")?; let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port(); let port = listener.local_addr()?.port();
let uri = format!("http://127.0.0.1:{port}/callback"); let uri = format!("http://127.0.0.1:{port}/callback");
drop(listener); drop(listener);
uri (uri, true)
} else { } else {
provider.redirect_uri().to_string() (provider.redirect_uri().to_string(), false)
}; };
let encoded_scopes = urlencoding::encode(provider.scopes()); let encoded_scopes = urlencoding::encode(provider.scopes());
@@ -112,7 +118,7 @@ pub async fn run_oauth_flow(provider: &dyn OAuthProvider, client_name: &str) ->
let _ = open::that(&authorize_url); let _ = open::that(&authorize_url);
let (code, returned_state) = if provider.uses_localhost_redirect() { let (code, returned_state) = if use_callback_listener {
listen_for_oauth_callback(&redirect_uri)? listen_for_oauth_callback(&redirect_uri)?
} else { } else {
let input = Text::new("Paste the authorization code:").prompt()?; let input = Text::new("Paste the authorization code:").prompt()?;
+11 -2
View File
@@ -1,4 +1,6 @@
use crate::mcp::{ConnectedServer, JsonField, McpServer, McpTransportType, spawn_mcp_server}; use crate::mcp::{
ConnectedServer, JsonField, McpServer, McpTransportType, oauth, spawn_mcp_server,
};
use anyhow::Result; use anyhow::Result;
use parking_lot::Mutex; use parking_lot::Mutex;
@@ -99,7 +101,12 @@ impl McpFactory {
return Ok(existing); return Ok(existing);
} }
let handle = spawn_mcp_server(spec, log_path).await?; let bearer_token = if spec.is_remote() {
oauth::load_valid_mcp_token(name)
} else {
None
};
let handle = spawn_mcp_server(spec, log_path, bearer_token).await?;
self.insert_active(key, &handle); self.insert_active(key, &handle);
Ok(handle) Ok(handle)
} }
@@ -125,6 +132,7 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: None, headers: None,
oauth_client_id: None,
} }
} }
@@ -141,6 +149,7 @@ mod tests {
cwd: None, cwd: None,
url: Some(url.to_string()), url: Some(url.to_string()),
headers, headers,
oauth_client_id: None,
} }
} }
+1
View File
@@ -135,6 +135,7 @@ const RAGS_DIR_NAME: &str = "rags";
const FUNCTIONS_DIR_NAME: &str = "functions"; const FUNCTIONS_DIR_NAME: &str = "functions";
const FUNCTIONS_BIN_DIR_NAME: &str = "bin"; const FUNCTIONS_BIN_DIR_NAME: &str = "bin";
const AGENTS_DIR_NAME: &str = "agents"; const AGENTS_DIR_NAME: &str = "agents";
const REPL_HISTORY_DIR_NAME: &str = "repl-history";
const GLOBAL_TOOLS_DIR_NAME: &str = "tools"; const GLOBAL_TOOLS_DIR_NAME: &str = "tools";
const GLOBAL_TOOLS_UTILS_DIR_NAME: &str = "utils"; const GLOBAL_TOOLS_UTILS_DIR_NAME: &str = "utils";
const BASH_PROMPT_UTILS_FILE_NAME: &str = "prompt-utils.sh"; const BASH_PROMPT_UTILS_FILE_NAME: &str = "prompt-utils.sh";
+16
View File
@@ -8,6 +8,8 @@ use super::{
SKILLS_DIR_NAME, WORKSPACE_MEMORY_DIR_NAME, SKILLS_DIR_NAME, WORKSPACE_MEMORY_DIR_NAME,
}; };
use crate::client::ProviderModels; use crate::client::ProviderModels;
use crate::config::REPL_HISTORY_DIR_NAME;
use crate::config::session::Session;
use crate::utils::{get_env_name, list_file_names, normalize_env_name}; use crate::utils::{get_env_name, list_file_names, normalize_env_name};
use anyhow::{Context, Result, anyhow, bail}; use anyhow::{Context, Result, anyhow, bail};
@@ -320,6 +322,20 @@ pub fn workspace_memory_dir_for(workspace_root: &Path) -> PathBuf {
.join(MEMORY_DIR_NAME) .join(MEMORY_DIR_NAME)
} }
pub fn repl_history_dir() -> PathBuf {
cache_path().join(REPL_HISTORY_DIR_NAME)
}
pub fn repl_history_file(session: &Option<Session>) -> PathBuf {
let history_key = if let Some(session) = &session {
format!("session_{}", session.name().replace('/', "_"))
} else {
"default".to_string()
};
repl_history_dir().join(history_key)
}
pub fn log_config() -> Result<(LevelFilter, Option<PathBuf>)> { pub fn log_config() -> Result<(LevelFilter, Option<PathBuf>)> {
let log_level = env::var(get_env_name("log_level")) let log_level = env::var(get_env_name("log_level"))
.ok() .ok()
+12
View File
@@ -2340,6 +2340,17 @@ impl RequestContext {
} }
_ => vec![], _ => vec![],
}; };
} else if cmd == ".mcp" && args.first() == Some(&"auth") && args.len() == 2 {
if let Some(mcp_config) = &self.app.mcp_config {
values = super::map_completion_values(
mcp_config
.mcp_servers
.iter()
.filter(|(_, spec)| spec.is_remote())
.map(|(name, _)| name.clone())
.collect(),
);
}
} else if (cmd == ".edit" && args.first() == Some(&"skill") && args.len() == 2) } else if (cmd == ".edit" && args.first() == Some(&"skill") && args.len() == 2)
|| (cmd == ".skill" && args.first() == Some(&"load") && args.len() == 2) || (cmd == ".skill" && args.first() == Some(&"load") && args.len() == 2)
{ {
@@ -3687,6 +3698,7 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: None, headers: None,
oauth_client_id: None,
}, },
); );
} }
+8
View File
@@ -163,6 +163,14 @@ impl Session {
self.messages.is_empty() && self.compressed_messages.is_empty() self.messages.is_empty() && self.compressed_messages.is_empty()
} }
pub fn messages(&self) -> &[Message] {
&self.messages
}
pub fn compressed_messages(&self) -> &[Message] {
&self.compressed_messages
}
pub fn name(&self) -> &str { pub fn name(&self) -> &str {
&self.name &self.name
} }
+46 -2
View File
@@ -28,11 +28,12 @@ use crate::config::{
install_builtins, list_agents, load_env_file, macro_execute, sync_models, install_builtins, list_agents, load_env_file, macro_execute, sync_models,
}; };
use crate::function::supervisor::{GuardrailAction, check_pending_agents_guardrail}; use crate::function::supervisor::{GuardrailAction, check_pending_agents_guardrail};
use crate::mcp::McpServersConfig;
use crate::render::{prompt_theme, render_error}; use crate::render::{prompt_theme, render_error};
use crate::repl::Repl; use crate::repl::Repl;
use crate::utils::*; use crate::utils::*;
use crate::vault::Vault; use crate::vault::{Vault, interpolate_secrets};
use anyhow::{Result, anyhow, bail}; use anyhow::{Context, Result, anyhow, bail};
use clap::{CommandFactory, Parser}; use clap::{CommandFactory, Parser};
use clap_complete::CompleteEnv; use clap_complete::CompleteEnv;
use client::ClientConfig; use client::ClientConfig;
@@ -120,6 +121,49 @@ async fn main() -> Result<()> {
return Ok(()); return Ok(());
} }
if let Some(server_name) = &cli.auth_mcp {
let cfg = Config::load_with_interpolation(true).await?;
let app_config = AppConfig::from_config(cfg)?;
let vault = Vault::init(&app_config)?;
let mcp_path = paths::mcp_config_file();
if !mcp_path.exists() {
bail!(
"No MCP configuration file found at '{}'",
mcp_path.display()
);
}
let raw = tokio::fs::read_to_string(&mcp_path)
.await
.with_context(|| format!("Failed to read MCP config at '{}'", mcp_path.display()))?;
let (content, missing) = interpolate_secrets(&raw, &vault)?;
if !missing.is_empty() {
bail!(
"MCP config references vault secrets that are missing: {:?}",
missing
);
}
let mcp_config: McpServersConfig =
serde_json::from_str(&content).context("Failed to parse MCP config file")?;
let spec = mcp_config
.mcp_servers
.get(server_name.as_str())
.ok_or_else(|| anyhow!("MCP server '{server_name}' not found in mcp.json"))?;
if !spec.is_remote() {
bail!(
"MCP server '{server_name}' is a stdio server; OAuth is only supported for http/sse servers"
);
}
let url = spec.url.as_deref().expect("validated: remote spec has url");
mcp::oauth::run_mcp_oauth_flow(server_name, url, spec.oauth_client_id.as_deref()).await?;
println!("Authentication saved. '{server_name}' is now available for use.");
return Ok(());
}
if vault_flags { if vault_flags {
let cfg = Config::load_with_interpolation(true).await?; let cfg = Config::load_with_interpolation(true).await?;
let app_config = AppConfig::from_config(cfg)?; let app_config = AppConfig::from_config(cfg)?;
+166 -9
View File
@@ -1,3 +1,4 @@
pub(crate) mod oauth;
mod sse_transport; mod sse_transport;
use crate::config::AppConfig; use crate::config::AppConfig;
@@ -73,6 +74,8 @@ pub(crate) struct McpServer {
pub url: Option<String>, pub url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub headers: Option<IndexMap<String, String>>, pub headers: Option<IndexMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub oauth_client_id: Option<String>,
} }
impl McpServer { impl McpServer {
@@ -107,10 +110,10 @@ impl McpServer {
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)" "MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
)); ));
} }
if self.url.is_some() || self.headers.is_some() { if self.url.is_some() || self.headers.is_some() || self.oauth_client_id.is_some() {
return Err(anyhow!( return Err(anyhow!(
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \ "MCP server '{name}' has type \"stdio\" but also specifies remote fields \
(url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"." (url/headers/oauth_client_id). Remove the remote fields or change the type to \"http\" or \"sse\"."
)); ));
} }
} }
@@ -237,7 +240,7 @@ impl McpRegistry {
debug!("Starting selected MCP servers: {:?}", ids_to_start); debug!("Starting selected MCP servers: {:?}", ids_to_start);
let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( let results: Vec<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> = stream::iter(
ids_to_start ids_to_start
.into_iter() .into_iter()
.map(|id| async { self.start_server(id).await }), .map(|id| async { self.start_server(id).await }),
@@ -246,7 +249,7 @@ impl McpRegistry {
.try_collect() .try_collect()
.await?; .await?;
for (id, server, catalog) in results { for (id, server, catalog) in results.into_iter().flatten() {
self.servers.insert(id.clone(), server); self.servers.insert(id.clone(), server);
self.catalogs.insert(id, catalog); self.catalogs.insert(id, catalog);
} }
@@ -257,14 +260,30 @@ impl McpRegistry {
async fn start_server( async fn start_server(
&self, &self,
id: String, id: String,
) -> Result<(String, Arc<ConnectedServer>, ServerCatalog)> { ) -> Result<Option<(String, Arc<ConnectedServer>, ServerCatalog)>> {
let spec = self let spec = self
.config .config
.as_ref() .as_ref()
.and_then(|c| c.mcp_servers.get(&id)) .and_then(|c| c.mcp_servers.get(&id))
.with_context(|| format!("MCP server not found in config: {id}"))?; .with_context(|| format!("MCP server not found in config: {id}"))?;
let service = spawn_mcp_server(spec, self.log_path.as_deref()).await?; let bearer_token = if spec.is_remote() {
oauth::load_valid_mcp_token(&id)
} else {
None
};
let service = match spawn_mcp_server(spec, self.log_path.as_deref(), bearer_token).await {
Ok(s) => s,
Err(e) if is_auth_required_error(&e) => {
warn!(
"MCP server '{id}' requires OAuth authentication. \
Run `.mcp auth {id}` in the REPL to authenticate, then restart Coyote."
);
return Ok(None);
}
Err(e) => return Err(e),
};
let tools = service.list_tools(None).await?; let tools = service.list_tools(None).await?;
debug!("Available tools for MCP server {id}: {tools:?}"); debug!("Available tools for MCP server {id}: {tools:?}");
@@ -289,7 +308,7 @@ impl McpRegistry {
info!("Started MCP server: {id}"); info!("Started MCP server: {id}");
Ok((id.to_string(), service, catalog)) Ok(Some((id.to_string(), service, catalog)))
} }
fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> { fn resolve_server_ids(&self, enabled_mcp_servers: Option<Vec<String>>) -> Vec<String> {
@@ -337,15 +356,18 @@ impl McpRegistry {
pub(crate) async fn spawn_mcp_server( pub(crate) async fn spawn_mcp_server(
spec: &McpServer, spec: &McpServer,
log_path: Option<&Path>, log_path: Option<&Path>,
bearer_token: Option<String>,
) -> Result<Arc<ConnectedServer>> { ) -> Result<Arc<ConnectedServer>> {
match spec.transport_type { match spec.transport_type {
McpTransportType::Http => { McpTransportType::Http => {
let url = spec.url.as_deref().expect("validated: http spec has url"); let url = spec.url.as_deref().expect("validated: http spec has url");
spawn_http_mcp_server(url, spec.headers.as_ref()).await let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
spawn_http_mcp_server(url, headers.as_ref()).await
} }
McpTransportType::Sse => { McpTransportType::Sse => {
let url = spec.url.as_deref().expect("validated: sse spec has url"); let url = spec.url.as_deref().expect("validated: sse spec has url");
spawn_sse_mcp_server(url, spec.headers.as_ref()).await let headers = merge_bearer_token(spec.headers.as_ref(), bearer_token);
spawn_sse_mcp_server(url, headers.as_ref()).await
} }
McpTransportType::Stdio => { McpTransportType::Stdio => {
let command = spec let command = spec
@@ -357,6 +379,30 @@ pub(crate) async fn spawn_mcp_server(
} }
} }
fn merge_bearer_token(
headers: Option<&IndexMap<String, String>>,
bearer_token: Option<String>,
) -> Option<IndexMap<String, String>> {
match (headers, bearer_token) {
(None, None) => None,
(Some(h), None) => Some(h.clone()),
(None, Some(token)) => {
let mut m = IndexMap::new();
m.insert("Authorization".to_string(), format!("Bearer {token}"));
Some(m)
}
(Some(h), Some(token)) => {
let mut m = h.clone();
m.insert("Authorization".to_string(), format!("Bearer {token}"));
Some(m)
}
}
}
fn is_auth_required_error(e: &anyhow::Error) -> bool {
e.to_string().contains("Auth required")
}
async fn spawn_http_mcp_server( async fn spawn_http_mcp_server(
url: &str, url: &str,
headers: Option<&IndexMap<String, String>>, headers: Option<&IndexMap<String, String>>,
@@ -465,6 +511,7 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: None, headers: None,
oauth_client_id: None,
} }
} }
@@ -477,6 +524,7 @@ mod tests {
cwd: None, cwd: None,
url: Some(url.to_string()), url: Some(url.to_string()),
headers: None, headers: None,
oauth_client_id: None,
} }
} }
@@ -489,6 +537,7 @@ mod tests {
cwd: None, cwd: None,
url: Some(url.to_string()), url: Some(url.to_string()),
headers: None, headers: None,
oauth_client_id: None,
} }
} }
@@ -506,6 +555,7 @@ mod tests {
#[test] #[test]
fn validate_stdio_with_command_succeeds() { fn validate_stdio_with_command_succeeds() {
let spec = stdio_server("npx"); let spec = stdio_server("npx");
assert!(spec.validate("test").is_ok()); assert!(spec.validate("test").is_ok());
} }
@@ -519,8 +569,11 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("missing a \"command\" field")); assert!(err.to_string().contains("missing a \"command\" field"));
} }
@@ -534,8 +587,11 @@ mod tests {
cwd: None, cwd: None,
url: Some("http://localhost".into()), url: Some("http://localhost".into()),
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("remote fields")); assert!(err.to_string().contains("remote fields"));
} }
@@ -551,14 +607,18 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: Some(headers), headers: Some(headers),
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("remote fields")); assert!(err.to_string().contains("remote fields"));
} }
#[test] #[test]
fn validate_http_with_url_succeeds() { fn validate_http_with_url_succeeds() {
let spec = http_server("http://localhost:8080"); let spec = http_server("http://localhost:8080");
assert!(spec.validate("test").is_ok()); assert!(spec.validate("test").is_ok());
} }
@@ -572,8 +632,11 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("missing a \"url\" field")); assert!(err.to_string().contains("missing a \"url\" field"));
} }
@@ -587,8 +650,11 @@ mod tests {
cwd: None, cwd: None,
url: Some("http://localhost".into()), url: Some("http://localhost".into()),
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("stdio fields")); assert!(err.to_string().contains("stdio fields"));
} }
@@ -602,8 +668,11 @@ mod tests {
cwd: None, cwd: None,
url: Some("http://localhost".into()), url: Some("http://localhost".into()),
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("stdio fields")); assert!(err.to_string().contains("stdio fields"));
} }
@@ -617,14 +686,18 @@ mod tests {
cwd: Some("/tmp".into()), cwd: Some("/tmp".into()),
url: Some("http://localhost".into()), url: Some("http://localhost".into()),
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("stdio fields")); assert!(err.to_string().contains("stdio fields"));
} }
#[test] #[test]
fn validate_sse_with_url_succeeds() { fn validate_sse_with_url_succeeds() {
let spec = sse_server("http://sse.example.com"); let spec = sse_server("http://sse.example.com");
assert!(spec.validate("test").is_ok()); assert!(spec.validate("test").is_ok());
} }
@@ -638,8 +711,11 @@ mod tests {
cwd: None, cwd: None,
url: None, url: None,
headers: None, headers: None,
oauth_client_id: None,
}; };
let err = spec.validate("test").unwrap_err(); let err = spec.validate("test").unwrap_err();
assert!(err.to_string().contains("missing a \"url\" field")); assert!(err.to_string().contains("missing a \"url\" field"));
} }
@@ -665,9 +741,13 @@ mod tests {
} }
} }
}"#; }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap(); let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert!(config.mcp_servers.contains_key("my-server")); assert!(config.mcp_servers.contains_key("my-server"));
let spec = &config.mcp_servers["my-server"]; let spec = &config.mcp_servers["my-server"];
assert_eq!(spec.transport_type, McpTransportType::Stdio); assert_eq!(spec.transport_type, McpTransportType::Stdio);
assert_eq!(spec.command.as_deref(), Some("npx")); assert_eq!(spec.command.as_deref(), Some("npx"));
assert_eq!( assert_eq!(
@@ -688,7 +768,9 @@ mod tests {
} }
}"#; }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap(); let config: McpServersConfig = serde_json::from_str(json).unwrap();
let spec = &config.mcp_servers["remote"]; let spec = &config.mcp_servers["remote"];
assert_eq!(spec.transport_type, McpTransportType::Http); assert_eq!(spec.transport_type, McpTransportType::Http);
assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp")); assert_eq!(spec.url.as_deref(), Some("http://localhost:8080/mcp"));
assert_eq!( assert_eq!(
@@ -713,7 +795,9 @@ mod tests {
} }
}"#; }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap(); let config: McpServersConfig = serde_json::from_str(json).unwrap();
let env = config.mcp_servers["s"].env.as_ref().unwrap(); let env = config.mcp_servers["s"].env.as_ref().unwrap();
assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello")); assert!(matches!(env["STR_VAR"], JsonField::Str(ref s) if s == "hello"));
assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true))); assert!(matches!(env["BOOL_VAR"], JsonField::Bool(true)));
assert!(matches!(env["INT_VAR"], JsonField::Int(42))); assert!(matches!(env["INT_VAR"], JsonField::Int(42)));
@@ -727,7 +811,9 @@ mod tests {
"remote-api": { "type": "http", "url": "http://api.example.com" } "remote-api": { "type": "http", "url": "http://api.example.com" }
} }
}"#; }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap(); let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.mcp_servers.len(), 2); assert_eq!(config.mcp_servers.len(), 2);
assert!(config.mcp_servers.contains_key("github")); assert!(config.mcp_servers.contains_key("github"));
assert!(config.mcp_servers.contains_key("remote-api")); assert!(config.mcp_servers.contains_key("remote-api"));
@@ -736,7 +822,9 @@ mod tests {
#[test] #[test]
fn deserialize_empty_servers_map() { fn deserialize_empty_servers_map() {
let json = r#"{ "mcpServers": {} }"#; let json = r#"{ "mcpServers": {} }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap(); let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert!(config.mcp_servers.is_empty()); assert!(config.mcp_servers.is_empty());
} }
@@ -751,77 +839,96 @@ mod tests {
} }
} }
}"#; }"#;
let config: McpServersConfig = serde_json::from_str(json).unwrap(); let config: McpServersConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work")); assert_eq!(config.mcp_servers["s"].cwd.as_deref(), Some("/tmp/work"));
} }
#[test] #[test]
fn resolve_all_returns_all_configured_servers() { fn resolve_all_returns_all_configured_servers() {
let registry = make_registry_with_config(&["github", "slack", "jira"]); let registry = make_registry_with_config(&["github", "slack", "jira"]);
let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()])); let mut ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
ids.sort(); ids.sort();
assert_eq!(ids, vec!["github", "jira", "slack"]); assert_eq!(ids, vec!["github", "jira", "slack"]);
} }
#[test] #[test]
fn resolve_comma_separated_returns_matching_servers() { fn resolve_comma_separated_returns_matching_servers() {
let registry = make_registry_with_config(&["github", "slack", "jira"]); let registry = make_registry_with_config(&["github", "slack", "jira"]);
let mut ids = let mut ids =
registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()])); registry.resolve_server_ids(Some(vec!["github".to_string(), "jira".to_string()]));
ids.sort(); ids.sort();
assert_eq!(ids, vec!["github", "jira"]); assert_eq!(ids, vec!["github", "jira"]);
} }
#[test] #[test]
fn resolve_single_server_name() { fn resolve_single_server_name() {
let registry = make_registry_with_config(&["github", "slack"]); let registry = make_registry_with_config(&["github", "slack"]);
let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()])); let ids = registry.resolve_server_ids(Some(vec!["slack".to_string()]));
assert_eq!(ids, vec!["slack"]); assert_eq!(ids, vec!["slack"]);
} }
#[test] #[test]
fn resolve_none_returns_empty() { fn resolve_none_returns_empty() {
let registry = make_registry_with_config(&["github"]); let registry = make_registry_with_config(&["github"]);
let ids = registry.resolve_server_ids(None); let ids = registry.resolve_server_ids(None);
assert!(ids.is_empty()); assert!(ids.is_empty());
} }
#[test] #[test]
fn resolve_no_config_returns_empty() { fn resolve_no_config_returns_empty() {
let registry = McpRegistry::default(); let registry = McpRegistry::default();
let ids = registry.resolve_server_ids(Some(vec!["all".to_string()])); let ids = registry.resolve_server_ids(Some(vec!["all".to_string()]));
assert!(ids.is_empty()); assert!(ids.is_empty());
} }
#[test] #[test]
fn resolve_nonexistent_server_filtered_out() { fn resolve_nonexistent_server_filtered_out() {
let registry = make_registry_with_config(&["github"]); let registry = make_registry_with_config(&["github"]);
let ids = registry let ids = registry
.resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()])); .resolve_server_ids(Some(vec!["github".to_string(), "nonexistent".to_string()]));
assert_eq!(ids, vec!["github"]); assert_eq!(ids, vec!["github"]);
} }
#[test] #[test]
fn resolve_all_nonexistent_returns_empty() { fn resolve_all_nonexistent_returns_empty() {
let registry = make_registry_with_config(&["github"]); let registry = make_registry_with_config(&["github"]);
let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()])); let ids = registry.resolve_server_ids(Some(vec!["foo".to_string(), "bar".to_string()]));
assert!(ids.is_empty()); assert!(ids.is_empty());
} }
#[test] #[test]
fn resolve_trims_whitespace() { fn resolve_trims_whitespace() {
let registry = make_registry_with_config(&["github", "slack"]); let registry = make_registry_with_config(&["github", "slack"]);
let mut ids = registry.resolve_server_ids(Some(vec![ let mut ids = registry.resolve_server_ids(Some(vec![
" github ".to_string(), " github ".to_string(),
" slack ".to_string(), " slack ".to_string(),
])); ]));
ids.sort(); ids.sort();
assert_eq!(ids, vec!["github", "slack"]); assert_eq!(ids, vec!["github", "slack"]);
} }
#[test] #[test]
fn registry_default_is_empty() { fn registry_default_is_empty() {
let registry = McpRegistry::default(); let registry = McpRegistry::default();
assert!(registry.is_empty()); assert!(registry.is_empty());
assert!(registry.list_started_servers().is_empty()); assert!(registry.list_started_servers().is_empty());
assert!(registry.mcp_config().is_none()); assert!(registry.mcp_config().is_none());
@@ -831,6 +938,7 @@ mod tests {
#[test] #[test]
fn registry_with_config_reports_config() { fn registry_with_config_reports_config() {
let registry = make_registry_with_config(&["github"]); let registry = make_registry_with_config(&["github"]);
assert!(registry.mcp_config().is_some()); assert!(registry.mcp_config().is_some());
assert!( assert!(
registry registry
@@ -847,4 +955,53 @@ mod tests {
assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search"); assert_eq!(MCP_SEARCH_META_FUNCTION_NAME_PREFIX, "mcp_search");
assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe"); assert_eq!(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, "mcp_describe");
} }
#[test]
fn merge_bearer_token_both_none_returns_none() {
assert!(merge_bearer_token(None, None).is_none());
}
#[test]
fn merge_bearer_token_headers_only_passes_through() {
let mut h = IndexMap::new();
h.insert("X-Key".to_string(), "val".to_string());
let result = merge_bearer_token(Some(&h), None).unwrap();
assert_eq!(result["X-Key"], "val");
assert!(!result.contains_key("Authorization"));
}
#[test]
fn merge_bearer_token_token_only_injects_bearer() {
let result = merge_bearer_token(None, Some("tok123".to_string())).unwrap();
assert_eq!(result["Authorization"], "Bearer tok123");
}
#[test]
fn merge_bearer_token_both_merges_and_overrides_authorization() {
let mut h = IndexMap::new();
h.insert("Authorization".to_string(), "old".to_string());
h.insert("X-Custom".to_string(), "keep".to_string());
let result = merge_bearer_token(Some(&h), Some("newtoken".to_string())).unwrap();
assert_eq!(result["Authorization"], "Bearer newtoken");
assert_eq!(result["X-Custom"], "keep");
}
#[test]
fn is_auth_required_error_matches_rmcp_message() {
let e = anyhow!("Auth required, when send initialize request");
assert!(is_auth_required_error(&e));
}
#[test]
fn is_auth_required_error_does_not_match_unrelated() {
let e = anyhow!("Connection refused");
assert!(!is_auth_required_error(&e));
}
} }
+329
View File
@@ -0,0 +1,329 @@
use crate::client::oauth::{OAuthProvider, TokenRequestFormat, load_oauth_tokens, run_oauth_flow};
use crate::config::paths;
use anyhow::{Context, Result, anyhow};
use chrono::Utc;
use inquire::Text;
use log::warn;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::fs;
use std::net::TcpListener;
use url::Url;
#[derive(Debug, Deserialize)]
struct ProtectedResourceMetadata {
#[serde(default)]
authorization_servers: Vec<String>,
}
#[derive(Debug, Deserialize)]
struct OAuthServerMetadata {
authorization_endpoint: String,
token_endpoint: String,
#[serde(default)]
scopes_supported: Vec<String>,
registration_endpoint: Option<String>,
}
#[derive(Serialize, Deserialize)]
struct McpRegistration {
client_id: String,
}
struct McpOAuthProvider {
client_id: String,
authorize_url: String,
token_url: String,
scopes: String,
fixed_redirect: String,
}
impl OAuthProvider for McpOAuthProvider {
fn provider_name(&self) -> &str {
"MCP"
}
fn client_id(&self) -> &str {
&self.client_id
}
fn authorize_url(&self) -> &str {
&self.authorize_url
}
fn token_url(&self) -> &str {
&self.token_url
}
fn redirect_uri(&self) -> &str {
""
}
fn scopes(&self) -> &str {
&self.scopes
}
fn token_request_format(&self) -> TokenRequestFormat {
TokenRequestFormat::FormUrlEncoded
}
fn uses_localhost_redirect(&self) -> bool {
false
}
fn fixed_redirect_uri(&self) -> Option<String> {
Some(self.fixed_redirect.clone())
}
}
pub async fn run_mcp_oauth_flow(
server_name: &str,
server_url: &str,
configured_client_id: Option<&str>,
) -> Result<()> {
let metadata = discover_oauth_metadata(server_url).await?;
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let redirect_uri = format!("http://127.0.0.1:{port}/callback");
let client_id = if let Some(id) = configured_client_id {
id.to_string()
} else if let Some(cached) = load_registered_client_id(server_name) {
cached
} else if let Some(reg_endpoint) = &metadata.registration_endpoint {
match register_client(reg_endpoint, &redirect_uri).await {
Ok(id) => {
let _ = save_registered_client_id(server_name, &id);
id
}
Err(e) => {
warn!("Dynamic client registration failed: {e}. Falling back to manual entry.");
Text::new("Enter the OAuth client ID for this MCP server:")
.prompt()
.context("Failed to read client ID")?
}
}
} else {
Text::new("Enter the OAuth client ID for this MCP server:")
.prompt()
.context("Failed to read client ID")?
};
let provider = McpOAuthProvider {
client_id,
authorize_url: metadata.authorization_endpoint,
token_url: metadata.token_endpoint,
scopes: metadata.scopes_supported.join(" "),
fixed_redirect: redirect_uri,
};
run_oauth_flow(&provider, &mcp_token_key(server_name)).await
}
pub fn load_valid_mcp_token(server_name: &str) -> Option<String> {
let tokens = load_oauth_tokens(&mcp_token_key(server_name))?;
if Utc::now().timestamp() < tokens.expires_at {
Some(tokens.access_token)
} else {
None
}
}
fn mcp_token_key(server_name: &str) -> String {
format!("mcp_{server_name}")
}
fn load_registered_client_id(server_name: &str) -> Option<String> {
let path = paths::oauth_tokens_path().join(format!("mcp_{server_name}_registration.json"));
let content = fs::read_to_string(path).ok()?;
let reg: McpRegistration = serde_json::from_str(&content).ok()?;
Some(reg.client_id)
}
fn save_registered_client_id(server_name: &str, client_id: &str) -> Result<()> {
let dir = paths::oauth_tokens_path();
fs::create_dir_all(&dir)?;
let path = dir.join(format!("mcp_{server_name}_registration.json"));
let reg = McpRegistration {
client_id: client_id.to_string(),
};
fs::write(path, serde_json::to_string_pretty(&reg)?)?;
Ok(())
}
async fn register_client(endpoint: &str, redirect_uri: &str) -> Result<String> {
let body = serde_json::json!({
"client_name": "Coyote",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "none"
});
let response: serde_json::Value = Client::new()
.post(endpoint)
.json(&body)
.send()
.await
.context("Failed to reach registration endpoint")?
.json()
.await
.context("Failed to parse registration response")?;
response["client_id"]
.as_str()
.ok_or_else(|| anyhow!("Missing client_id in registration response: {response}"))
.map(|s| s.to_string())
}
async fn discover_oauth_metadata(server_url: &str) -> Result<OAuthServerMetadata> {
let base = extract_base_url(server_url)?;
let client = Client::new();
// RFC 9728: try protected resource metadata first; it points to the auth server
let pr_url = format!("{base}/.well-known/oauth-protected-resource");
if let Ok(resp) = client.get(&pr_url).send().await
&& resp.status().is_success()
&& let Ok(pr) = resp.json::<ProtectedResourceMetadata>().await
&& let Some(auth_server) = pr.authorization_servers.first()
{
let as_url = format!("{auth_server}/.well-known/oauth-authorization-server");
if let Ok(resp) = client.get(&as_url).send().await
&& resp.status().is_success()
&& let Ok(meta) = resp.json::<OAuthServerMetadata>().await
{
return Ok(meta);
}
}
let as_url = format!("{base}/.well-known/oauth-authorization-server");
let resp = client
.get(&as_url)
.send()
.await
.with_context(|| format!("Failed to reach {as_url}"))?;
if resp.status().is_success() {
return resp
.json::<OAuthServerMetadata>()
.await
.with_context(|| format!("Failed to parse OAuth metadata from {as_url}"));
}
Err(anyhow!(
"Could not discover OAuth metadata for '{server_url}'.\n\
Tried:\n {pr_url}\n {as_url}\n\
Ensure the server supports MCP OAuth discovery, or consult its documentation."
))
}
fn extract_base_url(url: &str) -> Result<String> {
let parsed = Url::parse(url).with_context(|| format!("Invalid URL: {url}"))?;
let scheme = parsed.scheme();
let host = parsed
.host_str()
.ok_or_else(|| anyhow!("No host in URL: {url}"))?;
let port = parsed.port().map(|p| format!(":{p}")).unwrap_or_default();
Ok(format!("{scheme}://{host}{port}"))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::get_env_name;
use serial_test::serial;
use std::{
env, fs,
time::{self, SystemTime},
};
fn with_temp_cache<F: FnOnce()>(f: F) {
let unique = SystemTime::now()
.duration_since(time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let root = env::temp_dir().join(format!("coyote-mcp-oauth-test-{unique}"));
fs::create_dir_all(&root).unwrap();
let env_key = get_env_name("cache_dir");
let prev = env::var_os(&env_key);
unsafe {
env::set_var(&env_key, &root);
}
f();
unsafe {
match prev {
Some(v) => env::set_var(&env_key, v),
None => env::remove_var(&env_key),
}
}
let _ = fs::remove_dir_all(&root);
}
#[test]
fn extract_base_url_strips_path_and_query() {
let result = extract_base_url("https://mcp.notion.com/mcp?foo=bar").unwrap();
assert_eq!(result, "https://mcp.notion.com");
}
#[test]
fn extract_base_url_preserves_explicit_port() {
let result = extract_base_url("http://localhost:8080/mcp").unwrap();
assert_eq!(result, "http://localhost:8080");
}
#[test]
fn extract_base_url_standard_port_omitted() {
let result = extract_base_url("https://example.com/mcp/v1").unwrap();
assert_eq!(result, "https://example.com");
}
#[test]
fn extract_base_url_rejects_invalid_url() {
assert!(extract_base_url("not-a-url").is_err());
}
#[test]
#[serial]
fn registered_client_id_roundtrip() {
with_temp_cache(|| {
save_registered_client_id("notion", "client-xyz-123").unwrap();
let loaded = load_registered_client_id("notion");
assert_eq!(loaded, Some("client-xyz-123".to_string()));
});
}
#[test]
#[serial]
fn load_registered_client_id_returns_none_for_missing() {
with_temp_cache(|| {
let loaded = load_registered_client_id("no-such-server");
assert!(loaded.is_none());
});
}
#[test]
#[serial]
fn registered_client_id_second_save_overwrites_first() {
with_temp_cache(|| {
save_registered_client_id("github", "first-id").unwrap();
save_registered_client_id("github", "second-id").unwrap();
let loaded = load_registered_client_id("github");
assert_eq!(loaded, Some("second-id".to_string()));
});
}
}
+163 -8
View File
@@ -6,7 +6,10 @@ use self::completer::ReplCompleter;
use self::highlighter::ReplHighlighter; use self::highlighter::ReplHighlighter;
use self::prompt::ReplPrompt; use self::prompt::ReplPrompt;
use crate::client::{call_chat_completions, call_chat_completions_streaming, init_client, oauth}; use crate::client::{
Message, MessageRole, call_chat_completions, call_chat_completions_streaming, init_client,
oauth,
};
use crate::config::{ use crate::config::{
AgentVariables, AppConfig, AssertState, Input, LastMessage, RequestContext, StateFlags, AgentVariables, AppConfig, AssertState, Input, LastMessage, RequestContext, StateFlags,
macro_execute, macro_execute,
@@ -20,7 +23,7 @@ use crate::utils::{
}; };
use crate::sandbox::SANDBOX_ENV_FLAG; use crate::sandbox::SANDBOX_ENV_FLAG;
use crate::{config, graph, resolve_oauth_client}; use crate::{config, graph, mcp, resolve_oauth_client};
use anyhow::{Context, Result, bail}; use anyhow::{Context, Result, bail};
use crossterm::cursor::SetCursorStyle; use crossterm::cursor::SetCursorStyle;
use fancy_regex::Regex; use fancy_regex::Regex;
@@ -29,9 +32,9 @@ use log::warn;
use parking_lot::RwLock; use parking_lot::RwLock;
use reedline::CursorConfig; use reedline::CursorConfig;
use reedline::{ use reedline::{
ColumnarMenu, EditCommand, EditMode, Emacs, KeyCode, KeyModifiers, Keybindings, Reedline, ColumnarMenu, EditCommand, EditMode, Emacs, FileBackedHistory, KeyCode, KeyModifiers,
ReedlineEvent, ReedlineMenu, ValidationResult, Validator, Vi, default_emacs_keybindings, Keybindings, Reedline, ReedlineEvent, ReedlineMenu, ValidationResult, Validator, Vi,
default_vi_insert_keybindings, default_vi_normal_keybindings, default_emacs_keybindings, default_vi_insert_keybindings, default_vi_normal_keybindings,
}; };
use reedline::{MenuBuilder, Signal}; use reedline::{MenuBuilder, Signal};
use std::sync::LazyLock; use std::sync::LazyLock;
@@ -49,7 +52,7 @@ pub const DEFAULT_CONTINUATION_PROMPT: &str = indoc! {"
4. Continue with the next pending item now. Call tools immediately." 4. Continue with the next pending item now. Call tools immediately."
}; };
static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| { static REPL_COMMANDS: LazyLock<[ReplCommand; 50]> = LazyLock::new(|| {
[ [
ReplCommand::new(".help", "Show this help guide", AssertState::pass()), ReplCommand::new(".help", "Show this help guide", AssertState::pass()),
ReplCommand::new(".info", "Show system info", AssertState::pass()), ReplCommand::new(".info", "Show system info", AssertState::pass()),
@@ -63,6 +66,11 @@ static REPL_COMMANDS: LazyLock<[ReplCommand; 49]> = LazyLock::new(|| {
"Authenticate the current model client via OAuth (if configured)", "Authenticate the current model client via OAuth (if configured)",
AssertState::pass(), AssertState::pass(),
), ),
ReplCommand::new(
".mcp auth",
"Authenticate with an MCP server via OAuth",
AssertState::pass(),
),
ReplCommand::new( ReplCommand::new(
".edit config", ".edit config",
"Modify configuration file", "Modify configuration file",
@@ -313,6 +321,58 @@ Type ".help" for additional help.
} }
} }
{
let (messages_snapshot, compressed_count) = {
let ctx = self.ctx.read();
if let Some(session) = &ctx.session {
let msgs: Vec<Message> = session
.messages()
.iter()
.filter(|m| !m.role.is_system())
.cloned()
.collect();
let compressed = session.compressed_messages().len();
(msgs, compressed)
} else {
(vec![], 0)
}
};
if !messages_snapshot.is_empty() || compressed_count > 0 {
let app = Arc::clone(&self.ctx.read().app.config);
if compressed_count > 0 {
println!(
"{}",
dimmed_text(&format!(
"({compressed_count} earlier messages not shown; compressed for context)"
))
);
println!();
}
for message in &messages_snapshot {
match message.role {
MessageRole::User => {
if let Some(text) = message.content.as_text() {
println!("{}", dimmed_text("You:"));
println!("{text}");
println!();
}
}
MessageRole::Assistant => {
if let Some(text) = message.content.as_text() {
app.print_markdown(text)?;
println!();
}
}
_ => {}
}
}
println!("{}", dimmed_text("─── ↑ previous conversation ↑ ───"));
println!();
}
}
loop { loop {
if self.abort_signal.aborted_ctrld() { if self.abort_signal.aborted_ctrld() {
break; break;
@@ -388,6 +448,14 @@ Type ".help" for additional help.
editor = editor.with_buffer_editor(command, temp_file); editor = editor.with_buffer_editor(command, temp_file);
} }
if app.save_shell_history {
let ctx = ctx.read();
let history_path = paths::repl_history_file(&ctx.session);
if let Ok(history) = FileBackedHistory::with_file(1000, history_path) {
editor = editor.with_history(Box::new(history));
}
}
Ok(editor) Ok(editor)
} }
@@ -541,6 +609,53 @@ pub async fn run_repl_command(
let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?; let (client_name, provider) = resolve_oauth_client(Some(client.name()), &clients)?;
oauth::run_oauth_flow(&*provider, &client_name).await?; oauth::run_oauth_flow(&*provider, &client_name).await?;
} }
".mcp" => match args {
Some(args) => {
let mut parts = args.splitn(2, char::is_whitespace);
let sub = parts.next().unwrap_or("").trim();
let rest = parts.next().map(str::trim).unwrap_or("");
match sub {
"auth" => {
if rest.is_empty() {
println!("Usage: .mcp auth <server_name>");
} else {
let server_name = rest;
let server_spec = ctx
.app
.mcp_config
.as_ref()
.and_then(|c| c.mcp_servers.get(server_name))
.cloned();
match server_spec {
None => {
bail!("MCP server '{}' not found in mcp.json.", server_name)
}
Some(spec) if !spec.is_remote() => bail!(
"MCP server '{}' uses stdio transport; \
OAuth is only supported for http/sse servers.",
server_name
),
Some(spec) => {
let url = spec
.url
.as_deref()
.expect("validated: remote spec has url");
let client_id = spec.oauth_client_id.as_deref();
mcp::oauth::run_mcp_oauth_flow(server_name, url, client_id)
.await?;
println!(
"Authentication saved. \
Restart Coyote to connect to '{server_name}'."
);
}
}
}
}
_ => unknown_command()?,
}
}
None => println!("Usage: .mcp auth <server_name>"),
},
".prompt" => match args { ".prompt" => match args {
Some(text) => { Some(text) => {
let app = Arc::clone(&ctx.app.config); let app = Arc::clone(&ctx.app.config);
@@ -632,6 +747,46 @@ pub async fn run_repl_command(
session.set_autonaming(false); session.set_autonaming(false);
} }
} }
if let Some(session) = &ctx.session {
let messages_snapshot: Vec<Message> = session
.messages()
.iter()
.filter(|m| !m.role.is_system())
.cloned()
.collect();
let compressed_count = session.compressed_messages().len();
if !messages_snapshot.is_empty() || compressed_count > 0 {
if compressed_count > 0 {
println!(
"{}",
dimmed_text(&format!(
"({compressed_count} earlier messages not shown — compressed for context)"
))
);
println!();
}
for message in &messages_snapshot {
match message.role {
MessageRole::User => {
if let Some(text) = message.content.as_text() {
println!("{}", dimmed_text("You:"));
println!("{text}");
println!();
}
}
MessageRole::Assistant => {
if let Some(text) = message.content.as_text() {
app.print_markdown(text)?;
println!();
}
}
_ => {}
}
}
println!("{}", dimmed_text("─── ↑ previous conversation ↑ ───"));
println!();
}
}
} }
".install" => { ".install" => {
let trimmed = args.map(str::trim).unwrap_or(""); let trimmed = args.map(str::trim).unwrap_or("");
@@ -1415,8 +1570,8 @@ mod tests {
} }
#[test] #[test]
fn repl_commands_has_49_entries() { fn repl_commands_has_50_entries() {
assert_eq!(REPL_COMMANDS.len(), 49); assert_eq!(REPL_COMMANDS.len(), 50);
} }
#[test] #[test]