feat: scaffolding work for fan-out nodes for parallel branch execution support and stubbed out Map node types
This commit is contained in:
+27
-12
@@ -154,9 +154,12 @@ async fn step(
|
||||
match &node.node_type {
|
||||
NodeType::Agent(agent_node) => {
|
||||
AgentNodeExecutor::execute(agent_node, state, ctx).await?;
|
||||
let next = node.next.clone().ok_or_else(|| {
|
||||
anyhow!("agent node '{current}' has no `next` and is not an end node")
|
||||
})?;
|
||||
let next = node
|
||||
.next_single()?
|
||||
.ok_or_else(|| {
|
||||
anyhow!("agent node '{current}' has no `next` and is not an end node")
|
||||
})?
|
||||
.to_string();
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::Script(script_node) => {
|
||||
@@ -173,9 +176,17 @@ async fn step(
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
let next = dynamic.or_else(|| node.next.clone()).ok_or_else(|| {
|
||||
anyhow!("script node '{current}' did not emit `_next` and has no static `next`")
|
||||
})?;
|
||||
let next = match dynamic {
|
||||
Some(n) => n,
|
||||
None => node
|
||||
.next_single()?
|
||||
.ok_or_else(|| {
|
||||
anyhow!(
|
||||
"script node '{current}' did not emit `_next` and has no static `next`"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
};
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::Approval(approval_node) => {
|
||||
@@ -183,21 +194,25 @@ async fn step(
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::Input(input_node) => {
|
||||
let next =
|
||||
InputNodeExecutor::execute(input_node, node.next.as_deref(), state, ctx).await?;
|
||||
let next_id = node.next_single()?;
|
||||
let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::Llm(llm_node) => {
|
||||
let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?;
|
||||
let next_id = node.next_single()?;
|
||||
let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::Rag(rag_node) => {
|
||||
let next =
|
||||
RagNodeExecutor::execute(rag_node, current, node.next.as_deref(), state, ctx)
|
||||
.await?;
|
||||
let next_id = node.next_single()?;
|
||||
let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
}
|
||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||
NodeType::Map(_) => bail!(
|
||||
"Map nodes are not yet supported in this build \
|
||||
(parallel branch execution lands in Phase D/E)."
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -151,6 +151,7 @@ fn node_type_label(node: &Node) -> &'static str {
|
||||
NodeType::Llm(_) => "llm",
|
||||
NodeType::Rag(_) => "rag",
|
||||
NodeType::End(_) => "end",
|
||||
NodeType::Map(_) => "map",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+1
-1
@@ -161,7 +161,7 @@ mod tests {
|
||||
assert_eq!(graph.start, "node1");
|
||||
assert_eq!(graph.nodes.len(), 2);
|
||||
assert_eq!(
|
||||
graph.nodes.get("node1").unwrap().next.as_deref(),
|
||||
graph.nodes.get("node1").unwrap().next_target(),
|
||||
Some("node2")
|
||||
);
|
||||
}
|
||||
|
||||
+384
-4
@@ -1,8 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use anyhow::{Result, bail};
|
||||
use indexmap::IndexMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::slice;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Graph {
|
||||
@@ -38,6 +39,9 @@ pub struct Graph {
|
||||
#[serde(default)]
|
||||
pub initial_state: HashMap<String, Value>,
|
||||
|
||||
#[serde(default)]
|
||||
pub reducers: HashMap<String, Reducer>,
|
||||
|
||||
pub start: String,
|
||||
|
||||
pub nodes: IndexMap<String, Node>,
|
||||
@@ -80,6 +84,9 @@ pub struct GraphSettings {
|
||||
|
||||
#[serde(default = "default_true")]
|
||||
pub validate_before_run: bool,
|
||||
|
||||
#[serde(default = "default_max_concurrency")]
|
||||
pub max_concurrency: usize,
|
||||
}
|
||||
|
||||
impl Default for GraphSettings {
|
||||
@@ -89,6 +96,7 @@ impl Default for GraphSettings {
|
||||
timeout: None,
|
||||
log_state_snapshots: true,
|
||||
validate_before_run: true,
|
||||
max_concurrency: default_max_concurrency(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -101,6 +109,10 @@ fn default_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn default_max_concurrency() -> usize {
|
||||
8
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Node {
|
||||
#[serde(default)]
|
||||
@@ -113,7 +125,89 @@ pub struct Node {
|
||||
pub node_type: NodeType,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub next: Option<String>,
|
||||
pub next: Option<NextTargets>,
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Returns the single next target as a string slice, or `None` if no next is
|
||||
/// declared or if a multi-target fan-out is declared. Use this for read-only
|
||||
/// inspection (e.g. tests). For execution paths that require single-target
|
||||
/// semantics, use `next_single()` — it errors explicitly when a fan-out is
|
||||
/// declared so the caller can surface a clear failure instead of skipping it.
|
||||
#[allow(dead_code)]
|
||||
pub fn next_target(&self) -> Option<&str> {
|
||||
match &self.next {
|
||||
None => None,
|
||||
Some(NextTargets::One(s)) => Some(s),
|
||||
Some(NextTargets::Many(v)) if v.len() == 1 => Some(&v[0]),
|
||||
Some(NextTargets::Many(_)) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the single next target as a string slice, or an explicit error if
|
||||
/// the node declares a multi-target fan-out (which is not yet supported
|
||||
/// pre-Phase-D). Returns `Ok(None)` when no next is declared at all.
|
||||
pub fn next_single(&self) -> Result<Option<&str>> {
|
||||
match &self.next {
|
||||
None => Ok(None),
|
||||
Some(targets) => Ok(Some(targets.single()?.as_str())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum NextTargets {
|
||||
One(String),
|
||||
Many(Vec<String>),
|
||||
}
|
||||
|
||||
impl NextTargets {
|
||||
/// View as a slice of node ids. `One(s)` returns a single-element slice.
|
||||
pub fn as_slice(&self) -> &[String] {
|
||||
match self {
|
||||
NextTargets::One(s) => slice::from_ref(s),
|
||||
NextTargets::Many(v) => v.as_slice(),
|
||||
}
|
||||
}
|
||||
|
||||
/// True if this declares more than one parallel target (i.e., a real fan-out).
|
||||
#[allow(dead_code)]
|
||||
pub fn is_fan_out(&self) -> bool {
|
||||
matches!(self, NextTargets::Many(v) if v.len() > 1)
|
||||
}
|
||||
|
||||
/// Returns the single target if exactly one is declared, else errors with a
|
||||
/// clear "not yet supported" message. Used by the v1 executor until parallel
|
||||
/// branch execution lands in Phase D.
|
||||
pub fn single(&self) -> Result<&String> {
|
||||
match self {
|
||||
NextTargets::One(s) => Ok(s),
|
||||
NextTargets::Many(v) if v.len() == 1 => Ok(&v[0]),
|
||||
NextTargets::Many(_) => bail!(
|
||||
"Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \
|
||||
branch execution is not yet implemented in this build."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for NextTargets {
|
||||
fn from(s: String) -> Self {
|
||||
NextTargets::One(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for NextTargets {
|
||||
fn from(s: &str) -> Self {
|
||||
NextTargets::One(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for NextTargets {
|
||||
fn from(v: Vec<String>) -> Self {
|
||||
NextTargets::Many(v)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -126,6 +220,7 @@ pub enum NodeType {
|
||||
Llm(LlmNode),
|
||||
Rag(RagNode),
|
||||
End(EndNode),
|
||||
Map(MapNode),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -277,6 +372,51 @@ pub struct EndNode {
|
||||
pub state_updates: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct MapNode {
|
||||
/// Template expression that must resolve (via `interpolate_raw`, added in
|
||||
/// Phase B) to a JSON array. Each item in the array is one branch invocation.
|
||||
pub over: String,
|
||||
|
||||
/// The name to bind each item under, accessible as `{{<as_name>}}` inside
|
||||
/// the branch node's templates. YAML field is `as:`.
|
||||
#[serde(rename = "as")]
|
||||
pub as_name: String,
|
||||
|
||||
/// Node id to invoke once per item in the resolved list.
|
||||
pub branch: String,
|
||||
|
||||
/// State key that the branch node writes; the map collects this key's value
|
||||
/// across invocations. Defaults to "output".
|
||||
#[serde(default = "default_map_output_key")]
|
||||
pub output_key: String,
|
||||
|
||||
/// State key to receive the array of per-branch outputs, in input-list order.
|
||||
pub collect_into: String,
|
||||
|
||||
/// Optional cap on simultaneously-running sub-branches. Falls back to
|
||||
/// `settings.max_concurrency` when unset.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_concurrency: Option<usize>,
|
||||
}
|
||||
|
||||
fn default_map_output_key() -> String {
|
||||
"output".to_string()
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Reducer {
|
||||
Append,
|
||||
Extend,
|
||||
Concat,
|
||||
Sum,
|
||||
Max,
|
||||
Min,
|
||||
Merge,
|
||||
Overwrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct GraphState {
|
||||
data: HashMap<String, Value>,
|
||||
@@ -469,6 +609,7 @@ state_updates:
|
||||
next: configure
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
let next_target = node.next_target().map(str::to_string);
|
||||
let input = match node.node_type {
|
||||
NodeType::Input(i) => i,
|
||||
_ => panic!("expected Input variant"),
|
||||
@@ -481,7 +622,7 @@ next: configure
|
||||
updates.get("api_key").map(|s| s.as_str()),
|
||||
Some("{{input}}")
|
||||
);
|
||||
assert_eq!(node.next.as_deref(), Some("configure"));
|
||||
assert_eq!(next_target.as_deref(), Some("configure"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -627,6 +768,7 @@ timeout: 30
|
||||
next: review
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
let next_target = node.next_target().map(str::to_string);
|
||||
let llm = match node.node_type {
|
||||
NodeType::Llm(l) => l,
|
||||
_ => panic!("expected Llm variant"),
|
||||
@@ -646,7 +788,7 @@ next: review
|
||||
assert_eq!(llm.max_iterations, 5);
|
||||
assert_eq!(llm.timeout, Some(30));
|
||||
assert!(llm.state_updates.is_some());
|
||||
assert_eq!(node.next.as_deref(), Some("review"));
|
||||
assert_eq!(next_target.as_deref(), Some("review"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -788,4 +930,242 @@ nodes:
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
assert!(!graph.has_agent_node());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_static_fan_out_as_many_next_targets() {
|
||||
let yaml = r#"
|
||||
id: triage
|
||||
type: llm
|
||||
prompt: Classify
|
||||
next: [retrieve_local, retrieve_web, retrieve_docs]
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
let targets = node.next.as_ref().expect("next should be present");
|
||||
|
||||
assert!(targets.is_fan_out());
|
||||
assert_eq!(
|
||||
targets.as_slice(),
|
||||
&[
|
||||
"retrieve_local".to_string(),
|
||||
"retrieve_web".to_string(),
|
||||
"retrieve_docs".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_single_target_next_as_one_variant() {
|
||||
let yaml = r#"
|
||||
id: triage
|
||||
type: llm
|
||||
prompt: Classify
|
||||
next: retrieve
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
let targets = node.next.as_ref().expect("next should be present");
|
||||
|
||||
assert!(!targets.is_fan_out());
|
||||
assert_eq!(node.next_target(), Some("retrieve"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_single_errors_on_real_fan_out_with_clear_message() {
|
||||
let yaml = r#"
|
||||
id: triage
|
||||
type: llm
|
||||
prompt: Classify
|
||||
next: [a, b]
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
let err = node.next_single().unwrap_err().to_string();
|
||||
|
||||
assert!(err.contains("Parallel fan-out"), "got: {err}");
|
||||
assert!(err.contains("not yet implemented"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_single_accepts_many_containing_exactly_one_target() {
|
||||
let yaml = r#"
|
||||
id: triage
|
||||
type: llm
|
||||
prompt: Classify
|
||||
next: [retrieve]
|
||||
"#;
|
||||
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(node.next_single().unwrap(), Some("retrieve"));
|
||||
assert_eq!(node.next_target(), Some("retrieve"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_targets_round_trips_through_yaml_for_both_variants() {
|
||||
let one: NextTargets = serde_yaml::from_str(r#""foo""#).unwrap();
|
||||
let reparsed: NextTargets =
|
||||
serde_yaml::from_str(&serde_yaml::to_string(&one).unwrap()).unwrap();
|
||||
assert_eq!(reparsed.as_slice(), &["foo".to_string()]);
|
||||
|
||||
let many: NextTargets = serde_yaml::from_str("[a, b, c]").unwrap();
|
||||
let reparsed: NextTargets =
|
||||
serde_yaml::from_str(&serde_yaml::to_string(&many).unwrap()).unwrap();
|
||||
assert_eq!(
|
||||
reparsed.as_slice(),
|
||||
&["a".to_string(), "b".to_string(), "c".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_reducers_block_with_all_builtins() {
|
||||
let yaml = r#"
|
||||
name: g
|
||||
start: e
|
||||
reducers:
|
||||
sources: append
|
||||
findings: extend
|
||||
context: concat
|
||||
cost_usd: sum
|
||||
high_score: max
|
||||
low_score: min
|
||||
config: merge
|
||||
forced: overwrite
|
||||
nodes:
|
||||
e:
|
||||
type: end
|
||||
output: ok
|
||||
"#;
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.reducers.len(), 8);
|
||||
assert_eq!(graph.reducers.get("sources"), Some(&Reducer::Append));
|
||||
assert_eq!(graph.reducers.get("findings"), Some(&Reducer::Extend));
|
||||
assert_eq!(graph.reducers.get("context"), Some(&Reducer::Concat));
|
||||
assert_eq!(graph.reducers.get("cost_usd"), Some(&Reducer::Sum));
|
||||
assert_eq!(graph.reducers.get("high_score"), Some(&Reducer::Max));
|
||||
assert_eq!(graph.reducers.get("low_score"), Some(&Reducer::Min));
|
||||
assert_eq!(graph.reducers.get("config"), Some(&Reducer::Merge));
|
||||
assert_eq!(graph.reducers.get("forced"), Some(&Reducer::Overwrite));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reducers_default_to_empty_when_block_absent() {
|
||||
let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n";
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert!(graph.reducers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_concurrency_defaults_to_eight() {
|
||||
let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n";
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.settings.max_concurrency, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_concurrency_can_be_overridden() {
|
||||
let yaml = r#"
|
||||
name: g
|
||||
start: x
|
||||
settings:
|
||||
max_concurrency: 16
|
||||
nodes:
|
||||
x:
|
||||
type: end
|
||||
"#;
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(graph.settings.max_concurrency, 16);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_map_node_with_all_fields() {
|
||||
let yaml = r#"
|
||||
id: fan_out
|
||||
type: map
|
||||
over: "{{subjects}}"
|
||||
as: subject
|
||||
branch: research_subject
|
||||
output_key: research_result
|
||||
collect_into: research_results
|
||||
max_concurrency: 5
|
||||
next: rank
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
let map = match node.node_type {
|
||||
NodeType::Map(m) => m,
|
||||
_ => panic!("expected Map variant"),
|
||||
};
|
||||
|
||||
assert_eq!(map.over, "{{subjects}}");
|
||||
assert_eq!(map.as_name, "subject");
|
||||
assert_eq!(map.branch, "research_subject");
|
||||
assert_eq!(map.output_key, "research_result");
|
||||
assert_eq!(map.collect_into, "research_results");
|
||||
assert_eq!(map.max_concurrency, Some(5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_node_uses_default_output_key_and_no_concurrency_cap() {
|
||||
let yaml = r#"
|
||||
id: fan_out
|
||||
type: map
|
||||
over: "{{items}}"
|
||||
as: item
|
||||
branch: process
|
||||
collect_into: results
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
let map = match node.node_type {
|
||||
NodeType::Map(m) => m,
|
||||
_ => panic!("expected Map variant"),
|
||||
};
|
||||
|
||||
assert_eq!(map.output_key, "output");
|
||||
assert!(map.max_concurrency.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn full_graph_with_all_new_phase_a_fields_parses() {
|
||||
let yaml = r#"
|
||||
name: deep_research
|
||||
start: triage
|
||||
settings:
|
||||
max_concurrency: 4
|
||||
reducers:
|
||||
sources: append
|
||||
cost_usd: sum
|
||||
nodes:
|
||||
triage:
|
||||
type: llm
|
||||
prompt: Classify
|
||||
next: [retrieve_local, retrieve_web]
|
||||
retrieve_local:
|
||||
type: rag
|
||||
documents: ["./docs"]
|
||||
next: synthesize
|
||||
retrieve_web:
|
||||
type: llm
|
||||
prompt: Search web
|
||||
next: synthesize
|
||||
synthesize:
|
||||
type: end
|
||||
output: done
|
||||
"#;
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
assert_eq!(graph.settings.max_concurrency, 4);
|
||||
assert_eq!(graph.reducers.len(), 2);
|
||||
let triage = graph.get_node("triage").unwrap();
|
||||
assert!(triage.next.as_ref().unwrap().is_fan_out());
|
||||
assert_eq!(triage.next.as_ref().unwrap().as_slice().len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
+123
-4
@@ -318,8 +318,10 @@ impl GraphValidator {
|
||||
|
||||
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
||||
let mut out = Vec::new();
|
||||
if let Some(n) = &node.next {
|
||||
out.push((n.clone(), "'next'"));
|
||||
if let Some(targets) = &node.next {
|
||||
for target in targets.as_slice() {
|
||||
out.push((target.clone(), "'next'"));
|
||||
}
|
||||
}
|
||||
|
||||
match &node.node_type {
|
||||
@@ -342,6 +344,12 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
||||
// `agent`/`input`/`rag` route only via `next` (already collected
|
||||
// above); `end` is terminal. No type-specific routing edges to add.
|
||||
NodeType::Agent(_) | NodeType::Input(_) | NodeType::Rag(_) | NodeType::End(_) => {}
|
||||
// A `map` node invokes its `branch:` target once per item from the
|
||||
// resolved `over` list. The branch is statically referenced, so it
|
||||
// is a real declared edge for cycle/reachability purposes.
|
||||
NodeType::Map(m) => {
|
||||
out.push((m.branch.clone(), "map 'branch'"));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
@@ -434,6 +442,7 @@ mod tests {
|
||||
conversation_starters: Vec::new(),
|
||||
settings: GraphSettings::default(),
|
||||
initial_state: HashMap::new(),
|
||||
reducers: HashMap::new(),
|
||||
start: start.into(),
|
||||
nodes: map,
|
||||
}
|
||||
@@ -529,7 +538,7 @@ mod tests {
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}),
|
||||
next: next.map(String::from),
|
||||
next: next.map(NextTargets::from),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -759,7 +768,7 @@ mod tests {
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}),
|
||||
next: next.map(String::from),
|
||||
next: next.map(NextTargets::from),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1038,4 +1047,114 @@ mod tests {
|
||||
|
||||
assert!(validator().validate(&graph).into_result().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cycle_detector_treats_fan_out_diamond_as_a_valid_dag() {
|
||||
let mut start = end_node("start");
|
||||
start.next = Some(NextTargets::Many(vec!["a".into(), "b".into()]));
|
||||
let mut a = end_node("a");
|
||||
a.next = Some("join".into());
|
||||
let mut b = end_node("b");
|
||||
b.next = Some("join".into());
|
||||
let mut join = end_node("join");
|
||||
join.next = Some("end".into());
|
||||
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("start", start),
|
||||
("a", a),
|
||||
("b", b),
|
||||
("join", join),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"start",
|
||||
);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
assert!(
|
||||
!result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("Cycle detected")),
|
||||
"fan-out diamond incorrectly reported as cycle: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reachability_visits_every_member_of_many_next_targets() {
|
||||
let mut start = end_node("start");
|
||||
start.next = Some(NextTargets::Many(vec!["a".into(), "b".into(), "c".into()]));
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("start", start),
|
||||
("a", end_node("a")),
|
||||
("b", end_node("b")),
|
||||
("c", end_node("c")),
|
||||
],
|
||||
"start",
|
||||
);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
for orphan in ["a", "b", "c"] {
|
||||
assert!(
|
||||
!result
|
||||
.warnings
|
||||
.iter()
|
||||
.any(|w| w.node_id.as_deref() == Some(orphan)
|
||||
&& w.message.contains("unreachable")),
|
||||
"fan-out target '{orphan}' incorrectly marked unreachable: {:?}",
|
||||
result.warnings
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_reference_check_catches_missing_member_inside_many() {
|
||||
let mut start = end_node("start");
|
||||
start.next = Some(NextTargets::Many(vec!["a".into(), "ghost".into()]));
|
||||
let graph = graph_with(vec![("start", start), ("a", end_node("a"))], "start");
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("non-existent node 'ghost'")
|
||||
&& e.node_id.as_deref() == Some("start")),
|
||||
"expected error for missing 'ghost' target in Many: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_reference_check_catches_missing_map_branch_target() {
|
||||
let map = Node {
|
||||
id: "fan".into(),
|
||||
description: String::new(),
|
||||
node_type: NodeType::Map(MapNode {
|
||||
over: "{{items}}".into(),
|
||||
as_name: "item".into(),
|
||||
branch: "no_such_node".into(),
|
||||
output_key: "output".into(),
|
||||
collect_into: "results".into(),
|
||||
max_concurrency: None,
|
||||
}),
|
||||
next: Some("end".into()),
|
||||
};
|
||||
let graph = graph_with(vec![("fan", map), ("end", end_node("end"))], "fan");
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("non-existent node 'no_such_node'")
|
||||
&& e.message.contains("map 'branch'")),
|
||||
"expected error for missing map branch: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user