feat: added branch progress tracker for better visualization of parallel graph super-steps

This commit is contained in:
2026-05-20 15:50:38 -06:00
parent f32608169d
commit 76ee1ec7f1
7 changed files with 180 additions and 14 deletions
+21
View File
@@ -2,6 +2,7 @@ use super::agent::AgentNodeExecutor;
use super::llm::LlmNodeExecutor;
use super::logging::GraphLogger;
use super::map::MapNodeExecutor;
use super::progress::{BranchProgressHandle, BranchProgressTracker};
use super::rag::RagNodeExecutor;
use super::script::ScriptExecutor;
use super::staging::BranchWrites;
@@ -145,6 +146,11 @@ impl GraphExecutor {
let semaphore = Arc::new(Semaphore::new(max_concurrency));
let frontier_size = frontier.len();
let progress_tracker = if frontier_size > 1 {
Some(BranchProgressTracker::new())
} else {
None
};
let mut branch_tasks = Vec::with_capacity(frontier_size);
for node_id in &frontier {
let node = graph
@@ -163,13 +169,19 @@ impl GraphExecutor {
let current = node_id.clone();
let sem_clone = semaphore.clone();
let abort_clone = abort_signal.clone();
let progress_handle: Option<BranchProgressHandle> =
progress_tracker.as_ref().map(|t| t.add_branch(node_id));
let task = tokio::spawn(async move {
let mut progress_handle = progress_handle;
let _permit = sem_clone
.acquire()
.await
.expect("semaphore should not be closed");
if abort_clone.aborted() {
if let Some(h) = progress_handle.take() {
h.fail("aborted");
}
return (
current.clone(),
branch_state,
@@ -188,12 +200,21 @@ impl GraphExecutor {
};
let result = step(&node, &mut state, &mut ctx, &step_ctx, &current).await;
let elapsed = node_start.elapsed();
if let Some(h) = progress_handle.take() {
match &result {
Ok(_) => h.complete(),
Err(e) => h.fail(&e.to_string()),
}
}
(current, state, result, elapsed)
});
branch_tasks.push(task);
}
let joined = join_all(branch_tasks).await;
if let Some(t) = &progress_tracker {
t.clear();
}
let mut branch_writes: Vec<BranchWrites> = Vec::new();
let mut next_frontier: HashSet<String> = HashSet::new();