feat: implemented the frontier-based scheduling for the graph executor with simplified state management (gotta love .clone)
This commit is contained in:
+170
-45
@@ -3,6 +3,7 @@ use super::llm::LlmNodeExecutor;
|
||||
use super::logging::GraphLogger;
|
||||
use super::rag::RagNodeExecutor;
|
||||
use super::script::ScriptExecutor;
|
||||
use super::staging::BranchWrites;
|
||||
use super::state::StateManager;
|
||||
use super::types::{EndNode, Graph, Node, NodeType};
|
||||
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
||||
@@ -10,11 +11,13 @@ use super::validator::{AgentValidationContext, GraphValidator};
|
||||
use crate::config::RequestContext;
|
||||
use crate::utils::AbortSignal;
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use futures_util::future::join_all;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
pub struct GraphExecutor {
|
||||
graph: Graph,
|
||||
@@ -70,74 +73,196 @@ impl GraphExecutor {
|
||||
let script_executor = ScriptExecutor::new(&base_dir);
|
||||
let max_iterations = graph.settings.max_loop_iterations;
|
||||
let graph_timeout = graph.settings.timeout.map(Duration::from_secs);
|
||||
let max_concurrency = graph.settings.max_concurrency;
|
||||
let start = Instant::now();
|
||||
|
||||
let mut current = graph.start.clone();
|
||||
logger.graph_start(¤t, graph.nodes.len());
|
||||
let mut frontier: HashSet<String> = HashSet::from([graph.start.clone()]);
|
||||
logger.graph_start(&graph.start, graph.nodes.len());
|
||||
|
||||
loop {
|
||||
if frontier.is_empty() {
|
||||
bail!(
|
||||
"Graph '{}' frontier emptied without reaching an End node",
|
||||
graph.name
|
||||
);
|
||||
}
|
||||
|
||||
let output = loop {
|
||||
if abort_signal.aborted() {
|
||||
bail!("Graph '{}' aborted at '{}'", graph.name, current);
|
||||
bail!(
|
||||
"Graph '{}' aborted before super-step with frontier {:?}",
|
||||
graph.name,
|
||||
sorted_frontier(&frontier)
|
||||
);
|
||||
}
|
||||
if let Some(t) = graph_timeout
|
||||
&& start.elapsed() > t
|
||||
{
|
||||
bail!(
|
||||
"Graph '{}' timed out after {}s at '{}'",
|
||||
"Graph '{}' timed out after {}s before super-step with frontier {:?}",
|
||||
graph.name,
|
||||
t.as_secs(),
|
||||
current
|
||||
sorted_frontier(&frontier)
|
||||
);
|
||||
}
|
||||
|
||||
state.state_mut().visit_node(¤t);
|
||||
let visits = state.state().loop_count(¤t);
|
||||
if visits > max_iterations {
|
||||
// Loop-count and visit tracking on live state, BEFORE forking.
|
||||
// This counts every entry to a node toward max_loop_iterations
|
||||
// regardless of how many parallel branches converged on it.
|
||||
for node_id in &frontier {
|
||||
state.state_mut().visit_node(node_id);
|
||||
let visits = state.state().loop_count(node_id);
|
||||
if visits > max_iterations {
|
||||
bail!(
|
||||
"Node '{}' visited {} times (max_loop_iterations={}). \
|
||||
Possible infinite loop.",
|
||||
node_id,
|
||||
visits,
|
||||
max_iterations
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for node_id in &frontier {
|
||||
let node = graph.get_node(node_id).ok_or_else(|| {
|
||||
anyhow!("Node '{}' not found in graph '{}'", node_id, graph.name)
|
||||
})?;
|
||||
let visits = state.state().loop_count(node_id);
|
||||
logger.node_entry(node, visits);
|
||||
}
|
||||
let snapshot_label = if frontier.len() == 1 {
|
||||
frontier.iter().next().cloned().unwrap_or_default()
|
||||
} else {
|
||||
format!("super-step {{{}}}", sorted_frontier(&frontier).join(","))
|
||||
};
|
||||
logger.state_snapshot(&snapshot_label, &state);
|
||||
|
||||
let snapshot = state.read_snapshot();
|
||||
let semaphore = Arc::new(Semaphore::new(max_concurrency));
|
||||
|
||||
let mut branch_tasks = Vec::with_capacity(frontier.len());
|
||||
for node_id in &frontier {
|
||||
let node = graph
|
||||
.get_node(node_id)
|
||||
.ok_or_else(|| {
|
||||
anyhow!("Node '{}' not found in graph '{}'", node_id, graph.name)
|
||||
})?
|
||||
.clone();
|
||||
let branch_state = state.fork_for_branch_state();
|
||||
let branch_ctx = ctx.fork_for_branch();
|
||||
let script_exec_clone = script_executor.clone();
|
||||
let graph_name = graph.name.clone();
|
||||
let current = node_id.clone();
|
||||
let sem_clone = semaphore.clone();
|
||||
let abort_clone = abort_signal.clone();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
let _permit = sem_clone
|
||||
.acquire()
|
||||
.await
|
||||
.expect("semaphore should not be closed");
|
||||
if abort_clone.aborted() {
|
||||
return (
|
||||
current.clone(),
|
||||
branch_state,
|
||||
Err(anyhow!("branch aborted")),
|
||||
Duration::default(),
|
||||
);
|
||||
}
|
||||
let node_start = Instant::now();
|
||||
let mut state = branch_state;
|
||||
let mut ctx = branch_ctx;
|
||||
let result = step(
|
||||
&node,
|
||||
&mut state,
|
||||
&mut ctx,
|
||||
&script_exec_clone,
|
||||
&graph_name,
|
||||
¤t,
|
||||
)
|
||||
.await;
|
||||
let elapsed = node_start.elapsed();
|
||||
(current, state, result, elapsed)
|
||||
});
|
||||
branch_tasks.push(task);
|
||||
}
|
||||
|
||||
let joined = join_all(branch_tasks).await;
|
||||
|
||||
let mut branch_writes: Vec<BranchWrites> = Vec::new();
|
||||
let mut next_frontier: HashSet<String> = HashSet::new();
|
||||
let mut end_results: Vec<(String, StateManager, String)> = Vec::new();
|
||||
|
||||
for join_result in joined {
|
||||
let (node_id, branch_state, step_result, elapsed) =
|
||||
join_result.map_err(|e| anyhow!("Branch task panicked: {e}"))?;
|
||||
logger.record_timing(&node_id, elapsed);
|
||||
|
||||
let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?;
|
||||
|
||||
match step_outcome {
|
||||
StepResult::Continue(target) => {
|
||||
logger.routing(&node_id, &target);
|
||||
let diff = branch_state.diff_against(snapshot.as_ref());
|
||||
branch_writes.push(BranchWrites {
|
||||
node_id: node_id.clone(),
|
||||
invocation_index: 0,
|
||||
writes: diff,
|
||||
});
|
||||
next_frontier.insert(target);
|
||||
}
|
||||
StepResult::End(output) => {
|
||||
end_results.push((node_id.clone(), branch_state, output));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if end_results.len() > 1 {
|
||||
let mut ids: Vec<String> =
|
||||
end_results.iter().map(|(id, _, _)| id.clone()).collect();
|
||||
ids.sort();
|
||||
bail!(
|
||||
"Node '{}' visited {} times (max_loop_iterations={}). \
|
||||
Possible infinite loop.",
|
||||
current,
|
||||
visits,
|
||||
max_iterations
|
||||
"super-step ended with multiple End targets ({}). \
|
||||
Fan-out branches must converge at a join node before \
|
||||
terminating. To fix: route all parallel branches to a \
|
||||
single shared next-node, then terminate from there.",
|
||||
ids.join(", ")
|
||||
);
|
||||
}
|
||||
|
||||
let node = graph
|
||||
.get_node(¤t)
|
||||
.ok_or_else(|| anyhow!("Node '{}' not found in graph '{}'", current, graph.name))?;
|
||||
// Sort by (node_id, invocation_index) so non-commutative reducers
|
||||
// like Concat/Merge produce deterministic output across runs.
|
||||
branch_writes.sort_by(|a, b| {
|
||||
a.node_id
|
||||
.cmp(&b.node_id)
|
||||
.then(a.invocation_index.cmp(&b.invocation_index))
|
||||
});
|
||||
state.apply_branch_writes(branch_writes, &graph.reducers)?;
|
||||
|
||||
logger.node_entry(node, visits);
|
||||
logger.state_snapshot(¤t, &state);
|
||||
|
||||
let node_start = Instant::now();
|
||||
let step_result = step(
|
||||
node,
|
||||
&mut state,
|
||||
ctx,
|
||||
&script_executor,
|
||||
&graph.name,
|
||||
¤t,
|
||||
)
|
||||
.await;
|
||||
logger.record_timing(¤t, node_start.elapsed());
|
||||
let next = step_result.with_context(|| format!("at node '{current}'"))?;
|
||||
|
||||
match next {
|
||||
StepResult::Continue(next_id) => {
|
||||
logger.routing(¤t, &next_id);
|
||||
current = next_id;
|
||||
}
|
||||
StepResult::End(out) => {
|
||||
logger.graph_complete(¤t, start.elapsed());
|
||||
break out;
|
||||
}
|
||||
if let Some((node_id, end_state, output)) = end_results.into_iter().next() {
|
||||
let diff = end_state.diff_against(snapshot.as_ref());
|
||||
state.apply_branch_writes(
|
||||
vec![BranchWrites {
|
||||
node_id: node_id.clone(),
|
||||
invocation_index: 0,
|
||||
writes: diff,
|
||||
}],
|
||||
&graph.reducers,
|
||||
)?;
|
||||
logger.graph_complete(&node_id, start.elapsed());
|
||||
return Ok(output);
|
||||
}
|
||||
};
|
||||
|
||||
Ok(output)
|
||||
frontier = next_frontier;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sorted_frontier(frontier: &HashSet<String>) -> Vec<String> {
|
||||
let mut v: Vec<String> = frontier.iter().cloned().collect();
|
||||
v.sort();
|
||||
v
|
||||
}
|
||||
|
||||
enum StepResult {
|
||||
Continue(String),
|
||||
End(String),
|
||||
|
||||
Reference in New Issue
Block a user