feat: legacy SSE support for MCP server configurations
This commit is contained in:
@@ -5,4 +5,3 @@
|
|||||||
.idea/
|
.idea/
|
||||||
/loki.iml
|
/loki.iml
|
||||||
/.idea/
|
/.idea/
|
||||||
src
|
|
||||||
|
|||||||
+33
-7
@@ -1,3 +1,5 @@
|
|||||||
|
mod sse_transport;
|
||||||
|
|
||||||
use crate::config::AppConfig;
|
use crate::config::AppConfig;
|
||||||
use crate::config::paths;
|
use crate::config::paths;
|
||||||
use crate::utils::{AbortSignal, abortable_run_with_spinner};
|
use crate::utils::{AbortSignal, abortable_run_with_spinner};
|
||||||
@@ -5,6 +7,7 @@ use crate::vault::Vault;
|
|||||||
use crate::vault::interpolate_secrets;
|
use crate::vault::interpolate_secrets;
|
||||||
use anyhow::{Context, Result, anyhow};
|
use anyhow::{Context, Result, anyhow};
|
||||||
use futures_util::{StreamExt, TryStreamExt, stream};
|
use futures_util::{StreamExt, TryStreamExt, stream};
|
||||||
|
use http::{HeaderName, HeaderValue};
|
||||||
use indoc::formatdoc;
|
use indoc::formatdoc;
|
||||||
use rmcp::service::RunningService;
|
use rmcp::service::RunningService;
|
||||||
use rmcp::transport::StreamableHttpClientTransport;
|
use rmcp::transport::StreamableHttpClientTransport;
|
||||||
@@ -12,12 +15,12 @@ use rmcp::transport::TokioChildProcess;
|
|||||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||||
use rmcp::{RoleClient, ServiceExt};
|
use rmcp::{RoleClient, ServiceExt};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use sse_transport::LegacySseTransport;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::fs::OpenOptions;
|
use std::fs::OpenOptions;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::process::Stdio;
|
use std::process::Stdio;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use http::{HeaderName, HeaderValue};
|
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
|
|
||||||
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
|
pub const MCP_INVOKE_META_FUNCTION_NAME_PREFIX: &str = "mcp_invoke";
|
||||||
@@ -328,10 +331,16 @@ pub(crate) async fn spawn_mcp_server(
|
|||||||
spec: &McpServer,
|
spec: &McpServer,
|
||||||
log_path: Option<&Path>,
|
log_path: Option<&Path>,
|
||||||
) -> Result<Arc<ConnectedServer>> {
|
) -> Result<Arc<ConnectedServer>> {
|
||||||
if spec.is_remote() {
|
match spec.transport_type {
|
||||||
let url = spec.url.as_deref().expect("validated: remote spec has url");
|
McpTransportType::Http => {
|
||||||
spawn_remote_mcp_server(url, spec.headers.as_ref()).await
|
let url = spec.url.as_deref().expect("validated: http spec has url");
|
||||||
} else {
|
spawn_http_mcp_server(url, spec.headers.as_ref()).await
|
||||||
|
}
|
||||||
|
McpTransportType::Sse => {
|
||||||
|
let url = spec.url.as_deref().expect("validated: sse spec has url");
|
||||||
|
spawn_sse_mcp_server(url, spec.headers.as_ref()).await
|
||||||
|
}
|
||||||
|
McpTransportType::Stdio => {
|
||||||
let command = spec
|
let command = spec
|
||||||
.command
|
.command
|
||||||
.as_deref()
|
.as_deref()
|
||||||
@@ -339,8 +348,9 @@ pub(crate) async fn spawn_mcp_server(
|
|||||||
spawn_stdio_mcp_server(command, spec, log_path).await
|
spawn_stdio_mcp_server(command, spec, log_path).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn spawn_remote_mcp_server(
|
async fn spawn_http_mcp_server(
|
||||||
url: &str,
|
url: &str,
|
||||||
headers: Option<&HashMap<String, String>>,
|
headers: Option<&HashMap<String, String>>,
|
||||||
) -> Result<Arc<ConnectedServer>> {
|
) -> Result<Arc<ConnectedServer>> {
|
||||||
@@ -365,7 +375,23 @@ async fn spawn_remote_mcp_server(
|
|||||||
let service = Arc::new(
|
let service = Arc::new(
|
||||||
().serve(transport)
|
().serve(transport)
|
||||||
.await
|
.await
|
||||||
.with_context(|| format!("Failed to connect to remote MCP server: {url}"))?,
|
.with_context(|| format!("Failed to connect to HTTP MCP server: {url}"))?,
|
||||||
|
);
|
||||||
|
Ok(service)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn spawn_sse_mcp_server(
|
||||||
|
url: &str,
|
||||||
|
headers: Option<&HashMap<String, String>>,
|
||||||
|
) -> Result<Arc<ConnectedServer>> {
|
||||||
|
let sse = LegacySseTransport::connect(url, headers)
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to connect to SSE MCP server: {url}"))?;
|
||||||
|
let (sink, stream) = sse.into_parts();
|
||||||
|
let service = Arc::new(
|
||||||
|
().serve((sink, stream))
|
||||||
|
.await
|
||||||
|
.with_context(|| format!("Failed to initialize SSE MCP server: {url}"))?,
|
||||||
);
|
);
|
||||||
Ok(service)
|
Ok(service)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,261 @@
|
|||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use fmt::{Display, Formatter};
|
||||||
|
use futures_util::StreamExt;
|
||||||
|
use mpsc::{Receiver, Sender, channel};
|
||||||
|
use reqwest::Client;
|
||||||
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
||||||
|
use reqwest_eventsource::{Event, EventSource};
|
||||||
|
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::Poll;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::time::Duration;
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
const CHANNEL_BUF: usize = 64;
|
||||||
|
|
||||||
|
pub struct LegacySseTransport {
|
||||||
|
tx: Sender<ClientJsonRpcMessage>,
|
||||||
|
rx: Receiver<ServerJsonRpcMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LegacySseTransport {
|
||||||
|
pub async fn connect(sse_url: &str, headers: Option<&HashMap<String, String>>) -> Result<Self> {
|
||||||
|
let base_url =
|
||||||
|
Url::parse(sse_url).with_context(|| format!("Invalid SSE URL: {sse_url}"))?;
|
||||||
|
|
||||||
|
let mut client_builder = Client::builder();
|
||||||
|
let mut header_map = HeaderMap::new();
|
||||||
|
if let Some(hdrs) = headers {
|
||||||
|
for (k, v) in hdrs {
|
||||||
|
let name = k
|
||||||
|
.parse::<HeaderName>()
|
||||||
|
.with_context(|| format!("Invalid header name: {k}"))?;
|
||||||
|
let value = v
|
||||||
|
.parse::<HeaderValue>()
|
||||||
|
.with_context(|| format!("Invalid header value for {k}"))?;
|
||||||
|
header_map.insert(name, value);
|
||||||
|
}
|
||||||
|
client_builder = client_builder.default_headers(header_map);
|
||||||
|
}
|
||||||
|
let client = client_builder
|
||||||
|
.build()
|
||||||
|
.context("Failed to build HTTP client")?;
|
||||||
|
|
||||||
|
let request = client.get(sse_url);
|
||||||
|
let mut es = EventSource::new(request).context("Failed to open SSE connection")?;
|
||||||
|
|
||||||
|
let post_endpoint = wait_for_endpoint_event(&mut es, &base_url).await?;
|
||||||
|
|
||||||
|
let (outgoing_tx, outgoing_rx) = channel::<ClientJsonRpcMessage>(CHANNEL_BUF);
|
||||||
|
let (incoming_tx, incoming_rx) = channel::<ServerJsonRpcMessage>(CHANNEL_BUF);
|
||||||
|
|
||||||
|
tokio::spawn(sse_reader_task(es, incoming_tx));
|
||||||
|
tokio::spawn(post_writer_task(client, post_endpoint, outgoing_rx));
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
tx: outgoing_tx,
|
||||||
|
rx: incoming_rx,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_parts(
|
||||||
|
self,
|
||||||
|
) -> (
|
||||||
|
SseSink<ClientJsonRpcMessage>,
|
||||||
|
SseStream<ServerJsonRpcMessage>,
|
||||||
|
) {
|
||||||
|
(
|
||||||
|
SseSink {
|
||||||
|
tx: PollSender {
|
||||||
|
tx: self.tx,
|
||||||
|
permit: None,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SseStream { rx: self.rx },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn wait_for_endpoint_event(es: &mut EventSource, base_url: &Url) -> Result<String> {
|
||||||
|
let timeout = Duration::from_secs(30);
|
||||||
|
tokio::time::timeout(timeout, async {
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Open) => {}
|
||||||
|
Ok(Event::Message(msg)) if msg.event == "endpoint" => {
|
||||||
|
let endpoint = msg.data.trim().to_string();
|
||||||
|
let resolved = resolve_endpoint(&endpoint, base_url)?;
|
||||||
|
return Ok(resolved);
|
||||||
|
}
|
||||||
|
Ok(Event::Message(_)) => {}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(anyhow!(
|
||||||
|
"SSE connection error while waiting for endpoint event: {e}"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(anyhow!("SSE stream closed before receiving endpoint event"))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|_| anyhow!("Timed out waiting for endpoint event from SSE server (30s)"))?
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_endpoint(endpoint: &str, base_url: &Url) -> Result<String> {
|
||||||
|
if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
|
||||||
|
Ok(endpoint.to_string())
|
||||||
|
} else {
|
||||||
|
let mut resolved = base_url.clone();
|
||||||
|
let (path, query) = endpoint.split_once('?').unwrap_or((endpoint, ""));
|
||||||
|
resolved.set_path(path);
|
||||||
|
resolved.set_query(if query.is_empty() { None } else { Some(query) });
|
||||||
|
Ok(resolved.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn sse_reader_task(mut es: EventSource, tx: Sender<ServerJsonRpcMessage>) {
|
||||||
|
while let Some(event) = es.next().await {
|
||||||
|
match event {
|
||||||
|
Ok(Event::Message(msg)) if msg.event == "message" => {
|
||||||
|
match serde_json::from_str::<ServerJsonRpcMessage>(&msg.data) {
|
||||||
|
Ok(rpc_msg) => {
|
||||||
|
if tx.send(rpc_msg).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to parse SSE message as JSON-RPC: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(reqwest_eventsource::Error::StreamEnded) => break,
|
||||||
|
Err(e) => {
|
||||||
|
error!("SSE stream error: {e}");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
es.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn post_writer_task(
|
||||||
|
client: Client,
|
||||||
|
endpoint: String,
|
||||||
|
mut rx: Receiver<ClientJsonRpcMessage>,
|
||||||
|
) {
|
||||||
|
while let Some(msg) = rx.recv().await {
|
||||||
|
let body = match serde_json::to_string(&msg) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to serialize JSON-RPC message: {e}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = client
|
||||||
|
.post(&endpoint)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
error!("Failed to POST message to SSE endpoint: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SseSink<T> {
|
||||||
|
tx: PollSender<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SseStream<T> {
|
||||||
|
rx: Receiver<T>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Send + 'static> futures_util::Sink<T> for SseSink<T> {
|
||||||
|
type Error = SseSinkError;
|
||||||
|
|
||||||
|
fn poll_ready(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.tx.poll_ready(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
|
||||||
|
self.tx.start_send(item)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), Self::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Send + 'static> futures_util::Stream for SseStream<T> {
|
||||||
|
type Item = T;
|
||||||
|
|
||||||
|
fn poll_next(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<Option<Self::Item>> {
|
||||||
|
self.rx.poll_recv(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum SseSinkError {
|
||||||
|
Closed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for SseSinkError {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
SseSinkError::Closed => write!(f, "SSE transport channel closed"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for SseSinkError {}
|
||||||
|
|
||||||
|
struct PollSender<T> {
|
||||||
|
tx: Sender<T>,
|
||||||
|
permit: Option<mpsc::OwnedPermit<T>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Send + 'static> PollSender<T> {
|
||||||
|
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), SseSinkError>> {
|
||||||
|
if self.permit.is_some() {
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
let tx = self.tx.clone();
|
||||||
|
let mut fut = Box::pin(tx.reserve_owned());
|
||||||
|
match fut.as_mut().poll(cx) {
|
||||||
|
Poll::Ready(Ok(permit)) => {
|
||||||
|
self.permit = Some(permit);
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
Poll::Ready(Err(_)) => Poll::Ready(Err(SseSinkError::Closed)),
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(&mut self, item: T) -> Result<(), SseSinkError> {
|
||||||
|
let permit = self.permit.take().ok_or(SseSinkError::Closed)?;
|
||||||
|
permit.send(item);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user