feat: legacy SSE support for MCP server configurations

This commit is contained in:
2026-04-20 14:10:26 -06:00
parent 8f7d3bd13c
commit 49aa9fad41
3 changed files with 299 additions and 13 deletions
-1
View File
@@ -5,4 +5,3 @@
.idea/
/loki.iml
/.idea/
src
+33 -7
View File
@@ -1,3 +1,5 @@
mod sse_transport;
use crate::config::AppConfig;
use crate::config::paths;
use crate::utils::{AbortSignal, abortable_run_with_spinner};
@@ -5,6 +7,7 @@ use crate::vault::Vault;
use crate::vault::interpolate_secrets;
use anyhow::{Context, Result, anyhow};
use futures_util::{StreamExt, TryStreamExt, stream};
use http::{HeaderName, HeaderValue};
use indoc::formatdoc;
use rmcp::service::RunningService;
use rmcp::transport::StreamableHttpClientTransport;
@@ -12,12 +15,12 @@ use rmcp::transport::TokioChildProcess;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::{RoleClient, ServiceExt};
use serde::{Deserialize, Serialize};
use sse_transport::LegacySseTransport;
use std::collections::{HashMap, HashSet};
use std::fs::OpenOptions;
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::sync::Arc;
use http::{HeaderName, HeaderValue};
use tokio::process::Command;
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
@@ -328,10 +331,16 @@ pub(crate) async fn spawn_mcp_server(
spec: &McpServer,
log_path: Option<&Path>,
) -> Result<Arc<ConnectedServer>> {
if spec.is_remote() {
let url = spec.url.as_deref().expect("validated: remote spec has url");
spawn_remote_mcp_server(url, spec.headers.as_ref()).await
} else {
match spec.transport_type {
McpTransportType::Http => {
let url = spec.url.as_deref().expect("validated: http spec has url");
spawn_http_mcp_server(url, spec.headers.as_ref()).await
}
McpTransportType::Sse => {
let url = spec.url.as_deref().expect("validated: sse spec has url");
spawn_sse_mcp_server(url, spec.headers.as_ref()).await
}
McpTransportType::Stdio => {
let command = spec
.command
.as_deref()
@@ -339,8 +348,9 @@ pub(crate) async fn spawn_mcp_server(
spawn_stdio_mcp_server(command, spec, log_path).await
}
}
}
async fn spawn_remote_mcp_server(
async fn spawn_http_mcp_server(
url: &str,
headers: Option<&HashMap<String, String>>,
) -> Result<Arc<ConnectedServer>> {
@@ -365,7 +375,23 @@ async fn spawn_remote_mcp_server(
let service = Arc::new(
().serve(transport)
.await
.with_context(|| format!("Failed to connect to remote MCP server: {url}"))?,
.with_context(|| format!("Failed to connect to HTTP MCP server: {url}"))?,
);
Ok(service)
}
async fn spawn_sse_mcp_server(
url: &str,
headers: Option<&HashMap<String, String>>,
) -> Result<Arc<ConnectedServer>> {
let sse = LegacySseTransport::connect(url, headers)
.await
.with_context(|| format!("Failed to connect to SSE MCP server: {url}"))?;
let (sink, stream) = sse.into_parts();
let service = Arc::new(
().serve((sink, stream))
.await
.with_context(|| format!("Failed to initialize SSE MCP server: {url}"))?,
);
Ok(service)
}
+261
View File
@@ -0,0 +1,261 @@
use anyhow::{Context, Result, anyhow};
use fmt::{Display, Formatter};
use futures_util::StreamExt;
use mpsc::{Receiver, Sender, channel};
use reqwest::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest_eventsource::{Event, EventSource};
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::pin::Pin;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::Duration;
use url::Url;
const CHANNEL_BUF: usize = 64;
pub struct LegacySseTransport {
tx: Sender<ClientJsonRpcMessage>,
rx: Receiver<ServerJsonRpcMessage>,
}
impl LegacySseTransport {
pub async fn connect(sse_url: &str, headers: Option<&HashMap<String, String>>) -> Result<Self> {
let base_url =
Url::parse(sse_url).with_context(|| format!("Invalid SSE URL: {sse_url}"))?;
let mut client_builder = Client::builder();
let mut header_map = HeaderMap::new();
if let Some(hdrs) = headers {
for (k, v) in hdrs {
let name = k
.parse::<HeaderName>()
.with_context(|| format!("Invalid header name: {k}"))?;
let value = v
.parse::<HeaderValue>()
.with_context(|| format!("Invalid header value for {k}"))?;
header_map.insert(name, value);
}
client_builder = client_builder.default_headers(header_map);
}
let client = client_builder
.build()
.context("Failed to build HTTP client")?;
let request = client.get(sse_url);
let mut es = EventSource::new(request).context("Failed to open SSE connection")?;
let post_endpoint = wait_for_endpoint_event(&mut es, &base_url).await?;
let (outgoing_tx, outgoing_rx) = channel::<ClientJsonRpcMessage>(CHANNEL_BUF);
let (incoming_tx, incoming_rx) = channel::<ServerJsonRpcMessage>(CHANNEL_BUF);
tokio::spawn(sse_reader_task(es, incoming_tx));
tokio::spawn(post_writer_task(client, post_endpoint, outgoing_rx));
Ok(Self {
tx: outgoing_tx,
rx: incoming_rx,
})
}
pub fn into_parts(
self,
) -> (
SseSink<ClientJsonRpcMessage>,
SseStream<ServerJsonRpcMessage>,
) {
(
SseSink {
tx: PollSender {
tx: self.tx,
permit: None,
},
},
SseStream { rx: self.rx },
)
}
}
async fn wait_for_endpoint_event(es: &mut EventSource, base_url: &Url) -> Result<String> {
let timeout = Duration::from_secs(30);
tokio::time::timeout(timeout, async {
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(msg)) if msg.event == "endpoint" => {
let endpoint = msg.data.trim().to_string();
let resolved = resolve_endpoint(&endpoint, base_url)?;
return Ok(resolved);
}
Ok(Event::Message(_)) => {}
Err(e) => {
return Err(anyhow!(
"SSE connection error while waiting for endpoint event: {e}"
));
}
}
}
Err(anyhow!("SSE stream closed before receiving endpoint event"))
})
.await
.map_err(|_| anyhow!("Timed out waiting for endpoint event from SSE server (30s)"))?
}
fn resolve_endpoint(endpoint: &str, base_url: &Url) -> Result<String> {
if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
Ok(endpoint.to_string())
} else {
let mut resolved = base_url.clone();
let (path, query) = endpoint.split_once('?').unwrap_or((endpoint, ""));
resolved.set_path(path);
resolved.set_query(if query.is_empty() { None } else { Some(query) });
Ok(resolved.to_string())
}
}
async fn sse_reader_task(mut es: EventSource, tx: Sender<ServerJsonRpcMessage>) {
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) if msg.event == "message" => {
match serde_json::from_str::<ServerJsonRpcMessage>(&msg.data) {
Ok(rpc_msg) => {
if tx.send(rpc_msg).await.is_err() {
break;
}
}
Err(e) => {
warn!("Failed to parse SSE message as JSON-RPC: {e}");
}
}
}
Ok(_) => {}
Err(reqwest_eventsource::Error::StreamEnded) => break,
Err(e) => {
error!("SSE stream error: {e}");
break;
}
}
}
es.close();
}
async fn post_writer_task(
client: Client,
endpoint: String,
mut rx: Receiver<ClientJsonRpcMessage>,
) {
while let Some(msg) = rx.recv().await {
let body = match serde_json::to_string(&msg) {
Ok(b) => b,
Err(e) => {
error!("Failed to serialize JSON-RPC message: {e}");
continue;
}
};
if let Err(e) = client
.post(&endpoint)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
{
error!("Failed to POST message to SSE endpoint: {e}");
}
}
}
pub struct SseSink<T> {
tx: PollSender<T>,
}
pub struct SseStream<T> {
rx: Receiver<T>,
}
impl<T: Send + 'static> futures_util::Sink<T> for SseSink<T> {
type Error = SseSinkError;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
self.tx.start_send(item)
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl<T: Send + 'static> futures_util::Stream for SseStream<T> {
type Item = T;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}
#[derive(Debug)]
pub enum SseSinkError {
Closed,
}
impl Display for SseSinkError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
SseSinkError::Closed => write!(f, "SSE transport channel closed"),
}
}
}
impl Error for SseSinkError {}
struct PollSender<T> {
tx: Sender<T>,
permit: Option<mpsc::OwnedPermit<T>>,
}
impl<T: Send + 'static> PollSender<T> {
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), SseSinkError>> {
if self.permit.is_some() {
return Poll::Ready(Ok(()));
}
let tx = self.tx.clone();
let mut fut = Box::pin(tx.reserve_owned());
match fut.as_mut().poll(cx) {
Poll::Ready(Ok(permit)) => {
self.permit = Some(permit);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(_)) => Poll::Ready(Err(SseSinkError::Closed)),
Poll::Pending => Poll::Pending,
}
}
fn start_send(&mut self, item: T) -> Result<(), SseSinkError> {
let permit = self.permit.take().ok_or(SseSinkError::Closed)?;
permit.send(item);
Ok(())
}
}