use super::{ToolCall, catch_error}; use crate::utils::AbortSignal; use anyhow::{Context, Result, anyhow, bail}; use futures_util::{Stream, StreamExt}; use reqwest::RequestBuilder; use reqwest_eventsource::{Error as EventSourceError, Event, RequestBuilderExt}; use serde_json::Value; use tokio::sync::mpsc::UnboundedSender; pub struct SseHandler { sender: UnboundedSender, abort_signal: AbortSignal, buffer: String, tool_calls: Vec, } impl SseHandler { pub fn new(sender: UnboundedSender, abort_signal: AbortSignal) -> Self { Self { sender, abort_signal, buffer: String::new(), tool_calls: Vec::new(), } } pub fn text(&mut self, text: &str) -> Result<()> { // debug!("HandleText: {}", text); if text.is_empty() { return Ok(()); } self.buffer.push_str(text); let ret = self .sender .send(SseEvent::Text(text.to_string())) .with_context(|| "Failed to send SseEvent:Text"); if let Err(err) = ret { if self.abort_signal.aborted() { return Ok(()); } return Err(err); } Ok(()) } pub fn done(&mut self) { // debug!("HandleDone"); let ret = self.sender.send(SseEvent::Done); if ret.is_err() { if self.abort_signal.aborted() { return; } warn!("Failed to send SseEvent:Done"); } } pub fn tool_call(&mut self, call: ToolCall) -> Result<()> { self.tool_calls.push(call); Ok(()) } pub fn abort(&self) -> AbortSignal { self.abort_signal.clone() } pub fn take(self) -> (String, Vec) { let Self { buffer, tool_calls, .. } = self; (buffer, tool_calls) } } #[derive(Debug)] pub enum SseEvent { Text(String), Done, } #[derive(Debug)] pub struct SseMessage { #[allow(unused)] pub event: String, pub data: String, } pub async fn sse_stream(builder: RequestBuilder, mut handle: F) -> Result<()> where F: FnMut(SseMessage) -> Result, { let mut es = builder.eventsource()?; while let Some(event) = es.next().await { match event { Ok(Event::Open) => {} Ok(Event::Message(message)) => { let message = SseMessage { event: message.event, data: message.data, }; if handle(message)? { break; } } Err(err) => { match err { EventSourceError::StreamEnded => {} EventSourceError::InvalidStatusCode(status, res) => { let text = res.text().await?; let data: Value = match text.parse() { Ok(data) => data, Err(_) => { bail!( "Invalid response data: {text} (status: {})", status.as_u16() ); } }; catch_error(&data, status.as_u16())?; } EventSourceError::InvalidContentType(header_value, res) => { let text = res.text().await?; bail!( "Invalid response event-stream. content-type: {}, data: {text}", header_value.to_str().unwrap_or_default() ); } _ => { bail!("{}", err); } } es.close(); } } } Ok(()) } pub async fn json_stream(mut stream: S, mut handle: F) -> Result<()> where S: Stream> + Unpin, F: FnMut(&str) -> Result<()>, E: std::error::Error, { let mut parser = JsonStreamParser::default(); let mut unparsed_bytes = vec![]; while let Some(chunk_bytes) = stream.next().await { let chunk_bytes = chunk_bytes.map_err(|err| anyhow!("Failed to read json stream, {err}"))?; unparsed_bytes.extend(chunk_bytes); match std::str::from_utf8(&unparsed_bytes) { Ok(text) => { parser.process(text, &mut handle)?; unparsed_bytes.clear(); } Err(_) => { continue; } } } if !unparsed_bytes.is_empty() { let text = std::str::from_utf8(&unparsed_bytes)?; parser.process(text, &mut handle)?; } Ok(()) } #[derive(Debug, Default)] struct JsonStreamParser { buffer: Vec, cursor: usize, start: Option, balances: Vec, quoting: bool, escape: bool, } impl JsonStreamParser { fn process(&mut self, text: &str, handle: &mut F) -> Result<()> where F: FnMut(&str) -> Result<()>, { self.buffer.extend(text.chars()); for i in self.cursor..self.buffer.len() { let ch = self.buffer[i]; if self.quoting { if ch == '\\' { self.escape = !self.escape; } else { if !self.escape && ch == '"' { self.quoting = false; } self.escape = false; } continue; } match ch { '"' => { self.quoting = true; self.escape = false; } '{' => { if self.balances.is_empty() { self.start = Some(i); } self.balances.push(ch); } '[' => { if self.start.is_some() { self.balances.push(ch); } } '}' => { self.balances.pop(); if self.balances.is_empty() && let Some(start) = self.start.take() { let value: String = self.buffer[start..=i].iter().collect(); handle(&value)?; } } ']' => { self.balances.pop(); } _ => {} } } self.cursor = self.buffer.len(); Ok(()) } } #[cfg(test)] mod tests { use super::*; use bytes::Bytes; use futures_util::stream; use rand::Rng; fn split_chunks(text: &str) -> Vec> { let mut rng = rand::rng(); let len = text.len(); let cut1 = rng.random_range(1..len - 1); let cut2 = rng.random_range(cut1 + 1..len); let chunk1 = text.as_bytes()[..cut1].to_vec(); let chunk2 = text.as_bytes()[cut1..cut2].to_vec(); let chunk3 = text.as_bytes()[cut2..].to_vec(); vec![chunk1, chunk2, chunk3] } macro_rules! assert_json_stream { ($input:expr, $output:expr) => { let chunks: Vec<_> = split_chunks($input) .into_iter() .map(|chunk| Ok::<_, std::convert::Infallible>(Bytes::from(chunk))) .collect(); let stream = stream::iter(chunks); let mut output = vec![]; let ret = json_stream(stream, |data| { output.push(data.to_string()); Ok(()) }) .await; assert!(ret.is_ok()); assert_eq!($output.replace("\r\n", "\n"), output.join("\n")) }; } #[tokio::test] async fn test_json_stream_ndjson() { let data = r#"{"key": "value"} {"key": "value2"} {"key": "value3"}"#; assert_json_stream!(data, data); } #[tokio::test] async fn test_json_stream_array() { let input = r#"[ {"key": "value"}, {"key": "value2"}, {"key": "value3"},"#; let output = r#"{"key": "value"} {"key": "value2"} {"key": "value3"}"#; assert_json_stream!(input, output); } }