4 Commits

9 changed files with 315 additions and 27 deletions
+8
View File
@@ -659,6 +659,14 @@
# - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
- provider: vertexai
models:
- name: gemini-3-pro-preview
hipaa_safe: true
max_input_tokens: 1048576
max_output_tokens: 65536
input_price: 0
output_price: 0
supports_vision: true
supports_function_calling: true
- name: gemini-2.5-flash
max_input_tokens: 1048576
max_output_tokens: 65536
+2 -2
View File
@@ -234,7 +234,7 @@ async fn chat_completions_streaming(
}
let arguments: Value =
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
@@ -272,7 +272,7 @@ async fn chat_completions_streaming(
function_arguments = String::from("{}");
}
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
+8 -5
View File
@@ -93,10 +93,13 @@ pub async fn claude_chat_completions_streaming(
data["content_block"]["id"].as_str(),
) {
if !function_name.is_empty() {
let arguments: Value =
let arguments: Value = if function_arguments.is_empty() {
json!({})
} else {
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
})?;
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?
};
handler.tool_call(ToolCall::new(
function_name.clone(),
arguments,
@@ -134,7 +137,7 @@ pub async fn claude_chat_completions_streaming(
json!({})
} else {
function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?
};
handler.tool_call(ToolCall::new(
@@ -286,7 +289,7 @@ pub fn claude_build_chat_completions_body(
body["tools"] = functions
.iter()
.map(|v| {
if v.parameters.type_value.is_none() {
if v.parameters.is_empty_properties() {
json!({
"name": v.name,
"description": v.description,
+2 -2
View File
@@ -167,7 +167,7 @@ async fn chat_completions_streaming(
"tool-call-end" => {
if !function_name.is_empty() {
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
@@ -230,7 +230,7 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
call["id"].as_str(),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
format!("Tool call '{name}' has non-JSON arguments '{arguments}'")
})?;
tool_calls.push(ToolCall::new(
name.to_string(),
+14 -8
View File
@@ -433,10 +433,13 @@ pub async fn call_chat_completions(
client.global_config().read().print_markdown(&text)?;
}
}
Ok((
text,
eval_tool_calls(client.global_config(), tool_calls).await?,
))
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
tool_results
.iter()
.for_each(|res| tracker.record_call(res.call.clone()));
}
Ok((text, tool_results))
}
Err(err) => Err(err),
}
@@ -467,10 +470,13 @@ pub async fn call_chat_completions_streaming(
if !text.is_empty() && !text.ends_with('\n') {
println!();
}
Ok((
text,
eval_tool_calls(client.global_config(), tool_calls).await?,
))
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
tool_results
.iter()
.for_each(|res| tracker.record_call(res.call.clone()));
}
Ok((text, tool_results))
}
Err(err) => {
if !text.is_empty() {
+2 -2
View File
@@ -164,7 +164,7 @@ pub async fn openai_chat_completions_streaming(
function_arguments = String::from("{}");
}
let arguments: Value = function_arguments.parse().with_context(|| {
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
})?;
handler.tool_call(ToolCall::new(
function_name.clone(),
@@ -370,7 +370,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
call["id"].as_str(),
) {
let arguments: Value = arguments.parse().with_context(|| {
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
format!("Tool call '{name}' has non-JSON arguments '{arguments}'")
})?;
tool_calls.push(ToolCall::new(
name.to_string(),
+153 -2
View File
@@ -13,6 +13,9 @@ pub struct SseHandler {
abort_signal: AbortSignal,
buffer: String,
tool_calls: Vec<ToolCall>,
last_tool_calls: Vec<ToolCall>,
max_call_repeats: usize,
call_repeat_chain_len: usize,
}
impl SseHandler {
@@ -22,11 +25,13 @@ impl SseHandler {
abort_signal,
buffer: String::new(),
tool_calls: Vec::new(),
last_tool_calls: Vec::new(),
max_call_repeats: 2,
call_repeat_chain_len: 3,
}
}
pub fn text(&mut self, text: &str) -> Result<()> {
// debug!("HandleText: {}", text);
if text.is_empty() {
return Ok(());
}
@@ -45,7 +50,6 @@ impl SseHandler {
}
pub fn done(&mut self) {
// debug!("HandleDone");
let ret = self.sender.send(SseEvent::Done);
if ret.is_err() {
if self.abort_signal.aborted() {
@@ -56,14 +60,114 @@ impl SseHandler {
}
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
if self.is_call_loop(&call) {
let loop_message = self.create_loop_detection_message(&call);
return Err(anyhow!(loop_message));
}
if self.last_tool_calls.len() == self.call_repeat_chain_len * self.max_call_repeats {
self.last_tool_calls.remove(0);
}
self.last_tool_calls.push(call.clone());
self.tool_calls.push(call);
Ok(())
}
fn is_call_loop(&self, new_call: &ToolCall) -> bool {
if self.last_tool_calls.len() < self.call_repeat_chain_len {
return false;
}
if let Some(last_call) = self.last_tool_calls.last()
&& self.calls_match(last_call, new_call)
{
let mut repeat_count = 1;
for i in (0..self.last_tool_calls.len()).rev() {
if i == 0 {
break;
}
if self.calls_match(&self.last_tool_calls[i - 1], &self.last_tool_calls[i]) {
repeat_count += 1;
if repeat_count >= self.max_call_repeats {
return true;
}
} else {
break;
}
}
}
let chain_start = self
.last_tool_calls
.len()
.saturating_sub(self.call_repeat_chain_len);
let chain = &self.last_tool_calls[chain_start..];
if chain.len() == self.call_repeat_chain_len {
let mut is_repeating = true;
for i in 0..chain.len() - 1 {
if !self.calls_match(&chain[i], &chain[i + 1]) {
is_repeating = false;
break;
}
}
if is_repeating && self.calls_match(&chain[chain.len() - 1], new_call) {
return true;
}
}
false
}
fn calls_match(&self, call1: &ToolCall, call2: &ToolCall) -> bool {
call1.name == call2.name && call1.arguments == call2.arguments
}
fn create_loop_detection_message(&self, new_call: &ToolCall) -> String {
let mut message = String::from("⚠️ Call loop detected! ⚠️");
message.push_str(&format!(
"The call '{}' with arguments '{}' is repeating.\n",
new_call.name, new_call.arguments
));
if self.last_tool_calls.len() >= self.call_repeat_chain_len {
let chain_start = self
.last_tool_calls
.len()
.saturating_sub(self.call_repeat_chain_len);
let chain = &self.last_tool_calls[chain_start..];
message.push_str("The following sequence of calls is repeating:\n");
for (i, call) in chain.iter().enumerate() {
message.push_str(&format!(
" {}. {} with arguments {}\n",
i + 1,
call.name,
call.arguments
));
}
}
message.push_str("\nPlease move on to the next task in your sequence using the last output you got from the call or chain you are trying to re-execute. ");
message.push_str(
"Consider using different parameters or a different approach to avoid this loop.",
);
message
}
pub fn abort(&self) -> AbortSignal {
self.abort_signal.clone()
}
#[cfg(test)]
pub fn last_tool_calls(&self) -> &[ToolCall] {
&self.last_tool_calls
}
pub fn take(self) -> (String, Vec<ToolCall>) {
let Self {
buffer, tool_calls, ..
@@ -239,6 +343,53 @@ mod tests {
use bytes::Bytes;
use futures_util::stream;
use rand::Rng;
use serde_json::json;
#[test]
fn test_last_tool_calls_ring_buffer() {
let (sender, _) = tokio::sync::mpsc::unbounded_channel();
let abort_signal = crate::utils::create_abort_signal();
let mut handler = SseHandler::new(sender, abort_signal);
for i in 0..15 {
let call = ToolCall::new(format!("test_function_{}", i), json!({"param": i}), None);
handler.tool_call(call.clone()).unwrap();
}
let lt_len = handler.call_repeat_chain_len * handler.max_call_repeats;
assert_eq!(handler.last_tool_calls().len(), lt_len);
assert_eq!(
handler.last_tool_calls()[lt_len - 1].name,
"test_function_14"
);
assert_eq!(
handler.last_tool_calls()[0].name,
format!("test_function_{}", 14 - lt_len + 1)
);
}
#[test]
fn test_call_loop_detection() {
let (sender, _) = tokio::sync::mpsc::unbounded_channel();
let abort_signal = crate::utils::create_abort_signal();
let mut handler = SseHandler::new(sender, abort_signal);
handler.max_call_repeats = 2;
handler.call_repeat_chain_len = 3;
let call = ToolCall::new("test_function_loop".to_string(), json!({"param": 1}), None);
for _ in 0..3 {
handler.tool_call(call.clone()).unwrap();
}
let result = handler.tool_call(call.clone());
assert!(result.is_err());
let error_message = result.unwrap_err().to_string();
assert!(error_message.contains("Call loop detected!"));
assert!(error_message.contains("test_function_loop"));
}
fn split_chunks(text: &str) -> Vec<Vec<u8>> {
let mut rng = rand::rng();
+4 -1
View File
@@ -17,7 +17,7 @@ use crate::client::{
ClientConfig, MessageContentToolCalls, Model, ModelType, OPENAI_COMPATIBLE_PROVIDERS,
ProviderModels, create_client_config, list_client_types, list_models,
};
use crate::function::{FunctionDeclaration, Functions, ToolResult};
use crate::function::{FunctionDeclaration, Functions, ToolCallTracker, ToolResult};
use crate::rag::Rag;
use crate::render::{MarkdownRender, RenderOptions};
use crate::utils::*;
@@ -199,6 +199,8 @@ pub struct Config {
pub rag: Option<Arc<Rag>>,
#[serde(skip)]
pub agent: Option<Agent>,
#[serde(skip)]
pub(crate) tool_call_tracker: Option<ToolCallTracker>,
}
impl Default for Config {
@@ -271,6 +273,7 @@ impl Default for Config {
session: None,
rag: None,
agent: None,
tool_call_tracker: Some(ToolCallTracker::default()),
}
}
}
+122 -5
View File
@@ -15,6 +15,7 @@ use indoc::formatdoc;
use rust_embed::Embed;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::VecDeque;
use std::ffi::OsStr;
use std::fs::File;
use std::io::Write;
@@ -90,6 +91,19 @@ pub async fn eval_tool_calls(
}
let mut is_all_null = true;
for call in calls {
if let Some(checker) = &config.read().tool_call_tracker
&& let Some(msg) = checker.check_loop(&call.clone())
{
let dup_msg = format!("{{\"tool_call_loop_alert\":{}}}", &msg.trim());
println!(
"{}",
warning_text(format!("{}: ⚠️ Tool-call loop detected! ⚠️", &call.name).as_str())
);
let val = json!(dup_msg);
output.push(ToolResult::new(call, val));
is_all_null = false;
continue;
}
let mut result = call.eval(config).await?;
if result.is_null() {
result = json!("DONE");
@@ -841,11 +855,14 @@ impl ToolCall {
_ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
Self::invoke_mcp_tool(config, &cmd_name, &json_data).await?
}
_ => match run_llm_function(cmd_name, cmd_args, envs, agent_name)? {
Some(contents) => serde_json::from_str(&contents)
_ => match run_llm_function(cmd_name, cmd_args, envs, agent_name) {
Ok(Some(contents)) => serde_json::from_str(&contents)
.ok()
.unwrap_or_else(|| json!({"output": contents})),
None => Value::Null,
Ok(None) => Value::Null,
Err(e) => serde_json::from_str(&e.to_string())
.ok()
.unwrap_or_else(|| json!({"output": e.to_string()})),
},
};
@@ -978,7 +995,9 @@ pub fn run_llm_function(
agent_name: Option<String>,
) -> Result<Option<String>> {
let mut bin_dirs: Vec<PathBuf> = vec![];
let mut command_name = cmd_name.clone();
if let Some(agent_name) = agent_name {
command_name = cmd_args[0].clone();
let dir = Config::agent_bin_dir(&agent_name);
if dir.exists() {
bin_dirs.push(dir);
@@ -1001,9 +1020,13 @@ pub fn run_llm_function(
let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);
let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
.map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
.map_err(|err| anyhow!("Unable to run {command_name}, {err}"))?;
if exit_code != 0 {
bail!("Tool call exited with {exit_code}");
let tool_error_message =
format!("⚠️ Tool call '{command_name}' threw exit code {exit_code} ⚠️");
println!("{}", warning_text(&tool_error_message));
let tool_error_json = format!("{{\"tool_call_error\":\"{}\"}}", &tool_error_message);
return Ok(Some(tool_error_json));
}
let mut output = None;
if temp_file.exists() {
@@ -1032,3 +1055,97 @@ fn polyfill_cmd_name<T: AsRef<Path>>(cmd_name: &str, bin_dir: &[T]) -> String {
}
cmd_name
}
#[derive(Debug, Clone)]
pub struct ToolCallTracker {
last_calls: VecDeque<ToolCall>,
max_repeats: usize,
chain_len: usize,
}
impl ToolCallTracker {
pub fn new(max_repeats: usize, chain_len: usize) -> Self {
Self {
last_calls: VecDeque::new(),
max_repeats,
chain_len,
}
}
pub fn default() -> Self {
Self::new(2, 3)
}
pub fn check_loop(&self, new_call: &ToolCall) -> Option<String> {
if self.last_calls.len() < self.max_repeats {
return None;
}
if let Some(last) = self.last_calls.back()
&& self.calls_match(last, new_call)
{
let mut repeat_count = 1;
for i in (1..self.last_calls.len()).rev() {
if self.calls_match(&self.last_calls[i - 1], &self.last_calls[i]) {
repeat_count += 1;
if repeat_count >= self.max_repeats {
return Some(self.create_loop_message());
}
} else {
break;
}
}
}
let start = self.last_calls.len().saturating_sub(self.chain_len);
let chain: Vec<_> = self.last_calls.iter().skip(start).collect();
if chain.len() == self.chain_len {
let mut is_repeating = true;
for i in 0..chain.len() - 1 {
if !self.calls_match(chain[i], chain[i + 1]) {
is_repeating = false;
break;
}
}
if is_repeating && self.calls_match(chain[chain.len() - 1], new_call) {
return Some(self.create_loop_message());
}
}
None
}
fn calls_match(&self, a: &ToolCall, b: &ToolCall) -> bool {
a.name == b.name && a.arguments == b.arguments
}
fn create_loop_message(&self) -> String {
let message = r#"{"error":{"message":"⚠️ Tool-call loop detected! ⚠️","code":400,"param":"Use the output of the last call to this function and parameter-set then move on to the next step of workflow, change tools/parameters called, or request assistance in the conversation sream"}}"#;
if self.last_calls.len() >= self.chain_len {
let start = self.last_calls.len().saturating_sub(self.chain_len);
let chain: Vec<_> = self.last_calls.iter().skip(start).collect();
let mut loopset = "[".to_string();
for c in chain {
loopset +=
format!("{{\"name\":{},\"parameters\":{}}},", c.name, c.arguments).as_str();
}
let _ = loopset.pop();
loopset.push(']');
format!(
"{},\"call_history\":{}}}}}",
&message[..(&message.len() - 2)],
loopset
)
} else {
message.to_string()
}
}
pub fn record_call(&mut self, call: ToolCall) {
if self.last_calls.len() >= self.chain_len * self.max_repeats {
self.last_calls.pop_front();
}
self.last_calls.push_back(call);
}
}