feat: scaffolding work for fan-out nodes for parallel branch execution support and stubbed out Map node types

This commit is contained in:
2026-05-20 11:37:23 -06:00
parent 9c22b41a13
commit ad650116f3
5 changed files with 536 additions and 21 deletions
+27 -12
View File
@@ -154,9 +154,12 @@ async fn step(
match &node.node_type {
NodeType::Agent(agent_node) => {
AgentNodeExecutor::execute(agent_node, state, ctx).await?;
let next = node.next.clone().ok_or_else(|| {
anyhow!("agent node '{current}' has no `next` and is not an end node")
})?;
let next = node
.next_single()?
.ok_or_else(|| {
anyhow!("agent node '{current}' has no `next` and is not an end node")
})?
.to_string();
Ok(StepResult::Continue(next))
}
NodeType::Script(script_node) => {
@@ -173,9 +176,17 @@ async fn step(
return Err(e);
}
};
let next = dynamic.or_else(|| node.next.clone()).ok_or_else(|| {
anyhow!("script node '{current}' did not emit `_next` and has no static `next`")
})?;
let next = match dynamic {
Some(n) => n,
None => node
.next_single()?
.ok_or_else(|| {
anyhow!(
"script node '{current}' did not emit `_next` and has no static `next`"
)
})?
.to_string(),
};
Ok(StepResult::Continue(next))
}
NodeType::Approval(approval_node) => {
@@ -183,21 +194,25 @@ async fn step(
Ok(StepResult::Continue(next))
}
NodeType::Input(input_node) => {
let next =
InputNodeExecutor::execute(input_node, node.next.as_deref(), state, ctx).await?;
let next_id = node.next_single()?;
let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?;
Ok(StepResult::Continue(next))
}
NodeType::Llm(llm_node) => {
let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?;
let next_id = node.next_single()?;
let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?;
Ok(StepResult::Continue(next))
}
NodeType::Rag(rag_node) => {
let next =
RagNodeExecutor::execute(rag_node, current, node.next.as_deref(), state, ctx)
.await?;
let next_id = node.next_single()?;
let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?;
Ok(StepResult::Continue(next))
}
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
NodeType::Map(_) => bail!(
"Map nodes are not yet supported in this build \
(parallel branch execution lands in Phase D/E)."
),
}
}
+1
View File
@@ -151,6 +151,7 @@ fn node_type_label(node: &Node) -> &'static str {
NodeType::Llm(_) => "llm",
NodeType::Rag(_) => "rag",
NodeType::End(_) => "end",
NodeType::Map(_) => "map",
}
}
+1 -1
View File
@@ -161,7 +161,7 @@ mod tests {
assert_eq!(graph.start, "node1");
assert_eq!(graph.nodes.len(), 2);
assert_eq!(
graph.nodes.get("node1").unwrap().next.as_deref(),
graph.nodes.get("node1").unwrap().next_target(),
Some("node2")
);
}
+384 -4
View File
@@ -1,8 +1,9 @@
use anyhow::Result;
use anyhow::{Result, bail};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::slice;
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Graph {
@@ -38,6 +39,9 @@ pub struct Graph {
#[serde(default)]
pub initial_state: HashMap<String, Value>,
#[serde(default)]
pub reducers: HashMap<String, Reducer>,
pub start: String,
pub nodes: IndexMap<String, Node>,
@@ -80,6 +84,9 @@ pub struct GraphSettings {
#[serde(default = "default_true")]
pub validate_before_run: bool,
#[serde(default = "default_max_concurrency")]
pub max_concurrency: usize,
}
impl Default for GraphSettings {
@@ -89,6 +96,7 @@ impl Default for GraphSettings {
timeout: None,
log_state_snapshots: true,
validate_before_run: true,
max_concurrency: default_max_concurrency(),
}
}
}
@@ -101,6 +109,10 @@ fn default_true() -> bool {
true
}
/// Fallback for `GraphSettings::max_concurrency`, used both as the serde
/// default and by `GraphSettings::default()`.
fn default_max_concurrency() -> usize {
    const DEFAULT_MAX_CONCURRENCY: usize = 8;
    DEFAULT_MAX_CONCURRENCY
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Node {
#[serde(default)]
@@ -113,7 +125,89 @@ pub struct Node {
pub node_type: NodeType,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub next: Option<String>,
pub next: Option<NextTargets>,
}
impl Node {
    /// Lenient, read-only lookup of this node's only successor.
    ///
    /// Yields `Some(id)` when `next` names exactly one node (either the scalar
    /// form or a one-element list) and `None` in every other case, including
    /// when a multi-target list is declared. Because a fan-out is silently
    /// reported as `None`, execution paths should call `next_single()`
    /// instead, which turns that situation into an explicit error.
    #[allow(dead_code)]
    pub fn next_target(&self) -> Option<&str> {
        self.next.as_ref().and_then(|targets| match targets {
            NextTargets::One(only) => Some(only.as_str()),
            NextTargets::Many(list) => match list.as_slice() {
                [only] => Some(only.as_str()),
                _ => None,
            },
        })
    }

    /// Strict lookup of this node's only successor.
    ///
    /// Returns `Ok(None)` when no `next` is declared, `Ok(Some(id))` when a
    /// single target is declared, and propagates the error produced by
    /// `NextTargets::single()` otherwise.
    pub fn next_single(&self) -> Result<Option<&str>> {
        self.next
            .as_ref()
            .map(|targets| targets.single().map(String::as_str))
            .transpose()
    }
}
/// The value of a node's `next` field: either one successor id or a list of
/// successor ids.
///
/// The enum is `#[serde(untagged)]`, so a scalar such as `next: retrieve`
/// deserializes as `One` and a sequence such as `next: [a, b]` deserializes
/// as `Many`. Variant order is significant for untagged deserialization
/// (variants are tried in declaration order) — do not reorder.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum NextTargets {
    /// A single successor node id.
    One(String),
    /// A list of successor node ids. A one-element list is treated like
    /// `One` by `single()`; more than one element is a fan-out.
    Many(Vec<String>),
}
impl NextTargets {
    /// View as a slice of node ids. `One(s)` returns a single-element slice.
    pub fn as_slice(&self) -> &[String] {
        match self {
            NextTargets::One(s) => slice::from_ref(s),
            NextTargets::Many(v) => v.as_slice(),
        }
    }

    /// True if this declares more than one parallel target (i.e., a real fan-out).
    #[allow(dead_code)]
    pub fn is_fan_out(&self) -> bool {
        matches!(self, NextTargets::Many(v) if v.len() > 1)
    }

    /// Returns the single target if exactly one is declared.
    ///
    /// # Errors
    ///
    /// - An empty list (`next: []`) is rejected with its own message: it names
    ///   no target at all, so reporting it as a fan-out would be misleading.
    /// - A list with two or more targets is rejected with a "not yet
    ///   implemented" message, because this build only follows one successor.
    pub fn single(&self) -> Result<&String> {
        match self {
            NextTargets::One(s) => Ok(s),
            NextTargets::Many(v) => match v.as_slice() {
                [only] => Ok(only),
                [] => bail!(
                    "`next` is an empty list (`next: []`); declare at least one \
                     target node or omit `next` entirely."
                ),
                _ => bail!(
                    "Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \
                     branch execution is not yet implemented in this build."
                ),
            },
        }
    }
}
impl From<String> for NextTargets {
fn from(s: String) -> Self {
NextTargets::One(s)
}
}
impl From<&str> for NextTargets {
fn from(s: &str) -> Self {
NextTargets::One(s.to_string())
}
}
impl From<Vec<String>> for NextTargets {
fn from(v: Vec<String>) -> Self {
NextTargets::Many(v)
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -126,6 +220,7 @@ pub enum NodeType {
Llm(LlmNode),
Rag(RagNode),
End(EndNode),
Map(MapNode),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -277,6 +372,51 @@ pub struct EndNode {
pub state_updates: Option<HashMap<String, String>>,
}
/// Configuration payload of a `type: map` node.
///
/// The fields describe a per-item invocation of one branch node over a list
/// and where the per-item results are gathered. This struct only carries the
/// parsed configuration; no execution logic lives here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct MapNode {
    /// Template expression that must resolve (via `interpolate_raw`, added in
    /// Phase B) to a JSON array. Each item in the array is one branch invocation.
    pub over: String,
    /// The name to bind each item under, accessible as `{{<as_name>}}` inside
    /// the branch node's templates. YAML field is `as:`.
    #[serde(rename = "as")]
    pub as_name: String,
    /// Node id to invoke once per item in the resolved list.
    pub branch: String,
    /// State key that the branch node writes; the map collects this key's value
    /// across invocations. Defaults to "output".
    #[serde(default = "default_map_output_key")]
    pub output_key: String,
    /// State key to receive the array of per-branch outputs, in input-list order.
    pub collect_into: String,
    /// Optional cap on simultaneously-running sub-branches. Falls back to
    /// `settings.max_concurrency` when unset.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrency: Option<usize>,
}
/// Serde default for `MapNode::output_key`: the state key `"output"`.
fn default_map_output_key() -> String {
    String::from("output")
}
/// Named reduction strategy that can be assigned to a state key through the
/// graph-level `reducers` map.
///
/// Variants (de)serialize as their lowercase names (`append`, `extend`,
/// `concat`, `sum`, `max`, `min`, `merge`, `overwrite`).
///
/// NOTE(review): how each strategy actually combines values is not defined in
/// this part of the file — presumably it governs how concurrent branch writes
/// to the same key are merged; confirm against the executor.
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum Reducer {
    Append,
    Extend,
    Concat,
    Sum,
    Max,
    Min,
    Merge,
    Overwrite,
}
#[derive(Debug, Clone, Default)]
pub struct GraphState {
data: HashMap<String, Value>,
@@ -469,6 +609,7 @@ state_updates:
next: configure
"#;
let node: Node = serde_yaml::from_str(yaml).unwrap();
let next_target = node.next_target().map(str::to_string);
let input = match node.node_type {
NodeType::Input(i) => i,
_ => panic!("expected Input variant"),
@@ -481,7 +622,7 @@ next: configure
updates.get("api_key").map(|s| s.as_str()),
Some("{{input}}")
);
assert_eq!(node.next.as_deref(), Some("configure"));
assert_eq!(next_target.as_deref(), Some("configure"));
}
#[test]
@@ -627,6 +768,7 @@ timeout: 30
next: review
"#;
let node: Node = serde_yaml::from_str(yaml).unwrap();
let next_target = node.next_target().map(str::to_string);
let llm = match node.node_type {
NodeType::Llm(l) => l,
_ => panic!("expected Llm variant"),
@@ -646,7 +788,7 @@ next: review
assert_eq!(llm.max_iterations, 5);
assert_eq!(llm.timeout, Some(30));
assert!(llm.state_updates.is_some());
assert_eq!(node.next.as_deref(), Some("review"));
assert_eq!(next_target.as_deref(), Some("review"));
}
#[test]
@@ -788,4 +930,242 @@ nodes:
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
assert!(!graph.has_agent_node());
}
/// A YAML sequence under `next` must parse into a fan-out that keeps every
/// listed target, in declaration order.
#[test]
fn parses_static_fan_out_as_many_next_targets() {
    let source = r#"
id: triage
type: llm
prompt: Classify
next: [retrieve_local, retrieve_web, retrieve_docs]
"#;
    let parsed: Node = serde_yaml::from_str(source).unwrap();
    let fan_out = parsed.next.as_ref().expect("next should be present");

    assert!(fan_out.is_fan_out());

    let ids: Vec<&str> = fan_out.as_slice().iter().map(String::as_str).collect();
    assert_eq!(ids, ["retrieve_local", "retrieve_web", "retrieve_docs"]);
}
/// A scalar `next` is a plain single successor and must not be reported as a
/// fan-out.
#[test]
fn parses_single_target_next_as_one_variant() {
    let source = r#"
id: triage
type: llm
prompt: Classify
next: retrieve
"#;
    let parsed: Node = serde_yaml::from_str(source).unwrap();
    let declared_fan_out = parsed
        .next
        .as_ref()
        .expect("next should be present")
        .is_fan_out();

    assert!(!declared_fan_out);
    assert_eq!(parsed.next_target(), Some("retrieve"));
}
/// With two declared targets, `next_single()` must fail and the error text
/// must mention both the feature and that it is not implemented yet.
#[test]
fn next_single_errors_on_real_fan_out_with_clear_message() {
    let source = r#"
id: triage
type: llm
prompt: Classify
next: [a, b]
"#;
    let parsed: Node = serde_yaml::from_str(source).unwrap();
    let err = parsed.next_single().unwrap_err().to_string();

    for fragment in ["Parallel fan-out", "not yet implemented"] {
        assert!(err.contains(fragment), "got: {err}");
    }
}
/// A one-element list is equivalent to the scalar form for both the strict
/// and the lenient accessor.
#[test]
fn next_single_accepts_many_containing_exactly_one_target() {
    let source = r#"
id: triage
type: llm
prompt: Classify
next: [retrieve]
"#;
    let parsed: Node = serde_yaml::from_str(source).unwrap();

    let strict = parsed.next_single().unwrap();
    let lenient = parsed.next_target();

    assert_eq!(strict, Some("retrieve"));
    assert_eq!(lenient, Some("retrieve"));
}
/// Parsing, re-serializing, and parsing again must yield the same target ids
/// for both the scalar and the list form.
#[test]
fn next_targets_round_trips_through_yaml_for_both_variants() {
    // parse -> serialize -> parse, returning the value from the second parse.
    let round_trip = |source: &str| -> NextTargets {
        let first = serde_yaml::from_str::<NextTargets>(source).unwrap();
        let emitted = serde_yaml::to_string(&first).unwrap();
        serde_yaml::from_str(&emitted).unwrap()
    };

    let scalar = round_trip(r#""foo""#);
    assert_eq!(scalar.as_slice(), &["foo".to_string()]);

    let list = round_trip("[a, b, c]");
    assert_eq!(
        list.as_slice(),
        &["a".to_string(), "b".to_string(), "c".to_string()]
    );
}
/// Every built-in reducer name must parse from the graph-level `reducers`
/// block into the matching `Reducer` variant.
#[test]
fn parses_reducers_block_with_all_builtins() {
    // The entries under `reducers:` and `nodes:` must be indented: with flat
    // YAML both keys are empty and the reducer names become unrelated
    // top-level keys, so the graph cannot be parsed as intended.
    let yaml = r#"
name: g
start: e
reducers:
  sources: append
  findings: extend
  context: concat
  cost_usd: sum
  high_score: max
  low_score: min
  config: merge
  forced: overwrite
nodes:
  e:
    type: end
    output: ok
"#;
    let graph: Graph = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(graph.reducers.len(), 8);
    assert_eq!(graph.reducers.get("sources"), Some(&Reducer::Append));
    assert_eq!(graph.reducers.get("findings"), Some(&Reducer::Extend));
    assert_eq!(graph.reducers.get("context"), Some(&Reducer::Concat));
    assert_eq!(graph.reducers.get("cost_usd"), Some(&Reducer::Sum));
    assert_eq!(graph.reducers.get("high_score"), Some(&Reducer::Max));
    assert_eq!(graph.reducers.get("low_score"), Some(&Reducer::Min));
    assert_eq!(graph.reducers.get("config"), Some(&Reducer::Merge));
    assert_eq!(graph.reducers.get("forced"), Some(&Reducer::Overwrite));
}
/// A graph without a `reducers` block must parse with an empty reducer map.
#[test]
fn reducers_default_to_empty_when_block_absent() {
    // `type: end` has to be nested one level below `x:`; with equal indentation
    // `x` is a null node and `type` becomes a sibling node entry.
    let yaml = "name: g\nstart: x\nnodes:\n  x:\n    type: end\n";
    let graph: Graph = serde_yaml::from_str(yaml).unwrap();
    assert!(graph.reducers.is_empty());
}
/// A graph without a `settings` block must fall back to a concurrency cap of 8.
#[test]
fn max_concurrency_defaults_to_eight() {
    // `type: end` has to be nested one level below `x:`; with equal indentation
    // `x` is a null node and `type` becomes a sibling node entry.
    let yaml = "name: g\nstart: x\nnodes:\n  x:\n    type: end\n";
    let graph: Graph = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(graph.settings.max_concurrency, 8);
}
/// An explicit `settings.max_concurrency` must override the default.
#[test]
fn max_concurrency_can_be_overridden() {
    // `max_concurrency` must be nested under `settings:` (and the node body
    // under `nodes: x:`); flat YAML would leave `settings` empty and the
    // override would never reach `GraphSettings`.
    let yaml = r#"
name: g
start: x
settings:
  max_concurrency: 16
nodes:
  x:
    type: end
"#;
    let graph: Graph = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(graph.settings.max_concurrency, 16);
}
/// A `type: map` node with every field spelled out must populate each
/// `MapNode` field verbatim, including the `as` -> `as_name` rename.
#[test]
fn parses_map_node_with_all_fields() {
    let source = r#"
id: fan_out
type: map
over: "{{subjects}}"
as: subject
branch: research_subject
output_key: research_result
collect_into: research_results
max_concurrency: 5
next: rank
"#;
    let parsed: Node = serde_yaml::from_str(source).unwrap();
    let NodeType::Map(map) = parsed.node_type else {
        panic!("expected Map variant")
    };

    assert_eq!(map.over, "{{subjects}}");
    assert_eq!(map.as_name, "subject");
    assert_eq!(map.branch, "research_subject");
    assert_eq!(map.output_key, "research_result");
    assert_eq!(map.collect_into, "research_results");
    assert_eq!(map.max_concurrency, Some(5));
}
/// When `output_key` and `max_concurrency` are omitted, the map node must use
/// `"output"` and leave the per-node cap unset.
#[test]
fn map_node_uses_default_output_key_and_no_concurrency_cap() {
    let source = r#"
id: fan_out
type: map
over: "{{items}}"
as: item
branch: process
collect_into: results
"#;
    let parsed: Node = serde_yaml::from_str(source).unwrap();
    let NodeType::Map(map) = parsed.node_type else {
        panic!("expected Map variant")
    };

    assert_eq!(map.output_key, "output");
    assert!(map.max_concurrency.is_none());
}
/// End-to-end parse of a graph that uses `settings.max_concurrency`, the
/// `reducers` block, and a two-way `next` fan-out together.
#[test]
fn full_graph_with_all_new_phase_a_fields_parses() {
    // The nested structure is load-bearing: settings, reducers and every node
    // body must be indented under their parent key, otherwise the document is
    // a flat mapping with duplicate `type`/`next` keys and cannot be parsed
    // into a `Graph`.
    let yaml = r#"
name: deep_research
start: triage
settings:
  max_concurrency: 4
reducers:
  sources: append
  cost_usd: sum
nodes:
  triage:
    type: llm
    prompt: Classify
    next: [retrieve_local, retrieve_web]
  retrieve_local:
    type: rag
    documents: ["./docs"]
    next: synthesize
  retrieve_web:
    type: llm
    prompt: Search web
    next: synthesize
  synthesize:
    type: end
    output: done
"#;
    let graph: Graph = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(graph.settings.max_concurrency, 4);
    assert_eq!(graph.reducers.len(), 2);
    let triage = graph.get_node("triage").unwrap();
    assert!(triage.next.as_ref().unwrap().is_fan_out());
    assert_eq!(triage.next.as_ref().unwrap().as_slice().len(), 2);
}
}
+123 -4
View File
@@ -318,8 +318,10 @@ impl GraphValidator {
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
let mut out = Vec::new();
if let Some(n) = &node.next {
out.push((n.clone(), "'next'"));
if let Some(targets) = &node.next {
for target in targets.as_slice() {
out.push((target.clone(), "'next'"));
}
}
match &node.node_type {
@@ -342,6 +344,12 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
// `agent`/`input`/`rag` route only via `next` (already collected
// above); `end` is terminal. No type-specific routing edges to add.
NodeType::Agent(_) | NodeType::Input(_) | NodeType::Rag(_) | NodeType::End(_) => {}
// A `map` node invokes its `branch:` target once per item from the
// resolved `over` list. The branch is statically referenced, so it
// is a real declared edge for cycle/reachability purposes.
NodeType::Map(m) => {
out.push((m.branch.clone(), "map 'branch'"));
}
}
out
}
@@ -434,6 +442,7 @@ mod tests {
conversation_starters: Vec::new(),
settings: GraphSettings::default(),
initial_state: HashMap::new(),
reducers: HashMap::new(),
start: start.into(),
nodes: map,
}
@@ -529,7 +538,7 @@ mod tests {
output_schema: None,
timeout: None,
}),
next: next.map(String::from),
next: next.map(NextTargets::from),
}
}
@@ -759,7 +768,7 @@ mod tests {
output_schema: None,
timeout: None,
}),
next: next.map(String::from),
next: next.map(NextTargets::from),
}
}
@@ -1038,4 +1047,114 @@ mod tests {
assert!(validator().validate(&graph).into_result().is_ok());
}
/// Shape under test: `start -> {a, b} -> join -> end`. Two paths converging
/// on `join` form a diamond, not a loop, so no "Cycle detected" error may be
/// produced.
#[test]
fn cycle_detector_treats_fan_out_diamond_as_a_valid_dag() {
    // Builds a node through the shared `end_node` helper and then points its
    // `next` at the given target(s).
    let linked = |id: &'static str, next: NextTargets| {
        let mut node = end_node(id);
        node.next = Some(next);
        node
    };

    let nodes = vec![
        (
            "start",
            linked("start", NextTargets::Many(vec!["a".into(), "b".into()])),
        ),
        ("a", linked("a", "join".into())),
        ("b", linked("b", "join".into())),
        ("join", linked("join", "end".into())),
        ("end", end_node("end")),
    ];
    let graph = graph_with(nodes, "start");

    let result = validator().validate(&graph);
    let cycle_reported = result
        .errors
        .iter()
        .any(|e| e.message.contains("Cycle detected"));
    assert!(
        !cycle_reported,
        "fan-out diamond incorrectly reported as cycle: {:?}",
        result.errors
    );
}
/// Every id listed in a multi-target `next` counts as reachable: none of the
/// three targets may receive an "unreachable" warning.
#[test]
fn reachability_visits_every_member_of_many_next_targets() {
    let fan_out: Vec<String> = ["a", "b", "c"].iter().map(|id| id.to_string()).collect();
    let mut start = end_node("start");
    start.next = Some(NextTargets::Many(fan_out));

    let nodes = vec![
        ("start", start),
        ("a", end_node("a")),
        ("b", end_node("b")),
        ("c", end_node("c")),
    ];
    let graph = graph_with(nodes, "start");
    let result = validator().validate(&graph);

    for orphan in ["a", "b", "c"] {
        let flagged = result.warnings.iter().any(|w| {
            w.node_id.as_deref() == Some(orphan) && w.message.contains("unreachable")
        });
        assert!(
            !flagged,
            "fan-out target '{orphan}' incorrectly marked unreachable: {:?}",
            result.warnings
        );
    }
}
/// A multi-target `next` that names an undefined node (`ghost`) must produce
/// a reference error attributed to the declaring node (`start`).
#[test]
fn node_reference_check_catches_missing_member_inside_many() {
    let mut start = end_node("start");
    start.next = Some(NextTargets::from(vec![
        String::from("a"),
        String::from("ghost"),
    ]));

    let nodes = vec![("start", start), ("a", end_node("a"))];
    let graph = graph_with(nodes, "start");
    let result = validator().validate(&graph);

    let ghost_reported = result.errors.iter().any(|e| {
        e.message.contains("non-existent node 'ghost'") && e.node_id.as_deref() == Some("start")
    });
    assert!(
        ghost_reported,
        "expected error for missing 'ghost' target in Many: {:?}",
        result.errors
    );
}
/// A map node whose `branch` names an undefined node must produce a reference
/// error that identifies the edge as coming from the map's `branch` field.
#[test]
fn node_reference_check_catches_missing_map_branch_target() {
    let map_config = MapNode {
        over: "{{items}}".into(),
        as_name: "item".into(),
        branch: "no_such_node".into(),
        output_key: "output".into(),
        collect_into: "results".into(),
        max_concurrency: None,
    };
    let fan = Node {
        id: "fan".into(),
        description: String::new(),
        node_type: NodeType::Map(map_config),
        next: Some("end".into()),
    };

    let nodes = vec![("fan", fan), ("end", end_node("end"))];
    let graph = graph_with(nodes, "fan");
    let result = validator().validate(&graph);

    let branch_reported = result.errors.iter().any(|e| {
        e.message.contains("non-existent node 'no_such_node'")
            && e.message.contains("map 'branch'")
    });
    assert!(
        branch_reported,
        "expected error for missing map branch: {:?}",
        result.errors
    );
}
}