feat: improved UX for parallel graph execution

This commit is contained in:
2026-05-20 18:54:20 -06:00
parent 3c7d19da07
commit 81c037515e
12 changed files with 82 additions and 140 deletions
+6 -4
View File
@@ -1,6 +1,6 @@
use super::*; use super::*;
use crate::config::{paths, RenderMode}; use crate::config::{RenderMode, paths};
use crate::{ use crate::{
config::{AppConfig, Input, RequestContext}, config::{AppConfig, Input, RequestContext},
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls}, function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
@@ -418,7 +418,8 @@ pub async fn call_chat_completions(
abort_signal: AbortSignal, abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> { ) -> Result<(String, Vec<ToolResult>)> {
let is_child_agent = ctx.current_depth > 0; let is_child_agent = ctx.current_depth > 0;
let spinner_message = if is_child_agent { "" } else { "Generating" }; let suppress_spinner = is_child_agent || ctx.render_mode == RenderMode::Silent;
let spinner_message = if suppress_spinner { "" } else { "Generating" };
let ret = abortable_run_with_spinner( let ret = abortable_run_with_spinner(
client.chat_completions(input.clone()), client.chat_completions(input.clone()),
spinner_message, spinner_message,
@@ -459,13 +460,14 @@ pub async fn call_chat_completions_streaming(
) -> Result<(String, Vec<ToolResult>)> { ) -> Result<(String, Vec<ToolResult>)> {
let (tx, rx) = unbounded_channel(); let (tx, rx) = unbounded_channel();
let mut handler = SseHandler::new(tx, abort_signal.clone()); let mut handler = SseHandler::new(tx, abort_signal.clone());
if ctx.render_mode == RenderMode::Silent { let silent = ctx.render_mode == RenderMode::Silent;
if silent {
handler.set_silent(true); handler.set_silent(true);
} }
let (send_ret, render_ret) = tokio::join!( let (send_ret, render_ret) = tokio::join!(
client.chat_completions_streaming(input, &mut handler), client.chat_completions_streaming(input, &mut handler),
render_stream(rx, client.app_config(), abort_signal.clone()), render_stream(rx, client.app_config(), abort_signal.clone(), silent),
); );
if handler.abort().aborted() { if handler.abort().aborted() {
+28 -15
View File
@@ -2,7 +2,7 @@ use super::{FunctionDeclaration, JsonSchema};
use crate::config::RequestContext; use crate::config::RequestContext;
use crate::supervisor::escalation::{EscalationRequest, new_escalation_id}; use crate::supervisor::escalation::{EscalationRequest, new_escalation_id};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow, bail};
use indexmap::IndexMap; use indexmap::IndexMap;
use inquire::{Confirm, MultiSelect, Select, Text}; use inquire::{Confirm, MultiSelect, Select, Text};
use serde_json::{Value, json}; use serde_json::{Value, json};
@@ -155,7 +155,10 @@ fn handle_direct_ask(args: &Value) -> Result<Value> {
let mut options = parse_options(args)?; let mut options = parse_options(args)?;
options.push(CUSTOM_MULTI_CHOICE_ANSWER_OPTION.to_string()); options.push(CUSTOM_MULTI_CHOICE_ANSWER_OPTION.to_string());
let mut answer = Select::new(question, options).prompt()?; let mut answer = Select::new(question, options)
.without_filtering()
.with_help_message("↑↓ to move, enter to select")
.prompt()?;
if answer == CUSTOM_MULTI_CHOICE_ANSWER_OPTION { if answer == CUSTOM_MULTI_CHOICE_ANSWER_OPTION {
answer = Text::new("Custom response:").prompt()? answer = Text::new("Custom response:").prompt()?
@@ -205,12 +208,11 @@ async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> R
.ok_or_else(|| anyhow!("'question' is required"))? .ok_or_else(|| anyhow!("'question' is required"))?
.to_string(); .to_string();
let options: Option<Vec<String>> = args.get("options").and_then(Value::as_array).map(|arr| { let options: Option<Vec<String>> = if args.get("options").is_some() {
arr.iter() Some(parse_options(args)?)
.filter_map(Value::as_str) } else {
.map(String::from) None
.collect() };
});
let from_agent_id = ctx let from_agent_id = ctx
.self_agent_id .self_agent_id
@@ -262,13 +264,24 @@ async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> R
} }
fn parse_options(args: &Value) -> Result<Vec<String>> { fn parse_options(args: &Value) -> Result<Vec<String>> {
args.get("options") let raw = args
.and_then(Value::as_array) .get("options")
.map(|arr| { .ok_or_else(|| anyhow!("'options' is required and must be an array of strings"))?;
arr.iter()
let arr: Vec<Value> = match raw {
Value::Array(arr) => arr.clone(),
Value::String(s) => serde_json::from_str::<Vec<Value>>(s).map_err(|_| {
anyhow!(
"'options' was a string but did not parse as a JSON array. \
Pass options as a native JSON array, e.g. [\"yes\", \"no\"]."
)
})?,
_ => bail!("'options' is required and must be an array of strings"),
};
Ok(arr
.iter()
.filter_map(Value::as_str) .filter_map(Value::as_str)
.map(String::from) .map(String::from)
.collect() .collect())
})
.ok_or_else(|| anyhow!("'options' is required and must be an array of strings"))
} }
-24
View File
@@ -3,7 +3,6 @@ use super::structured;
use super::types::AgentNode; use super::types::AgentNode;
use crate::config::RequestContext; use crate::config::RequestContext;
use crate::function::supervisor::run_agent_for_graph; use crate::function::supervisor::run_agent_for_graph;
use crate::utils::dimmed_text;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use serde_json::Value; use serde_json::Value;
use std::time::Duration; use std::time::Duration;
@@ -24,14 +23,6 @@ impl AgentNodeExecutor {
.interpolate(&node.prompt) .interpolate(&node.prompt)
.with_context(|| format!("Failed to interpolate prompt for agent '{}'", node.agent))?; .with_context(|| format!("Failed to interpolate prompt for agent '{}'", node.agent))?;
eprintln!(
"{}",
dimmed_text(&format!("▸ spawning agent '{}' with prompt:", node.agent))
);
for line in indent_prompt(&prompt, 6) {
eprintln!("{}", dimmed_text(&line));
}
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS)); let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
let raw = timeout( let raw = timeout(
@@ -66,21 +57,6 @@ impl AgentNodeExecutor {
} }
} }
fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec<String> {
const MAX_LINES: usize = 12;
let pad = " ".repeat(prefix_spaces);
let mut out: Vec<String> = prompt
.lines()
.take(MAX_LINES)
.map(|line| format!("{pad}{line}"))
.collect();
let total = prompt.lines().count();
if total > MAX_LINES {
out.push(format!("{pad}... ({} more lines)", total - MAX_LINES));
}
out
}
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) { fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) {
if node.output_schema.is_some() if node.output_schema.is_some()
&& let Some(obj) = output.as_object() && let Some(obj) = output.as_object()
+18 -14
View File
@@ -1,6 +1,6 @@
use super::agent::AgentNodeExecutor; use super::agent::AgentNodeExecutor;
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor}; use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
use super::logging::GraphLogger; use super::logging::{GraphLogger, node_type_label};
use super::map::MapNodeExecutor; use super::map::MapNodeExecutor;
use super::progress::{BranchProgressHandle, BranchProgressTracker}; use super::progress::{BranchProgressHandle, BranchProgressTracker};
use super::rag::RagNodeExecutor; use super::rag::RagNodeExecutor;
@@ -146,11 +146,12 @@ impl GraphExecutor {
let semaphore = Arc::new(Semaphore::new(max_concurrency)); let semaphore = Arc::new(Semaphore::new(max_concurrency));
let frontier_size = frontier.len(); let frontier_size = frontier.len();
let progress_tracker = if frontier_size > 1 { let has_progress_nodes = frontier.iter().any(|nid| {
Some(BranchProgressTracker::new()) graph.get_node(nid).is_some_and(|n| {
} else { !matches!(n.node_type, NodeType::Approval(_) | NodeType::Input(_))
None })
}; });
let progress_tracker = has_progress_nodes.then(BranchProgressTracker::new);
let mut branch_tasks = Vec::with_capacity(frontier_size); let mut branch_tasks = Vec::with_capacity(frontier_size);
for node_id in &frontier { for node_id in &frontier {
let node = graph let node = graph
@@ -161,19 +162,24 @@ impl GraphExecutor {
.clone(); .clone();
let branch_state = state.fork_for_branch_state(); let branch_state = state.fork_for_branch_state();
let mut branch_ctx = ctx.fork_for_branch(); let mut branch_ctx = ctx.fork_for_branch();
if frontier_size > 1 {
branch_ctx.render_mode = RenderMode::Silent; branch_ctx.render_mode = RenderMode::Silent;
}
let script_exec_clone = script_executor.clone(); let script_exec_clone = script_executor.clone();
let graph_clone = Arc::clone(&graph); let graph_clone = Arc::clone(&graph);
let current = node_id.clone(); let current = node_id.clone();
let sem_clone = semaphore.clone(); let sem_clone = semaphore.clone();
let abort_clone = abort_signal.clone(); let abort_clone = abort_signal.clone();
let progress_handle: Option<BranchProgressHandle> = let progress_handle = match (
progress_tracker.as_ref().map(|t| t.add_branch(node_id)); matches!(node.node_type, NodeType::Approval(_) | NodeType::Input(_)),
&progress_tracker,
) {
(false, Some(tracker)) => {
tracker.add_branch(&format!("{} ({})", node_id, node_type_label(&node)))
}
_ => BranchProgressHandle::disabled(),
};
let task = tokio::spawn(async move { let task = tokio::spawn(async move {
let mut progress_handle = progress_handle; let mut progress_handle = Some(progress_handle);
let _permit = sem_clone let _permit = sem_clone
.acquire() .acquire()
.await .await
@@ -212,9 +218,7 @@ impl GraphExecutor {
} }
let joined = join_all(branch_tasks).await; let joined = join_all(branch_tasks).await;
if let Some(t) = &progress_tracker { drop(progress_tracker);
t.clear();
}
let mut branch_writes: Vec<BranchWrites> = Vec::new(); let mut branch_writes: Vec<BranchWrites> = Vec::new();
let mut next_frontier: HashSet<String> = HashSet::new(); let mut next_frontier: HashSet<String> = HashSet::new();
+1 -25
View File
@@ -3,7 +3,7 @@ use super::structured;
use super::types::LlmNode; use super::types::LlmNode;
use crate::client::{Model, ModelType, call_chat_completions}; use crate::client::{Model, ModelType, call_chat_completions};
use crate::config::{Input, RequestContext, Role, RoleLike}; use crate::config::{Input, RequestContext, Role, RoleLike};
use crate::utils::{create_abort_signal, dimmed_text}; use crate::utils::create_abort_signal;
use anyhow::{Context, Error, Result, anyhow, bail}; use anyhow::{Context, Error, Result, anyhow, bail};
use serde_json::Value; use serde_json::Value;
use std::collections::HashSet; use std::collections::HashSet;
@@ -101,15 +101,6 @@ async fn run(
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref()); let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
validate_tools_subset(&regular_tools, &mcp_servers, parent_ctx)?; validate_tools_subset(&regular_tools, &mcp_servers, parent_ctx)?;
eprintln!(
"{}",
dimmed_text(&format!(
"▸ llm call: model={} tools={}",
node.model.as_deref().unwrap_or("<active>"),
describe_tools_filter(node.tools.as_deref())
))
);
let role = build_inline_role( let role = build_inline_role(
node, node,
instructions.as_deref(), instructions.as_deref(),
@@ -363,13 +354,6 @@ fn format_schema_hint(schema: &Value) -> String {
) )
} }
fn describe_tools_filter(tools: Option<&[String]>) -> String {
match tools {
Some(t) if !t.is_empty() => t.join(","),
_ => "<none>".into(),
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::types::*; use super::super::types::*;
@@ -571,14 +555,6 @@ mod tests {
assert!(hint.contains("ONLY")); assert!(hint.contains("ONLY"));
} }
#[test]
fn describe_tools_filter_renders_each_case() {
assert_eq!(describe_tools_filter(None), "<none>");
assert_eq!(describe_tools_filter(Some(&[])), "<none>");
let tools = vec!["a".to_string(), "b".to_string()];
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
}
#[test] #[test]
fn categorize_tools_splits_mcp_and_regular() { fn categorize_tools_splits_mcp_and_regular() {
let entries = vec![ let entries = vec![
+1 -5
View File
@@ -72,10 +72,6 @@ impl GraphLogger {
"[graph:{}] entering '{}' (visit {visit})", "[graph:{}] entering '{}' (visit {visit})",
self.graph_name, node.id self.graph_name, node.id
); );
eprintln!(
"{}",
dimmed_text(&format!("{} ({})", node.id, node_type_label(node)))
);
} }
pub fn record_timing(&mut self, node_id: &str, elapsed: Duration) { pub fn record_timing(&mut self, node_id: &str, elapsed: Duration) {
@@ -142,7 +138,7 @@ impl GraphLogger {
} }
} }
fn node_type_label(node: &Node) -> &'static str { pub(super) fn node_type_label(node: &Node) -> &'static str {
match &node.node_type { match &node.node_type {
NodeType::Agent(_) => "agent", NodeType::Agent(_) => "agent",
NodeType::Script(_) => "script", NodeType::Script(_) => "script",
+1 -1
View File
@@ -123,7 +123,7 @@ impl MapNodeExecutor {
} }
let joined = join_all(sub_tasks).await; let joined = join_all(sub_tasks).await;
progress_tracker.clear(); drop(progress_tracker);
// Collect outputs keyed by input index so order is preserved regardless // Collect outputs keyed by input index so order is preserved regardless
// of finish order. This is the user-facing contract from plan E.2. // of finish order. This is the user-facing contract from plan E.2.
+1 -7
View File
@@ -46,12 +46,6 @@ impl BranchProgressTracker {
started: Instant::now(), started: Instant::now(),
} }
} }
pub fn clear(&self) {
if let Some(multi) = &self.multi {
let _ = multi.clear();
}
}
} }
pub(super) struct BranchProgressHandle { pub(super) struct BranchProgressHandle {
@@ -60,7 +54,7 @@ pub(super) struct BranchProgressHandle {
} }
impl BranchProgressHandle { impl BranchProgressHandle {
fn disabled() -> Self { pub fn disabled() -> Self {
Self { Self {
bar: None, bar: None,
started: Instant::now(), started: Instant::now(),
+1 -6
View File
@@ -1,7 +1,7 @@
use super::state::StateManager; use super::state::StateManager;
use super::types::RagNode; use super::types::RagNode;
use crate::config::RequestContext; use crate::config::RequestContext;
use crate::utils::{create_abort_signal, dimmed_text}; use crate::utils::create_abort_signal;
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, Result, anyhow};
use serde_json::{Map, Value}; use serde_json::{Map, Value};
use std::time::Duration; use std::time::Duration;
@@ -34,11 +34,6 @@ impl RagNodeExecutor {
let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k()); let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k());
let rerank = rag.configured_reranker(); let rerank = rag.configured_reranker();
eprintln!(
"{}",
dimmed_text(&format!("▸ rag lookup: node={node_id} top_k={top_k}"))
);
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS)); let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS));
let abort = create_abort_signal(); let abort = create_abort_signal();
let (context, sources_str, _ids) = let (context, sources_str, _ids) =
-23
View File
@@ -1,7 +1,6 @@
use super::state::{StateManager, StateRepresentation}; use super::state::{StateManager, StateRepresentation};
use super::types::ScriptNode; use super::types::ScriptNode;
use crate::function::Language; use crate::function::Language;
use crate::utils::dimmed_text;
use anyhow::{Context, Result, anyhow, bail}; use anyhow::{Context, Result, anyhow, bail};
use serde_json::Value; use serde_json::Value;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
@@ -32,11 +31,6 @@ impl ScriptExecutor {
bail!("Script file not found: '{}'", script_path.display()); bail!("Script file not found: '{}'", script_path.display());
} }
eprintln!(
"{}",
dimmed_text(&format!("▸ running script '{}'", node.script))
);
let language = detect_language(&script_path)?; let language = detect_language(&script_path)?;
let state_repr = state_manager.serialize_state()?; let state_repr = state_manager.serialize_state()?;
@@ -98,23 +92,6 @@ impl ScriptExecutor {
) )
})?; })?;
if let Ok(parsed) = serde_json::from_str::<serde_json::Map<String, Value>>(json_output) {
let keys: Vec<&str> = parsed
.keys()
.filter(|k| k.as_str() != "_next")
.map(|s| s.as_str())
.collect();
if !keys.is_empty() {
eprintln!(
"{}",
dimmed_text(&format!("▸ merged: {}", keys.join(", ")))
);
}
if let Some(n) = &next {
eprintln!("{}", dimmed_text(&format!("▸ script set _next = '{n}'")));
}
}
apply_state_updates(node, state_manager); apply_state_updates(node, state_manager);
Ok(next) Ok(next)
+2 -13
View File
@@ -1,6 +1,6 @@
use crate::client::call_chat_completions; use crate::client::call_chat_completions;
use crate::config::{Input, RequestContext, Role, RoleLike}; use crate::config::{Input, RequestContext, Role, RoleLike};
use crate::utils::{create_abort_signal, dimmed_text}; use crate::utils::create_abort_signal;
use anyhow::{Context, Result, bail}; use anyhow::{Context, Result, bail};
use serde_json::Value; use serde_json::Value;
use std::sync::Arc; use std::sync::Arc;
@@ -24,10 +24,6 @@ pub async fn extract(raw: &str, schema: &Value, parent_ctx: &mut RequestContext)
return Ok(parsed); return Ok(parsed);
} }
eprintln!(
"{}",
dimmed_text("▸ structured-output: parsing raw output failed, invoking extractor")
);
extract_via_extractor(raw, schema, parent_ctx, false).await extract_via_extractor(raw, schema, parent_ctx, false).await
} }
@@ -53,14 +49,7 @@ async fn extract_via_extractor(
"Structured-output extractor failed to produce valid JSON after repair retry. \ "Structured-output extractor failed to produce valid JSON after repair retry. \
Last response:\n{output}" Last response:\n{output}"
), ),
None => { None => Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await,
eprintln!(
"{}",
dimmed_text("▸ structured-output: extractor returned invalid JSON, retrying")
);
Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await
}
} }
} }
+20
View File
@@ -17,7 +17,11 @@ pub async fn render_stream(
rx: UnboundedReceiver<SseEvent>, rx: UnboundedReceiver<SseEvent>,
app: &AppConfig, app: &AppConfig,
abort_signal: AbortSignal, abort_signal: AbortSignal,
silent: bool,
) -> Result<()> { ) -> Result<()> {
if silent {
return drain_silently(rx, &abort_signal).await;
}
let ret = if *IS_STDOUT_TERMINAL && app.highlight { let ret = if *IS_STDOUT_TERMINAL && app.highlight {
let render_options = app.render_options()?; let render_options = app.render_options()?;
let mut render = MarkdownRender::init(render_options)?; let mut render = MarkdownRender::init(render_options)?;
@@ -28,6 +32,22 @@ pub async fn render_stream(
ret.map_err(|err| err.context("Failed to reader stream")) ret.map_err(|err| err.context("Failed to reader stream"))
} }
async fn drain_silently(
mut rx: UnboundedReceiver<SseEvent>,
abort_signal: &AbortSignal,
) -> Result<()> {
loop {
if abort_signal.aborted() {
break;
}
match rx.recv().await {
Some(SseEvent::Done) | None => break,
Some(SseEvent::Text(_)) => {}
}
}
Ok(())
}
pub fn render_error(err: anyhow::Error) { pub fn render_error(err: anyhow::Error) {
eprintln!("{}", error_text(&pretty_error(&err))); eprintln!("{}", error_text(&pretty_error(&err)));
} }