193 lines
5.3 KiB
Rust
193 lines
5.3 KiB
Rust
use crate::function::{Functions, ToolCallTracker};
|
|
use crate::mcp::{CatalogItem, ConnectedServer, McpRegistry};
|
|
|
|
use anyhow::{Context, Result, anyhow};
|
|
use bm25::{Document, Language, SearchEngineBuilder};
|
|
use rmcp::model::{CallToolRequestParams, CallToolResult};
|
|
use serde_json::{Value, json};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
#[derive(Clone)]
|
|
pub struct ToolScope {
|
|
pub functions: Functions,
|
|
pub mcp_runtime: McpRuntime,
|
|
pub tool_tracker: ToolCallTracker,
|
|
}
|
|
|
|
impl Default for ToolScope {
|
|
fn default() -> Self {
|
|
Self {
|
|
functions: Functions::default(),
|
|
mcp_runtime: McpRuntime::default(),
|
|
tool_tracker: ToolCallTracker::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Default, Clone)]
|
|
pub struct McpRuntime {
|
|
pub servers: HashMap<String, Arc<ConnectedServer>>,
|
|
}
|
|
|
|
impl McpRuntime {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
pub fn is_empty(&self) -> bool {
|
|
self.servers.is_empty()
|
|
}
|
|
|
|
pub fn insert(&mut self, name: String, handle: Arc<ConnectedServer>) {
|
|
self.servers.insert(name, handle);
|
|
}
|
|
|
|
pub fn get(&self, name: &str) -> Option<&Arc<ConnectedServer>> {
|
|
self.servers.get(name)
|
|
}
|
|
|
|
pub fn server_names(&self) -> Vec<String> {
|
|
self.servers.keys().cloned().collect()
|
|
}
|
|
|
|
pub fn sync_from_registry(&mut self, registry: &McpRegistry) {
|
|
self.servers.clear();
|
|
for (name, handle) in registry.running_servers() {
|
|
self.servers.insert(name.clone(), Arc::clone(handle));
|
|
}
|
|
}
|
|
|
|
async fn catalog_items(&self, server: &str) -> Result<HashMap<String, CatalogItem>> {
|
|
let server_handle = self
|
|
.get(server)
|
|
.cloned()
|
|
.with_context(|| format!("{server} MCP server not found in runtime"))?;
|
|
let tools = server_handle.list_tools(None).await?;
|
|
let mut items = HashMap::new();
|
|
|
|
for tool in tools.tools {
|
|
let item = CatalogItem {
|
|
name: tool.name.to_string(),
|
|
server: server.to_string(),
|
|
description: tool.description.unwrap_or_default().to_string(),
|
|
};
|
|
items.insert(item.name.clone(), item);
|
|
}
|
|
|
|
Ok(items)
|
|
}
|
|
|
|
pub async fn search(
|
|
&self,
|
|
server: &str,
|
|
query: &str,
|
|
top_k: usize,
|
|
) -> Result<Vec<CatalogItem>> {
|
|
let items = self.catalog_items(server).await?;
|
|
let docs = items.values().map(|item| Document {
|
|
id: item.name.clone(),
|
|
contents: format!(
|
|
"{}\n{}\nserver:{}",
|
|
item.name, item.description, item.server
|
|
),
|
|
});
|
|
let engine = SearchEngineBuilder::<String>::with_documents(Language::English, docs).build();
|
|
|
|
Ok(engine
|
|
.search(query, top_k.min(20))
|
|
.into_iter()
|
|
.filter_map(|result| items.get(&result.document.id))
|
|
.take(top_k)
|
|
.cloned()
|
|
.collect())
|
|
}
|
|
|
|
pub async fn describe(&self, server: &str, tool: &str) -> Result<Value> {
|
|
let server_handle = self
|
|
.get(server)
|
|
.cloned()
|
|
.with_context(|| format!("{server} MCP server not found in runtime"))?;
|
|
|
|
let tool_schema = server_handle
|
|
.list_tools(None)
|
|
.await?
|
|
.tools
|
|
.into_iter()
|
|
.find(|item| item.name == tool)
|
|
.ok_or_else(|| anyhow!("{tool} not found in {server} MCP server catalog"))?
|
|
.input_schema;
|
|
|
|
Ok(json!({
|
|
"type": "object",
|
|
"properties": {
|
|
"tool": {
|
|
"type": "string",
|
|
},
|
|
"arguments": tool_schema
|
|
}
|
|
}))
|
|
}
|
|
|
|
pub async fn invoke(
|
|
&self,
|
|
server: &str,
|
|
tool: &str,
|
|
arguments: Value,
|
|
) -> Result<CallToolResult> {
|
|
let server_handle = self
|
|
.get(server)
|
|
.cloned()
|
|
.with_context(|| format!("Invoked MCP server does not exist: {server}"))?;
|
|
|
|
let mut request = CallToolRequestParams::new(tool.to_owned());
|
|
request.arguments = arguments.as_object().cloned();
|
|
|
|
server_handle.call_tool(request).await.map_err(Into::into)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use crate::function::ToolCall;
|
|
|
|
#[test]
|
|
fn mcp_runtime_new_is_empty() {
|
|
let runtime = McpRuntime::new();
|
|
assert!(runtime.is_empty());
|
|
assert!(runtime.server_names().is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn mcp_runtime_default_is_empty() {
|
|
let runtime = McpRuntime::default();
|
|
assert!(runtime.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn mcp_runtime_get_returns_none_for_missing_server() {
|
|
let runtime = McpRuntime::new();
|
|
assert!(runtime.get("nonexistent").is_none());
|
|
}
|
|
|
|
#[test]
|
|
fn tool_scope_default_has_empty_mcp_runtime() {
|
|
let scope = ToolScope::default();
|
|
assert!(scope.mcp_runtime.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn tool_scope_default_has_empty_functions() {
|
|
let scope = ToolScope::default();
|
|
assert!(scope.functions.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn tool_scope_default_tracker_has_no_loops() {
|
|
let scope = ToolScope::default();
|
|
let dummy_call = ToolCall::default();
|
|
assert!(scope.tool_tracker.check_loop(&dummy_call).is_none());
|
|
}
|
|
}
|