diff --git a/src/config/mod.rs b/src/config/mod.rs index 25969d7..9dbde6f 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -24,7 +24,8 @@ use crate::utils::*; use crate::config::macros::Macro; use crate::mcp::{ - MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX, McpRegistry, + MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, MCP_INVOKE_META_FUNCTION_NAME_PREFIX, + MCP_SEARCH_META_FUNCTION_NAME_PREFIX, McpRegistry, }; use crate::vault::{GlobalVault, Vault, create_vault_password_file, interpolate_secrets}; use anyhow::{Context, Result, anyhow, bail}; @@ -1972,7 +1973,8 @@ impl Config { .iter() .filter(|v| { !v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) - && !v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + && !v.name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) + && !v.name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) }) .map(|v| v.name.to_string()) .collect(); @@ -2015,7 +2017,8 @@ impl Config { .into_iter() .filter(|v| { !v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) - && !v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + && !v.name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) + && !v.name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) }) .collect(); let tool_names: HashSet = agent_functions @@ -2051,7 +2054,8 @@ impl Config { .iter() .filter(|v| { v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) - || v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) }) .map(|v| v.name.to_string()) .collect(); @@ -2062,8 +2066,10 @@ impl Config { let item = item.trim(); let item_invoke_name = format!("{}_{item}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX); - let item_list_name = - format!("{}_{item}", MCP_LIST_META_FUNCTION_NAME_PREFIX); + let item_search_name = + format!("{}_{item}", MCP_SEARCH_META_FUNCTION_NAME_PREFIX); + let item_describe_name = + format!("{}_{item}", MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX); if let Some(values) = self.mapping_mcp_servers.get(item) { server_names.extend( values @@ -2077,7 +2083,12 @@ impl Config { ), format!( "{}_{}", - MCP_LIST_META_FUNCTION_NAME_PREFIX, + MCP_SEARCH_META_FUNCTION_NAME_PREFIX, + v.to_string() + ), + format!( + "{}_{}", + MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, v.to_string() ), ] @@ -2086,7 +2097,8 @@ impl Config { ) } else if mcp_declaration_names.contains(&item_invoke_name) { server_names.insert(item_invoke_name); - server_names.insert(item_list_name); + server_names.insert(item_search_name); + server_names.insert(item_describe_name); } } } @@ -2112,7 +2124,8 @@ impl Config { .into_iter() .filter(|v| { v.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) - || v.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) + || v.name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) }) .collect(); let tool_names: HashSet = agent_functions diff --git a/src/function/mod.rs b/src/function/mod.rs index 20810d9..28bccde 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -4,7 +4,10 @@ use crate::{ }; use crate::config::ensure_parent_exists; -use crate::mcp::{MCP_INVOKE_META_FUNCTION_NAME_PREFIX, MCP_LIST_META_FUNCTION_NAME_PREFIX}; +use crate::mcp::{ + MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX, MCP_INVOKE_META_FUNCTION_NAME_PREFIX, + MCP_SEARCH_META_FUNCTION_NAME_PREFIX, +}; use crate::parsers::{bash, python}; use anyhow::{Context, Result, anyhow, bail}; use indexmap::IndexMap; @@ -247,19 +250,13 @@ impl Functions { pub fn clear_mcp_meta_functions(&mut self) { self.declarations.retain(|d| { !d.name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) - && !d.name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) + && !d.name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) + && !d.name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) }); } pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec) { let mut invoke_function_properties = IndexMap::new(); - invoke_function_properties.insert( - "server".to_string(), - JsonSchema { - type_value: Some("string".to_string()), - ..Default::default() - }, - ); invoke_function_properties.insert( "tool".to_string(), JsonSchema { @@ -275,32 +272,85 @@ impl Functions { }, ); + let mut search_function_properties = IndexMap::new(); + search_function_properties.insert( + "query".to_string(), + JsonSchema { + type_value: Some("string".to_string()), + description: Some("Generalized explanation of what you want to do".into()), + ..Default::default() + }, + ); + search_function_properties.insert( + "top_k".to_string(), + JsonSchema { + type_value: Some("integer".to_string()), + description: Some("How many results to return, between 1 and 20".into()), + default: Some(Value::from(8usize)), + ..Default::default() + }, + ); + + let mut describe_function_properties = IndexMap::new(); + describe_function_properties.insert( + "tool".to_string(), + JsonSchema { + type_value: Some("string".to_string()), + description: Some("The name of the tool; e.g., search_issues".into()), + ..Default::default() + }, + ); + for server in mcp_servers { + let search_function_name = format!("{}_{server}", MCP_SEARCH_META_FUNCTION_NAME_PREFIX); + let describe_function_name = format!("{}_{server}", MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX); let invoke_function_name = format!("{}_{server}", MCP_INVOKE_META_FUNCTION_NAME_PREFIX); let invoke_function_declaration = FunctionDeclaration { name: invoke_function_name.clone(), description: formatdoc!( r#" - Invoke the specified tool on the {server} MCP server. Always call {invoke_function_name} first to find the - correct names of tools before calling '{invoke_function_name}'. + Invoke the specified tool on the {server} MCP server. Always call {describe_function_name} first to + find the correct invocation schema for the given tool. "# ), parameters: JsonSchema { type_value: Some("object".to_string()), properties: Some(invoke_function_properties.clone()), - required: Some(vec!["server".to_string(), "tool".to_string()]), + required: Some(vec!["tool".to_string()]), ..Default::default() }, agent: false, }; - let list_functions_declaration = FunctionDeclaration { - name: format!("{}_{}", MCP_LIST_META_FUNCTION_NAME_PREFIX, server), - description: format!("List all the available tools for the {server} MCP server"), - parameters: JsonSchema::default(), + let search_functions_declaration = FunctionDeclaration { + name: search_function_name.clone(), + description: formatdoc!( + r#" + Find candidate tools by keywords for the {server} MCP server. Returns small suggestions; fetch + schemas with {describe_function_name}. + "# + ), + parameters: JsonSchema { + type_value: Some("object".to_string()), + properties: Some(search_function_properties.clone()), + required: Some(vec!["query".to_string()]), + ..Default::default() + }, + agent: false, + }; + let describe_functions_declaration = FunctionDeclaration { + name: describe_function_name.clone(), + description: "Get the full JSON schema for exactly one MCP tool.".to_string(), + parameters: JsonSchema { + type_value: Some("object".to_string()), + properties: Some(describe_function_properties.clone()), + required: Some(vec!["tool".to_string()]), + ..Default::default() + }, agent: false, }; self.declarations.push(invoke_function_declaration); - self.declarations.push(list_functions_declaration); + self.declarations.push(search_functions_declaration); + self.declarations.push(describe_functions_declaration); } } @@ -771,39 +821,14 @@ impl ToolCall { } let output = match cmd_name.as_str() { - _ if cmd_name.starts_with(MCP_LIST_META_FUNCTION_NAME_PREFIX) => { - let registry_arc = { - let cfg = config.read(); - cfg.mcp_registry - .clone() - .with_context(|| "MCP is not configured")? - }; - - registry_arc.catalog().await? + _ if cmd_name.starts_with(MCP_SEARCH_META_FUNCTION_NAME_PREFIX) => { + Self::search_mcp_tools(config, &cmd_name, &json_data)? + } + _ if cmd_name.starts_with(MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX) => { + Self::describe_mcp_tool(config, &cmd_name, json_data).await? } _ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => { - let server = json_data - .get("server") - .ok_or_else(|| anyhow!("Missing 'server' in arguments"))? - .as_str() - .ok_or_else(|| anyhow!("Invalid 'server' in arguments"))?; - let tool = json_data - .get("tool") - .ok_or_else(|| anyhow!("Missing 'tool' in arguments"))? - .as_str() - .ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?; - let arguments = json_data - .get("arguments") - .cloned() - .unwrap_or_else(|| json!({})); - let registry_arc = { - let cfg = config.read(); - cfg.mcp_registry - .clone() - .with_context(|| "MCP is not configured")? - }; - let result = registry_arc.invoke(server, tool, arguments).await?; - serde_json::to_value(result)? + Self::invoke_mcp_tool(config, &cmd_name, &json_data).await? } _ => match run_llm_function(cmd_name, cmd_args, envs, agent_name)? { Some(contents) => serde_json::from_str(&contents) @@ -816,6 +841,82 @@ impl ToolCall { Ok(output) } + async fn describe_mcp_tool( + config: &GlobalConfig, + cmd_name: &str, + json_data: Value, + ) -> Result { + let server_id = cmd_name.replace(&format!("{MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX}_"), ""); + let tool = json_data + .get("tool") + .ok_or_else(|| anyhow!("Missing 'tool' in arguments"))? + .as_str() + .ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?; + let registry_arc = { + let cfg = config.read(); + cfg.mcp_registry + .clone() + .with_context(|| "MCP is not configured")? + }; + + let result = registry_arc.describe(&server_id, tool).await?; + Ok(serde_json::to_value(result)?) + } + + fn search_mcp_tools(config: &GlobalConfig, cmd_name: &str, json_data: &Value) -> Result { + let server = cmd_name.replace(&format!("{MCP_SEARCH_META_FUNCTION_NAME_PREFIX}_"), ""); + let query = json_data + .get("query") + .ok_or_else(|| anyhow!("Missing 'query' in arguments"))? + .as_str() + .ok_or_else(|| anyhow!("Invalid 'query' in arguments"))?; + let top_k = json_data + .get("top_k") + .cloned() + .unwrap_or_else(|| Value::from(8u64)) + .as_u64() + .ok_or_else(|| anyhow!("Invalid 'top_k' in arguments"))? as usize; + let registry_arc = { + let cfg = config.read(); + cfg.mcp_registry + .clone() + .with_context(|| "MCP is not configured")? + }; + + let catalog_items = registry_arc + .search_tools_server(&server, query, top_k) + .into_iter() + .map(|it| serde_json::to_value(&it).unwrap_or_default()) + .collect(); + Ok(Value::Array(catalog_items)) + } + + async fn invoke_mcp_tool( + config: &GlobalConfig, + cmd_name: &str, + json_data: &Value, + ) -> Result { + let server = cmd_name.replace(&format!("{MCP_INVOKE_META_FUNCTION_NAME_PREFIX}_"), ""); + let tool = json_data + .get("tool") + .ok_or_else(|| anyhow!("Missing 'tool' in arguments"))? + .as_str() + .ok_or_else(|| anyhow!("Invalid 'tool' in arguments"))?; + let arguments = json_data + .get("arguments") + .cloned() + .unwrap_or_else(|| json!({})); + let registry_arc = { + let cfg = config.read(); + cfg.mcp_registry + .clone() + .with_context(|| "MCP is not configured")? + }; + + let result = registry_arc.invoke(&server, tool, arguments).await?; + Ok(serde_json::to_value(result)?) + } + fn extract_call_config_from_agent( &self, config: &GlobalConfig, diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 32781b9..d6e38c9 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -2,6 +2,7 @@ use crate::config::Config; use crate::utils::{AbortSignal, abortable_run_with_spinner}; use crate::vault::interpolate_secrets; use anyhow::{Context, Result, anyhow}; +use bm25::{Document, Language, SearchEngine, SearchEngineBuilder}; use futures_util::future::BoxFuture; use futures_util::{StreamExt, TryStreamExt, stream}; use indoc::formatdoc; @@ -9,7 +10,7 @@ use rmcp::model::{CallToolRequestParam, CallToolResult}; use rmcp::service::RunningService; use rmcp::transport::TokioChildProcess; use rmcp::{RoleClient, ServiceExt}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; @@ -20,10 +21,46 @@ use std::sync::Arc; use tokio::process::Command; pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke"; -pub const MCP_LIST_META_FUNCTION_NAME_PREFIX: &str = "mcp_list"; +pub const MCP_SEARCH_META_FUNCTION_NAME_PREFIX: &str = "mcp_search"; +pub const MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX: &str = "mcp_describe"; type ConnectedServer = RunningService; +#[derive(Clone, Debug, Default, Serialize)] +pub struct CatalogItem { + pub name: String, + pub server: String, + pub description: String, +} + +#[derive(Debug)] +struct ServerCatalog { + engine: SearchEngine, + items: HashMap, +} + +impl ServerCatalog { + pub fn build_bm25(items: &HashMap) -> SearchEngine { + let docs = items.values().map(|it| { + let contents = format!("{}\n{}\nserver:{}", it.name, it.description, it.server); + Document { + id: it.name.clone(), + contents, + } + }); + SearchEngineBuilder::::with_documents(Language::English, docs).build() + } +} + +impl Clone for ServerCatalog { + fn clone(&self) -> Self { + Self { + engine: Self::build_bm25(&self.items), + items: self.items.clone(), + } + } +} + #[derive(Debug, Clone, Deserialize)] struct McpServersConfig { #[serde(rename = "mcpServers")] @@ -50,7 +87,8 @@ enum JsonField { pub struct McpRegistry { log_path: Option, config: Option, - servers: HashMap>>, + servers: HashMap>, + catalogs: HashMap, } impl McpRegistry { @@ -173,7 +211,7 @@ impl McpRegistry { .collect() }; - let results: Vec<(String, Arc<_>)> = stream::iter( + let results: Vec<(String, Arc<_>, ServerCatalog)> = stream::iter( server_ids .into_iter() .map(|id| async { self.start_server(id).await }), @@ -182,13 +220,24 @@ impl McpRegistry { .try_collect() .await?; - self.servers = results.into_iter().collect(); + self.servers = results + .clone() + .into_iter() + .map(|(id, server, _)| (id, server)) + .collect(); + self.catalogs = results + .into_iter() + .map(|(id, _, catalog)| (id, catalog)) + .collect(); } Ok(()) } - async fn start_server(&self, id: String) -> Result<(String, Arc)> { + async fn start_server( + &self, + id: String, + ) -> Result<(String, Arc, ServerCatalog)> { let server = self .config .as_ref() @@ -231,14 +280,33 @@ impl McpRegistry { .await .with_context(|| format!("Failed to start MCP server: {}", &server.command))?, ); - debug!( - "Available tools for MCP server {id}: {:?}", - service.list_tools(None).await? - ); + let tools = service.list_tools(None).await?; + debug!("Available tools for MCP server {id}: {tools:?}"); + + let mut items_vec = Vec::new(); + for t in tools.tools { + let name = t.name.to_string(); + let description = t.description.unwrap_or_default().to_string(); + items_vec.push(CatalogItem { + name, + server: id.clone(), + description, + }); + } + + let mut items_map = HashMap::new(); + items_vec.into_iter().for_each(|it| { + items_map.insert(it.name.clone(), it); + }); + + let catalog = ServerCatalog { + engine: ServerCatalog::build_bm25(&items_map), + items: items_map, + }; info!("Started MCP server: {id}"); - Ok((id.to_string(), service)) + Ok((id.to_string(), service, catalog)) } pub async fn stop_all_servers(mut self) -> Result { @@ -268,26 +336,48 @@ impl McpRegistry { } } - pub fn catalog(&self) -> BoxFuture<'static, Result> { - let servers: Vec<(String, Arc)> = self + pub fn search_tools_server(&self, server: &str, query: &str, top_k: usize) -> Vec { + let Some(catalog) = self.catalogs.get(server) else { + return vec![]; + }; + let engine = &catalog.engine; + let raw = engine.search(query, top_k.min(20)); + + raw.into_iter() + .filter_map(|r| catalog.items.get(&r.document.id)) + .take(top_k) + .cloned() + .collect() + } + + pub async fn describe(&self, server_id: &str, tool: &str) -> Result { + let server = self .servers .iter() - .map(|(id, s)| (id.clone(), s.clone())) - .collect(); + .filter(|(id, _)| &server_id == id) + .map(|(_, s)| s.clone()) + .next() + .ok_or(anyhow!("{server_id} MCP server not found in config"))?; - Box::pin(async move { - let mut out = Vec::with_capacity(servers.len()); - for (id, server) in servers { - let tools = server.list_tools(None).await?; - let resources = server.list_resources(None).await.unwrap_or_default(); - out.push(json!({ - "server": id, - "tools": tools, - "resources": resources, - })); + let tool_schema = server + .list_tools(None) + .await? + .tools + .into_iter() + .find(|it| it.name == tool) + .ok_or(anyhow!( + "{tool} not found in {server_id} MCP server catalog" + ))? + .input_schema; + Ok(json!({ + "type": "object", + "properties": { + "tool": { + "type": "string", + }, + "arguments": tool_schema } - Ok(Value::Array(out)) - }) + })) } pub fn invoke(