feat: added branch progress tracker for better visualization of parallel graph super-steps
This commit is contained in:
@@ -2,6 +2,7 @@ use super::agent::AgentNodeExecutor;
|
||||
use super::llm::LlmNodeExecutor;
|
||||
use super::logging::GraphLogger;
|
||||
use super::map::MapNodeExecutor;
|
||||
use super::progress::{BranchProgressHandle, BranchProgressTracker};
|
||||
use super::rag::RagNodeExecutor;
|
||||
use super::script::ScriptExecutor;
|
||||
use super::staging::BranchWrites;
|
||||
@@ -145,6 +146,11 @@ impl GraphExecutor {
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrency));
|
||||
|
||||
let frontier_size = frontier.len();
|
||||
let progress_tracker = if frontier_size > 1 {
|
||||
Some(BranchProgressTracker::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut branch_tasks = Vec::with_capacity(frontier_size);
|
||||
for node_id in &frontier {
|
||||
let node = graph
|
||||
@@ -163,13 +169,19 @@ impl GraphExecutor {
|
||||
let current = node_id.clone();
|
||||
let sem_clone = semaphore.clone();
|
||||
let abort_clone = abort_signal.clone();
|
||||
let progress_handle: Option<BranchProgressHandle> =
|
||||
progress_tracker.as_ref().map(|t| t.add_branch(node_id));
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
let mut progress_handle = progress_handle;
|
||||
let _permit = sem_clone
|
||||
.acquire()
|
||||
.await
|
||||
.expect("semaphore should not be closed");
|
||||
if abort_clone.aborted() {
|
||||
if let Some(h) = progress_handle.take() {
|
||||
h.fail("aborted");
|
||||
}
|
||||
return (
|
||||
current.clone(),
|
||||
branch_state,
|
||||
@@ -188,12 +200,21 @@ impl GraphExecutor {
|
||||
};
|
||||
let result = step(&node, &mut state, &mut ctx, &step_ctx, ¤t).await;
|
||||
let elapsed = node_start.elapsed();
|
||||
if let Some(h) = progress_handle.take() {
|
||||
match &result {
|
||||
Ok(_) => h.complete(),
|
||||
Err(e) => h.fail(&e.to_string()),
|
||||
}
|
||||
}
|
||||
(current, state, result, elapsed)
|
||||
});
|
||||
branch_tasks.push(task);
|
||||
}
|
||||
|
||||
let joined = join_all(branch_tasks).await;
|
||||
if let Some(t) = &progress_tracker {
|
||||
t.clear();
|
||||
}
|
||||
|
||||
let mut branch_writes: Vec<BranchWrites> = Vec::new();
|
||||
let mut next_frontier: HashSet<String> = HashSet::new();
|
||||
|
||||
Reference in New Issue
Block a user