diff --git a/src/mcp/sse_transport.rs b/src/mcp/sse_transport.rs index 876fa3d..bc19b74 100644 --- a/src/mcp/sse_transport.rs +++ b/src/mcp/sse_transport.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result, anyhow}; use fmt::{Display, Formatter}; use futures_util::StreamExt; -use mpsc::{Receiver, Sender, channel}; +use mpsc::{Receiver, Sender, channel, OwnedPermit}; use reqwest::Client; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest_eventsource::{Event, EventSource}; @@ -9,8 +9,10 @@ use rmcp::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; use std::collections::HashMap; use std::error::Error; use std::fmt; +use std::future::Future; use std::pin::Pin; use std::task::Poll; +use mpsc::error::SendError; use tokio::sync::mpsc; use tokio::time::Duration; use url::Url; @@ -73,6 +75,7 @@ impl LegacySseTransport { tx: PollSender { tx: self.tx, permit: None, + acquiring: None, }, }, SseStream { rx: self.rx }, @@ -231,9 +234,13 @@ impl Display for SseSinkError { impl Error for SseSinkError {} +type ReserveOwned = + Pin, SendError<()>>> + Send>>; + struct PollSender { tx: Sender, - permit: Option>, + permit: Option>, + acquiring: Option>, } impl PollSender { @@ -241,14 +248,21 @@ impl PollSender { if self.permit.is_some() { return Poll::Ready(Ok(())); } - let tx = self.tx.clone(); - let mut fut = Box::pin(tx.reserve_owned()); + + let fut = self + .acquiring + .get_or_insert_with(|| Box::pin(self.tx.clone().reserve_owned())); + match fut.as_mut().poll(cx) { Poll::Ready(Ok(permit)) => { + self.acquiring = None; self.permit = Some(permit); Poll::Ready(Ok(())) } - Poll::Ready(Err(_)) => Poll::Ready(Err(SseSinkError::Closed)), + Poll::Ready(Err(_)) => { + self.acquiring = None; + Poll::Ready(Err(SseSinkError::Closed)) + } Poll::Pending => Poll::Pending, } }