Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
d5e0728532
|
|||
|
25c0885dcc
|
|||
|
f56ed7d005
|
|||
|
d79e4b9dff
|
@@ -659,6 +659,14 @@
|
|||||||
# - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
|
# - https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/gemini
|
||||||
- provider: vertexai
|
- provider: vertexai
|
||||||
models:
|
models:
|
||||||
|
- name: gemini-3-pro-preview
|
||||||
|
hipaa_safe: true
|
||||||
|
max_input_tokens: 1048576
|
||||||
|
max_output_tokens: 65536
|
||||||
|
input_price: 0
|
||||||
|
output_price: 0
|
||||||
|
supports_vision: true
|
||||||
|
supports_function_calling: true
|
||||||
- name: gemini-2.5-flash
|
- name: gemini-2.5-flash
|
||||||
max_input_tokens: 1048576
|
max_input_tokens: 1048576
|
||||||
max_output_tokens: 65536
|
max_output_tokens: 65536
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ async fn chat_completions_streaming(
|
|||||||
}
|
}
|
||||||
let arguments: Value =
|
let arguments: Value =
|
||||||
function_arguments.parse().with_context(|| {
|
function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
||||||
})?;
|
})?;
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
function_name.clone(),
|
function_name.clone(),
|
||||||
@@ -272,7 +272,7 @@ async fn chat_completions_streaming(
|
|||||||
function_arguments = String::from("{}");
|
function_arguments = String::from("{}");
|
||||||
}
|
}
|
||||||
let arguments: Value = function_arguments.parse().with_context(|| {
|
let arguments: Value = function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
||||||
})?;
|
})?;
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
function_name.clone(),
|
function_name.clone(),
|
||||||
|
|||||||
@@ -93,10 +93,13 @@ pub async fn claude_chat_completions_streaming(
|
|||||||
data["content_block"]["id"].as_str(),
|
data["content_block"]["id"].as_str(),
|
||||||
) {
|
) {
|
||||||
if !function_name.is_empty() {
|
if !function_name.is_empty() {
|
||||||
let arguments: Value =
|
let arguments: Value = if function_arguments.is_empty() {
|
||||||
|
json!({})
|
||||||
|
} else {
|
||||||
function_arguments.parse().with_context(|| {
|
function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
||||||
})?;
|
})?
|
||||||
|
};
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
function_name.clone(),
|
function_name.clone(),
|
||||||
arguments,
|
arguments,
|
||||||
@@ -134,7 +137,7 @@ pub async fn claude_chat_completions_streaming(
|
|||||||
json!({})
|
json!({})
|
||||||
} else {
|
} else {
|
||||||
function_arguments.parse().with_context(|| {
|
function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
@@ -286,7 +289,7 @@ pub fn claude_build_chat_completions_body(
|
|||||||
body["tools"] = functions
|
body["tools"] = functions
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| {
|
.map(|v| {
|
||||||
if v.parameters.type_value.is_none() {
|
if v.parameters.is_empty_properties() {
|
||||||
json!({
|
json!({
|
||||||
"name": v.name,
|
"name": v.name,
|
||||||
"description": v.description,
|
"description": v.description,
|
||||||
|
|||||||
@@ -167,7 +167,7 @@ async fn chat_completions_streaming(
|
|||||||
"tool-call-end" => {
|
"tool-call-end" => {
|
||||||
if !function_name.is_empty() {
|
if !function_name.is_empty() {
|
||||||
let arguments: Value = function_arguments.parse().with_context(|| {
|
let arguments: Value = function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
||||||
})?;
|
})?;
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
function_name.clone(),
|
function_name.clone(),
|
||||||
@@ -230,7 +230,7 @@ fn extract_chat_completions(data: &Value) -> Result<ChatCompletionsOutput> {
|
|||||||
call["id"].as_str(),
|
call["id"].as_str(),
|
||||||
) {
|
) {
|
||||||
let arguments: Value = arguments.parse().with_context(|| {
|
let arguments: Value = arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
|
format!("Tool call '{name}' has non-JSON arguments '{arguments}'")
|
||||||
})?;
|
})?;
|
||||||
tool_calls.push(ToolCall::new(
|
tool_calls.push(ToolCall::new(
|
||||||
name.to_string(),
|
name.to_string(),
|
||||||
|
|||||||
+14
-8
@@ -433,10 +433,13 @@ pub async fn call_chat_completions(
|
|||||||
client.global_config().read().print_markdown(&text)?;
|
client.global_config().read().print_markdown(&text)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok((
|
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||||
text,
|
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||||
eval_tool_calls(client.global_config(), tool_calls).await?,
|
tool_results
|
||||||
))
|
.iter()
|
||||||
|
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||||
|
}
|
||||||
|
Ok((text, tool_results))
|
||||||
}
|
}
|
||||||
Err(err) => Err(err),
|
Err(err) => Err(err),
|
||||||
}
|
}
|
||||||
@@ -467,10 +470,13 @@ pub async fn call_chat_completions_streaming(
|
|||||||
if !text.is_empty() && !text.ends_with('\n') {
|
if !text.is_empty() && !text.ends_with('\n') {
|
||||||
println!();
|
println!();
|
||||||
}
|
}
|
||||||
Ok((
|
let tool_results = eval_tool_calls(client.global_config(), tool_calls).await?;
|
||||||
text,
|
if let Some(tracker) = client.global_config().write().tool_call_tracker.as_mut() {
|
||||||
eval_tool_calls(client.global_config(), tool_calls).await?,
|
tool_results
|
||||||
))
|
.iter()
|
||||||
|
.for_each(|res| tracker.record_call(res.call.clone()));
|
||||||
|
}
|
||||||
|
Ok((text, tool_results))
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ pub async fn openai_chat_completions_streaming(
|
|||||||
function_arguments = String::from("{}");
|
function_arguments = String::from("{}");
|
||||||
}
|
}
|
||||||
let arguments: Value = function_arguments.parse().with_context(|| {
|
let arguments: Value = function_arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{function_name}' have non-JSON arguments '{function_arguments}'")
|
format!("Tool call '{function_name}' has non-JSON arguments '{function_arguments}'")
|
||||||
})?;
|
})?;
|
||||||
handler.tool_call(ToolCall::new(
|
handler.tool_call(ToolCall::new(
|
||||||
function_name.clone(),
|
function_name.clone(),
|
||||||
@@ -370,7 +370,7 @@ pub fn openai_extract_chat_completions(data: &Value) -> Result<ChatCompletionsOu
|
|||||||
call["id"].as_str(),
|
call["id"].as_str(),
|
||||||
) {
|
) {
|
||||||
let arguments: Value = arguments.parse().with_context(|| {
|
let arguments: Value = arguments.parse().with_context(|| {
|
||||||
format!("Tool call '{name}' have non-JSON arguments '{arguments}'")
|
format!("Tool call '{name}' has non-JSON arguments '{arguments}'")
|
||||||
})?;
|
})?;
|
||||||
tool_calls.push(ToolCall::new(
|
tool_calls.push(ToolCall::new(
|
||||||
name.to_string(),
|
name.to_string(),
|
||||||
|
|||||||
+153
-2
@@ -13,6 +13,9 @@ pub struct SseHandler {
|
|||||||
abort_signal: AbortSignal,
|
abort_signal: AbortSignal,
|
||||||
buffer: String,
|
buffer: String,
|
||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
|
last_tool_calls: Vec<ToolCall>,
|
||||||
|
max_call_repeats: usize,
|
||||||
|
call_repeat_chain_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SseHandler {
|
impl SseHandler {
|
||||||
@@ -22,11 +25,13 @@ impl SseHandler {
|
|||||||
abort_signal,
|
abort_signal,
|
||||||
buffer: String::new(),
|
buffer: String::new(),
|
||||||
tool_calls: Vec::new(),
|
tool_calls: Vec::new(),
|
||||||
|
last_tool_calls: Vec::new(),
|
||||||
|
max_call_repeats: 2,
|
||||||
|
call_repeat_chain_len: 3,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn text(&mut self, text: &str) -> Result<()> {
|
pub fn text(&mut self, text: &str) -> Result<()> {
|
||||||
// debug!("HandleText: {}", text);
|
|
||||||
if text.is_empty() {
|
if text.is_empty() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
@@ -45,7 +50,6 @@ impl SseHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn done(&mut self) {
|
pub fn done(&mut self) {
|
||||||
// debug!("HandleDone");
|
|
||||||
let ret = self.sender.send(SseEvent::Done);
|
let ret = self.sender.send(SseEvent::Done);
|
||||||
if ret.is_err() {
|
if ret.is_err() {
|
||||||
if self.abort_signal.aborted() {
|
if self.abort_signal.aborted() {
|
||||||
@@ -56,14 +60,114 @@ impl SseHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
|
pub fn tool_call(&mut self, call: ToolCall) -> Result<()> {
|
||||||
|
if self.is_call_loop(&call) {
|
||||||
|
let loop_message = self.create_loop_detection_message(&call);
|
||||||
|
return Err(anyhow!(loop_message));
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.last_tool_calls.len() == self.call_repeat_chain_len * self.max_call_repeats {
|
||||||
|
self.last_tool_calls.remove(0);
|
||||||
|
}
|
||||||
|
self.last_tool_calls.push(call.clone());
|
||||||
|
|
||||||
self.tool_calls.push(call);
|
self.tool_calls.push(call);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_call_loop(&self, new_call: &ToolCall) -> bool {
|
||||||
|
if self.last_tool_calls.len() < self.call_repeat_chain_len {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(last_call) = self.last_tool_calls.last()
|
||||||
|
&& self.calls_match(last_call, new_call)
|
||||||
|
{
|
||||||
|
let mut repeat_count = 1;
|
||||||
|
for i in (0..self.last_tool_calls.len()).rev() {
|
||||||
|
if i == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if self.calls_match(&self.last_tool_calls[i - 1], &self.last_tool_calls[i]) {
|
||||||
|
repeat_count += 1;
|
||||||
|
if repeat_count >= self.max_call_repeats {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let chain_start = self
|
||||||
|
.last_tool_calls
|
||||||
|
.len()
|
||||||
|
.saturating_sub(self.call_repeat_chain_len);
|
||||||
|
let chain = &self.last_tool_calls[chain_start..];
|
||||||
|
|
||||||
|
if chain.len() == self.call_repeat_chain_len {
|
||||||
|
let mut is_repeating = true;
|
||||||
|
for i in 0..chain.len() - 1 {
|
||||||
|
if !self.calls_match(&chain[i], &chain[i + 1]) {
|
||||||
|
is_repeating = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if is_repeating && self.calls_match(&chain[chain.len() - 1], new_call) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calls_match(&self, call1: &ToolCall, call2: &ToolCall) -> bool {
|
||||||
|
call1.name == call2.name && call1.arguments == call2.arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_loop_detection_message(&self, new_call: &ToolCall) -> String {
|
||||||
|
let mut message = String::from("⚠️ Call loop detected! ⚠️");
|
||||||
|
|
||||||
|
message.push_str(&format!(
|
||||||
|
"The call '{}' with arguments '{}' is repeating.\n",
|
||||||
|
new_call.name, new_call.arguments
|
||||||
|
));
|
||||||
|
|
||||||
|
if self.last_tool_calls.len() >= self.call_repeat_chain_len {
|
||||||
|
let chain_start = self
|
||||||
|
.last_tool_calls
|
||||||
|
.len()
|
||||||
|
.saturating_sub(self.call_repeat_chain_len);
|
||||||
|
let chain = &self.last_tool_calls[chain_start..];
|
||||||
|
|
||||||
|
message.push_str("The following sequence of calls is repeating:\n");
|
||||||
|
for (i, call) in chain.iter().enumerate() {
|
||||||
|
message.push_str(&format!(
|
||||||
|
" {}. {} with arguments {}\n",
|
||||||
|
i + 1,
|
||||||
|
call.name,
|
||||||
|
call.arguments
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
message.push_str("\nPlease move on to the next task in your sequence using the last output you got from the call or chain you are trying to re-execute. ");
|
||||||
|
message.push_str(
|
||||||
|
"Consider using different parameters or a different approach to avoid this loop.",
|
||||||
|
);
|
||||||
|
|
||||||
|
message
|
||||||
|
}
|
||||||
|
|
||||||
pub fn abort(&self) -> AbortSignal {
|
pub fn abort(&self) -> AbortSignal {
|
||||||
self.abort_signal.clone()
|
self.abort_signal.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub fn last_tool_calls(&self) -> &[ToolCall] {
|
||||||
|
&self.last_tool_calls
|
||||||
|
}
|
||||||
|
|
||||||
pub fn take(self) -> (String, Vec<ToolCall>) {
|
pub fn take(self) -> (String, Vec<ToolCall>) {
|
||||||
let Self {
|
let Self {
|
||||||
buffer, tool_calls, ..
|
buffer, tool_calls, ..
|
||||||
@@ -239,6 +343,53 @@ mod tests {
|
|||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures_util::stream;
|
use futures_util::stream;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_last_tool_calls_ring_buffer() {
|
||||||
|
let (sender, _) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let abort_signal = crate::utils::create_abort_signal();
|
||||||
|
let mut handler = SseHandler::new(sender, abort_signal);
|
||||||
|
|
||||||
|
for i in 0..15 {
|
||||||
|
let call = ToolCall::new(format!("test_function_{}", i), json!({"param": i}), None);
|
||||||
|
handler.tool_call(call.clone()).unwrap();
|
||||||
|
}
|
||||||
|
let lt_len = handler.call_repeat_chain_len * handler.max_call_repeats;
|
||||||
|
assert_eq!(handler.last_tool_calls().len(), lt_len);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
handler.last_tool_calls()[lt_len - 1].name,
|
||||||
|
"test_function_14"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
handler.last_tool_calls()[0].name,
|
||||||
|
format!("test_function_{}", 14 - lt_len + 1)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_call_loop_detection() {
|
||||||
|
let (sender, _) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
let abort_signal = crate::utils::create_abort_signal();
|
||||||
|
let mut handler = SseHandler::new(sender, abort_signal);
|
||||||
|
|
||||||
|
handler.max_call_repeats = 2;
|
||||||
|
handler.call_repeat_chain_len = 3;
|
||||||
|
|
||||||
|
let call = ToolCall::new("test_function_loop".to_string(), json!({"param": 1}), None);
|
||||||
|
|
||||||
|
for _ in 0..3 {
|
||||||
|
handler.tool_call(call.clone()).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = handler.tool_call(call.clone());
|
||||||
|
assert!(result.is_err());
|
||||||
|
let error_message = result.unwrap_err().to_string();
|
||||||
|
assert!(error_message.contains("Call loop detected!"));
|
||||||
|
assert!(error_message.contains("test_function_loop"));
|
||||||
|
}
|
||||||
|
|
||||||
fn split_chunks(text: &str) -> Vec<Vec<u8>> {
|
fn split_chunks(text: &str) -> Vec<Vec<u8>> {
|
||||||
let mut rng = rand::rng();
|
let mut rng = rand::rng();
|
||||||
|
|||||||
+4
-1
@@ -17,7 +17,7 @@ use crate::client::{
|
|||||||
ClientConfig, MessageContentToolCalls, Model, ModelType, OPENAI_COMPATIBLE_PROVIDERS,
|
ClientConfig, MessageContentToolCalls, Model, ModelType, OPENAI_COMPATIBLE_PROVIDERS,
|
||||||
ProviderModels, create_client_config, list_client_types, list_models,
|
ProviderModels, create_client_config, list_client_types, list_models,
|
||||||
};
|
};
|
||||||
use crate::function::{FunctionDeclaration, Functions, ToolResult};
|
use crate::function::{FunctionDeclaration, Functions, ToolCallTracker, ToolResult};
|
||||||
use crate::rag::Rag;
|
use crate::rag::Rag;
|
||||||
use crate::render::{MarkdownRender, RenderOptions};
|
use crate::render::{MarkdownRender, RenderOptions};
|
||||||
use crate::utils::*;
|
use crate::utils::*;
|
||||||
@@ -199,6 +199,8 @@ pub struct Config {
|
|||||||
pub rag: Option<Arc<Rag>>,
|
pub rag: Option<Arc<Rag>>,
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
pub agent: Option<Agent>,
|
pub agent: Option<Agent>,
|
||||||
|
#[serde(skip)]
|
||||||
|
pub(crate) tool_call_tracker: Option<ToolCallTracker>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
@@ -271,6 +273,7 @@ impl Default for Config {
|
|||||||
session: None,
|
session: None,
|
||||||
rag: None,
|
rag: None,
|
||||||
agent: None,
|
agent: None,
|
||||||
|
tool_call_tracker: Some(ToolCallTracker::default()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+122
-5
@@ -15,6 +15,7 @@ use indoc::formatdoc;
|
|||||||
use rust_embed::Embed;
|
use rust_embed::Embed;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
|
use std::collections::VecDeque;
|
||||||
use std::ffi::OsStr;
|
use std::ffi::OsStr;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
@@ -90,6 +91,19 @@ pub async fn eval_tool_calls(
|
|||||||
}
|
}
|
||||||
let mut is_all_null = true;
|
let mut is_all_null = true;
|
||||||
for call in calls {
|
for call in calls {
|
||||||
|
if let Some(checker) = &config.read().tool_call_tracker
|
||||||
|
&& let Some(msg) = checker.check_loop(&call.clone())
|
||||||
|
{
|
||||||
|
let dup_msg = format!("{{\"tool_call_loop_alert\":{}}}", &msg.trim());
|
||||||
|
println!(
|
||||||
|
"{}",
|
||||||
|
warning_text(format!("{}: ⚠️ Tool-call loop detected! ⚠️", &call.name).as_str())
|
||||||
|
);
|
||||||
|
let val = json!(dup_msg);
|
||||||
|
output.push(ToolResult::new(call, val));
|
||||||
|
is_all_null = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let mut result = call.eval(config).await?;
|
let mut result = call.eval(config).await?;
|
||||||
if result.is_null() {
|
if result.is_null() {
|
||||||
result = json!("DONE");
|
result = json!("DONE");
|
||||||
@@ -841,11 +855,14 @@ impl ToolCall {
|
|||||||
_ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
|
_ if cmd_name.starts_with(MCP_INVOKE_META_FUNCTION_NAME_PREFIX) => {
|
||||||
Self::invoke_mcp_tool(config, &cmd_name, &json_data).await?
|
Self::invoke_mcp_tool(config, &cmd_name, &json_data).await?
|
||||||
}
|
}
|
||||||
_ => match run_llm_function(cmd_name, cmd_args, envs, agent_name)? {
|
_ => match run_llm_function(cmd_name, cmd_args, envs, agent_name) {
|
||||||
Some(contents) => serde_json::from_str(&contents)
|
Ok(Some(contents)) => serde_json::from_str(&contents)
|
||||||
.ok()
|
.ok()
|
||||||
.unwrap_or_else(|| json!({"output": contents})),
|
.unwrap_or_else(|| json!({"output": contents})),
|
||||||
None => Value::Null,
|
Ok(None) => Value::Null,
|
||||||
|
Err(e) => serde_json::from_str(&e.to_string())
|
||||||
|
.ok()
|
||||||
|
.unwrap_or_else(|| json!({"output": e.to_string()})),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -978,7 +995,9 @@ pub fn run_llm_function(
|
|||||||
agent_name: Option<String>,
|
agent_name: Option<String>,
|
||||||
) -> Result<Option<String>> {
|
) -> Result<Option<String>> {
|
||||||
let mut bin_dirs: Vec<PathBuf> = vec![];
|
let mut bin_dirs: Vec<PathBuf> = vec![];
|
||||||
|
let mut command_name = cmd_name.clone();
|
||||||
if let Some(agent_name) = agent_name {
|
if let Some(agent_name) = agent_name {
|
||||||
|
command_name = cmd_args[0].clone();
|
||||||
let dir = Config::agent_bin_dir(&agent_name);
|
let dir = Config::agent_bin_dir(&agent_name);
|
||||||
if dir.exists() {
|
if dir.exists() {
|
||||||
bin_dirs.push(dir);
|
bin_dirs.push(dir);
|
||||||
@@ -1001,9 +1020,13 @@ pub fn run_llm_function(
|
|||||||
let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);
|
let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);
|
||||||
|
|
||||||
let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
|
let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
|
||||||
.map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
|
.map_err(|err| anyhow!("Unable to run {command_name}, {err}"))?;
|
||||||
if exit_code != 0 {
|
if exit_code != 0 {
|
||||||
bail!("Tool call exited with {exit_code}");
|
let tool_error_message =
|
||||||
|
format!("⚠️ Tool call '{command_name}' threw exit code {exit_code} ⚠️");
|
||||||
|
println!("{}", warning_text(&tool_error_message));
|
||||||
|
let tool_error_json = format!("{{\"tool_call_error\":\"{}\"}}", &tool_error_message);
|
||||||
|
return Ok(Some(tool_error_json));
|
||||||
}
|
}
|
||||||
let mut output = None;
|
let mut output = None;
|
||||||
if temp_file.exists() {
|
if temp_file.exists() {
|
||||||
@@ -1032,3 +1055,97 @@ fn polyfill_cmd_name<T: AsRef<Path>>(cmd_name: &str, bin_dir: &[T]) -> String {
|
|||||||
}
|
}
|
||||||
cmd_name
|
cmd_name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ToolCallTracker {
|
||||||
|
last_calls: VecDeque<ToolCall>,
|
||||||
|
max_repeats: usize,
|
||||||
|
chain_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolCallTracker {
|
||||||
|
pub fn new(max_repeats: usize, chain_len: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
last_calls: VecDeque::new(),
|
||||||
|
max_repeats,
|
||||||
|
chain_len,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn default() -> Self {
|
||||||
|
Self::new(2, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_loop(&self, new_call: &ToolCall) -> Option<String> {
|
||||||
|
if self.last_calls.len() < self.max_repeats {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(last) = self.last_calls.back()
|
||||||
|
&& self.calls_match(last, new_call)
|
||||||
|
{
|
||||||
|
let mut repeat_count = 1;
|
||||||
|
for i in (1..self.last_calls.len()).rev() {
|
||||||
|
if self.calls_match(&self.last_calls[i - 1], &self.last_calls[i]) {
|
||||||
|
repeat_count += 1;
|
||||||
|
if repeat_count >= self.max_repeats {
|
||||||
|
return Some(self.create_loop_message());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let start = self.last_calls.len().saturating_sub(self.chain_len);
|
||||||
|
let chain: Vec<_> = self.last_calls.iter().skip(start).collect();
|
||||||
|
if chain.len() == self.chain_len {
|
||||||
|
let mut is_repeating = true;
|
||||||
|
for i in 0..chain.len() - 1 {
|
||||||
|
if !self.calls_match(chain[i], chain[i + 1]) {
|
||||||
|
is_repeating = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if is_repeating && self.calls_match(chain[chain.len() - 1], new_call) {
|
||||||
|
return Some(self.create_loop_message());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn calls_match(&self, a: &ToolCall, b: &ToolCall) -> bool {
|
||||||
|
a.name == b.name && a.arguments == b.arguments
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_loop_message(&self) -> String {
|
||||||
|
let message = r#"{"error":{"message":"⚠️ Tool-call loop detected! ⚠️","code":400,"param":"Use the output of the last call to this function and parameter-set then move on to the next step of workflow, change tools/parameters called, or request assistance in the conversation sream"}}"#;
|
||||||
|
|
||||||
|
if self.last_calls.len() >= self.chain_len {
|
||||||
|
let start = self.last_calls.len().saturating_sub(self.chain_len);
|
||||||
|
let chain: Vec<_> = self.last_calls.iter().skip(start).collect();
|
||||||
|
let mut loopset = "[".to_string();
|
||||||
|
for c in chain {
|
||||||
|
loopset +=
|
||||||
|
format!("{{\"name\":{},\"parameters\":{}}},", c.name, c.arguments).as_str();
|
||||||
|
}
|
||||||
|
let _ = loopset.pop();
|
||||||
|
loopset.push(']');
|
||||||
|
format!(
|
||||||
|
"{},\"call_history\":{}}}}}",
|
||||||
|
&message[..(&message.len() - 2)],
|
||||||
|
loopset
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
message.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn record_call(&mut self, call: ToolCall) {
|
||||||
|
if self.last_calls.len() >= self.chain_len * self.max_repeats {
|
||||||
|
self.last_calls.pop_front();
|
||||||
|
}
|
||||||
|
self.last_calls.push_back(call);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user