diff --git a/src/client/common.rs b/src/client/common.rs index ae6d2a3..ad93f1e 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -433,10 +433,13 @@ pub async fn call_chat_completions( client.global_config().read().print_markdown(&text)?; } } - Ok(( - text, - eval_tool_calls(client.global_config(), tool_calls).await?, - )) + let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?; + if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() { + tool_results + .iter() + .for_each(|res| tracker.record_call(res.call.clone())); + } + Ok((text, tool_results)) } Err(err) => Err(err), } @@ -467,10 +470,13 @@ pub async fn call_chat_completions_streaming( if !text.is_empty() && !text.ends_with('\n') { println!(); } - Ok(( - text, - eval_tool_calls(client.global_config(), tool_calls).await?, - )) + let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?; + if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() { + tool_results + .iter() + .for_each(|res| tracker.record_call(res.call.clone())); + } + Ok((text, tool_results)) } Err(err) => { if !text.is_empty() { diff --git a/src/client/stream.rs b/src/client/stream.rs index d74798d..3780d93 100644 --- a/src/client/stream.rs +++ b/src/client/stream.rs @@ -13,6 +13,9 @@ pub struct SseHandler { abort_signal: AbortSignal, buffer: String, tool_calls: Vec, + last_tool_calls: Vec, + max_call_repeats: usize, + call_repeat_chain_len: usize, } impl SseHandler { @@ -22,11 +25,13 @@ impl SseHandler { abort_signal, buffer: String::new(), tool_calls: Vec::new(), + last_tool_calls: Vec::new(), + max_call_repeats: 2, + call_repeat_chain_len: 3, } } pub fn text(&mut self, text: &str) -> Result<()> { - // debug!("HandleText: {}", text); if text.is_empty() { return Ok(()); } @@ -45,7 +50,6 @@ impl SseHandler { } pub fn done(&mut self) { - // debug!("HandleDone"); let ret = self.sender.send(SseEvent::Done); if ret.is_err() { if self.abort_signal.aborted() { @@ -56,15 +60,114 @@ impl SseHandler { } pub fn tool_call(&mut self, call: ToolCall) -> Result<()> { + if self.is_call_loop(&call) { + let loop_message = self.create_loop_detection_message(&call); + return Err(anyhow!(loop_message)); + } + + if self.last_tool_calls.len() == self.call_repeat_chain_len * self.max_call_repeats { + self.last_tool_calls.remove(0); + } + self.last_tool_calls.push(call.clone()); + self.tool_calls.push(call); Ok(()) } + fn is_call_loop(&self, new_call: &ToolCall) -> bool { + if self.last_tool_calls.len() < self.call_repeat_chain_len { + return false; + } + + if let Some(last_call) = self.last_tool_calls.last() + && self.calls_match(last_call, new_call) + { + let mut repeat_count = 1; + for i in (0..self.last_tool_calls.len()).rev() { + if i == 0 { + break; + } + if self.calls_match(&self.last_tool_calls[i - 1], &self.last_tool_calls[i]) { + repeat_count += 1; + if repeat_count >= self.max_call_repeats { + return true; + } + } else { + break; + } + } + } + + let chain_start = self + .last_tool_calls + .len() + .saturating_sub(self.call_repeat_chain_len); + let chain = &self.last_tool_calls[chain_start..]; + + if chain.len() == self.call_repeat_chain_len { + let mut is_repeating = true; + for i in 0..chain.len() - 1 { + if !self.calls_match(&chain[i], &chain[i + 1]) { + is_repeating = false; + break; + } + } + if is_repeating && self.calls_match(&chain[chain.len() - 1], new_call) { + return true; + } + } + + false + } + + fn calls_match(&self, call1: &ToolCall, call2: &ToolCall) -> bool { + call1.name == call2.name && call1.arguments == call2.arguments + } + + fn create_loop_detection_message(&self, new_call: &ToolCall) -> String { + let mut message = String::from("⚠️ Call loop detected! ⚠️"); + + message.push_str(&format!( + "The call '{}' with arguments '{}' is repeating.\n", + new_call.name, new_call.arguments + )); + + if self.last_tool_calls.len() >= self.call_repeat_chain_len { + let chain_start = self + .last_tool_calls + .len() + .saturating_sub(self.call_repeat_chain_len); + let chain = &self.last_tool_calls[chain_start..]; + + message.push_str("The following sequence of calls is repeating:\n"); + for (i, call) in chain.iter().enumerate() { + message.push_str(&format!( + " {}. {} with arguments {}\n", + i + 1, + call.name, + call.arguments + )); + } + } + + message.push_str("\nPlease move on to the next task in your sequence using the last output you got from the call or chain you are trying to re-execute. "); + message.push_str( + "Consider using different parameters or a different approach to avoid this loop.", + ); + + message + } + pub fn abort(&self) -> AbortSignal { self.abort_signal.clone() } + #[cfg(test)] + pub fn last_tool_calls(&self) -> &[ToolCall] { + &self.last_tool_calls + } + pub fn take(self) -> (String, Vec) { let Self { buffer, tool_calls, .. @@ -240,6 +343,53 @@ mod tests { use bytes::Bytes; use futures_util::stream; use rand::Rng; + use serde_json::json; + + #[test] + fn test_last_tool_calls_ring_buffer() { + let (sender, _) = tokio::sync::mpsc::unbounded_channel(); + let abort_signal = crate::utils::create_abort_signal(); + let mut handler = SseHandler::new(sender, abort_signal); + + for i in 0..15 { + let call = ToolCall::new(format!("test_function_{}", i), json!({"param": i}), None); + handler.tool_call(call.clone()).unwrap(); + } + let lt_len = handler.call_repeat_chain_len * handler.max_call_repeats; + assert_eq!(handler.last_tool_calls().len(), lt_len); + + assert_eq!( + handler.last_tool_calls()[lt_len - 1].name, + "test_function_14" + ); + + assert_eq!( + handler.last_tool_calls()[0].name, + format!("test_function_{}", 14 - lt_len + 1) + ); + } + + #[test] + fn test_call_loop_detection() { + let (sender, _) = tokio::sync::mpsc::unbounded_channel(); + let abort_signal = crate::utils::create_abort_signal(); + let mut handler = SseHandler::new(sender, abort_signal); + + handler.max_call_repeats = 2; + handler.call_repeat_chain_len = 3; + + let call = ToolCall::new("test_function_loop".to_string(), json!({"param": 1}), None); + + for _ in 0..3 { + handler.tool_call(call.clone()).unwrap(); + } + + let result = handler.tool_call(call.clone()); + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message.contains("Call loop detected!")); + assert!(error_message.contains("test_function_loop")); + } fn split_chunks(text: &str) -> Vec> { let mut rng = rand::rng(); diff --git a/src/config/mod.rs b/src/config/mod.rs index 60a3ca0..299bdbf 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -17,7 +17,7 @@ use crate::client::{ ClientConfig, MessageContentToolCalls, Model, ModelType, OPENAI_COMPATIBLE_PROVIDERS, ProviderModels, create_client_config, list_client_types, list_models, }; -use crate::function::{FunctionDeclaration, Functions, ToolResult}; +use crate::function::{FunctionDeclaration, Functions, ToolCallTracker, ToolResult}; use crate::rag::Rag; use crate::render::{MarkdownRender, RenderOptions}; use crate::utils::*; @@ -199,6 +199,8 @@ pub struct Config { pub rag: Option>, #[serde(skip)] pub agent: Option, + #[serde(skip)] + pub(crate) tool_call_tracker: Option, } impl Default for Config { @@ -271,6 +273,7 @@ impl Default for Config { session: None, rag: None, agent: None, + tool_call_tracker: Some(ToolCallTracker::default()), } } }