From 3eff13534931cdf8235065862346961eb3df2710 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 20 May 2026 15:50:38 -0600 Subject: [PATCH] feat: added branch progress tracker for better visualization of parallel graph super-steps --- Cargo.lock | 39 ++++++++++++++++-- Cargo.toml | 1 + src/graph/executor.rs | 21 ++++++++++ src/graph/map.rs | 16 ++++++++ src/graph/mod.rs | 21 +++++----- src/graph/progress.rs | 94 +++++++++++++++++++++++++++++++++++++++++++ src/graph/reducer.rs | 2 +- 7 files changed, 180 insertions(+), 14 deletions(-) create mode 100644 src/graph/progress.rs diff --git a/Cargo.lock b/Cargo.lock index 7c276f5..029450c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1264,6 +1264,19 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width", + "windows-sys 0.59.0", +] + [[package]] name = "console" version = "0.16.3" @@ -1658,7 +1671,7 @@ version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25f104b501bf2364e78d0d3974cbc774f738f5865306ed128e1e0d7499c0ad96" dependencies = [ - "console", + "console 0.16.3", "shell-words", "tempfile", "zeroize", @@ -2854,13 +2867,26 @@ dependencies = [ "serde_core", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console 0.15.11", + "number_prefix", + "portable-atomic", + "unicode-width", + "web-time", +] + [[package]] name = "indicatif" version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console", + "console 0.16.3", "portable-atomic", "unicode-width", "unit-prefix", @@ -3251,6 +3277,7 @@ dependencies = [ "html_to_markdown", "http 1.4.0", "indexmap 2.14.0", + "indicatif 0.17.11", "indoc", "inquire", "is-terminal", @@ -3578,6 +3605,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "objc2" version = "0.6.4" @@ -5109,7 +5142,7 @@ dependencies = [ "either", "flate2", "http 1.4.0", - "indicatif", + "indicatif 0.18.4", "log", "quick-xml 0.38.4", "regex", diff --git a/Cargo.toml b/Cargo.toml index e7f8a5b..9f4535b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ unicode-width = "0.2.0" async-recursion = "1.1.1" http = "1.1.0" indexmap = { version = "2.2.6", features = ["serde"] } +indicatif = "0.17" hmac = "0.12.1" aws-smithy-eventstream = "0.60.4" urlencoding = "2.1.3" diff --git a/src/graph/executor.rs b/src/graph/executor.rs index 661dc53..f9bf69c 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -2,6 +2,7 @@ use super::agent::AgentNodeExecutor; use super::llm::LlmNodeExecutor; use super::logging::GraphLogger; use super::map::MapNodeExecutor; +use super::progress::{BranchProgressHandle, BranchProgressTracker}; use super::rag::RagNodeExecutor; use super::script::ScriptExecutor; use super::staging::BranchWrites; @@ -145,6 +146,11 @@ impl GraphExecutor { let semaphore = Arc::new(Semaphore::new(max_concurrency)); let frontier_size = frontier.len(); + let progress_tracker = if frontier_size > 1 { + Some(BranchProgressTracker::new()) + } else { + None + }; let mut branch_tasks = Vec::with_capacity(frontier_size); for node_id in &frontier { let node = graph @@ -163,13 +169,19 @@ impl GraphExecutor { let current = node_id.clone(); let sem_clone = semaphore.clone(); let abort_clone = abort_signal.clone(); + let progress_handle: Option = + progress_tracker.as_ref().map(|t| t.add_branch(node_id)); let task = tokio::spawn(async move { + let mut progress_handle = progress_handle; let _permit = sem_clone .acquire() .await .expect("semaphore should not be closed"); if abort_clone.aborted() { + if let Some(h) = progress_handle.take() { + h.fail("aborted"); + } return ( current.clone(), branch_state, @@ -188,12 +200,21 @@ impl GraphExecutor { }; let result = step(&node, &mut state, &mut ctx, &step_ctx, ¤t).await; let elapsed = node_start.elapsed(); + if let Some(h) = progress_handle.take() { + match &result { + Ok(_) => h.complete(), + Err(e) => h.fail(&e.to_string()), + } + } (current, state, result, elapsed) }); branch_tasks.push(task); } let joined = join_all(branch_tasks).await; + if let Some(t) = &progress_tracker { + t.clear(); + } let mut branch_writes: Vec = Vec::new(); let mut next_frontier: HashSet = HashSet::new(); diff --git a/src/graph/map.rs b/src/graph/map.rs index a3a18d8..b798610 100644 --- a/src/graph/map.rs +++ b/src/graph/map.rs @@ -1,6 +1,7 @@ use super::agent::AgentNodeExecutor; use super::executor::StepContext; use super::llm::LlmNodeExecutor; +use super::progress::{BranchProgressHandle, BranchProgressTracker}; use super::rag::RagNodeExecutor; use super::state::StateManager; use super::types::{MapNode, NodeType}; @@ -59,6 +60,7 @@ impl MapNodeExecutor { .unwrap_or(step_ctx.max_concurrency) .max(1); let semaphore = Arc::new(Semaphore::new(max_conc)); + let progress_tracker = BranchProgressTracker::new(); let mut sub_tasks = Vec::with_capacity(items.len()); for (idx, item) in items.iter().enumerate() { @@ -72,15 +74,21 @@ impl MapNodeExecutor { let sub_branch_id = node.branch.clone(); let sem = semaphore.clone(); let abort = step_ctx.abort_signal.clone(); + let progress_handle: BranchProgressHandle = + progress_tracker.add_branch(&format!("{}[{idx}]", node.branch)); sub_state.state_mut().set(as_name, item); let task = tokio::spawn(async move { + let mut progress_handle = Some(progress_handle); let _permit = sem .acquire() .await .expect("map semaphore should not be closed"); if abort.aborted() { + if let Some(h) = progress_handle.take() { + h.fail("aborted"); + } return ( idx, sub_state, @@ -119,12 +127,20 @@ impl MapNodeExecutor { )), }; + if let Some(h) = progress_handle.take() { + match &exec_result { + Ok(_) => h.complete(), + Err(e) => h.fail(&e.to_string()), + } + } + (idx, state, exec_result) }); sub_tasks.push(task); } let joined = join_all(sub_tasks).await; + progress_tracker.clear(); // Collect outputs keyed by input index so order is preserved regardless // of finish order. This is the user-facing contract from plan E.2. diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 44fc381..6bd1cd0 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -5,6 +5,7 @@ pub mod llm; pub mod logging; pub mod map; pub mod parser; +pub mod progress; pub mod rag; pub mod reducer; pub mod script; @@ -15,10 +16,10 @@ pub mod types; pub mod user_interaction; pub mod validator; -use serde_json::Value; pub use dispatch::{active_agent_graph_name, run_active_agent_graph}; pub use executor::GraphExecutor; pub use parser::{GraphParser, agent_has_graph}; +use serde_json::Value; pub use types::{Graph, NodeType}; pub const GRAPH_SCHEMA_VERSION: &str = "1.0"; @@ -27,13 +28,13 @@ pub const DEFAULT_MAX_LOOP_ITERATIONS: usize = 100; pub const MAX_STATE_SIZE_BYTES: usize = 32 * 1024; -pub (in crate::graph) fn type_name(value: &Value) -> &'static str { - match value { - Value::Null => "null", - Value::Bool(_) => "bool", - Value::Number(_) => "number", - Value::String(_) => "string", - Value::Array(_) => "array", - Value::Object(_) => "object", - } +pub(in crate::graph) fn type_name(value: &Value) -> &'static str { + match value { + Value::Null => "null", + Value::Bool(_) => "bool", + Value::Number(_) => "number", + Value::String(_) => "string", + Value::Array(_) => "array", + Value::Object(_) => "object", + } } diff --git a/src/graph/progress.rs b/src/graph/progress.rs new file mode 100644 index 0000000..3c38b37 --- /dev/null +++ b/src/graph/progress.rs @@ -0,0 +1,94 @@ +use crate::utils::IS_STDOUT_TERMINAL; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use std::sync::LazyLock; +use std::time::{Duration, Instant}; + +static SPINNER_STYLE: LazyLock = LazyLock::new(|| { + ProgressStyle::with_template("{spinner} [{prefix}] {msg}") + .expect("valid template") + .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", ""]) +}); + +// Manages a set of per-branch spinners drawn side-by-side via indicatif's +// `MultiProgress`. Created at the start of a multi-branch graph super-step +// (or map sub-branch fan-out) and torn down at the join. +// +// When stdout isn't a terminal (CI, piped output), the tracker becomes a +// no-op — `add_branch` returns a disabled handle whose methods do nothing. +// This keeps machine-piped graph runs free of spinner garbage in their +// captured output. +pub(super) struct BranchProgressTracker { + multi: Option, +} + +impl BranchProgressTracker { + pub fn new() -> Self { + if *IS_STDOUT_TERMINAL { + Self { + multi: Some(MultiProgress::new()), + } + } else { + Self { multi: None } + } + } + + pub fn add_branch(&self, label: &str) -> BranchProgressHandle { + let Some(multi) = &self.multi else { + return BranchProgressHandle::disabled(); + }; + let bar = multi.add(ProgressBar::new_spinner()); + bar.set_style(SPINNER_STYLE.clone()); + bar.set_prefix(label.to_string()); + bar.set_message("running…"); + bar.enable_steady_tick(Duration::from_millis(80)); + BranchProgressHandle { + bar: Some(bar), + started: Instant::now(), + } + } + + pub fn clear(&self) { + if let Some(multi) = &self.multi { + let _ = multi.clear(); + } + } +} + +pub(super) struct BranchProgressHandle { + bar: Option, + started: Instant, +} + +impl BranchProgressHandle { + fn disabled() -> Self { + Self { + bar: None, + started: Instant::now(), + } + } + + pub fn complete(self) { + if let Some(bar) = self.bar { + let elapsed = self.started.elapsed(); + bar.finish_with_message(format!("✓ done ({:.1}s)", elapsed.as_secs_f64())); + } + } + + pub fn fail(self, err: &str) { + if let Some(bar) = self.bar { + let elapsed = self.started.elapsed(); + let truncated = if err.len() > 80 { + let mut s = err[..80].to_string(); + s.push('…'); + s + } else { + err.to_string() + }; + bar.finish_with_message(format!( + "✗ failed ({:.1}s) — {}", + elapsed.as_secs_f64(), + truncated + )); + } + } +} diff --git a/src/graph/reducer.rs b/src/graph/reducer.rs index 2d261a3..b634b97 100644 --- a/src/graph/reducer.rs +++ b/src/graph/reducer.rs @@ -1,7 +1,7 @@ use super::types::Reducer; +use crate::graph::type_name; use anyhow::{Result, bail}; use serde_json::{Number, Value}; -use crate::graph::type_name; /// Combines a branch's incoming write with the current state value (if any) /// via the specified reducer. The result is what gets written back to live