Files
coyote/src/mcp/sse_transport.rs
T

275 lines
8.1 KiB
Rust

use anyhow::{Context, Result, anyhow};
use fmt::{Display, Formatter};
use futures_util::StreamExt;
use mpsc::error::SendError;
use mpsc::{OwnedPermit, Receiver, Sender, channel};
use reqwest::Client;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest_eventsource::{Event, EventSource};
use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage};
use std::collections::HashMap;
use std::error::Error;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::Duration;
use url::Url;
const CHANNEL_BUF: usize = 64;
pub struct LegacySseTransport {
tx: Sender<ClientJsonRpcMessage>,
rx: Receiver<ServerJsonRpcMessage>,
}
impl LegacySseTransport {
pub async fn connect(sse_url: &str, headers: Option<&HashMap<String, String>>) -> Result<Self> {
let base_url =
Url::parse(sse_url).with_context(|| format!("Invalid SSE URL: {sse_url}"))?;
let mut client_builder = Client::builder();
let mut header_map = HeaderMap::new();
if let Some(hdrs) = headers {
for (k, v) in hdrs {
let name = k
.parse::<HeaderName>()
.with_context(|| format!("Invalid header name: {k}"))?;
let value = v
.parse::<HeaderValue>()
.with_context(|| format!("Invalid header value for {k}"))?;
header_map.insert(name, value);
}
client_builder = client_builder.default_headers(header_map);
}
let client = client_builder
.build()
.context("Failed to build HTTP client")?;
let request = client.get(sse_url);
let mut es = EventSource::new(request).context("Failed to open SSE connection")?;
let post_endpoint = wait_for_endpoint_event(&mut es, &base_url).await?;
let (outgoing_tx, outgoing_rx) = channel::<ClientJsonRpcMessage>(CHANNEL_BUF);
let (incoming_tx, incoming_rx) = channel::<ServerJsonRpcMessage>(CHANNEL_BUF);
tokio::spawn(sse_reader_task(es, incoming_tx));
tokio::spawn(post_writer_task(client, post_endpoint, outgoing_rx));
Ok(Self {
tx: outgoing_tx,
rx: incoming_rx,
})
}
pub fn into_parts(
self,
) -> (
SseSink<ClientJsonRpcMessage>,
SseStream<ServerJsonRpcMessage>,
) {
(
SseSink {
tx: PollSender {
tx: self.tx,
permit: None,
acquiring: None,
},
},
SseStream { rx: self.rx },
)
}
}
async fn wait_for_endpoint_event(es: &mut EventSource, base_url: &Url) -> Result<String> {
let timeout = Duration::from_secs(30);
tokio::time::timeout(timeout, async {
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => {}
Ok(Event::Message(msg)) if msg.event == "endpoint" => {
let endpoint = msg.data.trim().to_string();
let resolved = resolve_endpoint(&endpoint, base_url)?;
return Ok(resolved);
}
Ok(Event::Message(_)) => {}
Err(e) => {
return Err(anyhow!(
"SSE connection error while waiting for endpoint event: {e}"
));
}
}
}
Err(anyhow!("SSE stream closed before receiving endpoint event"))
})
.await
.map_err(|_| anyhow!("Timed out waiting for endpoint event from SSE server (30s)"))?
}
fn resolve_endpoint(endpoint: &str, base_url: &Url) -> Result<String> {
if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
Ok(endpoint.to_string())
} else {
let mut resolved = base_url.clone();
let (path, query) = endpoint.split_once('?').unwrap_or((endpoint, ""));
resolved.set_path(path);
resolved.set_query(if query.is_empty() { None } else { Some(query) });
Ok(resolved.to_string())
}
}
async fn sse_reader_task(mut es: EventSource, tx: Sender<ServerJsonRpcMessage>) {
while let Some(event) = es.next().await {
match event {
Ok(Event::Message(msg)) if msg.event == "message" => {
match serde_json::from_str::<ServerJsonRpcMessage>(&msg.data) {
Ok(rpc_msg) => {
if tx.send(rpc_msg).await.is_err() {
break;
}
}
Err(e) => {
warn!("Failed to parse SSE message as JSON-RPC: {e}");
}
}
}
Ok(_) => {}
Err(reqwest_eventsource::Error::StreamEnded) => break,
Err(e) => {
error!("SSE stream error: {e}");
break;
}
}
}
es.close();
}
async fn post_writer_task(
client: Client,
endpoint: String,
mut rx: Receiver<ClientJsonRpcMessage>,
) {
while let Some(msg) = rx.recv().await {
let body = match serde_json::to_string(&msg) {
Ok(b) => b,
Err(e) => {
error!("Failed to serialize JSON-RPC message: {e}");
continue;
}
};
if let Err(e) = client
.post(&endpoint)
.header("Content-Type", "application/json")
.body(body)
.send()
.await
{
error!("Failed to POST message to SSE endpoint: {e}");
}
}
}
pub struct SseSink<T> {
tx: PollSender<T>,
}
pub struct SseStream<T> {
rx: Receiver<T>,
}
impl<T: Send + 'static> futures_util::Sink<T> for SseSink<T> {
type Error = SseSinkError;
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
self.tx.start_send(item)
}
fn poll_flush(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(
self: Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
}
impl<T: Send + 'static> futures_util::Stream for SseStream<T> {
type Item = T;
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}
#[derive(Debug)]
pub enum SseSinkError {
Closed,
}
impl Display for SseSinkError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self {
SseSinkError::Closed => write!(f, "SSE transport channel closed"),
}
}
}
impl Error for SseSinkError {}
type ReserveOwned<T> = Pin<Box<dyn Future<Output = Result<OwnedPermit<T>, SendError<()>>> + Send>>;
struct PollSender<T> {
tx: Sender<T>,
permit: Option<OwnedPermit<T>>,
acquiring: Option<ReserveOwned<T>>,
}
impl<T: Send + 'static> PollSender<T> {
fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), SseSinkError>> {
if self.permit.is_some() {
return Poll::Ready(Ok(()));
}
let fut = self
.acquiring
.get_or_insert_with(|| Box::pin(self.tx.clone().reserve_owned()));
match fut.as_mut().poll(cx) {
Poll::Ready(Ok(permit)) => {
self.acquiring = None;
self.permit = Some(permit);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(_)) => {
self.acquiring = None;
Poll::Ready(Err(SseSinkError::Closed))
}
Poll::Pending => Poll::Pending,
}
}
fn start_send(&mut self, item: T) -> Result<(), SseSinkError> {
let permit = self.permit.take().ok_or(SseSinkError::Closed)?;
permit.send(item);
Ok(())
}
}