From de2a8dcf89ba730639f8afa0e2f35ab7535a0d53 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 20 May 2026 13:48:55 -0600 Subject: [PATCH] feat: implemented the frontier-based scheduling for the graph executor with simplified state management (gotta love .clone) --- src/config/request_context.rs | 50 ++++++++ src/config/tool_scope.rs | 3 +- src/graph/executor.rs | 215 +++++++++++++++++++++++++++------- src/graph/script.rs | 1 + src/graph/staging.rs | 82 ------------- src/graph/state.rs | 126 ++++++++++++++++++-- 6 files changed, 341 insertions(+), 136 deletions(-) diff --git a/src/config/request_context.rs b/src/config/request_context.rs index a3a7566..6f3b0f5 100644 --- a/src/config/request_context.rs +++ b/src/config/request_context.rs @@ -149,6 +149,56 @@ impl RequestContext { }) } + /// Forks the context for one parallel branch of a graph super-step. + /// + /// Each branch gets a fresh, owned clone — mutations (role swap, + /// `before/after_chat_completion`, tool tracker, last_message, etc.) are + /// scoped to the branch and discarded when the branch finishes. The + /// user-visible state communication happens through the graph's + /// `StateManager` (via `fork_for_branch_state` + `diff_against` + + /// `apply_branch_writes` reducers), NOT through `RequestContext`. + /// + /// Distinction from `new_for_child`: `new_for_child` builds a fresh context + /// for a SPAWNED SUB-AGENT (different agent identity, different supervisor + /// hierarchy, depth+1, fresh tool tracker). `fork_for_branch` keeps the + /// caller's identity and supervisor hierarchy — it's a sibling clone of the + /// SAME logical agent, running one of N parallel work items. + /// + /// Behavior of per-field cloning: + /// - `Arc`-wrapped fields (`app`, `rag`, `supervisor`, `parent_supervisor`, + /// `inbox`, `escalation_queue`) — shared via Arc::clone + /// - Owned heap fields (`model`, `role`, `session`, `agent`, `tool_scope`, + /// `todo_list`, etc.) — deep `.clone()` so the branch can mutate freely + /// - `auto_continue_count` reset to 0 (each branch starts a fresh + /// continuation budget) + /// - `last_continuation_response` reset to None + #[allow(dead_code)] + pub fn fork_for_branch(&self) -> Self { + Self { + app: Arc::clone(&self.app), + macro_flag: self.macro_flag, + info_flag: self.info_flag, + working_mode: self.working_mode, + model: self.model.clone(), + agent_variables: self.agent_variables.clone(), + role: self.role.clone(), + session: self.session.clone(), + rag: self.rag.clone(), + agent: self.agent.clone(), + last_message: self.last_message.clone(), + tool_scope: self.tool_scope.clone(), + supervisor: self.supervisor.clone(), + parent_supervisor: self.parent_supervisor.clone(), + self_agent_id: self.self_agent_id.clone(), + inbox: self.inbox.clone(), + escalation_queue: self.escalation_queue.clone(), + current_depth: self.current_depth, + auto_continue_count: 0, + todo_list: self.todo_list.clone(), + last_continuation_response: None, + } + } + pub fn new_for_child( app: Arc, parent: &Self, diff --git a/src/config/tool_scope.rs b/src/config/tool_scope.rs index b61c982..9836ca5 100644 --- a/src/config/tool_scope.rs +++ b/src/config/tool_scope.rs @@ -8,6 +8,7 @@ use serde_json::{Value, json}; use std::collections::HashMap; use std::sync::Arc; +#[derive(Clone)] pub struct ToolScope { pub functions: Functions, pub mcp_runtime: McpRuntime, @@ -24,7 +25,7 @@ impl Default for ToolScope { } } -#[derive(Default)] +#[derive(Default, Clone)] pub struct McpRuntime { pub servers: HashMap>, } diff --git a/src/graph/executor.rs b/src/graph/executor.rs index 4fa59f2..f5cb673 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -3,6 +3,7 @@ use super::llm::LlmNodeExecutor; use super::logging::GraphLogger; use super::rag::RagNodeExecutor; use super::script::ScriptExecutor; +use super::staging::BranchWrites; use super::state::StateManager; use super::types::{EndNode, Graph, Node, NodeType}; use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor}; @@ -10,11 +11,13 @@ use super::validator::{AgentValidationContext, GraphValidator}; use crate::config::RequestContext; use crate::utils::AbortSignal; use anyhow::{Context, Result, anyhow, bail}; +use futures_util::future::join_all; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::Semaphore; pub struct GraphExecutor { graph: Graph, @@ -70,74 +73,196 @@ impl GraphExecutor { let script_executor = ScriptExecutor::new(&base_dir); let max_iterations = graph.settings.max_loop_iterations; let graph_timeout = graph.settings.timeout.map(Duration::from_secs); + let max_concurrency = graph.settings.max_concurrency; let start = Instant::now(); - let mut current = graph.start.clone(); - logger.graph_start(¤t, graph.nodes.len()); + let mut frontier: HashSet = HashSet::from([graph.start.clone()]); + logger.graph_start(&graph.start, graph.nodes.len()); + + loop { + if frontier.is_empty() { + bail!( + "Graph '{}' frontier emptied without reaching an End node", + graph.name + ); + } - let output = loop { if abort_signal.aborted() { - bail!("Graph '{}' aborted at '{}'", graph.name, current); + bail!( + "Graph '{}' aborted before super-step with frontier {:?}", + graph.name, + sorted_frontier(&frontier) + ); } if let Some(t) = graph_timeout && start.elapsed() > t { bail!( - "Graph '{}' timed out after {}s at '{}'", + "Graph '{}' timed out after {}s before super-step with frontier {:?}", graph.name, t.as_secs(), - current + sorted_frontier(&frontier) ); } - state.state_mut().visit_node(¤t); - let visits = state.state().loop_count(¤t); - if visits > max_iterations { + // Loop-count and visit tracking on live state, BEFORE forking. + // This counts every entry to a node toward max_loop_iterations + // regardless of how many parallel branches converged on it. + for node_id in &frontier { + state.state_mut().visit_node(node_id); + let visits = state.state().loop_count(node_id); + if visits > max_iterations { + bail!( + "Node '{}' visited {} times (max_loop_iterations={}). \ + Possible infinite loop.", + node_id, + visits, + max_iterations + ); + } + } + + for node_id in &frontier { + let node = graph.get_node(node_id).ok_or_else(|| { + anyhow!("Node '{}' not found in graph '{}'", node_id, graph.name) + })?; + let visits = state.state().loop_count(node_id); + logger.node_entry(node, visits); + } + let snapshot_label = if frontier.len() == 1 { + frontier.iter().next().cloned().unwrap_or_default() + } else { + format!("super-step {{{}}}", sorted_frontier(&frontier).join(",")) + }; + logger.state_snapshot(&snapshot_label, &state); + + let snapshot = state.read_snapshot(); + let semaphore = Arc::new(Semaphore::new(max_concurrency)); + + let mut branch_tasks = Vec::with_capacity(frontier.len()); + for node_id in &frontier { + let node = graph + .get_node(node_id) + .ok_or_else(|| { + anyhow!("Node '{}' not found in graph '{}'", node_id, graph.name) + })? + .clone(); + let branch_state = state.fork_for_branch_state(); + let branch_ctx = ctx.fork_for_branch(); + let script_exec_clone = script_executor.clone(); + let graph_name = graph.name.clone(); + let current = node_id.clone(); + let sem_clone = semaphore.clone(); + let abort_clone = abort_signal.clone(); + + let task = tokio::spawn(async move { + let _permit = sem_clone + .acquire() + .await + .expect("semaphore should not be closed"); + if abort_clone.aborted() { + return ( + current.clone(), + branch_state, + Err(anyhow!("branch aborted")), + Duration::default(), + ); + } + let node_start = Instant::now(); + let mut state = branch_state; + let mut ctx = branch_ctx; + let result = step( + &node, + &mut state, + &mut ctx, + &script_exec_clone, + &graph_name, + ¤t, + ) + .await; + let elapsed = node_start.elapsed(); + (current, state, result, elapsed) + }); + branch_tasks.push(task); + } + + let joined = join_all(branch_tasks).await; + + let mut branch_writes: Vec = Vec::new(); + let mut next_frontier: HashSet = HashSet::new(); + let mut end_results: Vec<(String, StateManager, String)> = Vec::new(); + + for join_result in joined { + let (node_id, branch_state, step_result, elapsed) = + join_result.map_err(|e| anyhow!("Branch task panicked: {e}"))?; + logger.record_timing(&node_id, elapsed); + + let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?; + + match step_outcome { + StepResult::Continue(target) => { + logger.routing(&node_id, &target); + let diff = branch_state.diff_against(snapshot.as_ref()); + branch_writes.push(BranchWrites { + node_id: node_id.clone(), + invocation_index: 0, + writes: diff, + }); + next_frontier.insert(target); + } + StepResult::End(output) => { + end_results.push((node_id.clone(), branch_state, output)); + } + } + } + + if end_results.len() > 1 { + let mut ids: Vec = + end_results.iter().map(|(id, _, _)| id.clone()).collect(); + ids.sort(); bail!( - "Node '{}' visited {} times (max_loop_iterations={}). \ - Possible infinite loop.", - current, - visits, - max_iterations + "super-step ended with multiple End targets ({}). \ + Fan-out branches must converge at a join node before \ + terminating. To fix: route all parallel branches to a \ + single shared next-node, then terminate from there.", + ids.join(", ") ); } - let node = graph - .get_node(¤t) - .ok_or_else(|| anyhow!("Node '{}' not found in graph '{}'", current, graph.name))?; + // Sort by (node_id, invocation_index) so non-commutative reducers + // like Concat/Merge produce deterministic output across runs. + branch_writes.sort_by(|a, b| { + a.node_id + .cmp(&b.node_id) + .then(a.invocation_index.cmp(&b.invocation_index)) + }); + state.apply_branch_writes(branch_writes, &graph.reducers)?; - logger.node_entry(node, visits); - logger.state_snapshot(¤t, &state); - - let node_start = Instant::now(); - let step_result = step( - node, - &mut state, - ctx, - &script_executor, - &graph.name, - ¤t, - ) - .await; - logger.record_timing(¤t, node_start.elapsed()); - let next = step_result.with_context(|| format!("at node '{current}'"))?; - - match next { - StepResult::Continue(next_id) => { - logger.routing(¤t, &next_id); - current = next_id; - } - StepResult::End(out) => { - logger.graph_complete(¤t, start.elapsed()); - break out; - } + if let Some((node_id, end_state, output)) = end_results.into_iter().next() { + let diff = end_state.diff_against(snapshot.as_ref()); + state.apply_branch_writes( + vec![BranchWrites { + node_id: node_id.clone(), + invocation_index: 0, + writes: diff, + }], + &graph.reducers, + )?; + logger.graph_complete(&node_id, start.elapsed()); + return Ok(output); } - }; - Ok(output) + frontier = next_frontier; + } } } +fn sorted_frontier(frontier: &HashSet) -> Vec { + let mut v: Vec = frontier.iter().cloned().collect(); + v.sort(); + v +} + enum StepResult { Continue(String), End(String), diff --git a/src/graph/script.rs b/src/graph/script.rs index 8e82aab..91ebd41 100644 --- a/src/graph/script.rs +++ b/src/graph/script.rs @@ -10,6 +10,7 @@ use std::time::Duration; use tokio::process::Command; use tokio::time::timeout; +#[derive(Clone)] pub struct ScriptExecutor { base_dir: PathBuf, } diff --git a/src/graph/staging.rs b/src/graph/staging.rs index c89631c..097c7f2 100644 --- a/src/graph/staging.rs +++ b/src/graph/staging.rs @@ -1,38 +1,6 @@ use serde_json::Value; use std::collections::HashMap; -#[derive(Debug, Default, Clone)] -pub struct StagingArea { - writes: HashMap, -} - -#[allow(dead_code)] -impl StagingArea { - pub fn new() -> Self { - Self::default() - } - - pub fn write(&mut self, key: impl Into, value: Value) { - self.writes.insert(key.into(), value); - } - - pub fn get(&self, key: &str) -> Option<&Value> { - self.writes.get(key) - } - - pub fn is_empty(&self) -> bool { - self.writes.is_empty() - } - - pub fn len(&self) -> usize { - self.writes.len() - } - - pub fn into_writes(self) -> HashMap { - self.writes - } -} - /// Published form of one branch's writes for the super-step merge phase. /// Callers assemble these into a deterministically-ordered `Vec` keyed by /// `(node_id, invocation_index)` before passing to @@ -40,58 +8,8 @@ impl StagingArea { /// branches and the input-list position for map sub-branches — so multiple /// invocations of the same `branch:` node by a `map` are still totally ordered. #[derive(Debug, Clone)] -#[allow(dead_code)] pub struct BranchWrites { pub node_id: String, pub invocation_index: usize, pub writes: HashMap, } - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn new_staging_area_is_empty() { - let s = StagingArea::new(); - - assert!(s.is_empty()); - assert_eq!(s.len(), 0); - } - - #[test] - fn write_stores_value_under_key() { - let mut s = StagingArea::new(); - - s.write("key", json!("value")); - - assert_eq!(s.get("key"), Some(&json!("value"))); - assert_eq!(s.len(), 1); - assert!(!s.is_empty()); - } - - #[test] - fn write_overwrites_existing_key() { - let mut s = StagingArea::new(); - - s.write("k", json!(1)); - s.write("k", json!(2)); - - assert_eq!(s.get("k"), Some(&json!(2))); - assert_eq!(s.len(), 1); - } - - #[test] - fn into_writes_consumes_and_yields_map() { - let mut s = StagingArea::new(); - s.write("a", json!(1)); - s.write("b", json!(2)); - - let writes = s.into_writes(); - - assert_eq!(writes.len(), 2); - assert_eq!(writes.get("a"), Some(&json!(1))); - assert_eq!(writes.get("b"), Some(&json!(2))); - } -} diff --git a/src/graph/state.rs b/src/graph/state.rs index a7d6f8b..d242273 100644 --- a/src/graph/state.rs +++ b/src/graph/state.rs @@ -159,13 +159,44 @@ impl StateManager { } } - /// Returns an `Arc`-wrapped snapshot of the current graph state. Each branch - /// in a parallel super-step shares this snapshot for reads; their writes - /// accumulate into per-branch `StagingArea` instances, which are merged via - /// `apply_branch_writes` at the end of the super-step. + /// Forks state for a parallel branch: returns a fully-owned `StateManager` + /// seeded from the current state's data. The branch mutates its fork + /// freely; callers extract its writes via `diff_against` after the branch + /// completes, then merge them via `apply_branch_writes`. /// - /// Distinct from the older `snapshot()` method (returns a `HashMap` clone of - /// the data only — used by `script_executor` to ship state to child processes). + /// Distinct from `read_snapshot` (returns a shared `Arc` for + /// reads) — `fork_for_branch_state` returns a writable owned clone. + pub fn fork_for_branch_state(&self) -> Self { + Self { + state: self.state.clone(), + temp_file: None, + } + } + + /// Returns the keys whose values differ from `snapshot`. Use this after a + /// branch finishes to extract its writes (input to `apply_branch_writes`). + /// Keys present in `self` but absent from `snapshot`, or with different + /// values, count as writes. Deletions are not represented (no current node + /// executor deletes state). + pub fn diff_against(&self, snapshot: &GraphState) -> HashMap { + let mut diff = HashMap::new(); + for (k, v) in self.state.data() { + if snapshot.get(k) != Some(v) { + diff.insert(k.clone(), v.clone()); + } + } + diff + } + + /// Returns an `Arc`-wrapped snapshot of the current graph state. Each + /// branch in a parallel super-step uses this snapshot as the baseline for + /// its `diff_against` call at branch end. The executor extracts each + /// branch's writes (the diff) and merges them via `apply_branch_writes` at + /// the super-step boundary. + /// + /// Distinct from the older `snapshot()` method (returns a `HashMap` clone + /// of the data only — used by `script_executor` to ship state to child + /// processes). #[allow(dead_code)] pub fn read_snapshot(&self) -> Arc { Arc::new(self.state.clone()) @@ -936,12 +967,91 @@ mod tests { #[test] fn interpolate_raw_inner_spaces_treated_as_mixed() { let manager = manager_with(&[("k", json!("v"))]); - // `{{ k }}` is not a valid pure reference (spaces inside braces are // outside the allowed character set). Fall back to string interpolation // -- which doesn't match the regex either, so the literal passes through. let result = manager.interpolate_raw("{{ k }}").unwrap(); - assert_eq!(result, json!("{{ k }}")); } + + #[test] + fn fork_for_branch_state_copies_data() { + let parent = manager_with(&[("a", json!(1)), ("b", json!("x"))]); + + let fork = parent.fork_for_branch_state(); + + assert_eq!(fork.state().get("a"), Some(&json!(1))); + assert_eq!(fork.state().get("b"), Some(&json!("x"))); + } + + #[test] + fn fork_for_branch_state_isolates_writes_from_parent() { + let parent = manager_with(&[("count", json!(10))]); + let mut fork = parent.fork_for_branch_state(); + + fork.state_mut().set("count".into(), json!(999)); + + assert_eq!(fork.state().get("count"), Some(&json!(999))); + assert_eq!(parent.state().get("count"), Some(&json!(10))); + } + + #[test] + fn fork_for_branch_state_does_not_share_temp_file_lifecycle() { + let parent = manager_with(&[("k", json!("v"))]); + let fork = parent.fork_for_branch_state(); + + assert!(fork.temp_file.is_none()); + // Dropping the fork must not affect the parent's data + drop(fork); + assert_eq!(parent.state().get("k"), Some(&json!("v"))); + } + + #[test] + fn diff_against_returns_empty_when_unchanged() { + let original = manager_with(&[("a", json!(1)), ("b", json!(2))]); + let fork = original.fork_for_branch_state(); + + let diff = fork.diff_against(original.state()); + + assert!(diff.is_empty()); + } + + #[test] + fn diff_against_reports_newly_written_keys() { + let original = manager_with(&[]); + let mut fork = original.fork_for_branch_state(); + fork.state_mut().set("new".into(), json!(42)); + + let diff = fork.diff_against(original.state()); + + assert_eq!(diff.len(), 1); + assert_eq!(diff.get("new"), Some(&json!(42))); + } + + #[test] + fn diff_against_reports_changed_values_only() { + let original = manager_with(&[("a", json!(1)), ("b", json!(2)), ("c", json!(3))]); + let mut fork = original.fork_for_branch_state(); + fork.state_mut().set("b".into(), json!(99)); + + let diff = fork.diff_against(original.state()); + + assert_eq!(diff.len(), 1); + assert_eq!(diff.get("b"), Some(&json!(99))); + assert!(!diff.contains_key("a")); + assert!(!diff.contains_key("c")); + } + + #[test] + fn diff_against_does_not_report_reverted_writes() { + // Branch writes then writes back to the original value; net change = 0. + let original = manager_with(&[("x", json!("initial"))]); + let mut fork = original.fork_for_branch_state(); + fork.state_mut().set("x".into(), json!("modified")); + fork.state_mut().set("x".into(), json!("initial")); + + let diff = fork.diff_against(original.state()); + + assert!(diff.is_empty(), "reverted write should not appear in diff"); + } }