fix: Added additional graph validation for parallel reads and writes with dependencies between node states
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use super::state::template_root_keys;
|
||||
use super::types::{Graph, Node, NodeType};
|
||||
use crate::client::{Model, ModelType};
|
||||
use crate::config::{Agent, AppConfig, paths};
|
||||
@@ -122,6 +123,7 @@ impl GraphValidator {
|
||||
self.validate_map_branches(graph, &mut result);
|
||||
self.validate_parallel_user_interaction(graph, &mut result);
|
||||
self.validate_parallel_writes(graph, &mut result);
|
||||
self.validate_parallel_reads(graph, &mut result);
|
||||
|
||||
result
|
||||
}
|
||||
@@ -539,6 +541,51 @@ impl GraphValidator {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_parallel_reads(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||
for group in compute_parallel_groups(graph) {
|
||||
let nodes: Vec<(&String, &Node)> = group
|
||||
.iter()
|
||||
.filter_map(|id| graph.nodes.get(id).map(|n| (id, n)))
|
||||
.collect();
|
||||
|
||||
for (id_a, node_a) in &nodes {
|
||||
let read_set_a = read_set_of(node_a);
|
||||
if read_set_a.is_empty() {
|
||||
continue;
|
||||
}
|
||||
for (id_b, node_b) in &nodes {
|
||||
if id_b == id_a {
|
||||
continue;
|
||||
}
|
||||
let Some(write_set_b) = write_set_of(node_b) else {
|
||||
continue;
|
||||
};
|
||||
let mut collisions: Vec<String> =
|
||||
read_set_a.intersection(&write_set_b).cloned().collect();
|
||||
if collisions.is_empty() {
|
||||
continue;
|
||||
}
|
||||
collisions.sort();
|
||||
let keys = collisions
|
||||
.iter()
|
||||
.map(|k| format!("`{k}`"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
result.error(ValidationError::with_node(
|
||||
id_a.as_str(),
|
||||
format!(
|
||||
"node '{id_a}' reads state key(s) {keys} which sibling parallel \
|
||||
branch '{id_b}' writes in the same super-step; parallel branches \
|
||||
see a state snapshot taken BEFORE the super-step and cannot observe \
|
||||
each other's writes. Move the dependent read to a later super-step \
|
||||
(or remove the cross-branch reference)."
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
||||
@@ -646,6 +693,103 @@ fn write_set_of(node: &Node) -> Option<HashSet<String>> {
|
||||
Some(writes)
|
||||
}
|
||||
|
||||
// Computes the set of root state keys this node's templated fields read from.
|
||||
//
|
||||
// "Root key" follows the same definition as `template_root_keys`: for a
|
||||
// reference like `{{user.name}}` or `{{items[0]}}`, the root is the bare
|
||||
// identifier before the first `.` or `[`.
|
||||
//
|
||||
// Templated fields scanned per node type:
|
||||
// - llm: instructions, prompt, state_updates values
|
||||
// - agent: prompt, state_updates values
|
||||
// - rag: query (defaulting to "{{initial_prompt}}"), state_updates values
|
||||
// - approval: question, state_updates values
|
||||
// - input: question, default, state_updates values
|
||||
// - end: output, state_updates values
|
||||
// - map: over (its `{{...}}` IS the dynamic read of the list to fan out over)
|
||||
// - script: state_updates values only (the script body is opaque to static
|
||||
// analysis; its reads via GRAPH_STATE / GRAPH_STATE_FILE can't be
|
||||
// inferred at load time)
|
||||
//
|
||||
// Scoped variables produced by THIS node's own execution are excluded from
|
||||
// state_updates value scanning:
|
||||
// - llm/agent/rag → "output" (the node's body output)
|
||||
// - approval → "choice" (the user's selected option)
|
||||
// - input → "input" (the user's typed text)
|
||||
// These are bindings created inside the node, not reads from prior state, so
|
||||
// they cannot race with a sibling's writes.
|
||||
fn read_set_of(node: &Node) -> HashSet<String> {
|
||||
let mut reads: HashSet<String> = HashSet::new();
|
||||
let scoped: &[&str] = match &node.node_type {
|
||||
NodeType::Llm(_) | NodeType::Agent(_) | NodeType::Rag(_) => &["output"],
|
||||
NodeType::Approval(_) => &["choice"],
|
||||
NodeType::Input(_) => &["input"],
|
||||
NodeType::Script(_) | NodeType::End(_) | NodeType::Map(_) => &[],
|
||||
};
|
||||
|
||||
for s in primary_templated_fields(node) {
|
||||
for k in template_root_keys(&s) {
|
||||
reads.insert(k);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(updates) = node_state_updates_map(node) {
|
||||
for v in updates.values() {
|
||||
for k in template_root_keys(v) {
|
||||
if !scoped.contains(&k.as_str()) {
|
||||
reads.insert(k);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reads
|
||||
}
|
||||
|
||||
fn primary_templated_fields(node: &Node) -> Vec<String> {
|
||||
match &node.node_type {
|
||||
NodeType::Llm(n) => {
|
||||
let mut v = vec![n.prompt.clone()];
|
||||
if let Some(i) = &n.instructions {
|
||||
v.push(i.clone());
|
||||
}
|
||||
v
|
||||
}
|
||||
NodeType::Agent(n) => vec![n.prompt.clone()],
|
||||
NodeType::Rag(n) => {
|
||||
vec![
|
||||
n.query
|
||||
.clone()
|
||||
.unwrap_or_else(|| "{{initial_prompt}}".to_string()),
|
||||
]
|
||||
}
|
||||
NodeType::Approval(n) => vec![n.question.clone()],
|
||||
NodeType::Input(n) => {
|
||||
let mut v = vec![n.question.clone()];
|
||||
if let Some(d) = &n.default {
|
||||
v.push(d.clone());
|
||||
}
|
||||
v
|
||||
}
|
||||
NodeType::End(n) => vec![n.output.clone()],
|
||||
NodeType::Map(n) => vec![n.over.clone()],
|
||||
NodeType::Script(_) => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn node_state_updates_map(node: &Node) -> Option<&std::collections::HashMap<String, String>> {
|
||||
match &node.node_type {
|
||||
NodeType::Llm(n) => n.state_updates.as_ref(),
|
||||
NodeType::Agent(n) => n.state_updates.as_ref(),
|
||||
NodeType::Rag(n) => n.state_updates.as_ref(),
|
||||
NodeType::Approval(n) => n.state_updates.as_ref(),
|
||||
NodeType::Input(n) => n.state_updates.as_ref(),
|
||||
NodeType::Script(n) => n.state_updates.as_ref(),
|
||||
NodeType::End(n) => n.state_updates.as_ref(),
|
||||
NodeType::Map(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn node_state_updates_keys(node: &Node) -> Option<HashSet<String>> {
|
||||
let updates = match &node.node_type {
|
||||
NodeType::Agent(n) => n.state_updates.as_ref(),
|
||||
@@ -2064,4 +2208,140 @@ mod tests {
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
fn llm_with_prompt(id: &str, prompt: &str, next: Option<&str>) -> Node {
|
||||
let mut node = llm_node(id, None, next);
|
||||
if let NodeType::Llm(ref mut n) = node.node_type {
|
||||
n.prompt = prompt.into();
|
||||
}
|
||||
node
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parallel_read_of_sibling_write_errors() {
|
||||
let reader = llm_with_prompt("worker_a", "Hello {{summary}}!", Some("end"));
|
||||
let writer = llm_with_state_updates("worker_b", &[("summary", "static")], Some("end"));
|
||||
let graph = fan_out_graph_with_two_workers(reader, writer);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("reads state key(s) `summary`")
|
||||
&& e.message.contains("'worker_b'")),
|
||||
"expected cross-branch read error mentioning `summary` and sibling writer: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parallel_read_of_upstream_key_passes() {
|
||||
let reader_a = llm_with_prompt("worker_a", "Topic is {{topic}}", Some("end"));
|
||||
let reader_b = llm_with_prompt("worker_b", "Also {{topic}}", Some("end"));
|
||||
let graph = fan_out_graph_with_two_workers(reader_a, reader_b);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
!result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("reads state key")),
|
||||
"upstream `topic` shouldn't trigger cross-branch read error: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scoped_output_var_in_state_updates_not_treated_as_read() {
|
||||
let scoped_user =
|
||||
llm_with_state_updates("worker_a", &[("a_key", "{{output}}")], Some("end"));
|
||||
let writes_output =
|
||||
llm_with_state_updates("worker_b", &[("output", "{{output}}")], Some("end"));
|
||||
let graph = fan_out_graph_with_two_workers(scoped_user, writes_output);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
!result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("reads state key(s) `output`")
|
||||
&& e.message.contains("worker_a")),
|
||||
"scoped `{{{{output}}}}` inside state_updates value should NOT be treated as a read: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rag_query_reading_sibling_script_write_errors() {
|
||||
let mut rag = rag_node("worker_a", &["./k"], true);
|
||||
if let NodeType::Rag(ref mut n) = rag.node_type {
|
||||
n.query = Some("codes: {{loinc_codes}}\n{{db_result}}".into());
|
||||
if let Some(m) = n.state_updates.as_mut() {
|
||||
m.insert("rag_ctx".into(), "{{output.context}}".into());
|
||||
}
|
||||
}
|
||||
rag.next = Some("end".into());
|
||||
let mut script = script_with_state_updates("worker_b", &[("db_result", "{{output}}")]);
|
||||
script.next = Some("end".into());
|
||||
let graph = fan_out_graph_with_two_workers(rag, script);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("reads state key(s) `db_result`")
|
||||
&& e.message.contains("'worker_b'")),
|
||||
"expected cross-branch read error for rag query reading db_result: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn map_over_reading_sibling_write_errors() {
|
||||
let map_n = Node {
|
||||
id: "fan".into(),
|
||||
description: String::new(),
|
||||
node_type: NodeType::Map(MapNode {
|
||||
over: "{{items}}".into(),
|
||||
as_name: "item".into(),
|
||||
branch: "branch_n".into(),
|
||||
output_key: "output".into(),
|
||||
collect_into: "results".into(),
|
||||
max_concurrency: None,
|
||||
}),
|
||||
next: Some("end".into()),
|
||||
};
|
||||
let branch_n = llm_with_prompt("branch_n", "Process {{item}}", None);
|
||||
let producer = llm_with_state_updates("producer", &[("items", "[1,2,3]")], Some("end"));
|
||||
let mut start = end_node("start");
|
||||
start.next = Some(NextTargets::Many(vec!["fan".into(), "producer".into()]));
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("start", start),
|
||||
("fan", map_n),
|
||||
("branch_n", branch_n),
|
||||
("producer", producer),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"start",
|
||||
);
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("reads state key(s) `items`")
|
||||
&& e.message.contains("'producer'")),
|
||||
"expected cross-branch read error for map `over` reading sibling write: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user