feat: scaffolding work for fan-out nodes for parallel branch execution support and stubbed out Map node types
This commit is contained in:
+27
-12
@@ -154,9 +154,12 @@ async fn step(
|
|||||||
match &node.node_type {
|
match &node.node_type {
|
||||||
NodeType::Agent(agent_node) => {
|
NodeType::Agent(agent_node) => {
|
||||||
AgentNodeExecutor::execute(agent_node, state, ctx).await?;
|
AgentNodeExecutor::execute(agent_node, state, ctx).await?;
|
||||||
let next = node.next.clone().ok_or_else(|| {
|
let next = node
|
||||||
anyhow!("agent node '{current}' has no `next` and is not an end node")
|
.next_single()?
|
||||||
})?;
|
.ok_or_else(|| {
|
||||||
|
anyhow!("agent node '{current}' has no `next` and is not an end node")
|
||||||
|
})?
|
||||||
|
.to_string();
|
||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
NodeType::Script(script_node) => {
|
NodeType::Script(script_node) => {
|
||||||
@@ -173,9 +176,17 @@ async fn step(
|
|||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let next = dynamic.or_else(|| node.next.clone()).ok_or_else(|| {
|
let next = match dynamic {
|
||||||
anyhow!("script node '{current}' did not emit `_next` and has no static `next`")
|
Some(n) => n,
|
||||||
})?;
|
None => node
|
||||||
|
.next_single()?
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow!(
|
||||||
|
"script node '{current}' did not emit `_next` and has no static `next`"
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.to_string(),
|
||||||
|
};
|
||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
NodeType::Approval(approval_node) => {
|
NodeType::Approval(approval_node) => {
|
||||||
@@ -183,21 +194,25 @@ async fn step(
|
|||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
NodeType::Input(input_node) => {
|
NodeType::Input(input_node) => {
|
||||||
let next =
|
let next_id = node.next_single()?;
|
||||||
InputNodeExecutor::execute(input_node, node.next.as_deref(), state, ctx).await?;
|
let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?;
|
||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
NodeType::Llm(llm_node) => {
|
NodeType::Llm(llm_node) => {
|
||||||
let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?;
|
let next_id = node.next_single()?;
|
||||||
|
let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?;
|
||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
NodeType::Rag(rag_node) => {
|
NodeType::Rag(rag_node) => {
|
||||||
let next =
|
let next_id = node.next_single()?;
|
||||||
RagNodeExecutor::execute(rag_node, current, node.next.as_deref(), state, ctx)
|
let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?;
|
||||||
.await?;
|
|
||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||||
|
NodeType::Map(_) => bail!(
|
||||||
|
"Map nodes are not yet supported in this build \
|
||||||
|
(parallel branch execution lands in Phase D/E)."
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -151,6 +151,7 @@ fn node_type_label(node: &Node) -> &'static str {
|
|||||||
NodeType::Llm(_) => "llm",
|
NodeType::Llm(_) => "llm",
|
||||||
NodeType::Rag(_) => "rag",
|
NodeType::Rag(_) => "rag",
|
||||||
NodeType::End(_) => "end",
|
NodeType::End(_) => "end",
|
||||||
|
NodeType::Map(_) => "map",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+1
-1
@@ -161,7 +161,7 @@ mod tests {
|
|||||||
assert_eq!(graph.start, "node1");
|
assert_eq!(graph.start, "node1");
|
||||||
assert_eq!(graph.nodes.len(), 2);
|
assert_eq!(graph.nodes.len(), 2);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
graph.nodes.get("node1").unwrap().next.as_deref(),
|
graph.nodes.get("node1").unwrap().next_target(),
|
||||||
Some("node2")
|
Some("node2")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
+384
-4
@@ -1,8 +1,9 @@
|
|||||||
use anyhow::Result;
|
use anyhow::{Result, bail};
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::slice;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Graph {
|
pub struct Graph {
|
||||||
@@ -38,6 +39,9 @@ pub struct Graph {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub initial_state: HashMap<String, Value>,
|
pub initial_state: HashMap<String, Value>,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub reducers: HashMap<String, Reducer>,
|
||||||
|
|
||||||
pub start: String,
|
pub start: String,
|
||||||
|
|
||||||
pub nodes: IndexMap<String, Node>,
|
pub nodes: IndexMap<String, Node>,
|
||||||
@@ -80,6 +84,9 @@ pub struct GraphSettings {
|
|||||||
|
|
||||||
#[serde(default = "default_true")]
|
#[serde(default = "default_true")]
|
||||||
pub validate_before_run: bool,
|
pub validate_before_run: bool,
|
||||||
|
|
||||||
|
#[serde(default = "default_max_concurrency")]
|
||||||
|
pub max_concurrency: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for GraphSettings {
|
impl Default for GraphSettings {
|
||||||
@@ -89,6 +96,7 @@ impl Default for GraphSettings {
|
|||||||
timeout: None,
|
timeout: None,
|
||||||
log_state_snapshots: true,
|
log_state_snapshots: true,
|
||||||
validate_before_run: true,
|
validate_before_run: true,
|
||||||
|
max_concurrency: default_max_concurrency(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,6 +109,10 @@ fn default_true() -> bool {
|
|||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_max_concurrency() -> usize {
|
||||||
|
8
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
pub struct Node {
|
pub struct Node {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -113,7 +125,89 @@ pub struct Node {
|
|||||||
pub node_type: NodeType,
|
pub node_type: NodeType,
|
||||||
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub next: Option<String>,
|
pub next: Option<NextTargets>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
/// Returns the single next target as a string slice, or `None` if no next is
|
||||||
|
/// declared or if a multi-target fan-out is declared. Use this for read-only
|
||||||
|
/// inspection (e.g. tests). For execution paths that require single-target
|
||||||
|
/// semantics, use `next_single()` — it errors explicitly when a fan-out is
|
||||||
|
/// declared so the caller can surface a clear failure instead of skipping it.
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn next_target(&self) -> Option<&str> {
|
||||||
|
match &self.next {
|
||||||
|
None => None,
|
||||||
|
Some(NextTargets::One(s)) => Some(s),
|
||||||
|
Some(NextTargets::Many(v)) if v.len() == 1 => Some(&v[0]),
|
||||||
|
Some(NextTargets::Many(_)) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the single next target as a string slice, or an explicit error if
|
||||||
|
/// the node declares a multi-target fan-out (which is not yet supported
|
||||||
|
/// pre-Phase-D). Returns `Ok(None)` when no next is declared at all.
|
||||||
|
pub fn next_single(&self) -> Result<Option<&str>> {
|
||||||
|
match &self.next {
|
||||||
|
None => Ok(None),
|
||||||
|
Some(targets) => Ok(Some(targets.single()?.as_str())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum NextTargets {
|
||||||
|
One(String),
|
||||||
|
Many(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NextTargets {
|
||||||
|
/// View as a slice of node ids. `One(s)` returns a single-element slice.
|
||||||
|
pub fn as_slice(&self) -> &[String] {
|
||||||
|
match self {
|
||||||
|
NextTargets::One(s) => slice::from_ref(s),
|
||||||
|
NextTargets::Many(v) => v.as_slice(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// True if this declares more than one parallel target (i.e., a real fan-out).
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn is_fan_out(&self) -> bool {
|
||||||
|
matches!(self, NextTargets::Many(v) if v.len() > 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the single target if exactly one is declared, else errors with a
|
||||||
|
/// clear "not yet supported" message. Used by the v1 executor until parallel
|
||||||
|
/// branch execution lands in Phase D.
|
||||||
|
pub fn single(&self) -> Result<&String> {
|
||||||
|
match self {
|
||||||
|
NextTargets::One(s) => Ok(s),
|
||||||
|
NextTargets::Many(v) if v.len() == 1 => Ok(&v[0]),
|
||||||
|
NextTargets::Many(_) => bail!(
|
||||||
|
"Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \
|
||||||
|
branch execution is not yet implemented in this build."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for NextTargets {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
NextTargets::One(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for NextTargets {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
NextTargets::One(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<Vec<String>> for NextTargets {
|
||||||
|
fn from(v: Vec<String>) -> Self {
|
||||||
|
NextTargets::Many(v)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -126,6 +220,7 @@ pub enum NodeType {
|
|||||||
Llm(LlmNode),
|
Llm(LlmNode),
|
||||||
Rag(RagNode),
|
Rag(RagNode),
|
||||||
End(EndNode),
|
End(EndNode),
|
||||||
|
Map(MapNode),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -277,6 +372,51 @@ pub struct EndNode {
|
|||||||
pub state_updates: Option<HashMap<String, String>>,
|
pub state_updates: Option<HashMap<String, String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct MapNode {
|
||||||
|
/// Template expression that must resolve (via `interpolate_raw`, added in
|
||||||
|
/// Phase B) to a JSON array. Each item in the array is one branch invocation.
|
||||||
|
pub over: String,
|
||||||
|
|
||||||
|
/// The name to bind each item under, accessible as `{{<as_name>}}` inside
|
||||||
|
/// the branch node's templates. YAML field is `as:`.
|
||||||
|
#[serde(rename = "as")]
|
||||||
|
pub as_name: String,
|
||||||
|
|
||||||
|
/// Node id to invoke once per item in the resolved list.
|
||||||
|
pub branch: String,
|
||||||
|
|
||||||
|
/// State key that the branch node writes; the map collects this key's value
|
||||||
|
/// across invocations. Defaults to "output".
|
||||||
|
#[serde(default = "default_map_output_key")]
|
||||||
|
pub output_key: String,
|
||||||
|
|
||||||
|
/// State key to receive the array of per-branch outputs, in input-list order.
|
||||||
|
pub collect_into: String,
|
||||||
|
|
||||||
|
/// Optional cap on simultaneously-running sub-branches. Falls back to
|
||||||
|
/// `settings.max_concurrency` when unset.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_concurrency: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_map_output_key() -> String {
|
||||||
|
"output".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum Reducer {
|
||||||
|
Append,
|
||||||
|
Extend,
|
||||||
|
Concat,
|
||||||
|
Sum,
|
||||||
|
Max,
|
||||||
|
Min,
|
||||||
|
Merge,
|
||||||
|
Overwrite,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct GraphState {
|
pub struct GraphState {
|
||||||
data: HashMap<String, Value>,
|
data: HashMap<String, Value>,
|
||||||
@@ -469,6 +609,7 @@ state_updates:
|
|||||||
next: configure
|
next: configure
|
||||||
"#;
|
"#;
|
||||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
let next_target = node.next_target().map(str::to_string);
|
||||||
let input = match node.node_type {
|
let input = match node.node_type {
|
||||||
NodeType::Input(i) => i,
|
NodeType::Input(i) => i,
|
||||||
_ => panic!("expected Input variant"),
|
_ => panic!("expected Input variant"),
|
||||||
@@ -481,7 +622,7 @@ next: configure
|
|||||||
updates.get("api_key").map(|s| s.as_str()),
|
updates.get("api_key").map(|s| s.as_str()),
|
||||||
Some("{{input}}")
|
Some("{{input}}")
|
||||||
);
|
);
|
||||||
assert_eq!(node.next.as_deref(), Some("configure"));
|
assert_eq!(next_target.as_deref(), Some("configure"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -627,6 +768,7 @@ timeout: 30
|
|||||||
next: review
|
next: review
|
||||||
"#;
|
"#;
|
||||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
let next_target = node.next_target().map(str::to_string);
|
||||||
let llm = match node.node_type {
|
let llm = match node.node_type {
|
||||||
NodeType::Llm(l) => l,
|
NodeType::Llm(l) => l,
|
||||||
_ => panic!("expected Llm variant"),
|
_ => panic!("expected Llm variant"),
|
||||||
@@ -646,7 +788,7 @@ next: review
|
|||||||
assert_eq!(llm.max_iterations, 5);
|
assert_eq!(llm.max_iterations, 5);
|
||||||
assert_eq!(llm.timeout, Some(30));
|
assert_eq!(llm.timeout, Some(30));
|
||||||
assert!(llm.state_updates.is_some());
|
assert!(llm.state_updates.is_some());
|
||||||
assert_eq!(node.next.as_deref(), Some("review"));
|
assert_eq!(next_target.as_deref(), Some("review"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -788,4 +930,242 @@ nodes:
|
|||||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||||
assert!(!graph.has_agent_node());
|
assert!(!graph.has_agent_node());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_static_fan_out_as_many_next_targets() {
|
||||||
|
let yaml = r#"
|
||||||
|
id: triage
|
||||||
|
type: llm
|
||||||
|
prompt: Classify
|
||||||
|
next: [retrieve_local, retrieve_web, retrieve_docs]
|
||||||
|
"#;
|
||||||
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
let targets = node.next.as_ref().expect("next should be present");
|
||||||
|
|
||||||
|
assert!(targets.is_fan_out());
|
||||||
|
assert_eq!(
|
||||||
|
targets.as_slice(),
|
||||||
|
&[
|
||||||
|
"retrieve_local".to_string(),
|
||||||
|
"retrieve_web".to_string(),
|
||||||
|
"retrieve_docs".to_string()
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_single_target_next_as_one_variant() {
|
||||||
|
let yaml = r#"
|
||||||
|
id: triage
|
||||||
|
type: llm
|
||||||
|
prompt: Classify
|
||||||
|
next: retrieve
|
||||||
|
"#;
|
||||||
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
let targets = node.next.as_ref().expect("next should be present");
|
||||||
|
|
||||||
|
assert!(!targets.is_fan_out());
|
||||||
|
assert_eq!(node.next_target(), Some("retrieve"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn next_single_errors_on_real_fan_out_with_clear_message() {
|
||||||
|
let yaml = r#"
|
||||||
|
id: triage
|
||||||
|
type: llm
|
||||||
|
prompt: Classify
|
||||||
|
next: [a, b]
|
||||||
|
"#;
|
||||||
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
let err = node.next_single().unwrap_err().to_string();
|
||||||
|
|
||||||
|
assert!(err.contains("Parallel fan-out"), "got: {err}");
|
||||||
|
assert!(err.contains("not yet implemented"), "got: {err}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn next_single_accepts_many_containing_exactly_one_target() {
|
||||||
|
let yaml = r#"
|
||||||
|
id: triage
|
||||||
|
type: llm
|
||||||
|
prompt: Classify
|
||||||
|
next: [retrieve]
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(node.next_single().unwrap(), Some("retrieve"));
|
||||||
|
assert_eq!(node.next_target(), Some("retrieve"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn next_targets_round_trips_through_yaml_for_both_variants() {
|
||||||
|
let one: NextTargets = serde_yaml::from_str(r#""foo""#).unwrap();
|
||||||
|
let reparsed: NextTargets =
|
||||||
|
serde_yaml::from_str(&serde_yaml::to_string(&one).unwrap()).unwrap();
|
||||||
|
assert_eq!(reparsed.as_slice(), &["foo".to_string()]);
|
||||||
|
|
||||||
|
let many: NextTargets = serde_yaml::from_str("[a, b, c]").unwrap();
|
||||||
|
let reparsed: NextTargets =
|
||||||
|
serde_yaml::from_str(&serde_yaml::to_string(&many).unwrap()).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
reparsed.as_slice(),
|
||||||
|
&["a".to_string(), "b".to_string(), "c".to_string()]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_reducers_block_with_all_builtins() {
|
||||||
|
let yaml = r#"
|
||||||
|
name: g
|
||||||
|
start: e
|
||||||
|
reducers:
|
||||||
|
sources: append
|
||||||
|
findings: extend
|
||||||
|
context: concat
|
||||||
|
cost_usd: sum
|
||||||
|
high_score: max
|
||||||
|
low_score: min
|
||||||
|
config: merge
|
||||||
|
forced: overwrite
|
||||||
|
nodes:
|
||||||
|
e:
|
||||||
|
type: end
|
||||||
|
output: ok
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(graph.reducers.len(), 8);
|
||||||
|
assert_eq!(graph.reducers.get("sources"), Some(&Reducer::Append));
|
||||||
|
assert_eq!(graph.reducers.get("findings"), Some(&Reducer::Extend));
|
||||||
|
assert_eq!(graph.reducers.get("context"), Some(&Reducer::Concat));
|
||||||
|
assert_eq!(graph.reducers.get("cost_usd"), Some(&Reducer::Sum));
|
||||||
|
assert_eq!(graph.reducers.get("high_score"), Some(&Reducer::Max));
|
||||||
|
assert_eq!(graph.reducers.get("low_score"), Some(&Reducer::Min));
|
||||||
|
assert_eq!(graph.reducers.get("config"), Some(&Reducer::Merge));
|
||||||
|
assert_eq!(graph.reducers.get("forced"), Some(&Reducer::Overwrite));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reducers_default_to_empty_when_block_absent() {
|
||||||
|
let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n";
|
||||||
|
|
||||||
|
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
assert!(graph.reducers.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn max_concurrency_defaults_to_eight() {
|
||||||
|
let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n";
|
||||||
|
|
||||||
|
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(graph.settings.max_concurrency, 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn max_concurrency_can_be_overridden() {
|
||||||
|
let yaml = r#"
|
||||||
|
name: g
|
||||||
|
start: x
|
||||||
|
settings:
|
||||||
|
max_concurrency: 16
|
||||||
|
nodes:
|
||||||
|
x:
|
||||||
|
type: end
|
||||||
|
"#;
|
||||||
|
|
||||||
|
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(graph.settings.max_concurrency, 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parses_map_node_with_all_fields() {
|
||||||
|
let yaml = r#"
|
||||||
|
id: fan_out
|
||||||
|
type: map
|
||||||
|
over: "{{subjects}}"
|
||||||
|
as: subject
|
||||||
|
branch: research_subject
|
||||||
|
output_key: research_result
|
||||||
|
collect_into: research_results
|
||||||
|
max_concurrency: 5
|
||||||
|
next: rank
|
||||||
|
"#;
|
||||||
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
let map = match node.node_type {
|
||||||
|
NodeType::Map(m) => m,
|
||||||
|
_ => panic!("expected Map variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(map.over, "{{subjects}}");
|
||||||
|
assert_eq!(map.as_name, "subject");
|
||||||
|
assert_eq!(map.branch, "research_subject");
|
||||||
|
assert_eq!(map.output_key, "research_result");
|
||||||
|
assert_eq!(map.collect_into, "research_results");
|
||||||
|
assert_eq!(map.max_concurrency, Some(5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn map_node_uses_default_output_key_and_no_concurrency_cap() {
|
||||||
|
let yaml = r#"
|
||||||
|
id: fan_out
|
||||||
|
type: map
|
||||||
|
over: "{{items}}"
|
||||||
|
as: item
|
||||||
|
branch: process
|
||||||
|
collect_into: results
|
||||||
|
"#;
|
||||||
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
|
let map = match node.node_type {
|
||||||
|
NodeType::Map(m) => m,
|
||||||
|
_ => panic!("expected Map variant"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(map.output_key, "output");
|
||||||
|
assert!(map.max_concurrency.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn full_graph_with_all_new_phase_a_fields_parses() {
|
||||||
|
let yaml = r#"
|
||||||
|
name: deep_research
|
||||||
|
start: triage
|
||||||
|
settings:
|
||||||
|
max_concurrency: 4
|
||||||
|
reducers:
|
||||||
|
sources: append
|
||||||
|
cost_usd: sum
|
||||||
|
nodes:
|
||||||
|
triage:
|
||||||
|
type: llm
|
||||||
|
prompt: Classify
|
||||||
|
next: [retrieve_local, retrieve_web]
|
||||||
|
retrieve_local:
|
||||||
|
type: rag
|
||||||
|
documents: ["./docs"]
|
||||||
|
next: synthesize
|
||||||
|
retrieve_web:
|
||||||
|
type: llm
|
||||||
|
prompt: Search web
|
||||||
|
next: synthesize
|
||||||
|
synthesize:
|
||||||
|
type: end
|
||||||
|
output: done
|
||||||
|
"#;
|
||||||
|
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
assert_eq!(graph.settings.max_concurrency, 4);
|
||||||
|
assert_eq!(graph.reducers.len(), 2);
|
||||||
|
let triage = graph.get_node("triage").unwrap();
|
||||||
|
assert!(triage.next.as_ref().unwrap().is_fan_out());
|
||||||
|
assert_eq!(triage.next.as_ref().unwrap().as_slice().len(), 2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+123
-4
@@ -318,8 +318,10 @@ impl GraphValidator {
|
|||||||
|
|
||||||
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
||||||
let mut out = Vec::new();
|
let mut out = Vec::new();
|
||||||
if let Some(n) = &node.next {
|
if let Some(targets) = &node.next {
|
||||||
out.push((n.clone(), "'next'"));
|
for target in targets.as_slice() {
|
||||||
|
out.push((target.clone(), "'next'"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match &node.node_type {
|
match &node.node_type {
|
||||||
@@ -342,6 +344,12 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
|||||||
// `agent`/`input`/`rag` route only via `next` (already collected
|
// `agent`/`input`/`rag` route only via `next` (already collected
|
||||||
// above); `end` is terminal. No type-specific routing edges to add.
|
// above); `end` is terminal. No type-specific routing edges to add.
|
||||||
NodeType::Agent(_) | NodeType::Input(_) | NodeType::Rag(_) | NodeType::End(_) => {}
|
NodeType::Agent(_) | NodeType::Input(_) | NodeType::Rag(_) | NodeType::End(_) => {}
|
||||||
|
// A `map` node invokes its `branch:` target once per item from the
|
||||||
|
// resolved `over` list. The branch is statically referenced, so it
|
||||||
|
// is a real declared edge for cycle/reachability purposes.
|
||||||
|
NodeType::Map(m) => {
|
||||||
|
out.push((m.branch.clone(), "map 'branch'"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
@@ -434,6 +442,7 @@ mod tests {
|
|||||||
conversation_starters: Vec::new(),
|
conversation_starters: Vec::new(),
|
||||||
settings: GraphSettings::default(),
|
settings: GraphSettings::default(),
|
||||||
initial_state: HashMap::new(),
|
initial_state: HashMap::new(),
|
||||||
|
reducers: HashMap::new(),
|
||||||
start: start.into(),
|
start: start.into(),
|
||||||
nodes: map,
|
nodes: map,
|
||||||
}
|
}
|
||||||
@@ -529,7 +538,7 @@ mod tests {
|
|||||||
output_schema: None,
|
output_schema: None,
|
||||||
timeout: None,
|
timeout: None,
|
||||||
}),
|
}),
|
||||||
next: next.map(String::from),
|
next: next.map(NextTargets::from),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -759,7 +768,7 @@ mod tests {
|
|||||||
output_schema: None,
|
output_schema: None,
|
||||||
timeout: None,
|
timeout: None,
|
||||||
}),
|
}),
|
||||||
next: next.map(String::from),
|
next: next.map(NextTargets::from),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1038,4 +1047,114 @@ mod tests {
|
|||||||
|
|
||||||
assert!(validator().validate(&graph).into_result().is_ok());
|
assert!(validator().validate(&graph).into_result().is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cycle_detector_treats_fan_out_diamond_as_a_valid_dag() {
|
||||||
|
let mut start = end_node("start");
|
||||||
|
start.next = Some(NextTargets::Many(vec!["a".into(), "b".into()]));
|
||||||
|
let mut a = end_node("a");
|
||||||
|
a.next = Some("join".into());
|
||||||
|
let mut b = end_node("b");
|
||||||
|
b.next = Some("join".into());
|
||||||
|
let mut join = end_node("join");
|
||||||
|
join.next = Some("end".into());
|
||||||
|
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("start", start),
|
||||||
|
("a", a),
|
||||||
|
("b", b),
|
||||||
|
("join", join),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"start",
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
assert!(
|
||||||
|
!result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("Cycle detected")),
|
||||||
|
"fan-out diamond incorrectly reported as cycle: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reachability_visits_every_member_of_many_next_targets() {
|
||||||
|
let mut start = end_node("start");
|
||||||
|
start.next = Some(NextTargets::Many(vec!["a".into(), "b".into(), "c".into()]));
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("start", start),
|
||||||
|
("a", end_node("a")),
|
||||||
|
("b", end_node("b")),
|
||||||
|
("c", end_node("c")),
|
||||||
|
],
|
||||||
|
"start",
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
|
||||||
|
for orphan in ["a", "b", "c"] {
|
||||||
|
assert!(
|
||||||
|
!result
|
||||||
|
.warnings
|
||||||
|
.iter()
|
||||||
|
.any(|w| w.node_id.as_deref() == Some(orphan)
|
||||||
|
&& w.message.contains("unreachable")),
|
||||||
|
"fan-out target '{orphan}' incorrectly marked unreachable: {:?}",
|
||||||
|
result.warnings
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn node_reference_check_catches_missing_member_inside_many() {
|
||||||
|
let mut start = end_node("start");
|
||||||
|
start.next = Some(NextTargets::Many(vec!["a".into(), "ghost".into()]));
|
||||||
|
let graph = graph_with(vec![("start", start), ("a", end_node("a"))], "start");
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("non-existent node 'ghost'")
|
||||||
|
&& e.node_id.as_deref() == Some("start")),
|
||||||
|
"expected error for missing 'ghost' target in Many: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn node_reference_check_catches_missing_map_branch_target() {
|
||||||
|
let map = Node {
|
||||||
|
id: "fan".into(),
|
||||||
|
description: String::new(),
|
||||||
|
node_type: NodeType::Map(MapNode {
|
||||||
|
over: "{{items}}".into(),
|
||||||
|
as_name: "item".into(),
|
||||||
|
branch: "no_such_node".into(),
|
||||||
|
output_key: "output".into(),
|
||||||
|
collect_into: "results".into(),
|
||||||
|
max_concurrency: None,
|
||||||
|
}),
|
||||||
|
next: Some("end".into()),
|
||||||
|
};
|
||||||
|
let graph = graph_with(vec![("fan", map), ("end", end_node("end"))], "fan");
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("non-existent node 'no_such_node'")
|
||||||
|
&& e.message.contains("map 'branch'")),
|
||||||
|
"expected error for missing map branch: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user