fix: Added additional graph validation for parallel reads and writes with dependencies between node states

This commit is contained in:
2026-05-20 17:35:33 -06:00
parent 4536d00067
commit 3c7d19da07
2 changed files with 298 additions and 0 deletions
+18
View File
@@ -353,6 +353,24 @@ fn single_reference_key(template: &str) -> Option<&str> {
valid.then_some(inner)
}
// Collects the root state key of every `{{...}}` expression found in
// `template`.
//
// The root key is the part of the expression that precedes the first `.` or
// `[`: `{{user.name}}` yields `user`, `{{items[0]}}` yields `items`, and an
// expression with neither separator is returned whole. One entry is produced
// per matched expression, in match order, so repeated references yield
// repeated keys. Matches the regex engine reports as errors, and matches
// without capture group 1, are skipped.
//
// The validator uses this to derive a node's static read-set from its
// templated fields without needing a runtime `StateManager`.
pub(super) fn template_root_keys(template: &str) -> Vec<String> {
    let mut roots = Vec::new();
    for caps in TEMPLATE_VAR_RE.captures_iter(template).flatten() {
        let Some(expr) = caps.get(1) else {
            continue;
        };
        let expr = expr.as_str();
        // Keep only what precedes the first path/index separator.
        let root = match expr.find(['.', '[']) {
            Some(end) => &expr[..end],
            None => expr,
        };
        roots.push(root.to_string());
    }
    roots
}
fn value_to_string(value: &Value) -> String {
match value {
Value::String(s) => s.clone(),
+280
View File
@@ -1,3 +1,4 @@
use super::state::template_root_keys;
use super::types::{Graph, Node, NodeType};
use crate::client::{Model, ModelType};
use crate::config::{Agent, AppConfig, paths};
@@ -122,6 +123,7 @@ impl GraphValidator {
self.validate_map_branches(graph, &mut result);
self.validate_parallel_user_interaction(graph, &mut result);
self.validate_parallel_writes(graph, &mut result);
self.validate_parallel_reads(graph, &mut result);
result
}
@@ -539,6 +541,51 @@ impl GraphValidator {
}
}
}
// Reports an error for every parallel branch that reads a state key which a
// sibling branch in the same parallel group writes.
//
// For each group returned by `compute_parallel_groups`, every ordered pair of
// distinct nodes (reader, writer) is checked: if the reader's static read-set
// (`read_set_of`) intersects the writer's write-set (`write_set_of`), one
// error is recorded against the reader, listing the colliding keys in sorted
// order. Ids in a group that are missing from `graph.nodes` are ignored, as
// are readers with an empty read-set and writers for which `write_set_of`
// returns `None`.
fn validate_parallel_reads(&self, graph: &Graph, result: &mut ValidationResult) {
    for group in compute_parallel_groups(graph) {
        let nodes: Vec<(&String, &Node)> = group
            .iter()
            .filter_map(|id| graph.nodes.get(id).map(|n| (id, n)))
            .collect();
        // A node's write-set does not depend on which reader it is compared
        // against, so compute it once per node instead of once per pair.
        let write_sets: Vec<_> = nodes.iter().map(|(_, n)| write_set_of(n)).collect();
        for (id_a, node_a) in &nodes {
            let read_set_a = read_set_of(node_a);
            if read_set_a.is_empty() {
                continue;
            }
            for ((id_b, _), write_set_b) in nodes.iter().zip(&write_sets) {
                if id_b == id_a {
                    continue;
                }
                let Some(write_set_b) = write_set_b else {
                    continue;
                };
                let mut collisions: Vec<String> =
                    read_set_a.intersection(write_set_b).cloned().collect();
                if collisions.is_empty() {
                    continue;
                }
                // Sort so the message is stable regardless of hash order.
                collisions.sort();
                let keys = collisions
                    .iter()
                    .map(|k| format!("`{k}`"))
                    .collect::<Vec<_>>()
                    .join(", ");
                result.error(ValidationError::with_node(
                    id_a.as_str(),
                    format!(
                        "node '{id_a}' reads state key(s) {keys} which sibling parallel \
                         branch '{id_b}' writes in the same super-step; parallel branches \
                         see a state snapshot taken BEFORE the super-step and cannot observe \
                         each other's writes. Move the dependent read to a later super-step \
                         (or remove the cross-branch reference)."
                    ),
                ));
            }
        }
    }
}
}
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
@@ -646,6 +693,103 @@ fn write_set_of(node: &Node) -> Option<HashSet<String>> {
Some(writes)
}
// Builds the static read-set of a node: the root state keys referenced by
// its templated fields.
//
// A "root key" is defined exactly as in `template_root_keys`: for
// `{{user.name}}` or `{{items[0]}}` it is the bare identifier in front of the
// first `.` or `[`.
//
// Which templated fields are scanned depends on the node type:
// - llm: instructions, prompt, state_updates values
// - agent: prompt, state_updates values
// - rag: query (falls back to "{{initial_prompt}}"), state_updates values
// - approval: question, state_updates values
// - input: question, default, state_updates values
// - end: output, state_updates values
// - map: over (its `{{...}}` IS the dynamic read of the list being fanned
//   out over)
// - script: state_updates values only (the script body is opaque to static
//   analysis; reads it performs through GRAPH_STATE / GRAPH_STATE_FILE
//   cannot be inferred at load time)
//
// When scanning state_updates values, variables that the node's own
// execution binds are left out:
// - llm/agent/rag → "output" (the node's body output)
// - approval → "choice" (the option the user selected)
// - input → "input" (the text the user typed)
// Those names are created inside the node rather than read from prior state,
// so they cannot race with a sibling branch's writes. The exclusion applies
// only to state_updates values, not to the primary templated fields.
fn read_set_of(node: &Node) -> HashSet<String> {
    let own_bindings: &[&str] = match &node.node_type {
        NodeType::Llm(_) | NodeType::Agent(_) | NodeType::Rag(_) => &["output"],
        NodeType::Approval(_) => &["choice"],
        NodeType::Input(_) => &["input"],
        NodeType::Script(_) | NodeType::End(_) | NodeType::Map(_) => &[],
    };

    // Every key referenced by a primary templated field counts as a read.
    let mut reads: HashSet<String> = primary_templated_fields(node)
        .iter()
        .flat_map(|field| template_root_keys(field))
        .collect();

    // state_updates values count too, minus the node's own scoped bindings.
    if let Some(updates) = node_state_updates_map(node) {
        reads.extend(
            updates
                .values()
                .flat_map(|value| template_root_keys(value))
                .filter(|key| !own_bindings.contains(&key.as_str())),
        );
    }

    reads
}
// Returns owned copies of the node's primary templated strings, i.e. the
// templated fields other than `state_updates`.
//
// Per node type: llm → prompt, then instructions when set; agent → prompt;
// rag → query, or "{{initial_prompt}}" when no query is set; approval →
// question; input → question, then default when set; end → output; map →
// over; script → nothing.
fn primary_templated_fields(node: &Node) -> Vec<String> {
    let mut fields: Vec<String> = Vec::new();
    match &node.node_type {
        NodeType::Llm(llm) => {
            fields.push(llm.prompt.clone());
            fields.extend(llm.instructions.iter().cloned());
        }
        NodeType::Agent(agent) => fields.push(agent.prompt.clone()),
        NodeType::Rag(rag) => {
            let query = match &rag.query {
                Some(q) => q.clone(),
                None => "{{initial_prompt}}".to_string(),
            };
            fields.push(query);
        }
        NodeType::Approval(approval) => fields.push(approval.question.clone()),
        NodeType::Input(input) => {
            fields.push(input.question.clone());
            fields.extend(input.default.iter().cloned());
        }
        NodeType::End(end) => fields.push(end.output.clone()),
        NodeType::Map(map) => fields.push(map.over.clone()),
        // Script nodes have no primary templated field.
        NodeType::Script(_) => {}
    }
    fields
}
// Borrows the node's `state_updates` map, if it has one.
//
// Returns `None` for map nodes, and for any other node type whose
// `state_updates` field is unset.
fn node_state_updates_map(node: &Node) -> Option<&std::collections::HashMap<String, String>> {
    let updates = match &node.node_type {
        // Map nodes never yield a state_updates map here.
        NodeType::Map(_) => return None,
        NodeType::Llm(n) => &n.state_updates,
        NodeType::Agent(n) => &n.state_updates,
        NodeType::Rag(n) => &n.state_updates,
        NodeType::Approval(n) => &n.state_updates,
        NodeType::Input(n) => &n.state_updates,
        NodeType::Script(n) => &n.state_updates,
        NodeType::End(n) => &n.state_updates,
    };
    updates.as_ref()
}
fn node_state_updates_keys(node: &Node) -> Option<HashSet<String>> {
let updates = match &node.node_type {
NodeType::Agent(n) => n.state_updates.as_ref(),
@@ -2064,4 +2208,140 @@ mod tests {
result.errors
);
}
// Test helper: builds an llm node via `llm_node(id, None, next)` and replaces
// its prompt with `prompt`.
fn llm_with_prompt(id: &str, prompt: &str, next: Option<&str>) -> Node {
    let mut built = llm_node(id, None, next);
    if let NodeType::Llm(llm) = &mut built.node_type {
        llm.prompt = String::from(prompt);
    }
    built
}
// A branch whose prompt references `summary` must be rejected when a sibling
// parallel branch writes `summary` through its state_updates.
#[test]
fn parallel_read_of_sibling_write_errors() {
    let graph = fan_out_graph_with_two_workers(
        llm_with_prompt("worker_a", "Hello {{summary}}!", Some("end")),
        llm_with_state_updates("worker_b", &[("summary", "static")], Some("end")),
    );
    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|err| {
        let msg = &err.message;
        msg.contains("reads state key(s) `summary`") && msg.contains("'worker_b'")
    });
    assert!(
        flagged,
        "expected cross-branch read error mentioning `summary` and sibling writer: {:?}",
        result.errors
    );
}
// Two sibling branches that both read `topic`, with neither writing it, must
// not produce a cross-branch read error.
#[test]
fn parallel_read_of_upstream_key_passes() {
    let graph = fan_out_graph_with_two_workers(
        llm_with_prompt("worker_a", "Topic is {{topic}}", Some("end")),
        llm_with_prompt("worker_b", "Also {{topic}}", Some("end")),
    );
    let result = validator().validate(&graph);

    let has_read_error = result
        .errors
        .iter()
        .any(|err| err.message.contains("reads state key"));
    assert!(
        !has_read_error,
        "upstream `topic` shouldn't trigger cross-branch read error: {:?}",
        result.errors
    );
}
// `{{output}}` inside an llm node's state_updates value refers to that node's
// own output, so it must not be reported as a read of a sibling's `output`
// write.
#[test]
fn scoped_output_var_in_state_updates_not_treated_as_read() {
    let graph = fan_out_graph_with_two_workers(
        llm_with_state_updates("worker_a", &[("a_key", "{{output}}")], Some("end")),
        llm_with_state_updates("worker_b", &[("output", "{{output}}")], Some("end")),
    );
    let result = validator().validate(&graph);

    let wrongly_flagged = result.errors.iter().any(|err| {
        let msg = &err.message;
        msg.contains("reads state key(s) `output`") && msg.contains("worker_a")
    });
    assert!(
        !wrongly_flagged,
        "scoped `{{{{output}}}}` inside state_updates value should NOT be treated as a read: {:?}",
        result.errors
    );
}
// A rag node whose query references `db_result` must be rejected when a
// sibling script node writes `db_result` in the same parallel group.
#[test]
fn rag_query_reading_sibling_script_write_errors() {
    let mut reader = rag_node("worker_a", &["./k"], true);
    if let NodeType::Rag(rag) = &mut reader.node_type {
        rag.query = Some(String::from("codes: {{loinc_codes}}\n{{db_result}}"));
        // Only applied when the helper already gave the node a state_updates map.
        if let Some(updates) = rag.state_updates.as_mut() {
            updates.insert("rag_ctx".into(), "{{output.context}}".into());
        }
    }
    reader.next = Some("end".into());

    let mut writer = script_with_state_updates("worker_b", &[("db_result", "{{output}}")]);
    writer.next = Some("end".into());

    let graph = fan_out_graph_with_two_workers(reader, writer);
    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|err| {
        let msg = &err.message;
        msg.contains("reads state key(s) `db_result`") && msg.contains("'worker_b'")
    });
    assert!(
        flagged,
        "expected cross-branch read error for rag query reading db_result: {:?}",
        result.errors
    );
}
// A map node whose `over` expression references `items` must be rejected when
// a sibling parallel branch (`producer`) writes `items`.
#[test]
fn map_over_reading_sibling_write_errors() {
    let map_cfg = MapNode {
        over: "{{items}}".into(),
        as_name: "item".into(),
        branch: "branch_n".into(),
        output_key: "output".into(),
        collect_into: "results".into(),
        max_concurrency: None,
    };
    let fan = Node {
        id: "fan".into(),
        description: String::new(),
        node_type: NodeType::Map(map_cfg),
        next: Some("end".into()),
    };
    let branch_n = llm_with_prompt("branch_n", "Process {{item}}", None);
    let producer = llm_with_state_updates("producer", &[("items", "[1,2,3]")], Some("end"));

    // `start` fans out to the map node and the producer in parallel.
    let mut start = end_node("start");
    start.next = Some(NextTargets::Many(vec!["fan".into(), "producer".into()]));

    let graph = graph_with(
        vec![
            ("start", start),
            ("fan", fan),
            ("branch_n", branch_n),
            ("producer", producer),
            ("end", end_node("end")),
        ],
        "start",
    );
    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|err| {
        let msg = &err.message;
        msg.contains("reads state key(s) `items`") && msg.contains("'producer'")
    });
    assert!(
        flagged,
        "expected cross-branch read error for map `over` reading sibling write: {:?}",
        result.errors
    );
}
}