feat: improved UX for parallel graph execution
This commit is contained in:
+19
-15
@@ -1,6 +1,6 @@
|
||||
use super::agent::AgentNodeExecutor;
|
||||
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
|
||||
use super::logging::GraphLogger;
|
||||
use super::logging::{GraphLogger, node_type_label};
|
||||
use super::map::MapNodeExecutor;
|
||||
use super::progress::{BranchProgressHandle, BranchProgressTracker};
|
||||
use super::rag::RagNodeExecutor;
|
||||
@@ -146,11 +146,12 @@ impl GraphExecutor {
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrency));
|
||||
|
||||
let frontier_size = frontier.len();
|
||||
let progress_tracker = if frontier_size > 1 {
|
||||
Some(BranchProgressTracker::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let has_progress_nodes = frontier.iter().any(|nid| {
|
||||
graph.get_node(nid).is_some_and(|n| {
|
||||
!matches!(n.node_type, NodeType::Approval(_) | NodeType::Input(_))
|
||||
})
|
||||
});
|
||||
let progress_tracker = has_progress_nodes.then(BranchProgressTracker::new);
|
||||
let mut branch_tasks = Vec::with_capacity(frontier_size);
|
||||
for node_id in &frontier {
|
||||
let node = graph
|
||||
@@ -161,19 +162,24 @@ impl GraphExecutor {
|
||||
.clone();
|
||||
let branch_state = state.fork_for_branch_state();
|
||||
let mut branch_ctx = ctx.fork_for_branch();
|
||||
if frontier_size > 1 {
|
||||
branch_ctx.render_mode = RenderMode::Silent;
|
||||
}
|
||||
branch_ctx.render_mode = RenderMode::Silent;
|
||||
let script_exec_clone = script_executor.clone();
|
||||
let graph_clone = Arc::clone(&graph);
|
||||
let current = node_id.clone();
|
||||
let sem_clone = semaphore.clone();
|
||||
let abort_clone = abort_signal.clone();
|
||||
let progress_handle: Option<BranchProgressHandle> =
|
||||
progress_tracker.as_ref().map(|t| t.add_branch(node_id));
|
||||
let progress_handle = match (
|
||||
matches!(node.node_type, NodeType::Approval(_) | NodeType::Input(_)),
|
||||
&progress_tracker,
|
||||
) {
|
||||
(false, Some(tracker)) => {
|
||||
tracker.add_branch(&format!("{} ({})", node_id, node_type_label(&node)))
|
||||
}
|
||||
_ => BranchProgressHandle::disabled(),
|
||||
};
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
let mut progress_handle = progress_handle;
|
||||
let mut progress_handle = Some(progress_handle);
|
||||
let _permit = sem_clone
|
||||
.acquire()
|
||||
.await
|
||||
@@ -212,9 +218,7 @@ impl GraphExecutor {
|
||||
}
|
||||
|
||||
let joined = join_all(branch_tasks).await;
|
||||
if let Some(t) = &progress_tracker {
|
||||
t.clear();
|
||||
}
|
||||
drop(progress_tracker);
|
||||
|
||||
let mut branch_writes: Vec<BranchWrites> = Vec::new();
|
||||
let mut next_frontier: HashSet<String> = HashSet::new();
|
||||
|
||||
Reference in New Issue
Block a user