feat: Added explicit guardrail handling for pending agents
This commit is contained in:
+152
-18
@@ -3,7 +3,7 @@ use crate::client::{Model, ModelType, call_chat_completions};
|
||||
use crate::config::{Agent, AppState, Input, RequestContext, Role, RoleLike};
|
||||
use crate::supervisor::mailbox::{Envelope, EnvelopePayload, Inbox};
|
||||
use crate::supervisor::{AgentExitStatus, AgentHandle, AgentResult, Supervisor};
|
||||
use crate::utils::{AbortSignal, create_abort_signal};
|
||||
use crate::utils::{AbortSignal, create_abort_signal, wait_abort_signal};
|
||||
|
||||
use crate::graph;
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
@@ -16,10 +16,69 @@ use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::time;
|
||||
use tokio::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub const SUPERVISOR_FUNCTION_PREFIX: &str = "agent__";
|
||||
|
||||
pub const PENDING_AGENTS_GUARDRAIL_MAX: u32 = 3;
|
||||
|
||||
pub enum GuardrailAction {
|
||||
NoAction,
|
||||
Inject(String),
|
||||
ForceTerminate(Vec<String>),
|
||||
}
|
||||
|
||||
pub fn pending_agent_ids(ctx: &RequestContext) -> Vec<String> {
|
||||
let Some(sup) = ctx.supervisor.as_ref() else {
|
||||
return Vec::new();
|
||||
};
|
||||
let sup = sup.read();
|
||||
sup.list_agents()
|
||||
.into_iter()
|
||||
.filter_map(|(id, _)| match sup.is_finished(id) {
|
||||
Some(false) => Some(id.to_string()),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn build_pending_agents_guardrail_prompt(ids: &[String]) -> String {
|
||||
let count = ids.len();
|
||||
let id_list = ids
|
||||
.iter()
|
||||
.map(|id| format!("- {id}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
format!(
|
||||
"[SYSTEM GUARDRAIL] You attempted to end your turn while {count} spawned background agent(s) \
|
||||
are still running:\n{id_list}\n\nThese agents will be abandoned if your turn ends now. You MUST \
|
||||
reclaim each one before ending your turn. For each agent: call `agent__collect` (blocks until \
|
||||
done, returns output) or `agent__cancel` (discards). Do NOT emit a text-only response \
|
||||
expecting them to 'report back' — they will not."
|
||||
)
|
||||
}
|
||||
|
||||
pub fn check_pending_agents_guardrail(ctx: &mut RequestContext) -> GuardrailAction {
|
||||
let pending = pending_agent_ids(ctx);
|
||||
if pending.is_empty() {
|
||||
ctx.pending_agents_guardrail_count = 0;
|
||||
return GuardrailAction::NoAction;
|
||||
}
|
||||
|
||||
if ctx.pending_agents_guardrail_count >= PENDING_AGENTS_GUARDRAIL_MAX {
|
||||
if let Some(sup) = ctx.supervisor.as_ref().cloned() {
|
||||
sup.read().cancel_recursive();
|
||||
}
|
||||
ctx.pending_agents_guardrail_count = 0;
|
||||
|
||||
return GuardrailAction::ForceTerminate(pending);
|
||||
}
|
||||
|
||||
ctx.pending_agents_guardrail_count += 1;
|
||||
GuardrailAction::Inject(build_pending_agents_guardrail_prompt(&pending))
|
||||
}
|
||||
|
||||
pub fn escalation_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
vec![FunctionDeclaration {
|
||||
name: format!("{SUPERVISOR_FUNCTION_PREFIX}reply_escalation"),
|
||||
@@ -55,7 +114,11 @@ pub fn supervisor_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
vec![
|
||||
FunctionDeclaration {
|
||||
name: format!("{SUPERVISOR_FUNCTION_PREFIX}spawn"),
|
||||
description: "Spawn a subagent to run in the background. Returns a task_id for tracking. The agent runs in parallel. You can continue working while it executes.".to_string(),
|
||||
description: "Spawn a subagent to run in the background. Returns an `id` immediately so you can continue \
|
||||
working in parallel. CRITICAL: every spawned agent MUST be reclaimed before you end your \
|
||||
turn — call `agent__collect` to retrieve its output, or `agent__cancel` if you no longer \
|
||||
need it. Ending your turn with pending agents will abandon their work and the system will \
|
||||
reject the turn-end.".to_string(),
|
||||
parameters: JsonSchema {
|
||||
type_value: Some("object".to_string()),
|
||||
properties: Some(IndexMap::from([
|
||||
@@ -109,7 +172,11 @@ pub fn supervisor_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
},
|
||||
FunctionDeclaration {
|
||||
name: format!("{SUPERVISOR_FUNCTION_PREFIX}collect"),
|
||||
description: "Wait for a spawned agent to finish and return its result. Blocks until the agent completes.".to_string(),
|
||||
description: "Block until the named spawned agent finishes and return its result. This is your primary \
|
||||
wait primitive — it pauses your execution until the agent completes (or you are interrupted). \
|
||||
Call this for every agent you spawned before ending your turn. Do NOT end your turn assuming \
|
||||
agents will 'report back later' — they will not; they will be abandoned. If you no longer \
|
||||
need an agent's result, call `agent__cancel` instead.".to_string(),
|
||||
parameters: JsonSchema {
|
||||
type_value: Some("object".to_string()),
|
||||
properties: Some(IndexMap::from([(
|
||||
@@ -137,7 +204,10 @@ pub fn supervisor_function_declarations() -> Vec<FunctionDeclaration> {
|
||||
},
|
||||
FunctionDeclaration {
|
||||
name: format!("{SUPERVISOR_FUNCTION_PREFIX}cancel"),
|
||||
description: "Cancel a running subagent by its ID.".to_string(),
|
||||
description: "Cancel a running subagent by its ID. Use this when an agent's output is no longer needed \
|
||||
(e.g. you changed direction, or you're about to end your turn and don't want to wait). \
|
||||
Cancellation cascades: all of the cancelled agent's own descendants are also cancelled. This \
|
||||
call waits briefly for the agent to actually finish cleanup before returning.".to_string(),
|
||||
parameters: JsonSchema {
|
||||
type_value: Some("object".to_string()),
|
||||
properties: Some(IndexMap::from([(
|
||||
@@ -315,7 +385,7 @@ pub async fn handle_supervisor_tool(
|
||||
"check" => handle_check(ctx, args).await,
|
||||
"collect" => handle_collect(ctx, args).await,
|
||||
"list" => handle_list(ctx),
|
||||
"cancel" => handle_cancel(ctx, args),
|
||||
"cancel" => handle_cancel(ctx, args).await,
|
||||
"send_message" => handle_send_message(ctx, args),
|
||||
"check_inbox" => handle_check_inbox(ctx),
|
||||
"task_create" => handle_task_create(ctx, args),
|
||||
@@ -370,14 +440,28 @@ pub fn run_child_agent(
|
||||
}
|
||||
|
||||
if tool_results.is_empty() {
|
||||
break;
|
||||
match check_pending_agents_guardrail(&mut child_ctx) {
|
||||
GuardrailAction::NoAction => break,
|
||||
GuardrailAction::ForceTerminate(ids) => {
|
||||
log::warn!(
|
||||
"Pending-agent guardrail force-cancelled {} agent(s) after max reminders: {:?}",
|
||||
ids.len(),
|
||||
ids
|
||||
);
|
||||
break;
|
||||
}
|
||||
GuardrailAction::Inject(prompt) => {
|
||||
input = Input::from_str(&child_ctx, &prompt, None)?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
input = input.merge_tool_results(output, tool_results);
|
||||
}
|
||||
|
||||
if let Some(supervisor) = child_ctx.supervisor.clone() {
|
||||
supervisor.read().cancel_all();
|
||||
supervisor.read().cancel_recursive();
|
||||
}
|
||||
|
||||
Ok(accumulated_output)
|
||||
@@ -642,6 +726,7 @@ async fn handle_spawn(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let spawn_agent_id = agent_id.clone();
|
||||
let spawn_agent_name = agent_name.clone();
|
||||
let spawn_abort = child_abort.clone();
|
||||
let child_supervisor = child_ctx.supervisor.clone();
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let result = run_child_agent(child_ctx, input, spawn_abort).await;
|
||||
@@ -669,6 +754,7 @@ async fn handle_spawn(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
inbox: child_inbox,
|
||||
abort_signal: child_abort,
|
||||
join_handle,
|
||||
child_supervisor,
|
||||
};
|
||||
|
||||
let supervisor = ctx
|
||||
@@ -683,7 +769,11 @@ async fn handle_spawn(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
"status": "ok",
|
||||
"id": agent_id,
|
||||
"agent": agent_name,
|
||||
"message": format!("Agent '{agent_name}' spawned as '{agent_id}'. Use agent__check or agent__collect to get results."),
|
||||
"message": format!("Agent '{agent_name}' spawned as '{agent_id}' and is running in the background. CRITICAL: \
|
||||
you MUST reclaim this agent before ending your turn — call `agent__collect` (blocks until \
|
||||
done, returns output) or `agent__cancel` (if you no longer need it). Ending your turn with \
|
||||
unreclaimed agents will be rejected and forces you to handle them. Do NOT assume the agent \
|
||||
will 'report back' on its own."),
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -743,7 +833,7 @@ async fn handle_collect(ctx: &mut RequestContext, args: &Value) -> Result<Value>
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
|
||||
{
|
||||
let target_abort = {
|
||||
let sup = supervisor.read();
|
||||
if sup.is_finished(id).is_none() {
|
||||
return Ok(json!({
|
||||
@@ -751,7 +841,8 @@ async fn handle_collect(ctx: &mut RequestContext, args: &Value) -> Result<Value>
|
||||
"message": format!("Agent '{id}' not found. Use agent__check to verify it exists and is finished.")
|
||||
}));
|
||||
}
|
||||
}
|
||||
sup.abort_signal_for(id)
|
||||
};
|
||||
|
||||
loop {
|
||||
let is_finished = {
|
||||
@@ -775,7 +866,27 @@ async fn handle_collect(ctx: &mut RequestContext, args: &Value) -> Result<Value>
|
||||
}));
|
||||
}
|
||||
|
||||
time::sleep(Duration::from_millis(200)).await;
|
||||
match target_abort.as_ref() {
|
||||
Some(abort) if abort.aborted() => {
|
||||
let deadline = Instant::now() + Duration::from_secs(2);
|
||||
while Instant::now() < deadline {
|
||||
if supervisor.read().is_finished(id).unwrap_or(false) {
|
||||
break;
|
||||
}
|
||||
time::sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
break;
|
||||
}
|
||||
Some(abort) => {
|
||||
tokio::select! {
|
||||
_ = time::sleep(Duration::from_millis(200)) => {}
|
||||
_ = wait_abort_signal(abort) => {}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
time::sleep(Duration::from_millis(200)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let handle = {
|
||||
@@ -792,6 +903,7 @@ async fn handle_collect(ctx: &mut RequestContext, args: &Value) -> Result<Value>
|
||||
.map_err(|e| anyhow!("Agent failed: {e}"))?;
|
||||
|
||||
let output = summarize_output(ctx, &result.agent_name, &result.output).await?;
|
||||
ctx.pending_agents_guardrail_count = 0;
|
||||
|
||||
Ok(json!({
|
||||
"status": "completed",
|
||||
@@ -836,7 +948,7 @@ fn handle_list(ctx: &mut RequestContext) -> Result<Value> {
|
||||
}))
|
||||
}
|
||||
|
||||
fn handle_cancel(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
async fn handle_cancel(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
let id = args
|
||||
.get("id")
|
||||
.and_then(Value::as_str)
|
||||
@@ -847,14 +959,34 @@ fn handle_cancel(ctx: &mut RequestContext, args: &Value) -> Result<Value> {
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("No supervisor active"))?;
|
||||
let mut sup = supervisor.write();
|
||||
|
||||
match sup.take(id) {
|
||||
let handle = {
|
||||
let mut sup = supervisor.write();
|
||||
sup.take(id)
|
||||
};
|
||||
|
||||
match handle {
|
||||
Some(handle) => {
|
||||
let agent_name = handle.agent_name.clone();
|
||||
if let Some(child_sup) = handle.child_supervisor.as_ref() {
|
||||
child_sup.read().cancel_recursive();
|
||||
}
|
||||
handle.abort_signal.set_ctrlc();
|
||||
|
||||
let cleanup = tokio::time::timeout(Duration::from_secs(5), handle.join_handle).await;
|
||||
|
||||
ctx.pending_agents_guardrail_count = 0;
|
||||
|
||||
let message = match cleanup {
|
||||
Ok(_) => format!("Cancelled agent '{agent_name}' and waited for cleanup."),
|
||||
Err(_) => format!(
|
||||
"Cancelled agent '{agent_name}'; cleanup did not complete within 5s. Its descendants have been signalled and will tear down asynchronously."
|
||||
),
|
||||
};
|
||||
|
||||
Ok(json!({
|
||||
"status": "ok",
|
||||
"message": format!("Cancelled agent '{}'", handle.agent_name),
|
||||
"message": message,
|
||||
}))
|
||||
}
|
||||
None => Ok(json!({
|
||||
@@ -1283,6 +1415,7 @@ mod tests {
|
||||
inbox: Arc::new(Inbox::new()),
|
||||
abort_signal: create_abort_signal(),
|
||||
join_handle,
|
||||
child_supervisor: None,
|
||||
};
|
||||
ctx.supervisor
|
||||
.as_ref()
|
||||
@@ -1362,6 +1495,7 @@ mod tests {
|
||||
inbox,
|
||||
abort_signal: abort,
|
||||
join_handle,
|
||||
child_supervisor: None,
|
||||
};
|
||||
ctx.supervisor
|
||||
.as_ref()
|
||||
@@ -1381,7 +1515,7 @@ mod tests {
|
||||
fn handle_cancel_registered_agent() {
|
||||
let mut ctx = ctx_with_supervisor(4, 3);
|
||||
register_fake_agent(&mut ctx, "a1", "explore");
|
||||
let result = handle_cancel(&mut ctx, &json!({"id": "a1"})).unwrap();
|
||||
let result = run_async(handle_cancel(&mut ctx, &json!({"id": "a1"}))).unwrap();
|
||||
assert_eq!(result["status"], "ok");
|
||||
assert_eq!(ctx.supervisor.as_ref().unwrap().read().active_count(), 0);
|
||||
}
|
||||
@@ -1389,14 +1523,14 @@ mod tests {
|
||||
#[test]
|
||||
fn handle_cancel_unknown_agent() {
|
||||
let mut ctx = ctx_with_supervisor(4, 3);
|
||||
let result = handle_cancel(&mut ctx, &json!({"id": "missing"})).unwrap();
|
||||
let result = run_async(handle_cancel(&mut ctx, &json!({"id": "missing"}))).unwrap();
|
||||
assert_eq!(result["status"], "error");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_cancel_no_supervisor_errors() {
|
||||
let mut ctx = RequestContext::new(default_app_state(), WorkingMode::Cmd);
|
||||
let result = handle_cancel(&mut ctx, &json!({"id": "x"}));
|
||||
let result = run_async(handle_cancel(&mut ctx, &json!({"id": "x"})));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user