From fd0e4e6d0e3360f21453e945e37028e5b837642c Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 20 May 2026 18:54:20 -0600 Subject: [PATCH] feat: improved UX for parallel graph execution --- src/client/common.rs | 10 ++++--- src/function/user_interaction.rs | 47 ++++++++++++++++++++------------ src/graph/agent.rs | 24 ---------------- src/graph/executor.rs | 34 +++++++++++++---------- src/graph/llm.rs | 26 +----------------- src/graph/logging.rs | 6 +--- src/graph/map.rs | 2 +- src/graph/progress.rs | 8 +----- src/graph/rag.rs | 7 +---- src/graph/script.rs | 23 ---------------- src/graph/structured.rs | 15 ++-------- src/render/mod.rs | 20 ++++++++++++++ 12 files changed, 82 insertions(+), 140 deletions(-) diff --git a/src/client/common.rs b/src/client/common.rs index 8bb1d6e..33fb8a7 100644 --- a/src/client/common.rs +++ b/src/client/common.rs @@ -1,6 +1,6 @@ use super::*; -use crate::config::{paths, RenderMode}; +use crate::config::{RenderMode, paths}; use crate::{ config::{AppConfig, Input, RequestContext}, function::{FunctionDeclaration, ToolCall, ToolResult, eval_tool_calls}, @@ -418,7 +418,8 @@ pub async fn call_chat_completions( abort_signal: AbortSignal, ) -> Result<(String, Vec)> { let is_child_agent = ctx.current_depth > 0; - let spinner_message = if is_child_agent { "" } else { "Generating" }; + let suppress_spinner = is_child_agent || ctx.render_mode == RenderMode::Silent; + let spinner_message = if suppress_spinner { "" } else { "Generating" }; let ret = abortable_run_with_spinner( client.chat_completions(input.clone()), spinner_message, @@ -459,13 +460,14 @@ pub async fn call_chat_completions_streaming( ) -> Result<(String, Vec)> { let (tx, rx) = unbounded_channel(); let mut handler = SseHandler::new(tx, abort_signal.clone()); - if ctx.render_mode == RenderMode::Silent { + let silent = ctx.render_mode == RenderMode::Silent; + if silent { handler.set_silent(true); } let (send_ret, render_ret) = tokio::join!( client.chat_completions_streaming(input, &mut handler), - render_stream(rx, client.app_config(), abort_signal.clone()), + render_stream(rx, client.app_config(), abort_signal.clone(), silent), ); if handler.abort().aborted() { diff --git a/src/function/user_interaction.rs b/src/function/user_interaction.rs index e747a76..92c407c 100644 --- a/src/function/user_interaction.rs +++ b/src/function/user_interaction.rs @@ -2,7 +2,7 @@ use super::{FunctionDeclaration, JsonSchema}; use crate::config::RequestContext; use crate::supervisor::escalation::{EscalationRequest, new_escalation_id}; -use anyhow::{Result, anyhow}; +use anyhow::{Result, anyhow, bail}; use indexmap::IndexMap; use inquire::{Confirm, MultiSelect, Select, Text}; use serde_json::{Value, json}; @@ -155,7 +155,10 @@ fn handle_direct_ask(args: &Value) -> Result { let mut options = parse_options(args)?; options.push(CUSTOM_MULTI_CHOICE_ANSWER_OPTION.to_string()); - let mut answer = Select::new(question, options).prompt()?; + let mut answer = Select::new(question, options) + .without_filtering() + .with_help_message("↑↓ to move, enter to select") + .prompt()?; if answer == CUSTOM_MULTI_CHOICE_ANSWER_OPTION { answer = Text::new("Custom response:").prompt()? @@ -205,12 +208,11 @@ async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> R .ok_or_else(|| anyhow!("'question' is required"))? .to_string(); - let options: Option> = args.get("options").and_then(Value::as_array).map(|arr| { - arr.iter() - .filter_map(Value::as_str) - .map(String::from) - .collect() - }); + let options: Option> = if args.get("options").is_some() { + Some(parse_options(args)?) + } else { + None + }; let from_agent_id = ctx .self_agent_id @@ -262,13 +264,24 @@ async fn handle_escalated(ctx: &RequestContext, action: &str, args: &Value) -> R } fn parse_options(args: &Value) -> Result> { - args.get("options") - .and_then(Value::as_array) - .map(|arr| { - arr.iter() - .filter_map(Value::as_str) - .map(String::from) - .collect() - }) - .ok_or_else(|| anyhow!("'options' is required and must be an array of strings")) + let raw = args + .get("options") + .ok_or_else(|| anyhow!("'options' is required and must be an array of strings"))?; + + let arr: Vec = match raw { + Value::Array(arr) => arr.clone(), + Value::String(s) => serde_json::from_str::>(s).map_err(|_| { + anyhow!( + "'options' was a string but did not parse as a JSON array. \ + Pass options as a native JSON array, e.g. [\"yes\", \"no\"]." + ) + })?, + _ => bail!("'options' is required and must be an array of strings"), + }; + + Ok(arr + .iter() + .filter_map(Value::as_str) + .map(String::from) + .collect()) } diff --git a/src/graph/agent.rs b/src/graph/agent.rs index 45fbb75..de95cfd 100644 --- a/src/graph/agent.rs +++ b/src/graph/agent.rs @@ -3,7 +3,6 @@ use super::structured; use super::types::AgentNode; use crate::config::RequestContext; use crate::function::supervisor::run_agent_for_graph; -use crate::utils::dimmed_text; use anyhow::{Context, Result}; use serde_json::Value; use std::time::Duration; @@ -24,14 +23,6 @@ impl AgentNodeExecutor { .interpolate(&node.prompt) .with_context(|| format!("Failed to interpolate prompt for agent '{}'", node.agent))?; - eprintln!( - "{}", - dimmed_text(&format!("▸ spawning agent '{}' with prompt:", node.agent)) - ); - for line in indent_prompt(&prompt, 6) { - eprintln!("{}", dimmed_text(&line)); - } - let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS)); let raw = timeout( @@ -66,21 +57,6 @@ impl AgentNodeExecutor { } } -fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec { - const MAX_LINES: usize = 12; - let pad = " ".repeat(prefix_spaces); - let mut out: Vec = prompt - .lines() - .take(MAX_LINES) - .map(|line| format!("{pad}{line}")) - .collect(); - let total = prompt.lines().count(); - if total > MAX_LINES { - out.push(format!("{pad}... ({} more lines)", total - MAX_LINES)); - } - out -} - fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) { if node.output_schema.is_some() && let Some(obj) = output.as_object() diff --git a/src/graph/executor.rs b/src/graph/executor.rs index 552c270..563506b 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -1,6 +1,6 @@ use super::agent::AgentNodeExecutor; use super::llm::{LlmExecutionOutcome, LlmNodeExecutor}; -use super::logging::GraphLogger; +use super::logging::{GraphLogger, node_type_label}; use super::map::MapNodeExecutor; use super::progress::{BranchProgressHandle, BranchProgressTracker}; use super::rag::RagNodeExecutor; @@ -146,11 +146,12 @@ impl GraphExecutor { let semaphore = Arc::new(Semaphore::new(max_concurrency)); let frontier_size = frontier.len(); - let progress_tracker = if frontier_size > 1 { - Some(BranchProgressTracker::new()) - } else { - None - }; + let has_progress_nodes = frontier.iter().any(|nid| { + graph.get_node(nid).is_some_and(|n| { + !matches!(n.node_type, NodeType::Approval(_) | NodeType::Input(_)) + }) + }); + let progress_tracker = has_progress_nodes.then(BranchProgressTracker::new); let mut branch_tasks = Vec::with_capacity(frontier_size); for node_id in &frontier { let node = graph @@ -161,19 +162,24 @@ impl GraphExecutor { .clone(); let branch_state = state.fork_for_branch_state(); let mut branch_ctx = ctx.fork_for_branch(); - if frontier_size > 1 { - branch_ctx.render_mode = RenderMode::Silent; - } + branch_ctx.render_mode = RenderMode::Silent; let script_exec_clone = script_executor.clone(); let graph_clone = Arc::clone(&graph); let current = node_id.clone(); let sem_clone = semaphore.clone(); let abort_clone = abort_signal.clone(); - let progress_handle: Option = - progress_tracker.as_ref().map(|t| t.add_branch(node_id)); + let progress_handle = match ( + matches!(node.node_type, NodeType::Approval(_) | NodeType::Input(_)), + &progress_tracker, + ) { + (false, Some(tracker)) => { + tracker.add_branch(&format!("{} ({})", node_id, node_type_label(&node))) + } + _ => BranchProgressHandle::disabled(), + }; let task = tokio::spawn(async move { - let mut progress_handle = progress_handle; + let mut progress_handle = Some(progress_handle); let _permit = sem_clone .acquire() .await @@ -212,9 +218,7 @@ impl GraphExecutor { } let joined = join_all(branch_tasks).await; - if let Some(t) = &progress_tracker { - t.clear(); - } + drop(progress_tracker); let mut branch_writes: Vec = Vec::new(); let mut next_frontier: HashSet = HashSet::new(); diff --git a/src/graph/llm.rs b/src/graph/llm.rs index 98b47f1..86ea453 100644 --- a/src/graph/llm.rs +++ b/src/graph/llm.rs @@ -3,7 +3,7 @@ use super::structured; use super::types::LlmNode; use crate::client::{Model, ModelType, call_chat_completions}; use crate::config::{Input, RequestContext, Role, RoleLike}; -use crate::utils::{create_abort_signal, dimmed_text}; +use crate::utils::create_abort_signal; use anyhow::{Context, Error, Result, anyhow, bail}; use serde_json::Value; use std::collections::HashSet; @@ -101,15 +101,6 @@ async fn run( let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref()); validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?; - eprintln!( - "{}", - dimmed_text(&format!( - "▸ llm call: model={} tools={}", - node.model.as_deref().unwrap_or(""), - describe_tools_filter(node.tools.as_deref()) - )) - ); - let role = build_inline_role( node, instructions.as_deref(), @@ -363,13 +354,6 @@ fn format_schema_hint(schema: &Value) -> String { ) } -fn describe_tools_filter(tools: Option<&[String]>) -> String { - match tools { - Some(t) if !t.is_empty() => t.join(","), - _ => "".into(), - } -} - #[cfg(test)] mod tests { use super::super::types::*; @@ -571,14 +555,6 @@ mod tests { assert!(hint.contains("ONLY")); } - #[test] - fn describe_tools_filter_renders_each_case() { - assert_eq!(describe_tools_filter(None), ""); - assert_eq!(describe_tools_filter(Some(&[])), ""); - let tools = vec!["a".to_string(), "b".to_string()]; - assert_eq!(describe_tools_filter(Some(&tools)), "a,b"); - } - #[test] fn categorize_tools_splits_mcp_and_regular() { let entries = vec![ diff --git a/src/graph/logging.rs b/src/graph/logging.rs index 768d611..69d2990 100644 --- a/src/graph/logging.rs +++ b/src/graph/logging.rs @@ -72,10 +72,6 @@ impl GraphLogger { "[graph:{}] entering '{}' (visit {visit})", self.graph_name, node.id ); - eprintln!( - "{}", - dimmed_text(&format!("▸ {} ({})", node.id, node_type_label(node))) - ); } pub fn record_timing(&mut self, node_id: &str, elapsed: Duration) { @@ -142,7 +138,7 @@ impl GraphLogger { } } -fn node_type_label(node: &Node) -> &'static str { +pub(super) fn node_type_label(node: &Node) -> &'static str { match &node.node_type { NodeType::Agent(_) => "agent", NodeType::Script(_) => "script", diff --git a/src/graph/map.rs b/src/graph/map.rs index a81156f..627a07c 100644 --- a/src/graph/map.rs +++ b/src/graph/map.rs @@ -123,7 +123,7 @@ impl MapNodeExecutor { } let joined = join_all(sub_tasks).await; - progress_tracker.clear(); + drop(progress_tracker); // Collect outputs keyed by input index so order is preserved regardless // of finish order. This is the user-facing contract from plan E.2. diff --git a/src/graph/progress.rs b/src/graph/progress.rs index 3c38b37..2aa5330 100644 --- a/src/graph/progress.rs +++ b/src/graph/progress.rs @@ -46,12 +46,6 @@ impl BranchProgressTracker { started: Instant::now(), } } - - pub fn clear(&self) { - if let Some(multi) = &self.multi { - let _ = multi.clear(); - } - } } pub(super) struct BranchProgressHandle { @@ -60,7 +54,7 @@ pub(super) struct BranchProgressHandle { } impl BranchProgressHandle { - fn disabled() -> Self { + pub fn disabled() -> Self { Self { bar: None, started: Instant::now(), diff --git a/src/graph/rag.rs b/src/graph/rag.rs index 3439952..2aed523 100644 --- a/src/graph/rag.rs +++ b/src/graph/rag.rs @@ -1,7 +1,7 @@ use super::state::StateManager; use super::types::RagNode; use crate::config::RequestContext; -use crate::utils::{create_abort_signal, dimmed_text}; +use crate::utils::create_abort_signal; use anyhow::{Context, Result, anyhow}; use serde_json::{Map, Value}; use std::time::Duration; @@ -34,11 +34,6 @@ impl RagNodeExecutor { let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k()); let rerank = rag.configured_reranker(); - eprintln!( - "{}", - dimmed_text(&format!("▸ rag lookup: node={node_id} top_k={top_k}")) - ); - let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS)); let abort = create_abort_signal(); let (context, sources_str, _ids) = diff --git a/src/graph/script.rs b/src/graph/script.rs index 91ebd41..f4a01bb 100644 --- a/src/graph/script.rs +++ b/src/graph/script.rs @@ -1,7 +1,6 @@ use super::state::{StateManager, StateRepresentation}; use super::types::ScriptNode; use crate::function::Language; -use crate::utils::dimmed_text; use anyhow::{Context, Result, anyhow, bail}; use serde_json::Value; use std::path::{Path, PathBuf}; @@ -32,11 +31,6 @@ impl ScriptExecutor { bail!("Script file not found: '{}'", script_path.display()); } - eprintln!( - "{}", - dimmed_text(&format!("▸ running script '{}'", node.script)) - ); - let language = detect_language(&script_path)?; let state_repr = state_manager.serialize_state()?; @@ -98,23 +92,6 @@ impl ScriptExecutor { ) })?; - if let Ok(parsed) = serde_json::from_str::>(json_output) { - let keys: Vec<&str> = parsed - .keys() - .filter(|k| k.as_str() != "_next") - .map(|s| s.as_str()) - .collect(); - if !keys.is_empty() { - eprintln!( - "{}", - dimmed_text(&format!("▸ merged: {}", keys.join(", "))) - ); - } - if let Some(n) = &next { - eprintln!("{}", dimmed_text(&format!("▸ script set _next = '{n}'"))); - } - } - apply_state_updates(node, state_manager); Ok(next) diff --git a/src/graph/structured.rs b/src/graph/structured.rs index 26acc5e..a590709 100644 --- a/src/graph/structured.rs +++ b/src/graph/structured.rs @@ -1,6 +1,6 @@ use crate::client::call_chat_completions; use crate::config::{Input, RequestContext, Role, RoleLike}; -use crate::utils::{create_abort_signal, dimmed_text}; +use crate::utils::create_abort_signal; use anyhow::{Context, Result, bail}; use serde_json::Value; use std::sync::Arc; @@ -24,10 +24,6 @@ pub async fn extract(raw: &str, schema: &Value, parent_ctx: &mut RequestContext) return Ok(parsed); } - eprintln!( - "{}", - dimmed_text("▸ structured-output: parsing raw output failed, invoking extractor") - ); extract_via_extractor(raw, schema, parent_ctx, false).await } @@ -53,14 +49,7 @@ async fn extract_via_extractor( "Structured-output extractor failed to produce valid JSON after repair retry. \ Last response:\n{output}" ), - None => { - eprintln!( - "{}", - dimmed_text("▸ structured-output: extractor returned invalid JSON, retrying") - ); - - Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await - } + None => Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await, } } diff --git a/src/render/mod.rs b/src/render/mod.rs index 80a5816..8bf2baa 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -17,7 +17,11 @@ pub async fn render_stream( rx: UnboundedReceiver, app: &AppConfig, abort_signal: AbortSignal, + silent: bool, ) -> Result<()> { + if silent { + return drain_silently(rx, &abort_signal).await; + } let ret = if *IS_STDOUT_TERMINAL && app.highlight { let render_options = app.render_options()?; let mut render = MarkdownRender::init(render_options)?; @@ -28,6 +32,22 @@ pub async fn render_stream( ret.map_err(|err| err.context("Failed to reader stream")) } +async fn drain_silently( + mut rx: UnboundedReceiver, + abort_signal: &AbortSignal, +) -> Result<()> { + loop { + if abort_signal.aborted() { + break; + } + match rx.recv().await { + Some(SseEvent::Done) | None => break, + Some(SseEvent::Text(_)) => {} + } + } + Ok(()) +} + pub fn render_error(err: anyhow::Error) { eprintln!("{}", error_text(&pretty_error(&err))); }