feat: Implemented retry logic for failed tool invocations so the LLM can learn from the result and try again; also implemented chain loop detection to prevent repeated tool-call loops
This commit is contained in:
+14
-8
@@ -433,10 +433,13 @@ pub async fn call_chat_completions(
|
||||
client.global_config().read().print_markdown(&text)?;
|
||||
}
|
||||
}
|
||||
Ok((
|
||||
text,
|
||||
eval_tool_calls(client.global_config(), tool_calls).await?,
|
||||
))
|
||||
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||
}
|
||||
Ok((text, tool_results))
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
@@ -467,10 +470,13 @@ pub async fn call_chat_completions_streaming(
|
||||
if !text.is_empty() && !text.ends_with('\n') {
|
||||
println!();
|
||||
}
|
||||
Ok((
|
||||
text,
|
||||
eval_tool_calls(client.global_config(), tool_calls).await?,
|
||||
))
|
||||
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||
tool_results
|
||||
.iter()
|
||||
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||
}
|
||||
Ok((text, tool_results))
|
||||
}
|
||||
Err(err) => {
|
||||
if !text.is_empty() {
|
||||
|
||||
+152
-2
@@ -13,6 +13,9 @@ pub struct SseHandler {
|
||||
abort_signal: AbortSignal,
|
||||
buffer: String,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
last_tool_calls: Vec<ToolCall>,
|
||||
max_call_repeats: usize,
|
||||
call_repeat_chain_len: usize,
|
||||
}
|
||||
|
||||
impl SseHandler {
|
||||
@@ -22,11 +25,13 @@ impl SseHandler {
|
||||
abort_signal,
|
||||
buffer: String::new(),
|
||||
tool_calls: Vec::new(),
|
||||
last_tool_calls: Vec::new(),
|
||||
max_call_repeats: 2,
|
||||
call_repeat_chain_len: 3,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn text(&mut self, text: &str) -> Result<()> {
|
||||
// debug!("HandleText: {}", text);
|
||||
if text.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -45,7 +50,6 @@ impl SseHandler {
|
||||
}
|
||||
|
||||
pub fn done(&mut self) {
|
||||
// debug!("HandleDone");
|
||||
let ret = self.sender.send(SseEvent::Done);
|
||||
if ret.is_err() {
|
||||
if self.abort_signal.aborted() {
|
||||
@@ -56,15 +60,114 @@ impl SseHandler {
|
||||
}
|
||||
|
||||
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
|
||||
if self.is_call_loop(&call) {
|
||||
let loop_message = self.create_loop_detection_message(&call);
|
||||
return Err(anyhow!(loop_message));
|
||||
}
|
||||
|
||||
if self.last_tool_calls.len() == self.call_repeat_chain_len * self.max_call_repeats {
|
||||
self.last_tool_calls.remove(0);
|
||||
}
|
||||
self.last_tool_calls.push(call.clone());
|
||||
|
||||
self.tool_calls.push(call);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reports whether `new_call` would continue a run of identical tool calls.
///
/// Two calls are considered identical when `calls_match` returns `true` for them. Nothing is
/// reported until the history holds at least `call_repeat_chain_len` entries. After that a
/// loop is reported when either of the following holds:
///
/// 1. `new_call` matches the most recent call, and a counter that starts at 1 and is
///    incremented for every adjacent pair of identical calls (walking the history from newest
///    to oldest, stopping at the first mismatch) reaches `max_call_repeats`.
/// 2. The last `call_repeat_chain_len` calls are all identical and `new_call` matches them.
///
/// NOTE(review): both checks only look at the same call repeated back-to-back; an alternating
/// sequence such as A, B, A, B is not flagged by this function.
fn is_call_loop(&self, new_call: &ToolCall) -> bool {
    let history = self.last_tool_calls.as_slice();
    let chain_len = self.call_repeat_chain_len;
    if history.len() < chain_len {
        return false;
    }

    // Check 1: count the trailing run of identical calls that `new_call` would extend.
    let repeats_latest = history
        .last()
        .is_some_and(|latest| self.calls_match(latest, new_call));
    if repeats_latest {
        let mut repeat_count = 1;
        // `windows(2).rev()` yields the adjacent pairs from newest to oldest.
        for pair in history.windows(2).rev() {
            if !self.calls_match(&pair[0], &pair[1]) {
                break;
            }
            repeat_count += 1;
            if repeat_count >= self.max_call_repeats {
                return true;
            }
        }
    }

    // Check 2: the most recent `chain_len` calls are all identical and `new_call` matches.
    // The early return above guarantees `history.len() >= chain_len`, so the subtraction
    // cannot underflow and the slice has exactly `chain_len` elements.
    let chain = &history[history.len() - chain_len..];
    match chain.last() {
        Some(newest) => {
            chain
                .windows(2)
                .all(|pair| self.calls_match(&pair[0], &pair[1]))
                && self.calls_match(newest, new_call)
        }
        // An empty chain (only possible when `chain_len` is 0) has nothing to compare
        // against, so it is never a loop.
        None => false,
    }
}
|
||||
|
||||
/// Returns `true` when both calls have equal `name` and equal `arguments`.
fn calls_match(&self, call1: &ToolCall, call2: &ToolCall) -> bool {
    if call1.name != call2.name {
        return false;
    }
    call1.arguments == call2.arguments
}
|
||||
|
||||
/// Builds the error text returned when a call loop is detected.
///
/// The message consists of a warning banner, the name and arguments of `new_call`, the last
/// `call_repeat_chain_len` calls from the history (only when the history holds at least that
/// many entries), and a closing instruction asking for a different approach.
fn create_loop_detection_message(&self, new_call: &ToolCall) -> String {
    // The banner ends with a newline so it is not glued to the sentence that follows.
    let mut message = String::from("⚠️ Call loop detected! ⚠️\n");

    message.push_str(&format!(
        "The call '{}' with arguments '{}' is repeating.\n",
        new_call.name, new_call.arguments
    ));

    let chain_len = self.call_repeat_chain_len;
    if self.last_tool_calls.len() >= chain_len {
        // The length check above makes this subtraction safe.
        let chain = &self.last_tool_calls[self.last_tool_calls.len() - chain_len..];

        message.push_str("The following sequence of calls is repeating:\n");
        for (i, call) in chain.iter().enumerate() {
            message.push_str(&format!(
                "  {}. {} with arguments {}\n",
                i + 1,
                call.name,
                call.arguments
            ));
        }
    }

    message.push_str("\nPlease move on to the next task in your sequence using the last output you got from the call or chain you are trying to re-execute. ");
    message.push_str(
        "Consider using different parameters or a different approach to avoid this loop.",
    );

    message
}
|
||||
|
||||
/// Returns a clone of the abort signal held by this handler.
pub fn abort(&self) -> AbortSignal {
    Clone::clone(&self.abort_signal)
}
|
||||
|
||||
/// Test-only view of the recent-call history, oldest entry first.
#[cfg(test)]
pub fn last_tool_calls(&self) -> &[ToolCall] {
    self.last_tool_calls.as_slice()
}
|
||||
|
||||
pub fn take(self) -> (String, Vec<ToolCall>) {
|
||||
let Self {
|
||||
buffer, tool_calls, ..
|
||||
@@ -240,6 +343,53 @@ mod tests {
|
||||
use bytes::Bytes;
|
||||
use futures_util::stream;
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
|
||||
/// Feeding more distinct calls than the history can hold must keep only the newest ones.
#[test]
fn test_last_tool_calls_ring_buffer() {
    let (sender, _) = tokio::sync::mpsc::unbounded_channel();
    let abort_signal = crate::utils::create_abort_signal();
    let mut handler = SseHandler::new(sender, abort_signal);

    // Fifteen distinct calls: none of them repeats, so none may be rejected.
    (0..15)
        .map(|i| ToolCall::new(format!("test_function_{}", i), json!({"param": i}), None))
        .for_each(|call| handler.tool_call(call).unwrap());

    let capacity = handler.call_repeat_chain_len * handler.max_call_repeats;
    let history = handler.last_tool_calls();
    assert_eq!(history.len(), capacity);

    // The newest call sits at the back of the history ...
    assert_eq!(history[capacity - 1].name, "test_function_14");

    // ... and the front holds the oldest call that has not been evicted yet.
    assert_eq!(
        history[0].name,
        format!("test_function_{}", 14 - capacity + 1)
    );
}
|
||||
|
||||
/// Three identical calls are accepted; the fourth one must be rejected as a loop.
#[test]
fn test_call_loop_detection() {
    let (sender, _) = tokio::sync::mpsc::unbounded_channel();
    let abort_signal = crate::utils::create_abort_signal();
    let mut handler = SseHandler::new(sender, abort_signal);

    handler.max_call_repeats = 2;
    handler.call_repeat_chain_len = 3;

    let looping_call =
        ToolCall::new("test_function_loop".to_string(), json!({"param": 1}), None);

    // The history is still shorter than the chain length for each of these calls.
    for _attempt in 0..3 {
        handler.tool_call(looping_call.clone()).unwrap();
    }

    let outcome = handler.tool_call(looping_call);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    for expected in ["Call loop detected!", "test_function_loop"] {
        assert!(message.contains(expected));
    }
}
|
||||
|
||||
fn split_chunks(text: &str) -> Vec<Vec<u8>> {
|
||||
let mut rng = rand::rng();
|
||||
|
||||
+4
-1
@@ -17,7 +17,7 @@ use crate::client::{
|
||||
ClientConfig, MessageContentToolCalls, Model, ModelType, OPENAI_COMPATIBLE_PROVIDERS,
|
||||
ProviderModels, create_client_config, list_client_types, list_models,
|
||||
};
|
||||
use crate::function::{FunctionDeclaration, Functions, ToolResult};
|
||||
use crate::function::{FunctionDeclaration, Functions, ToolCallTracker, ToolResult};
|
||||
use crate::rag::Rag;
|
||||
use crate::render::{MarkdownRender, RenderOptions};
|
||||
use crate::utils::*;
|
||||
@@ -199,6 +199,8 @@ pub struct Config {
|
||||
pub rag: Option<Arc<Rag>>,
|
||||
#[serde(skip)]
|
||||
pub agent: Option<Agent>,
|
||||
#[serde(skip)]
|
||||
pub(crate) tool_call_tracker: Option<ToolCallTracker>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@@ -271,6 +273,7 @@ impl Default for Config {
|
||||
session: None,
|
||||
rag: None,
|
||||
agent: None,
|
||||
tool_call_tracker: Some(ToolCallTracker::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user