feat: Full support for map node types

This commit is contained in:
2026-05-20 15:15:58 -06:00
parent 36ac924d77
commit 7154c3a652
5 changed files with 229 additions and 30 deletions
-1
View File
@@ -172,7 +172,6 @@ impl RequestContext {
/// - `auto_continue_count` reset to 0 (each branch starts a fresh
/// continuation budget)
/// - `last_continuation_response` reset to None
#[allow(dead_code)]
pub fn fork_for_branch(&self) -> Self {
Self {
app: Arc::clone(&self.app),
+46 -18
View File
@@ -1,6 +1,7 @@
use super::agent::AgentNodeExecutor;
use super::llm::LlmNodeExecutor;
use super::logging::GraphLogger;
use super::map::MapNodeExecutor;
use super::rag::RagNodeExecutor;
use super::script::ScriptExecutor;
use super::staging::BranchWrites;
@@ -74,6 +75,10 @@ impl GraphExecutor {
let max_iterations = graph.settings.max_loop_iterations;
let graph_timeout = graph.settings.timeout.map(Duration::from_secs);
let max_concurrency = graph.settings.max_concurrency;
// Wrap in Arc so spawned branch tasks can cheaply share the Graph for
// node lookup (especially the map executor, which needs to resolve its
// `branch:` target from inside a spawned task).
let graph = Arc::new(graph);
let start = Instant::now();
let mut frontier: HashSet<String> = HashSet::from([graph.start.clone()]);
@@ -150,7 +155,7 @@ impl GraphExecutor {
let branch_state = state.fork_for_branch_state();
let branch_ctx = ctx.fork_for_branch();
let script_exec_clone = script_executor.clone();
let graph_name = graph.name.clone();
let graph_clone = Arc::clone(&graph);
let current = node_id.clone();
let sem_clone = semaphore.clone();
let abort_clone = abort_signal.clone();
@@ -171,15 +176,13 @@ impl GraphExecutor {
let node_start = Instant::now();
let mut state = branch_state;
let mut ctx = branch_ctx;
let result = step(
&node,
&mut state,
&mut ctx,
&script_exec_clone,
&graph_name,
&current,
)
.await;
let step_ctx = StepContext {
graph: graph_clone.as_ref(),
script_executor: &script_exec_clone,
max_concurrency,
abort_signal: &abort_clone,
};
let result = step(&node, &mut state, &mut ctx, &step_ctx, &current).await;
let elapsed = node_start.elapsed();
(current, state, result, elapsed)
});
@@ -263,6 +266,23 @@ fn sorted_frontier(frontier: &HashSet<String>) -> Vec<String> {
v
}
// Bundles the engine-config refs that every `step()` call needs to thread
// through. Constructed once per spawned branch task (or once at the call site
// for sequential paths) so step() and downstream executors (MapNodeExecutor)
// take one parameter instead of five.
pub(super) struct StepContext<'a> {
pub graph: &'a Graph,
pub script_executor: &'a ScriptExecutor,
pub max_concurrency: usize,
pub abort_signal: &'a AbortSignal,
}
impl StepContext<'_> {
pub fn graph_name(&self) -> &str {
&self.graph.name
}
}
enum StepResult {
Continue(String),
End(String),
@@ -272,8 +292,7 @@ async fn step(
node: &Node,
state: &mut StateManager,
ctx: &mut RequestContext,
script_executor: &ScriptExecutor,
graph_name: &str,
step_ctx: &StepContext<'_>,
current: &str,
) -> Result<StepResult> {
match &node.node_type {
@@ -288,13 +307,16 @@ async fn step(
Ok(StepResult::Continue(next))
}
NodeType::Script(script_node) => {
let dynamic = match script_executor.execute(script_node, state).await {
let dynamic = match step_ctx.script_executor.execute(script_node, state).await {
Ok(n) => n,
Err(e) => {
if let Some(fallback) = &script_node.fallback {
warn!(
"[graph:{}] script '{}' failed, routing to fallback '{}': {}",
graph_name, current, fallback, e
step_ctx.graph_name(),
current,
fallback,
e
);
return Ok(StepResult::Continue(fallback.clone()));
}
@@ -334,10 +356,16 @@ async fn step(
Ok(StepResult::Continue(next))
}
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
NodeType::Map(_) => bail!(
"Map nodes are not yet supported in this build \
(parallel branch execution lands in Phase D/E)."
),
NodeType::Map(map_node) => {
let next = node
.next_single()?
.ok_or_else(|| {
anyhow!("map node '{current}' has no `next` and is not an end node")
})?
.to_string();
MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?;
Ok(StepResult::Continue(next))
}
}
}
+169
View File
@@ -0,0 +1,169 @@
use super::agent::AgentNodeExecutor;
use super::executor::StepContext;
use super::llm::LlmNodeExecutor;
use super::rag::RagNodeExecutor;
use super::state::StateManager;
use super::types::{MapNode, NodeType};
use crate::config::RequestContext;
use anyhow::{Context, Result, anyhow};
use futures_util::future::join_all;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::graph::type_name;
// Map sub-branches are atomic — the branch node has no `next` (enforced by
// validator rule C.5). But LLM/RAG node executors require an `Option<&str>` for
// their routing argument and error if it's `None`. Passing this sentinel
// satisfies their contract; the map executor discards the returned routing.
const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__";
pub(super) struct MapNodeExecutor;
impl MapNodeExecutor {
pub(super) async fn execute(
node: &MapNode,
state: &mut StateManager,
ctx: &mut RequestContext,
step_ctx: &StepContext<'_>,
node_id: &str,
) -> Result<()> {
let over_value = state
.interpolate_raw(&node.over)
.with_context(|| format!("map node '{node_id}': evaluating `over` template"))?;
let items = over_value.as_array().ok_or_else(|| {
anyhow!(
"map node '{}': `over` template '{}' must resolve to an array, got {}",
node_id,
node.over,
type_name(&over_value)
)
})?;
let items = items.clone();
let branch_node = step_ctx
.graph
.get_node(&node.branch)
.ok_or_else(|| {
anyhow!(
"map node '{node_id}': branch '{}' not found in graph",
node.branch
)
})?
.clone();
let max_conc = node
.max_concurrency
.unwrap_or(step_ctx.max_concurrency)
.max(1);
let semaphore = Arc::new(Semaphore::new(max_conc));
let mut sub_tasks = Vec::with_capacity(items.len());
for (idx, item) in items.iter().enumerate() {
let item = item.clone();
let as_name = node.as_name.clone();
let branch_clone = branch_node.clone();
let mut sub_state = state.fork_for_branch_state();
let sub_ctx = ctx.fork_for_branch();
let script_clone = step_ctx.script_executor.clone();
let sub_branch_id = node.branch.clone();
let sem = semaphore.clone();
let abort = step_ctx.abort_signal.clone();
sub_state.state_mut().set(as_name, item);
let task = tokio::spawn(async move {
let _permit = sem
.acquire()
.await
.expect("map semaphore should not be closed");
if abort.aborted() {
return (
idx,
sub_state,
Err(anyhow!("map sub-branch [{idx}] aborted")),
);
}
let mut state = sub_state;
let mut ctx = sub_ctx;
let exec_result: Result<()> = match &branch_clone.node_type {
NodeType::Llm(n) => LlmNodeExecutor::execute(
n,
Some(MAP_BRANCH_SENTINEL_NEXT),
&mut state,
&mut ctx,
)
.await
.map(|_| ()),
NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
.await
.map(|_| ()),
NodeType::Rag(n) => RagNodeExecutor::execute(
n,
&sub_branch_id,
Some(MAP_BRANCH_SENTINEL_NEXT),
&mut state,
&mut ctx,
)
.await
.map(|_| ()),
NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
_ => Err(anyhow!(
"map branch '{}' has type that cannot run inside a map \
(validator should have caught this; internal error)",
branch_clone.id
)),
};
(idx, state, exec_result)
});
sub_tasks.push(task);
}
let joined = join_all(sub_tasks).await;
// Collect outputs keyed by input index so order is preserved regardless
// of finish order. This is the user-facing contract from plan E.2.
let mut outputs: HashMap<usize, Value> = HashMap::new();
for join_result in joined {
let (idx, sub_state, exec_result) =
join_result.map_err(|e| anyhow!("map sub-branch panicked: {e}"))?;
exec_result
.with_context(|| format!("map node '{node_id}': sub-branch [{idx}] failed"))?;
let output_value = sub_state
.state()
.get(&node.output_key)
.cloned()
.ok_or_else(|| {
anyhow!(
"map node '{node_id}': sub-branch [{idx}] did not write \
`output_key` '{}'",
node.output_key
)
})?;
outputs.insert(idx, output_value);
}
let mut collected = Vec::with_capacity(items.len());
for idx in 0..items.len() {
let value = outputs.remove(&idx).ok_or_else(|| {
anyhow!(
"map node '{node_id}': internal error: missing result for sub-branch [{idx}]"
)
})?;
collected.push(value);
}
state
.state_mut()
.set(node.collect_into.clone(), Value::Array(collected));
Ok(())
}
}
+13
View File
@@ -3,6 +3,7 @@ pub mod dispatch;
pub mod executor;
pub mod llm;
pub mod logging;
pub mod map;
pub mod parser;
pub mod rag;
pub mod reducer;
@@ -14,6 +15,7 @@ pub mod types;
pub mod user_interaction;
pub mod validator;
use serde_json::Value;
pub use dispatch::{active_agent_graph_name, run_active_agent_graph};
pub use executor::GraphExecutor;
pub use parser::{GraphParser, agent_has_graph};
@@ -24,3 +26,14 @@ pub const GRAPH_SCHEMA_VERSION: &str = "1.0";
pub const DEFAULT_MAX_LOOP_ITERATIONS: usize = 100;
pub const MAX_STATE_SIZE_BYTES: usize = 32 * 1024;
pub (in crate::graph) fn type_name(value: &Value) -> &'static str {
match value {
Value::Null => "null",
Value::Bool(_) => "bool",
Value::Number(_) => "number",
Value::String(_) => "string",
Value::Array(_) => "array",
Value::Object(_) => "object",
}
}
+1 -11
View File
@@ -1,6 +1,7 @@
use super::types::Reducer;
use anyhow::{Result, bail};
use serde_json::{Number, Value};
use crate::graph::type_name;
/// Combines a branch's incoming write with the current state value (if any)
/// via the specified reducer. The result is what gets written back to live
@@ -147,17 +148,6 @@ fn number_or_error(value: &Value, reducer_name: &str, position: &str) -> Result<
}
}
fn type_name(value: &Value) -> &'static str {
match value {
Value::Null => "null",
Value::Bool(_) => "bool",
Value::Number(_) => "number",
Value::String(_) => "string",
Value::Array(_) => "array",
Value::Object(_) => "object",
}
}
// Numeric reducers compute in f64 for simplicity. We preserve integer typing
// when the result is losslessly representable as i64 so `count: sum` stays an
// integer rather than degrading to a float. Non-finite values (NaN, Inf) can't