From 49aa9fad4105216bd2e724ce20ce6155e7ff4196 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Mon, 20 Apr 2026 14:10:26 -0600 Subject: [PATCH] feat: legacy SSE support for MCP server configurations --- .gitignore | 1 - src/mcp/mod.rs | 50 ++++++-- src/mcp/sse_transport.rs | 261 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 299 insertions(+), 13 deletions(-) create mode 100644 src/mcp/sse_transport.rs diff --git a/.gitignore b/.gitignore index 58a2721..b937ecf 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,3 @@ .idea/ /loki.iml /.idea/ -src diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 79921d0..3c33fe8 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -1,3 +1,5 @@ +mod sse_transport; + use crate::config::AppConfig; use crate::config::paths; use crate::utils::{AbortSignal, abortable_run_with_spinner}; @@ -5,6 +7,7 @@ use crate::vault::Vault; use crate::vault::interpolate_secrets; use anyhow::{Context, Result, anyhow}; use futures_util::{StreamExt, TryStreamExt, stream}; +use http::{HeaderName, HeaderValue}; use indoc::formatdoc; use rmcp::service::RunningService; use rmcp::transport::StreamableHttpClientTransport; @@ -12,12 +15,12 @@ use rmcp::transport::TokioChildProcess; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use rmcp::{RoleClient, ServiceExt}; use serde::{Deserialize, Serialize}; +use sse_transport::LegacySseTransport; use std::collections::{HashMap, HashSet}; use std::fs::OpenOptions; use std::path::{Path, PathBuf}; use std::process::Stdio; use std::sync::Arc; -use http::{HeaderName, HeaderValue}; use tokio::process::Command; pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke"; @@ -328,19 +331,26 @@ pub(crate) async fn spawn_mcp_server( spec: &McpServer, log_path: Option<&Path>, ) -> Result> { - if spec.is_remote() { - let url = spec.url.as_deref().expect("validated: remote spec has url"); - spawn_remote_mcp_server(url, spec.headers.as_ref()).await - } else { - let command = spec - .command - .as_deref() - .expect("validated: stdio spec has command"); - spawn_stdio_mcp_server(command, spec, log_path).await + match spec.transport_type { + McpTransportType::Http => { + let url = spec.url.as_deref().expect("validated: http spec has url"); + spawn_http_mcp_server(url, spec.headers.as_ref()).await + } + McpTransportType::Sse => { + let url = spec.url.as_deref().expect("validated: sse spec has url"); + spawn_sse_mcp_server(url, spec.headers.as_ref()).await + } + McpTransportType::Stdio => { + let command = spec + .command + .as_deref() + .expect("validated: stdio spec has command"); + spawn_stdio_mcp_server(command, spec, log_path).await + } } } -async fn spawn_remote_mcp_server( +async fn spawn_http_mcp_server( url: &str, headers: Option<&HashMap>, ) -> Result> { @@ -365,7 +375,23 @@ async fn spawn_remote_mcp_server( let service = Arc::new( ().serve(transport) .await - .with_context(|| format!("Failed to connect to remote MCP server: {url}"))?, + .with_context(|| format!("Failed to connect to HTTP MCP server: {url}"))?, + ); + Ok(service) +} + +async fn spawn_sse_mcp_server( + url: &str, + headers: Option<&HashMap>, +) -> Result> { + let sse = LegacySseTransport::connect(url, headers) + .await + .with_context(|| format!("Failed to connect to SSE MCP server: {url}"))?; + let (sink, stream) = sse.into_parts(); + let service = Arc::new( + ().serve((sink, stream)) + .await + .with_context(|| format!("Failed to initialize SSE MCP server: {url}"))?, ); Ok(service) } diff --git a/src/mcp/sse_transport.rs b/src/mcp/sse_transport.rs new file mode 100644 index 0000000..876fa3d --- /dev/null +++ b/src/mcp/sse_transport.rs @@ -0,0 +1,261 @@ +use anyhow::{Context, Result, anyhow}; +use fmt::{Display, Formatter}; +use futures_util::StreamExt; +use mpsc::{Receiver, Sender, channel}; +use reqwest::Client; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use reqwest_eventsource::{Event, EventSource}; +use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; +use std::collections::HashMap; +use std::error::Error; +use std::fmt; +use std::pin::Pin; +use std::task::Poll; +use tokio::sync::mpsc; +use tokio::time::Duration; +use url::Url; + +const CHANNEL_BUF: usize = 64; + +pub struct LegacySseTransport { + tx: Sender, + rx: Receiver, +} + +impl LegacySseTransport { + pub async fn connect(sse_url: &str, headers: Option<&HashMap>) -> Result { + let base_url = + Url::parse(sse_url).with_context(|| format!("Invalid SSE URL: {sse_url}"))?; + + let mut client_builder = Client::builder(); + let mut header_map = HeaderMap::new(); + if let Some(hdrs) = headers { + for (k, v) in hdrs { + let name = k + .parse::() + .with_context(|| format!("Invalid header name: {k}"))?; + let value = v + .parse::() + .with_context(|| format!("Invalid header value for {k}"))?; + header_map.insert(name, value); + } + client_builder = client_builder.default_headers(header_map); + } + let client = client_builder + .build() + .context("Failed to build HTTP client")?; + + let request = client.get(sse_url); + let mut es = EventSource::new(request).context("Failed to open SSE connection")?; + + let post_endpoint = wait_for_endpoint_event(&mut es, &base_url).await?; + + let (outgoing_tx, outgoing_rx) = channel::(CHANNEL_BUF); + let (incoming_tx, incoming_rx) = channel::(CHANNEL_BUF); + + tokio::spawn(sse_reader_task(es, incoming_tx)); + tokio::spawn(post_writer_task(client, post_endpoint, outgoing_rx)); + + Ok(Self { + tx: outgoing_tx, + rx: incoming_rx, + }) + } + + pub fn into_parts( + self, + ) -> ( + SseSink, + SseStream, + ) { + ( + SseSink { + tx: PollSender { + tx: self.tx, + permit: None, + }, + }, + SseStream { rx: self.rx }, + ) + } +} + +async fn wait_for_endpoint_event(es: &mut EventSource, base_url: &Url) -> Result { + let timeout = Duration::from_secs(30); + tokio::time::timeout(timeout, async { + while let Some(event) = es.next().await { + match event { + Ok(Event::Open) => {} + Ok(Event::Message(msg)) if msg.event == "endpoint" => { + let endpoint = msg.data.trim().to_string(); + let resolved = resolve_endpoint(&endpoint, base_url)?; + return Ok(resolved); + } + Ok(Event::Message(_)) => {} + Err(e) => { + return Err(anyhow!( + "SSE connection error while waiting for endpoint event: {e}" + )); + } + } + } + Err(anyhow!("SSE stream closed before receiving endpoint event")) + }) + .await + .map_err(|_| anyhow!("Timed out waiting for endpoint event from SSE server (30s)"))? +} + +fn resolve_endpoint(endpoint: &str, base_url: &Url) -> Result { + if endpoint.starts_with("http://") || endpoint.starts_with("https://") { + Ok(endpoint.to_string()) + } else { + let mut resolved = base_url.clone(); + let (path, query) = endpoint.split_once('?').unwrap_or((endpoint, "")); + resolved.set_path(path); + resolved.set_query(if query.is_empty() { None } else { Some(query) }); + Ok(resolved.to_string()) + } +} + +async fn sse_reader_task(mut es: EventSource, tx: Sender) { + while let Some(event) = es.next().await { + match event { + Ok(Event::Message(msg)) if msg.event == "message" => { + match serde_json::from_str::(&msg.data) { + Ok(rpc_msg) => { + if tx.send(rpc_msg).await.is_err() { + break; + } + } + Err(e) => { + warn!("Failed to parse SSE message as JSON-RPC: {e}"); + } + } + } + Ok(_) => {} + Err(reqwest_eventsource::Error::StreamEnded) => break, + Err(e) => { + error!("SSE stream error: {e}"); + break; + } + } + } + es.close(); +} + +async fn post_writer_task( + client: Client, + endpoint: String, + mut rx: Receiver, +) { + while let Some(msg) = rx.recv().await { + let body = match serde_json::to_string(&msg) { + Ok(b) => b, + Err(e) => { + error!("Failed to serialize JSON-RPC message: {e}"); + continue; + } + }; + if let Err(e) = client + .post(&endpoint) + .header("Content-Type", "application/json") + .body(body) + .send() + .await + { + error!("Failed to POST message to SSE endpoint: {e}"); + } + } +} + +pub struct SseSink { + tx: PollSender, +} + +pub struct SseStream { + rx: Receiver, +} + +impl futures_util::Sink for SseSink { + type Error = SseSinkError; + + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.tx.poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + self.tx.start_send(item) + } + + fn poll_flush( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_close( + self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl futures_util::Stream for SseStream { + type Item = T; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + self.rx.poll_recv(cx) + } +} + +#[derive(Debug)] +pub enum SseSinkError { + Closed, +} + +impl Display for SseSinkError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + SseSinkError::Closed => write!(f, "SSE transport channel closed"), + } + } +} + +impl Error for SseSinkError {} + +struct PollSender { + tx: Sender, + permit: Option>, +} + +impl PollSender { + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if self.permit.is_some() { + return Poll::Ready(Ok(())); + } + let tx = self.tx.clone(); + let mut fut = Box::pin(tx.reserve_owned()); + match fut.as_mut().poll(cx) { + Poll::Ready(Ok(permit)) => { + self.permit = Some(permit); + Poll::Ready(Ok(())) + } + Poll::Ready(Err(_)) => Poll::Ready(Err(SseSinkError::Closed)), + Poll::Pending => Poll::Pending, + } + } + + fn start_send(&mut self, item: T) -> Result<(), SseSinkError> { + let permit = self.permit.take().ok_or(SseSinkError::Closed)?; + permit.send(item); + Ok(()) + } +}