From 7154c3a6525b028166a0561eb7fc0b2943ff7509 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 20 May 2026 15:15:58 -0600 Subject: [PATCH] feat: Full support for map node types --- src/config/request_context.rs | 1 - src/graph/executor.rs | 64 +++++++++---- src/graph/map.rs | 169 ++++++++++++++++++++++++++++++++++ src/graph/mod.rs | 13 +++ src/graph/reducer.rs | 12 +-- 5 files changed, 229 insertions(+), 30 deletions(-) create mode 100644 src/graph/map.rs diff --git a/src/config/request_context.rs b/src/config/request_context.rs index 6f3b0f5..26a8498 100644 --- a/src/config/request_context.rs +++ b/src/config/request_context.rs @@ -172,7 +172,6 @@ impl RequestContext { /// - `auto_continue_count` reset to 0 (each branch starts a fresh /// continuation budget) /// - `last_continuation_response` reset to None - #[allow(dead_code)] pub fn fork_for_branch(&self) -> Self { Self { app: Arc::clone(&self.app), diff --git a/src/graph/executor.rs b/src/graph/executor.rs index f5cb673..7466ba4 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -1,6 +1,7 @@ use super::agent::AgentNodeExecutor; use super::llm::LlmNodeExecutor; use super::logging::GraphLogger; +use super::map::MapNodeExecutor; use super::rag::RagNodeExecutor; use super::script::ScriptExecutor; use super::staging::BranchWrites; @@ -74,6 +75,10 @@ impl GraphExecutor { let max_iterations = graph.settings.max_loop_iterations; let graph_timeout = graph.settings.timeout.map(Duration::from_secs); let max_concurrency = graph.settings.max_concurrency; + // Wrap in Arc so spawned branch tasks can cheaply share the Graph for + // node lookup (especially the map executor, which needs to resolve its + // `branch:` target from inside a spawned task). 
+ let graph = Arc::new(graph); let start = Instant::now(); let mut frontier: HashSet<String> = HashSet::from([graph.start.clone()]); @@ -150,7 +155,7 @@ impl GraphExecutor { let branch_state = state.fork_for_branch_state(); let branch_ctx = ctx.fork_for_branch(); let script_exec_clone = script_executor.clone(); - let graph_name = graph.name.clone(); + let graph_clone = Arc::clone(&graph); let current = node_id.clone(); let sem_clone = semaphore.clone(); let abort_clone = abort_signal.clone(); @@ -171,15 +176,13 @@ impl GraphExecutor { let node_start = Instant::now(); let mut state = branch_state; let mut ctx = branch_ctx; - let result = step( - &node, - &mut state, - &mut ctx, - &script_exec_clone, - &graph_name, - &current, - ) - .await; + let step_ctx = StepContext { + graph: graph_clone.as_ref(), + script_executor: &script_exec_clone, + max_concurrency, + abort_signal: &abort_clone, + }; + let result = step(&node, &mut state, &mut ctx, &step_ctx, &current).await; let elapsed = node_start.elapsed(); (current, state, result, elapsed) }); @@ -263,6 +266,23 @@ fn sorted_frontier(frontier: &HashSet<String>) -> Vec<String> { v } +// Bundles the engine-config refs that every `step()` call needs to thread +// through. Constructed once per spawned branch task (or once at the call site +// for sequential paths) so step() and downstream executors (MapNodeExecutor) +// take one parameter instead of five.
+pub(super) struct StepContext<'a> { + pub graph: &'a Graph, + pub script_executor: &'a ScriptExecutor, + pub max_concurrency: usize, + pub abort_signal: &'a AbortSignal, +} + +impl StepContext<'_> { + pub fn graph_name(&self) -> &str { + &self.graph.name + } +} + enum StepResult { Continue(String), End(String), @@ -272,8 +292,7 @@ async fn step( node: &Node, state: &mut StateManager, ctx: &mut RequestContext, - script_executor: &ScriptExecutor, - graph_name: &str, + step_ctx: &StepContext<'_>, current: &str, ) -> Result<StepResult> { match &node.node_type { @@ -288,13 +307,16 @@ async fn step( Ok(StepResult::Continue(next)) } NodeType::Script(script_node) => { - let dynamic = match script_executor.execute(script_node, state).await { + let dynamic = match step_ctx.script_executor.execute(script_node, state).await { Ok(n) => n, Err(e) => { if let Some(fallback) = &script_node.fallback { warn!( "[graph:{}] script '{}' failed, routing to fallback '{}': {}", - graph_name, current, fallback, e + step_ctx.graph_name(), + current, + fallback, + e ); return Ok(StepResult::Continue(fallback.clone())); } @@ -334,10 +356,16 @@ async fn step( Ok(StepResult::Continue(next)) } NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))), - NodeType::Map(_) => bail!( - "Map nodes are not yet supported in this build \ - (parallel branch execution lands in Phase D/E)." - ), + NodeType::Map(map_node) => { + let next = node + .next_single()? + .ok_or_else(|| { + anyhow!("map node '{current}' has no `next` and is not an end node") + })?
+ .to_string(); + MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?; + Ok(StepResult::Continue(next)) + } } } diff --git a/src/graph/map.rs b/src/graph/map.rs new file mode 100644 index 0000000..aa800fe --- /dev/null +++ b/src/graph/map.rs @@ -0,0 +1,169 @@ +use super::agent::AgentNodeExecutor; +use super::executor::StepContext; +use super::llm::LlmNodeExecutor; +use super::rag::RagNodeExecutor; +use super::state::StateManager; +use super::types::{MapNode, NodeType}; +use crate::config::RequestContext; +use anyhow::{Context, Result, anyhow}; +use futures_util::future::join_all; +use serde_json::Value; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::Semaphore; +use crate::graph::type_name; + +// Map sub-branches are atomic — the branch node has no `next` (enforced by +// validator rule C.5). But LLM/RAG node executors require an `Option<&str>` for +// their routing argument and error if it's `None`. Passing this sentinel +// satisfies their contract; the map executor discards the returned routing. +const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__"; + +pub(super) struct MapNodeExecutor; + +impl MapNodeExecutor { + pub(super) async fn execute( + node: &MapNode, + state: &mut StateManager, + ctx: &mut RequestContext, + step_ctx: &StepContext<'_>, + node_id: &str, + ) -> Result<()> { + let over_value = state + .interpolate_raw(&node.over) + .with_context(|| format!("map node '{node_id}': evaluating `over` template"))?; + + let items = over_value.as_array().ok_or_else(|| { + anyhow!( + "map node '{}': `over` template '{}' must resolve to an array, got {}", + node_id, + node.over, + type_name(&over_value) + ) + })?; + let items = items.clone(); + + let branch_node = step_ctx + .graph + .get_node(&node.branch) + .ok_or_else(|| { + anyhow!( + "map node '{node_id}': branch '{}' not found in graph", + node.branch + ) + })? 
+ .clone(); + + let max_conc = node + .max_concurrency + .unwrap_or(step_ctx.max_concurrency) + .max(1); + let semaphore = Arc::new(Semaphore::new(max_conc)); + let mut sub_tasks = Vec::with_capacity(items.len()); + + for (idx, item) in items.iter().enumerate() { + let item = item.clone(); + let as_name = node.as_name.clone(); + let branch_clone = branch_node.clone(); + let mut sub_state = state.fork_for_branch_state(); + let sub_ctx = ctx.fork_for_branch(); + let script_clone = step_ctx.script_executor.clone(); + let sub_branch_id = node.branch.clone(); + let sem = semaphore.clone(); + let abort = step_ctx.abort_signal.clone(); + + sub_state.state_mut().set(as_name, item); + + let task = tokio::spawn(async move { + let _permit = sem + .acquire() + .await + .expect("map semaphore should not be closed"); + if abort.aborted() { + return ( + idx, + sub_state, + Err(anyhow!("map sub-branch [{idx}] aborted")), + ); + } + let mut state = sub_state; + let mut ctx = sub_ctx; + + let exec_result: Result<()> = match &branch_clone.node_type { + NodeType::Llm(n) => LlmNodeExecutor::execute( + n, + Some(MAP_BRANCH_SENTINEL_NEXT), + &mut state, + &mut ctx, + ) + .await + .map(|_| ()), + NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx) + .await + .map(|_| ()), + NodeType::Rag(n) => RagNodeExecutor::execute( + n, + &sub_branch_id, + Some(MAP_BRANCH_SENTINEL_NEXT), + &mut state, + &mut ctx, + ) + .await + .map(|_| ()), + NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()), + _ => Err(anyhow!( + "map branch '{}' has type that cannot run inside a map \ + (validator should have caught this; internal error)", + branch_clone.id + )), + }; + + (idx, state, exec_result) + }); + sub_tasks.push(task); + } + + let joined = join_all(sub_tasks).await; + + // Collect outputs keyed by input index so order is preserved regardless + // of finish order. This is the user-facing contract from plan E.2. 
+ let mut outputs: HashMap<usize, Value> = HashMap::new(); + for join_result in joined { + let (idx, sub_state, exec_result) = + join_result.map_err(|e| anyhow!("map sub-branch panicked: {e}"))?; + + exec_result + .with_context(|| format!("map node '{node_id}': sub-branch [{idx}] failed"))?; + + let output_value = sub_state + .state() + .get(&node.output_key) + .cloned() + .ok_or_else(|| { + anyhow!( + "map node '{node_id}': sub-branch [{idx}] did not write \ + `output_key` '{}'", + node.output_key + ) + })?; + + outputs.insert(idx, output_value); + } + + let mut collected = Vec::with_capacity(items.len()); + for idx in 0..items.len() { + let value = outputs.remove(&idx).ok_or_else(|| { + anyhow!( + "map node '{node_id}': internal error: missing result for sub-branch [{idx}]" + ) + })?; + collected.push(value); + } + + state + .state_mut() + .set(node.collect_into.clone(), Value::Array(collected)); + + Ok(()) + } +} diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 3268ed2..44fc381 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -3,6 +3,7 @@ pub mod dispatch; pub mod executor; pub mod llm; pub mod logging; +pub mod map; pub mod parser; pub mod rag; pub mod reducer; @@ -14,6 +15,7 @@ pub mod types; pub mod user_interaction; pub mod validator; +use serde_json::Value; pub use dispatch::{active_agent_graph_name, run_active_agent_graph}; pub use executor::GraphExecutor; pub use parser::{GraphParser, agent_has_graph}; @@ -24,3 +26,14 @@ pub const GRAPH_SCHEMA_VERSION: &str = "1.0"; pub const DEFAULT_MAX_LOOP_ITERATIONS: usize = 100; pub const MAX_STATE_SIZE_BYTES: usize = 32 * 1024; + +pub (in crate::graph) fn type_name(value: &Value) -> &'static str { + match value { + Value::Null => "null", + Value::Bool(_) => "bool", + Value::Number(_) => "number", + Value::String(_) => "string", + Value::Array(_) => "array", + Value::Object(_) => "object", + } +} diff --git a/src/graph/reducer.rs b/src/graph/reducer.rs index 1ceb058..2d261a3 100644 --- a/src/graph/reducer.rs +++
b/src/graph/reducer.rs @@ -1,6 +1,7 @@ use super::types::Reducer; use anyhow::{Result, bail}; use serde_json::{Number, Value}; +use crate::graph::type_name; /// Combines a branch's incoming write with the current state value (if any) /// via the specified reducer. The result is what gets written back to live @@ -147,17 +148,6 @@ fn number_or_error(value: &Value, reducer_name: &str, position: &str) -> Result< } } -fn type_name(value: &Value) -> &'static str { - match value { - Value::Null => "null", - Value::Bool(_) => "bool", - Value::Number(_) => "number", - Value::String(_) => "string", - Value::Array(_) => "array", - Value::Object(_) => "object", - } -} - // Numeric reducers compute in f64 for simplicity. We preserve integer typing // when the result is losslessly representable as i64 so `count: sum` stays an // integer rather than degrading to a float. Non-finite values (NaN, Inf) can't