feat: support http/sse transport types for MCP server configurations so it fully supports claude desktop-style MCP configs
This commit is contained in:
+114
-3
@@ -7,7 +7,9 @@ use anyhow::{Context, Result, anyhow};
|
||||
use futures_util::{StreamExt, TryStreamExt, stream};
|
||||
use indoc::formatdoc;
|
||||
use rmcp::service::RunningService;
|
||||
use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::TokioChildProcess;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use rmcp::{RoleClient, ServiceExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@@ -15,6 +17,7 @@ use std::fs::OpenOptions;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use http::{HeaderName, HeaderValue};
|
||||
use tokio::process::Command;
|
||||
|
||||
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
|
||||
@@ -50,11 +53,67 @@ pub(crate) struct McpServersConfig {
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub(crate) struct McpServer {
|
||||
pub command: String,
|
||||
#[serde(rename = "type")]
|
||||
pub transport_type: McpTransportType,
|
||||
pub command: Option<String>,
|
||||
pub args: Option<Vec<String>>,
|
||||
pub env: Option<HashMap<String, JsonField>>,
|
||||
pub cwd: Option<String>,
|
||||
pub url: Option<String>,
|
||||
pub headers: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
pub fn is_remote(&self) -> bool {
|
||||
matches!(
|
||||
self.transport_type,
|
||||
McpTransportType::Http | McpTransportType::Sse
|
||||
)
|
||||
}
|
||||
|
||||
pub fn validate(&self, name: &str) -> Result<()> {
|
||||
if self.is_remote() {
|
||||
let type_label = match self.transport_type {
|
||||
McpTransportType::Http => "http",
|
||||
McpTransportType::Sse => "sse",
|
||||
_ => unreachable!(),
|
||||
};
|
||||
if self.url.is_none() {
|
||||
return Err(anyhow!(
|
||||
"MCP server '{name}' has type \"{type_label}\" but is missing a \"url\" field"
|
||||
));
|
||||
}
|
||||
if self.command.is_some() || self.args.is_some() || self.cwd.is_some() {
|
||||
return Err(anyhow!(
|
||||
"MCP server '{name}' has type \"{type_label}\" but also specifies stdio fields \
|
||||
(command/args/cwd). Remove the stdio fields or change the type to \"stdio\"."
|
||||
));
|
||||
}
|
||||
} else {
|
||||
if self.command.is_none() {
|
||||
return Err(anyhow!(
|
||||
"MCP server '{name}' is missing a \"command\" field (required for stdio transport)"
|
||||
));
|
||||
}
|
||||
if self.url.is_some() || self.headers.is_some() {
|
||||
return Err(anyhow!(
|
||||
"MCP server '{name}' has type \"stdio\" but also specifies remote fields \
|
||||
(url/headers). Remove the remote fields or change the type to \"http\" or \"sse\"."
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub(crate) enum McpTransportType {
|
||||
Stdio,
|
||||
Http,
|
||||
Sse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -126,6 +185,11 @@ impl McpRegistry {
|
||||
|
||||
let mcp_servers_config: McpServersConfig =
|
||||
serde_json::from_str(&parsed_content).with_context(err)?;
|
||||
|
||||
for (name, spec) in &mcp_servers_config.mcp_servers {
|
||||
spec.validate(name)?;
|
||||
}
|
||||
|
||||
registry.config = Some(mcp_servers_config);
|
||||
|
||||
if start_mcp_servers && app_config.mcp_server_support {
|
||||
@@ -264,7 +328,54 @@ pub(crate) async fn spawn_mcp_server(
|
||||
spec: &McpServer,
|
||||
log_path: Option<&Path>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
let mut cmd = Command::new(&spec.command);
|
||||
if spec.is_remote() {
|
||||
let url = spec.url.as_deref().expect("validated: remote spec has url");
|
||||
spawn_remote_mcp_server(url, spec.headers.as_ref()).await
|
||||
} else {
|
||||
let command = spec
|
||||
.command
|
||||
.as_deref()
|
||||
.expect("validated: stdio spec has command");
|
||||
spawn_stdio_mcp_server(command, spec, log_path).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn spawn_remote_mcp_server(
|
||||
url: &str,
|
||||
headers: Option<&HashMap<String, String>>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
let transport = if let Some(hdrs) = headers
|
||||
&& !hdrs.is_empty()
|
||||
{
|
||||
let mut custom = HashMap::new();
|
||||
for (k, v) in hdrs {
|
||||
let name = k
|
||||
.parse::<HeaderName>()
|
||||
.with_context(|| format!("Invalid header name: {k}"))?;
|
||||
let value = v
|
||||
.parse::<HeaderValue>()
|
||||
.with_context(|| format!("Invalid header value for {k}"))?;
|
||||
custom.insert(name, value);
|
||||
}
|
||||
let config = StreamableHttpClientTransportConfig::with_uri(url).custom_headers(custom);
|
||||
StreamableHttpClientTransport::from_config(config)
|
||||
} else {
|
||||
StreamableHttpClientTransport::from_uri(url)
|
||||
};
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to connect to remote MCP server: {url}"))?,
|
||||
);
|
||||
Ok(service)
|
||||
}
|
||||
|
||||
async fn spawn_stdio_mcp_server(
|
||||
command: &str,
|
||||
spec: &McpServer,
|
||||
log_path: Option<&Path>,
|
||||
) -> Result<Arc<ConnectedServer>> {
|
||||
let mut cmd = Command::new(command);
|
||||
if let Some(args) = &spec.args {
|
||||
cmd.args(args);
|
||||
}
|
||||
@@ -299,7 +410,7 @@ pub(crate) async fn spawn_mcp_server(
|
||||
let service = Arc::new(
|
||||
().serve(transport)
|
||||
.await
|
||||
.with_context(|| format!("Failed to start MCP server: {}", &spec.command))?,
|
||||
.with_context(|| format!("Failed to start MCP server: {command}"))?,
|
||||
);
|
||||
Ok(service)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user