diff --git a/src/graph/executor.rs b/src/graph/executor.rs index a1773e0..4fa59f2 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -154,9 +154,12 @@ async fn step( match &node.node_type { NodeType::Agent(agent_node) => { AgentNodeExecutor::execute(agent_node, state, ctx).await?; - let next = node.next.clone().ok_or_else(|| { - anyhow!("agent node '{current}' has no `next` and is not an end node") - })?; + let next = node + .next_single()? + .ok_or_else(|| { + anyhow!("agent node '{current}' has no `next` and is not an end node") + })? + .to_string(); Ok(StepResult::Continue(next)) } NodeType::Script(script_node) => { @@ -173,9 +176,17 @@ async fn step( return Err(e); } }; - let next = dynamic.or_else(|| node.next.clone()).ok_or_else(|| { - anyhow!("script node '{current}' did not emit `_next` and has no static `next`") - })?; + let next = match dynamic { + Some(n) => n, + None => node + .next_single()? + .ok_or_else(|| { + anyhow!( + "script node '{current}' did not emit `_next` and has no static `next`" + ) + })? + .to_string(), + }; Ok(StepResult::Continue(next)) } NodeType::Approval(approval_node) => { @@ -183,21 +194,25 @@ async fn step( Ok(StepResult::Continue(next)) } NodeType::Input(input_node) => { - let next = - InputNodeExecutor::execute(input_node, node.next.as_deref(), state, ctx).await?; + let next_id = node.next_single()?; + let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?; Ok(StepResult::Continue(next)) } NodeType::Llm(llm_node) => { - let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?; + let next_id = node.next_single()?; + let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?; Ok(StepResult::Continue(next)) } NodeType::Rag(rag_node) => { - let next = - RagNodeExecutor::execute(rag_node, current, node.next.as_deref(), state, ctx) - .await?; + let next_id = node.next_single()?; + let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?; Ok(StepResult::Continue(next)) } NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))), + NodeType::Map(_) => bail!( + "Map nodes are not yet supported in this build \ + (parallel branch execution lands in Phase D/E)." + ), } } diff --git a/src/graph/logging.rs b/src/graph/logging.rs index af29ef5..768d611 100644 --- a/src/graph/logging.rs +++ b/src/graph/logging.rs @@ -151,6 +151,7 @@ fn node_type_label(node: &Node) -> &'static str { NodeType::Llm(_) => "llm", NodeType::Rag(_) => "rag", NodeType::End(_) => "end", + NodeType::Map(_) => "map", } } diff --git a/src/graph/parser.rs b/src/graph/parser.rs index 2085187..58b6a5f 100644 --- a/src/graph/parser.rs +++ b/src/graph/parser.rs @@ -161,7 +161,7 @@ mod tests { assert_eq!(graph.start, "node1"); assert_eq!(graph.nodes.len(), 2); assert_eq!( - graph.nodes.get("node1").unwrap().next.as_deref(), + graph.nodes.get("node1").unwrap().next_target(), Some("node2") ); } diff --git a/src/graph/types.rs b/src/graph/types.rs index 5a7bc0c..81ba8f9 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -1,8 +1,9 @@ -use anyhow::Result; +use anyhow::{Result, bail}; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +use std::slice; #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Graph { @@ -38,6 +39,9 @@ pub struct Graph { #[serde(default)] pub initial_state: HashMap, + #[serde(default)] + pub reducers: HashMap, + pub start: String, pub nodes: IndexMap, @@ -80,6 +84,9 @@ pub struct GraphSettings { #[serde(default = "default_true")] pub validate_before_run: bool, + + #[serde(default = "default_max_concurrency")] + pub max_concurrency: usize, } impl Default for GraphSettings { @@ -89,6 +96,7 @@ impl Default for GraphSettings { timeout: None, log_state_snapshots: true, validate_before_run: true, + max_concurrency: default_max_concurrency(), } } } @@ -101,6 +109,10 @@ fn default_true() -> bool { true } +fn default_max_concurrency() -> usize { + 8 +} + #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Node { #[serde(default)] @@ -113,7 +125,89 @@ pub struct Node { pub node_type: NodeType, #[serde(default, skip_serializing_if = "Option::is_none")] - pub next: Option, + pub next: Option, +} + +impl Node { + /// Returns the single next target as a string slice, or `None` if no next is + /// declared or if a multi-target fan-out is declared. Use this for read-only + /// inspection (e.g. tests). For execution paths that require single-target + /// semantics, use `next_single()` — it errors explicitly when a fan-out is + /// declared so the caller can surface a clear failure instead of skipping it. + #[allow(dead_code)] + pub fn next_target(&self) -> Option<&str> { + match &self.next { + None => None, + Some(NextTargets::One(s)) => Some(s), + Some(NextTargets::Many(v)) if v.len() == 1 => Some(&v[0]), + Some(NextTargets::Many(_)) => None, + } + } + + /// Returns the single next target as a string slice, or an explicit error if + /// the node declares a multi-target fan-out (which is not yet supported + /// pre-Phase-D). Returns `Ok(None)` when no next is declared at all. + pub fn next_single(&self) -> Result> { + match &self.next { + None => Ok(None), + Some(targets) => Ok(Some(targets.single()?.as_str())), + } + } +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum NextTargets { + One(String), + Many(Vec), +} + +impl NextTargets { + /// View as a slice of node ids. `One(s)` returns a single-element slice. + pub fn as_slice(&self) -> &[String] { + match self { + NextTargets::One(s) => slice::from_ref(s), + NextTargets::Many(v) => v.as_slice(), + } + } + + /// True if this declares more than one parallel target (i.e., a real fan-out). + #[allow(dead_code)] + pub fn is_fan_out(&self) -> bool { + matches!(self, NextTargets::Many(v) if v.len() > 1) + } + + /// Returns the single target if exactly one is declared, else errors with a + /// clear "not yet supported" message. Used by the v1 executor until parallel + /// branch execution lands in Phase D. + pub fn single(&self) -> Result<&String> { + match self { + NextTargets::One(s) => Ok(s), + NextTargets::Many(v) if v.len() == 1 => Ok(&v[0]), + NextTargets::Many(_) => bail!( + "Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \ + branch execution is not yet implemented in this build." + ), + } + } +} + +impl From for NextTargets { + fn from(s: String) -> Self { + NextTargets::One(s) + } +} + +impl From<&str> for NextTargets { + fn from(s: &str) -> Self { + NextTargets::One(s.to_string()) + } +} + +impl From> for NextTargets { + fn from(v: Vec) -> Self { + NextTargets::Many(v) + } } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -126,6 +220,7 @@ pub enum NodeType { Llm(LlmNode), Rag(RagNode), End(EndNode), + Map(MapNode), } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -277,6 +372,51 @@ pub struct EndNode { pub state_updates: Option>, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MapNode { + /// Template expression that must resolve (via `interpolate_raw`, added in + /// Phase B) to a JSON array. Each item in the array is one branch invocation. + pub over: String, + + /// The name to bind each item under, accessible as `{{}}` inside + /// the branch node's templates. YAML field is `as:`. + #[serde(rename = "as")] + pub as_name: String, + + /// Node id to invoke once per item in the resolved list. + pub branch: String, + + /// State key that the branch node writes; the map collects this key's value + /// across invocations. Defaults to "output". + #[serde(default = "default_map_output_key")] + pub output_key: String, + + /// State key to receive the array of per-branch outputs, in input-list order. + pub collect_into: String, + + /// Optional cap on simultaneously-running sub-branches. Falls back to + /// `settings.max_concurrency` when unset. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub max_concurrency: Option, +} + +fn default_map_output_key() -> String { + "output".to_string() +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum Reducer { + Append, + Extend, + Concat, + Sum, + Max, + Min, + Merge, + Overwrite, +} + #[derive(Debug, Clone, Default)] pub struct GraphState { data: HashMap, @@ -469,6 +609,7 @@ state_updates: next: configure "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); + let next_target = node.next_target().map(str::to_string); let input = match node.node_type { NodeType::Input(i) => i, _ => panic!("expected Input variant"), @@ -481,7 +622,7 @@ next: configure updates.get("api_key").map(|s| s.as_str()), Some("{{input}}") ); - assert_eq!(node.next.as_deref(), Some("configure")); + assert_eq!(next_target.as_deref(), Some("configure")); } #[test] @@ -627,6 +768,7 @@ timeout: 30 next: review "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); + let next_target = node.next_target().map(str::to_string); let llm = match node.node_type { NodeType::Llm(l) => l, _ => panic!("expected Llm variant"), @@ -646,7 +788,7 @@ next: review assert_eq!(llm.max_iterations, 5); assert_eq!(llm.timeout, Some(30)); assert!(llm.state_updates.is_some()); - assert_eq!(node.next.as_deref(), Some("review")); + assert_eq!(next_target.as_deref(), Some("review")); } #[test] @@ -788,4 +930,242 @@ nodes: let graph: Graph = serde_yaml::from_str(yaml).unwrap(); assert!(!graph.has_agent_node()); } + + #[test] + fn parses_static_fan_out_as_many_next_targets() { + let yaml = r#" +id: triage +type: llm +prompt: Classify +next: [retrieve_local, retrieve_web, retrieve_docs] +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + + let targets = node.next.as_ref().expect("next should be present"); + + assert!(targets.is_fan_out()); + assert_eq!( + targets.as_slice(), + &[ + "retrieve_local".to_string(), + "retrieve_web".to_string(), + "retrieve_docs".to_string() + ] + ); + } + + #[test] + fn parses_single_target_next_as_one_variant() { + let yaml = r#" +id: triage +type: llm +prompt: Classify +next: retrieve +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + + let targets = node.next.as_ref().expect("next should be present"); + + assert!(!targets.is_fan_out()); + assert_eq!(node.next_target(), Some("retrieve")); + } + + #[test] + fn next_single_errors_on_real_fan_out_with_clear_message() { + let yaml = r#" +id: triage +type: llm +prompt: Classify +next: [a, b] +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + + let err = node.next_single().unwrap_err().to_string(); + + assert!(err.contains("Parallel fan-out"), "got: {err}"); + assert!(err.contains("not yet implemented"), "got: {err}"); + } + + #[test] + fn next_single_accepts_many_containing_exactly_one_target() { + let yaml = r#" +id: triage +type: llm +prompt: Classify +next: [retrieve] +"#; + + let node: Node = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(node.next_single().unwrap(), Some("retrieve")); + assert_eq!(node.next_target(), Some("retrieve")); + } + + #[test] + fn next_targets_round_trips_through_yaml_for_both_variants() { + let one: NextTargets = serde_yaml::from_str(r#""foo""#).unwrap(); + let reparsed: NextTargets = + serde_yaml::from_str(&serde_yaml::to_string(&one).unwrap()).unwrap(); + assert_eq!(reparsed.as_slice(), &["foo".to_string()]); + + let many: NextTargets = serde_yaml::from_str("[a, b, c]").unwrap(); + let reparsed: NextTargets = + serde_yaml::from_str(&serde_yaml::to_string(&many).unwrap()).unwrap(); + assert_eq!( + reparsed.as_slice(), + &["a".to_string(), "b".to_string(), "c".to_string()] + ); + } + + #[test] + fn parses_reducers_block_with_all_builtins() { + let yaml = r#" +name: g +start: e +reducers: + sources: append + findings: extend + context: concat + cost_usd: sum + high_score: max + low_score: min + config: merge + forced: overwrite +nodes: + e: + type: end + output: ok +"#; + + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(graph.reducers.len(), 8); + assert_eq!(graph.reducers.get("sources"), Some(&Reducer::Append)); + assert_eq!(graph.reducers.get("findings"), Some(&Reducer::Extend)); + assert_eq!(graph.reducers.get("context"), Some(&Reducer::Concat)); + assert_eq!(graph.reducers.get("cost_usd"), Some(&Reducer::Sum)); + assert_eq!(graph.reducers.get("high_score"), Some(&Reducer::Max)); + assert_eq!(graph.reducers.get("low_score"), Some(&Reducer::Min)); + assert_eq!(graph.reducers.get("config"), Some(&Reducer::Merge)); + assert_eq!(graph.reducers.get("forced"), Some(&Reducer::Overwrite)); + } + + #[test] + fn reducers_default_to_empty_when_block_absent() { + let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n"; + + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + + assert!(graph.reducers.is_empty()); + } + + #[test] + fn max_concurrency_defaults_to_eight() { + let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n"; + + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(graph.settings.max_concurrency, 8); + } + + #[test] + fn max_concurrency_can_be_overridden() { + let yaml = r#" +name: g +start: x +settings: + max_concurrency: 16 +nodes: + x: + type: end +"#; + + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(graph.settings.max_concurrency, 16); + } + + #[test] + fn parses_map_node_with_all_fields() { + let yaml = r#" +id: fan_out +type: map +over: "{{subjects}}" +as: subject +branch: research_subject +output_key: research_result +collect_into: research_results +max_concurrency: 5 +next: rank +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + + let map = match node.node_type { + NodeType::Map(m) => m, + _ => panic!("expected Map variant"), + }; + + assert_eq!(map.over, "{{subjects}}"); + assert_eq!(map.as_name, "subject"); + assert_eq!(map.branch, "research_subject"); + assert_eq!(map.output_key, "research_result"); + assert_eq!(map.collect_into, "research_results"); + assert_eq!(map.max_concurrency, Some(5)); + } + + #[test] + fn map_node_uses_default_output_key_and_no_concurrency_cap() { + let yaml = r#" +id: fan_out +type: map +over: "{{items}}" +as: item +branch: process +collect_into: results +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + + let map = match node.node_type { + NodeType::Map(m) => m, + _ => panic!("expected Map variant"), + }; + + assert_eq!(map.output_key, "output"); + assert!(map.max_concurrency.is_none()); + } + + #[test] + fn full_graph_with_all_new_phase_a_fields_parses() { + let yaml = r#" +name: deep_research +start: triage +settings: + max_concurrency: 4 +reducers: + sources: append + cost_usd: sum +nodes: + triage: + type: llm + prompt: Classify + next: [retrieve_local, retrieve_web] + retrieve_local: + type: rag + documents: ["./docs"] + next: synthesize + retrieve_web: + type: llm + prompt: Search web + next: synthesize + synthesize: + type: end + output: done +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(graph.settings.max_concurrency, 4); + assert_eq!(graph.reducers.len(), 2); + let triage = graph.get_node("triage").unwrap(); + assert!(triage.next.as_ref().unwrap().is_fan_out()); + assert_eq!(triage.next.as_ref().unwrap().as_slice().len(), 2); + } } diff --git a/src/graph/validator.rs b/src/graph/validator.rs index 05cb91f..90beb96 100644 --- a/src/graph/validator.rs +++ b/src/graph/validator.rs @@ -318,8 +318,10 @@ impl GraphValidator { fn declared_targets(node: &Node) -> Vec<(String, &'static str)> { let mut out = Vec::new(); - if let Some(n) = &node.next { - out.push((n.clone(), "'next'")); + if let Some(targets) = &node.next { + for target in targets.as_slice() { + out.push((target.clone(), "'next'")); + } } match &node.node_type { @@ -342,6 +344,12 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> { // `agent`/`input`/`rag` route only via `next` (already collected // above); `end` is terminal. No type-specific routing edges to add. NodeType::Agent(_) | NodeType::Input(_) | NodeType::Rag(_) | NodeType::End(_) => {} + // A `map` node invokes its `branch:` target once per item from the + // resolved `over` list. The branch is statically referenced, so it + // is a real declared edge for cycle/reachability purposes. + NodeType::Map(m) => { + out.push((m.branch.clone(), "map 'branch'")); + } } out } @@ -434,6 +442,7 @@ mod tests { conversation_starters: Vec::new(), settings: GraphSettings::default(), initial_state: HashMap::new(), + reducers: HashMap::new(), start: start.into(), nodes: map, } @@ -529,7 +538,7 @@ mod tests { output_schema: None, timeout: None, }), - next: next.map(String::from), + next: next.map(NextTargets::from), } } @@ -759,7 +768,7 @@ mod tests { output_schema: None, timeout: None, }), - next: next.map(String::from), + next: next.map(NextTargets::from), } } @@ -1038,4 +1047,114 @@ mod tests { assert!(validator().validate(&graph).into_result().is_ok()); } + + #[test] + fn cycle_detector_treats_fan_out_diamond_as_a_valid_dag() { + let mut start = end_node("start"); + start.next = Some(NextTargets::Many(vec!["a".into(), "b".into()])); + let mut a = end_node("a"); + a.next = Some("join".into()); + let mut b = end_node("b"); + b.next = Some("join".into()); + let mut join = end_node("join"); + join.next = Some("end".into()); + + let graph = graph_with( + vec![ + ("start", start), + ("a", a), + ("b", b), + ("join", join), + ("end", end_node("end")), + ], + "start", + ); + + let result = validator().validate(&graph); + assert!( + !result + .errors + .iter() + .any(|e| e.message.contains("Cycle detected")), + "fan-out diamond incorrectly reported as cycle: {:?}", + result.errors + ); + } + + #[test] + fn reachability_visits_every_member_of_many_next_targets() { + let mut start = end_node("start"); + start.next = Some(NextTargets::Many(vec!["a".into(), "b".into(), "c".into()])); + let graph = graph_with( + vec![ + ("start", start), + ("a", end_node("a")), + ("b", end_node("b")), + ("c", end_node("c")), + ], + "start", + ); + + let result = validator().validate(&graph); + + for orphan in ["a", "b", "c"] { + assert!( + !result + .warnings + .iter() + .any(|w| w.node_id.as_deref() == Some(orphan) + && w.message.contains("unreachable")), + "fan-out target '{orphan}' incorrectly marked unreachable: {:?}", + result.warnings + ); + } + } + + #[test] + fn node_reference_check_catches_missing_member_inside_many() { + let mut start = end_node("start"); + start.next = Some(NextTargets::Many(vec!["a".into(), "ghost".into()])); + let graph = graph_with(vec![("start", start), ("a", end_node("a"))], "start"); + + let result = validator().validate(&graph); + + assert!( + result + .errors + .iter() + .any(|e| e.message.contains("non-existent node 'ghost'") + && e.node_id.as_deref() == Some("start")), + "expected error for missing 'ghost' target in Many: {:?}", + result.errors + ); + } + + #[test] + fn node_reference_check_catches_missing_map_branch_target() { + let map = Node { + id: "fan".into(), + description: String::new(), + node_type: NodeType::Map(MapNode { + over: "{{items}}".into(), + as_name: "item".into(), + branch: "no_such_node".into(), + output_key: "output".into(), + collect_into: "results".into(), + max_concurrency: None, + }), + next: Some("end".into()), + }; + let graph = graph_with(vec![("fan", map), ("end", end_node("end"))], "fan"); + + let result = validator().validate(&graph); + assert!( + result + .errors + .iter() + .any(|e| e.message.contains("non-existent node 'no_such_node'") + && e.message.contains("map 'branch'")), + "expected error for missing map branch: {:?}", + result.errors + ); + } }