fix: Added additional graph validation for parallel reads and writes with dependencies between node states
This commit is contained in:
@@ -353,6 +353,24 @@ fn single_reference_key(template: &str) -> Option<&str> {
|
|||||||
valid.then_some(inner)
|
valid.then_some(inner)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the root state keys referenced by any `{{...}}` expressions in the
|
||||||
|
// given template string. The "root key" is the identifier before the first
|
||||||
|
// `.` or `[` — i.e. for `{{user.name}}` the root is `user`, for `{{items[0]}}`
|
||||||
|
// the root is `items`. Used by the validator to compute the static read-set of
|
||||||
|
// a node's templated fields without depending on a runtime `StateManager`.
|
||||||
|
pub(super) fn template_root_keys(template: &str) -> Vec<String> {
|
||||||
|
TEMPLATE_VAR_RE
|
||||||
|
.captures_iter(template)
|
||||||
|
.flatten()
|
||||||
|
.filter_map(|c| c.get(1))
|
||||||
|
.map(|m| {
|
||||||
|
let inner = m.as_str();
|
||||||
|
let cut = inner.find(['.', '[']).unwrap_or(inner.len());
|
||||||
|
inner[..cut].to_string()
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
fn value_to_string(value: &Value) -> String {
|
fn value_to_string(value: &Value) -> String {
|
||||||
match value {
|
match value {
|
||||||
Value::String(s) => s.clone(),
|
Value::String(s) => s.clone(),
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use super::state::template_root_keys;
|
||||||
use super::types::{Graph, Node, NodeType};
|
use super::types::{Graph, Node, NodeType};
|
||||||
use crate::client::{Model, ModelType};
|
use crate::client::{Model, ModelType};
|
||||||
use crate::config::{Agent, AppConfig, paths};
|
use crate::config::{Agent, AppConfig, paths};
|
||||||
@@ -122,6 +123,7 @@ impl GraphValidator {
|
|||||||
self.validate_map_branches(graph, &mut result);
|
self.validate_map_branches(graph, &mut result);
|
||||||
self.validate_parallel_user_interaction(graph, &mut result);
|
self.validate_parallel_user_interaction(graph, &mut result);
|
||||||
self.validate_parallel_writes(graph, &mut result);
|
self.validate_parallel_writes(graph, &mut result);
|
||||||
|
self.validate_parallel_reads(graph, &mut result);
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
@@ -539,6 +541,51 @@ impl GraphValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_parallel_reads(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
|
for group in compute_parallel_groups(graph) {
|
||||||
|
let nodes: Vec<(&String, &Node)> = group
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| graph.nodes.get(id).map(|n| (id, n)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
for (id_a, node_a) in &nodes {
|
||||||
|
let read_set_a = read_set_of(node_a);
|
||||||
|
if read_set_a.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (id_b, node_b) in &nodes {
|
||||||
|
if id_b == id_a {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let Some(write_set_b) = write_set_of(node_b) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let mut collisions: Vec<String> =
|
||||||
|
read_set_a.intersection(&write_set_b).cloned().collect();
|
||||||
|
if collisions.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
collisions.sort();
|
||||||
|
let keys = collisions
|
||||||
|
.iter()
|
||||||
|
.map(|k| format!("`{k}`"))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(", ");
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
id_a.as_str(),
|
||||||
|
format!(
|
||||||
|
"node '{id_a}' reads state key(s) {keys} which sibling parallel \
|
||||||
|
branch '{id_b}' writes in the same super-step; parallel branches \
|
||||||
|
see a state snapshot taken BEFORE the super-step and cannot observe \
|
||||||
|
each other's writes. Move the dependent read to a later super-step \
|
||||||
|
(or remove the cross-branch reference)."
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
||||||
@@ -646,6 +693,103 @@ fn write_set_of(node: &Node) -> Option<HashSet<String>> {
|
|||||||
Some(writes)
|
Some(writes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes the set of root state keys this node's templated fields read from.
|
||||||
|
//
|
||||||
|
// "Root key" follows the same definition as `template_root_keys`: for a
|
||||||
|
// reference like `{{user.name}}` or `{{items[0]}}`, the root is the bare
|
||||||
|
// identifier before the first `.` or `[`.
|
||||||
|
//
|
||||||
|
// Templated fields scanned per node type:
|
||||||
|
// - llm: instructions, prompt, state_updates values
|
||||||
|
// - agent: prompt, state_updates values
|
||||||
|
// - rag: query (defaulting to "{{initial_prompt}}"), state_updates values
|
||||||
|
// - approval: question, state_updates values
|
||||||
|
// - input: question, default, state_updates values
|
||||||
|
// - end: output, state_updates values
|
||||||
|
// - map: over (its `{{...}}` IS the dynamic read of the list to fan out over)
|
||||||
|
// - script: state_updates values only (the script body is opaque to static
|
||||||
|
// analysis; its reads via GRAPH_STATE / GRAPH_STATE_FILE can't be
|
||||||
|
// inferred at load time)
|
||||||
|
//
|
||||||
|
// Scoped variables produced by THIS node's own execution are excluded from
|
||||||
|
// state_updates value scanning:
|
||||||
|
// - llm/agent/rag → "output" (the node's body output)
|
||||||
|
// - approval → "choice" (the user's selected option)
|
||||||
|
// - input → "input" (the user's typed text)
|
||||||
|
// These are bindings created inside the node, not reads from prior state, so
|
||||||
|
// they cannot race with a sibling's writes.
|
||||||
|
fn read_set_of(node: &Node) -> HashSet<String> {
|
||||||
|
let mut reads: HashSet<String> = HashSet::new();
|
||||||
|
let scoped: &[&str] = match &node.node_type {
|
||||||
|
NodeType::Llm(_) | NodeType::Agent(_) | NodeType::Rag(_) => &["output"],
|
||||||
|
NodeType::Approval(_) => &["choice"],
|
||||||
|
NodeType::Input(_) => &["input"],
|
||||||
|
NodeType::Script(_) | NodeType::End(_) | NodeType::Map(_) => &[],
|
||||||
|
};
|
||||||
|
|
||||||
|
for s in primary_templated_fields(node) {
|
||||||
|
for k in template_root_keys(&s) {
|
||||||
|
reads.insert(k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(updates) = node_state_updates_map(node) {
|
||||||
|
for v in updates.values() {
|
||||||
|
for k in template_root_keys(v) {
|
||||||
|
if !scoped.contains(&k.as_str()) {
|
||||||
|
reads.insert(k);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
reads
|
||||||
|
}
|
||||||
|
|
||||||
|
fn primary_templated_fields(node: &Node) -> Vec<String> {
|
||||||
|
match &node.node_type {
|
||||||
|
NodeType::Llm(n) => {
|
||||||
|
let mut v = vec![n.prompt.clone()];
|
||||||
|
if let Some(i) = &n.instructions {
|
||||||
|
v.push(i.clone());
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
|
NodeType::Agent(n) => vec![n.prompt.clone()],
|
||||||
|
NodeType::Rag(n) => {
|
||||||
|
vec![
|
||||||
|
n.query
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| "{{initial_prompt}}".to_string()),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
NodeType::Approval(n) => vec![n.question.clone()],
|
||||||
|
NodeType::Input(n) => {
|
||||||
|
let mut v = vec![n.question.clone()];
|
||||||
|
if let Some(d) = &n.default {
|
||||||
|
v.push(d.clone());
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
|
NodeType::End(n) => vec![n.output.clone()],
|
||||||
|
NodeType::Map(n) => vec![n.over.clone()],
|
||||||
|
NodeType::Script(_) => Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn node_state_updates_map(node: &Node) -> Option<&std::collections::HashMap<String, String>> {
|
||||||
|
match &node.node_type {
|
||||||
|
NodeType::Llm(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::Agent(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::Rag(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::Approval(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::Input(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::Script(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::End(n) => n.state_updates.as_ref(),
|
||||||
|
NodeType::Map(_) => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn node_state_updates_keys(node: &Node) -> Option<HashSet<String>> {
|
fn node_state_updates_keys(node: &Node) -> Option<HashSet<String>> {
|
||||||
let updates = match &node.node_type {
|
let updates = match &node.node_type {
|
||||||
NodeType::Agent(n) => n.state_updates.as_ref(),
|
NodeType::Agent(n) => n.state_updates.as_ref(),
|
||||||
@@ -2064,4 +2208,140 @@ mod tests {
|
|||||||
result.errors
|
result.errors
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn llm_with_prompt(id: &str, prompt: &str, next: Option<&str>) -> Node {
|
||||||
|
let mut node = llm_node(id, None, next);
|
||||||
|
if let NodeType::Llm(ref mut n) = node.node_type {
|
||||||
|
n.prompt = prompt.into();
|
||||||
|
}
|
||||||
|
node
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker_a's prompt references `summary`, which worker_b writes through its
// state_updates; the validator must report the cross-branch read.
#[test]
fn parallel_read_of_sibling_write_errors() {
    let graph = fan_out_graph_with_two_workers(
        llm_with_prompt("worker_a", "Hello {{summary}}!", Some("end")),
        llm_with_state_updates("worker_b", &[("summary", "static")], Some("end")),
    );

    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|e| {
        e.message.contains("reads state key(s) `summary`") && e.message.contains("'worker_b'")
    });
    assert!(
        flagged,
        "expected cross-branch read error mentioning `summary` and sibling writer: {:?}",
        result.errors
    );
}
|
||||||
|
|
||||||
|
// Both workers only read `topic` and neither writes it, so no cross-branch
// read error may be produced.
#[test]
fn parallel_read_of_upstream_key_passes() {
    let graph = fan_out_graph_with_two_workers(
        llm_with_prompt("worker_a", "Topic is {{topic}}", Some("end")),
        llm_with_prompt("worker_b", "Also {{topic}}", Some("end")),
    );

    let result = validator().validate(&graph);

    let flagged = result
        .errors
        .iter()
        .any(|e| e.message.contains("reads state key"));
    assert!(
        !flagged,
        "upstream `topic` shouldn't trigger cross-branch read error: {:?}",
        result.errors
    );
}
|
||||||
|
|
||||||
|
// worker_a uses `{{output}}` only as a state_updates value, while worker_b
// writes a state key named `output`. The scoped `output` binding of an llm
// node must not be counted as a read of that key.
#[test]
fn scoped_output_var_in_state_updates_not_treated_as_read() {
    let graph = fan_out_graph_with_two_workers(
        llm_with_state_updates("worker_a", &[("a_key", "{{output}}")], Some("end")),
        llm_with_state_updates("worker_b", &[("output", "{{output}}")], Some("end")),
    );

    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|e| {
        e.message.contains("reads state key(s) `output`") && e.message.contains("worker_a")
    });
    assert!(
        !flagged,
        "scoped `{{{{output}}}}` inside state_updates value should NOT be treated as a read: {:?}",
        result.errors
    );
}
|
||||||
|
|
||||||
|
// A rag node whose query references `db_result` runs beside a script node
// that writes `db_result`; the validator must report the read. The rag node's
// own `{{output.context}}` state update is added alongside.
#[test]
fn rag_query_reading_sibling_script_write_errors() {
    let mut rag_worker = rag_node("worker_a", &["./k"], true);
    if let NodeType::Rag(rag) = &mut rag_worker.node_type {
        rag.query = Some("codes: {{loinc_codes}}\n{{db_result}}".into());
        if let Some(updates) = rag.state_updates.as_mut() {
            updates.insert("rag_ctx".into(), "{{output.context}}".into());
        }
    }
    rag_worker.next = Some("end".into());

    let mut script_worker = script_with_state_updates("worker_b", &[("db_result", "{{output}}")]);
    script_worker.next = Some("end".into());

    let graph = fan_out_graph_with_two_workers(rag_worker, script_worker);
    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|e| {
        e.message.contains("reads state key(s) `db_result`") && e.message.contains("'worker_b'")
    });
    assert!(
        flagged,
        "expected cross-branch read error for rag query reading db_result: {:?}",
        result.errors
    );
}
|
||||||
|
|
||||||
|
// `start` fans out to a map node (`fan`, iterating over `{{items}}`) and to
// `producer`, which writes `items`. The map's `over` expression is a read, so
// the validator must report it against the sibling writer.
#[test]
fn map_over_reading_sibling_write_errors() {
    let fan = Node {
        id: "fan".into(),
        description: String::new(),
        node_type: NodeType::Map(MapNode {
            over: "{{items}}".into(),
            as_name: "item".into(),
            branch: "branch_n".into(),
            output_key: "output".into(),
            collect_into: "results".into(),
            max_concurrency: None,
        }),
        next: Some("end".into()),
    };
    let branch = llm_with_prompt("branch_n", "Process {{item}}", None);
    let producer = llm_with_state_updates("producer", &[("items", "[1,2,3]")], Some("end"));

    let mut start = end_node("start");
    start.next = Some(NextTargets::Many(vec!["fan".into(), "producer".into()]));

    let all_nodes = vec![
        ("start", start),
        ("fan", fan),
        ("branch_n", branch),
        ("producer", producer),
        ("end", end_node("end")),
    ];
    let graph = graph_with(all_nodes, "start");

    let result = validator().validate(&graph);

    let flagged = result.errors.iter().any(|e| {
        e.message.contains("reads state key(s) `items`") && e.message.contains("'producer'")
    });
    assert!(
        flagged,
        "expected cross-branch read error for map `over` reading sibling write: {:?}",
        result.errors
    );
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user