feat: improved UX for parallel graph execution

This commit is contained in:
2026-05-20 18:54:20 -06:00
parent 28262cd860
commit fd0e4e6d0e
12 changed files with 82 additions and 140 deletions
+6 -4
View File
@@ -1,6 +1,6 @@
use super::*;
use crate::config::{paths, RenderMode};
use crate::config::{RenderMode, paths};
use crate::{
config::{AppConfig, Input, RequestContext},
function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls},
@@ -418,7 +418,8 @@ pub async fn call_chat_completions(
abort_signal: AbortSignal,
) -> Result<(String, Vec<ToolResult>)> {
let is_child_agent = ctx.current_depth > 0;
let spinner_message = if is_child_agent { "" } else { "Generating" };
let suppress_spinner = is_child_agent || ctx.render_mode == RenderMode::Silent;
let spinner_message = if suppress_spinner { "" } else { "Generating" };
let ret = abortable_run_with_spinner(
client.chat_completions(input.clone()),
spinner_message,
@@ -459,13 +460,14 @@ pub async fn call_chat_completions_streaming(
) -> Result<(String, Vec<ToolResult>)> {
let (tx, rx) = unbounded_channel();
let mut handler = SseHandler::new(tx, abort_signal.clone());
if ctx.render_mode == RenderMode::Silent {
let silent = ctx.render_mode == RenderMode::Silent;
if silent {
handler.set_silent(true);
}
let (send_ret, render_ret) = tokio::join!(
client.chat_completions_streaming(input, &mut handler),
render_stream(rx, client.app_config(), abort_signal.clone()),
render_stream(rx, client.app_config(), abort_signal.clone(), silent),
);
if handler.abort().aborted() {
+30 -17
View File
@@ -2,7 +2,7 @@ use super::{FunctionDeclaration, JsonSchema};
use crate::config::RequestContext;
use crate::supervisor::escalation::{EscalationRequest, new_escalation_id};
use anyhow::{Result, anyhow};
use anyhow::{Result, anyhow, bail};
use indexmap::IndexMap;
use inquire::{Confirm, MultiSelect, Select, Text};
use serde_json::{Value, json};
@@ -155,7 +155,10 @@ fn handle_direct_ask(args: &Value) -> Result<Value> {
let mut options = parse_options(args)?;
options.push(CUSTOM_MULTI_CHOICE_ANSWER_OPTION.to_string());
let mut answer = Select::new(question, options).prompt()?;
let mut answer = Select::new(question, options)
.without_filtering()
.with_help_message("↑↓ to move, enter to select")
.prompt()?;
if answer == CUSTOM_MULTI_CHOICE_ANSWER_OPTION {
answer = Text::new("Custom response:").prompt()?
@@ -205,12 +208,11 @@ async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> R
.ok_or_else(|| anyhow!("'question' is required"))?
.to_string();
let options: Option<Vec<String>> = args.get("options").and_then(Value::as_array).map(|arr| {
arr.iter()
.filter_map(Value::as_str)
.map(String::from)
.collect()
});
let options: Option<Vec<String>> = if args.get("options").is_some() {
Some(parse_options(args)?)
} else {
None
};
let from_agent_id = ctx
.self_agent_id
@@ -262,13 +264,24 @@ async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> R
}
fn parse_options(args: &Value) -> Result<Vec<String>> {
args.get("options")
.and_then(Value::as_array)
.map(|arr| {
arr.iter()
.filter_map(Value::as_str)
.map(String::from)
.collect()
})
.ok_or_else(|| anyhow!("'options' is required and must be an array of strings"))
let raw = args
.get("options")
.ok_or_else(|| anyhow!("'options' is required and must be an array of strings"))?;
let arr: Vec<Value> = match raw {
Value::Array(arr) => arr.clone(),
Value::String(s) => serde_json::from_str::<Vec<Value>>(s).map_err(|_| {
anyhow!(
"'options' was a string but did not parse as a JSON array. \
Pass options as a native JSON array, e.g. [\"yes\", \"no\"]."
)
})?,
_ => bail!("'options' is required and must be an array of strings"),
};
Ok(arr
.iter()
.filter_map(Value::as_str)
.map(String::from)
.collect())
}
-24
View File
@@ -3,7 +3,6 @@ use super::structured;
use super::types::AgentNode;
use crate::config::RequestContext;
use crate::function::supervisor::run_agent_for_graph;
use crate::utils::dimmed_text;
use anyhow::{Context, Result};
use serde_json::Value;
use std::time::Duration;
@@ -24,14 +23,6 @@ impl AgentNodeExecutor {
.interpolate(&node.prompt)
.with_context(|| format!("Failed to interpolate prompt for agent '{}'", node.agent))?;
eprintln!(
"{}",
dimmed_text(&format!("▸ spawning agent '{}' with prompt:", node.agent))
);
for line in indent_prompt(&prompt, 6) {
eprintln!("{}", dimmed_text(&line));
}
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
let raw = timeout(
@@ -66,21 +57,6 @@ impl AgentNodeExecutor {
}
}
fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec<String> {
const MAX_LINES: usize = 12;
let pad = " ".repeat(prefix_spaces);
let mut out: Vec<String> = prompt
.lines()
.take(MAX_LINES)
.map(|line| format!("{pad}{line}"))
.collect();
let total = prompt.lines().count();
if total > MAX_LINES {
out.push(format!("{pad}... ({} more lines)", total - MAX_LINES));
}
out
}
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) {
if node.output_schema.is_some()
&& let Some(obj) = output.as_object()
+19 -15
View File
@@ -1,6 +1,6 @@
use super::agent::AgentNodeExecutor;
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
use super::logging::GraphLogger;
use super::logging::{GraphLogger, node_type_label};
use super::map::MapNodeExecutor;
use super::progress::{BranchProgressHandle, BranchProgressTracker};
use super::rag::RagNodeExecutor;
@@ -146,11 +146,12 @@ impl GraphExecutor {
let semaphore = Arc::new(Semaphore::new(max_concurrency));
let frontier_size = frontier.len();
let progress_tracker = if frontier_size > 1 {
Some(BranchProgressTracker::new())
} else {
None
};
let has_progress_nodes = frontier.iter().any(|nid| {
graph.get_node(nid).is_some_and(|n| {
!matches!(n.node_type, NodeType::Approval(_) | NodeType::Input(_))
})
});
let progress_tracker = has_progress_nodes.then(BranchProgressTracker::new);
let mut branch_tasks = Vec::with_capacity(frontier_size);
for node_id in &frontier {
let node = graph
@@ -161,19 +162,24 @@ impl GraphExecutor {
.clone();
let branch_state = state.fork_for_branch_state();
let mut branch_ctx = ctx.fork_for_branch();
if frontier_size > 1 {
branch_ctx.render_mode = RenderMode::Silent;
}
branch_ctx.render_mode = RenderMode::Silent;
let script_exec_clone = script_executor.clone();
let graph_clone = Arc::clone(&graph);
let current = node_id.clone();
let sem_clone = semaphore.clone();
let abort_clone = abort_signal.clone();
let progress_handle: Option<BranchProgressHandle> =
progress_tracker.as_ref().map(|t| t.add_branch(node_id));
let progress_handle = match (
matches!(node.node_type, NodeType::Approval(_) | NodeType::Input(_)),
&progress_tracker,
) {
(false, Some(tracker)) => {
tracker.add_branch(&format!("{} ({})", node_id, node_type_label(&node)))
}
_ => BranchProgressHandle::disabled(),
};
let task = tokio::spawn(async move {
let mut progress_handle = progress_handle;
let mut progress_handle = Some(progress_handle);
let _permit = sem_clone
.acquire()
.await
@@ -212,9 +218,7 @@ impl GraphExecutor {
}
let joined = join_all(branch_tasks).await;
if let Some(t) = &progress_tracker {
t.clear();
}
drop(progress_tracker);
let mut branch_writes: Vec<BranchWrites> = Vec::new();
let mut next_frontier: HashSet<String> = HashSet::new();
+1 -25
View File
@@ -3,7 +3,7 @@ use super::structured;
use super::types::LlmNode;
use crate::client::{Model, ModelType, call_chat_completions};
use crate::config::{Input, RequestContext, Role, RoleLike};
use crate::utils::{create_abort_signal, dimmed_text};
use crate::utils::create_abort_signal;
use anyhow::{Context, Error, Result, anyhow, bail};
use serde_json::Value;
use std::collections::HashSet;
@@ -101,15 +101,6 @@ async fn run(
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
validate_tools_subset(&regular_tools, &mcp_servers, parent_ctx)?;
eprintln!(
"{}",
dimmed_text(&format!(
"▸ llm call: model={} tools={}",
node.model.as_deref().unwrap_or("<active>"),
describe_tools_filter(node.tools.as_deref())
))
);
let role = build_inline_role(
node,
instructions.as_deref(),
@@ -363,13 +354,6 @@ fn format_schema_hint(schema: &Value) -> String {
)
}
fn describe_tools_filter(tools: Option<&[String]>) -> String {
match tools {
Some(t) if !t.is_empty() => t.join(","),
_ => "<none>".into(),
}
}
#[cfg(test)]
mod tests {
use super::super::types::*;
@@ -571,14 +555,6 @@ mod tests {
assert!(hint.contains("ONLY"));
}
#[test]
fn describe_tools_filter_renders_each_case() {
assert_eq!(describe_tools_filter(None), "<none>");
assert_eq!(describe_tools_filter(Some(&[])), "<none>");
let tools = vec!["a".to_string(), "b".to_string()];
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
}
#[test]
fn categorize_tools_splits_mcp_and_regular() {
let entries = vec![
+1 -5
View File
@@ -72,10 +72,6 @@ impl GraphLogger {
"[graph:{}] entering '{}' (visit {visit})",
self.graph_name, node.id
);
eprintln!(
"{}",
dimmed_text(&format!("{} ({})", node.id, node_type_label(node)))
);
}
pub fn record_timing(&mut self, node_id: &str, elapsed: Duration) {
@@ -142,7 +138,7 @@ impl GraphLogger {
}
}
fn node_type_label(node: &Node) -> &'static str {
pub(super) fn node_type_label(node: &Node) -> &'static str {
match &node.node_type {
NodeType::Agent(_) => "agent",
NodeType::Script(_) => "script",
+1 -1
View File
@@ -123,7 +123,7 @@ impl MapNodeExecutor {
}
let joined = join_all(sub_tasks).await;
progress_tracker.clear();
drop(progress_tracker);
// Collect outputs keyed by input index so order is preserved regardless
// of finish order. This is the user-facing contract from plan E.2.
+1 -7
View File
@@ -46,12 +46,6 @@ impl BranchProgressTracker {
started: Instant::now(),
}
}
pub fn clear(&self) {
if let Some(multi) = &self.multi {
let _ = multi.clear();
}
}
}
pub(super) struct BranchProgressHandle {
@@ -60,7 +54,7 @@ pub(super) struct BranchProgressHandle {
}
impl BranchProgressHandle {
fn disabled() -> Self {
pub fn disabled() -> Self {
Self {
bar: None,
started: Instant::now(),
+1 -6
View File
@@ -1,7 +1,7 @@
use super::state::StateManager;
use super::types::RagNode;
use crate::config::RequestContext;
use crate::utils::{create_abort_signal, dimmed_text};
use crate::utils::create_abort_signal;
use anyhow::{Context, Result, anyhow};
use serde_json::{Map, Value};
use std::time::Duration;
@@ -34,11 +34,6 @@ impl RagNodeExecutor {
let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k());
let rerank = rag.configured_reranker();
eprintln!(
"{}",
dimmed_text(&format!("▸ rag lookup: node={node_id} top_k={top_k}"))
);
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS));
let abort = create_abort_signal();
let (context, sources_str, _ids) =
-23
View File
@@ -1,7 +1,6 @@
use super::state::{StateManager, StateRepresentation};
use super::types::ScriptNode;
use crate::function::Language;
use crate::utils::dimmed_text;
use anyhow::{Context, Result, anyhow, bail};
use serde_json::Value;
use std::path::{Path, PathBuf};
@@ -32,11 +31,6 @@ impl ScriptExecutor {
bail!("Script file not found: '{}'", script_path.display());
}
eprintln!(
"{}",
dimmed_text(&format!("▸ running script '{}'", node.script))
);
let language = detect_language(&script_path)?;
let state_repr = state_manager.serialize_state()?;
@@ -98,23 +92,6 @@ impl ScriptExecutor {
)
})?;
if let Ok(parsed) = serde_json::from_str::<serde_json::Map<String, Value>>(json_output) {
let keys: Vec<&str> = parsed
.keys()
.filter(|k| k.as_str() != "_next")
.map(|s| s.as_str())
.collect();
if !keys.is_empty() {
eprintln!(
"{}",
dimmed_text(&format!("▸ merged: {}", keys.join(", ")))
);
}
if let Some(n) = &next {
eprintln!("{}", dimmed_text(&format!("▸ script set _next = '{n}'")));
}
}
apply_state_updates(node, state_manager);
Ok(next)
+2 -13
View File
@@ -1,6 +1,6 @@
use crate::client::call_chat_completions;
use crate::config::{Input, RequestContext, Role, RoleLike};
use crate::utils::{create_abort_signal, dimmed_text};
use crate::utils::create_abort_signal;
use anyhow::{Context, Result, bail};
use serde_json::Value;
use std::sync::Arc;
@@ -24,10 +24,6 @@ pub async fn extract(raw: &str, schema: &Value, parent_ctx: &mut RequestContext)
return Ok(parsed);
}
eprintln!(
"{}",
dimmed_text("▸ structured-output: parsing raw output failed, invoking extractor")
);
extract_via_extractor(raw, schema, parent_ctx, false).await
}
@@ -53,14 +49,7 @@ async fn extract_via_extractor(
"Structured-output extractor failed to produce valid JSON after repair retry. \
Last response:\n{output}"
),
None => {
eprintln!(
"{}",
dimmed_text("▸ structured-output: extractor returned invalid JSON, retrying")
);
Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await
}
None => Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await,
}
}
+20
View File
@@ -17,7 +17,11 @@ pub async fn render_stream(
rx: UnboundedReceiver<SseEvent>,
app: &AppConfig,
abort_signal: AbortSignal,
silent: bool,
) -> Result<()> {
if silent {
return drain_silently(rx, &abort_signal).await;
}
let ret = if *IS_STDOUT_TERMINAL && app.highlight {
let render_options = app.render_options()?;
let mut render = MarkdownRender::init(render_options)?;
@@ -28,6 +32,22 @@ pub async fn render_stream(
ret.map_err(|err| err.context("Failed to reader stream"))
}
async fn drain_silently(
mut rx: UnboundedReceiver<SseEvent>,
abort_signal: &AbortSignal,
) -> Result<()> {
loop {
if abort_signal.aborted() {
break;
}
match rx.recv().await {
Some(SseEvent::Done) | None => break,
Some(SseEvent::Text(_)) => {}
}
}
Ok(())
}
pub fn render_error(err: anyhow::Error) {
eprintln!("{}", error_text(&pretty_error(&err)));
}