diff --git a/src/graph/state.rs b/src/graph/state.rs
index d242273..5f8fb72 100644
--- a/src/graph/state.rs
+++ b/src/graph/state.rs
@@ -353,6 +353,29 @@ fn single_reference_key(template: &str) -> Option<&str> {
     valid.then_some(inner)
 }
 
+// Returns the root state keys referenced by any `{{...}}` expressions in the
+// given template string. The "root key" is the identifier before the first
+// `.` or `[` — i.e. for `{{user.name}}` the root is `user`, for `{{items[0]}}`
+// the root is `items`. Used by the validator to compute the static read-set of
+// a node's templated fields without depending on a runtime `StateManager`.
+//
+// Keys come back in match order and are not de-duplicated, so a template that
+// mentions the same root twice yields it twice.
+pub(super) fn template_root_keys(template: &str) -> Vec<String> {
+    TEMPLATE_VAR_RE
+        .captures_iter(template)
+        .flatten()
+        .filter_map(|c| c.get(1))
+        .map(|m| {
+            let inner = m.as_str();
+            let cut = inner.find(['.', '[']).unwrap_or(inner.len());
+            // Trim so a padded reference such as `{{ user.name }}` still yields
+            // `user` should capture group 1 include the surrounding whitespace.
+            inner[..cut].trim().to_string()
+        })
+        .collect()
+}
+
 fn value_to_string(value: &Value) -> String {
     match value {
         Value::String(s) => s.clone(),
diff --git a/src/graph/validator.rs b/src/graph/validator.rs
index 8785b4a..01cd166 100644
--- a/src/graph/validator.rs
+++ b/src/graph/validator.rs
@@ -1,3 +1,4 @@
+use super::state::template_root_keys;
 use super::types::{Graph, Node, NodeType};
 use crate::client::{Model, ModelType};
 use crate::config::{Agent, AppConfig, paths};
@@ -122,6 +123,7 @@ impl GraphValidator {
         self.validate_map_branches(graph, &mut result);
         self.validate_parallel_user_interaction(graph, &mut result);
         self.validate_parallel_writes(graph, &mut result);
+        self.validate_parallel_reads(graph, &mut result);
 
         result
     }
@@ -539,6 +541,51 @@ impl GraphValidator {
             }
         }
     }
+
+    // Errors when a node reads a state key that a sibling in the same parallel
+    // group writes: per the message below, every branch sees the pre-step
+    // snapshot. One error is emitted per (reader, writer) pair.
+    fn validate_parallel_reads(&self, graph: &Graph, result: &mut ValidationResult) {
+        for group in compute_parallel_groups(graph) {
+            let nodes: Vec<(&String, &Node)> = group
+                .iter()
+                .filter_map(|id| graph.nodes.get(id).map(|n| (id, n)))
+                .collect();
+
+            for (id_a, node_a) in &nodes {
+                let read_set_a = read_set_of(node_a);
+                if read_set_a.is_empty() {
+                    continue;
+                }
+                for (id_b, node_b) in nodes.iter().filter(|(id_b, _)| id_b != id_a) {
+                    let Some(write_set_b) = write_set_of(node_b) else {
+                        continue;
+                    };
+                    let mut collisions: Vec<_> = read_set_a.intersection(&write_set_b).collect();
+                    if collisions.is_empty() {
+                        continue;
+                    }
+                    // Hash-set order is arbitrary; sort so the message is deterministic.
+                    collisions.sort();
+                    let keys = collisions
+                        .iter()
+                        .map(|k| format!("`{k}`"))
+                        .collect::<Vec<_>>()
+                        .join(", ");
+                    result.error(ValidationError::with_node(
+                        id_a.as_str(),
+                        format!(
+                            "node '{id_a}' reads state key(s) {keys} which sibling parallel \
+                             branch '{id_b}' writes in the same super-step; parallel branches \
+                             see a state snapshot taken BEFORE the super-step and cannot observe \
+                             each other's writes. Move the dependent read to a later super-step \
+                             (or remove the cross-branch reference)."
+                        ),
+                    ));
+                }
+            }
+        }
+    }
 }
 
 fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
@@ -646,6 +693,103 @@ fn write_set_of(node: &Node) -> Option<HashSet<String>> {
     Some(writes)
 }
 
+// Computes the set of root state keys this node's templated fields read from.
+//
+// "Root key" follows the same definition as `template_root_keys`: for a
+// reference like `{{user.name}}` or `{{items[0]}}`, the root is the bare
+// identifier before the first `.` or `[`.
+//
+// Templated fields scanned per node type:
+// - llm: instructions, prompt, state_updates values
+// - agent: prompt, state_updates values
+// - rag: query (defaulting to "{{initial_prompt}}"), state_updates values
+// - approval: question, state_updates values
+// - input: question, default, state_updates values
+// - end: output, state_updates values
+// - map: over (its `{{...}}` IS the dynamic read of the list to fan out over)
+// - script: state_updates values only (the script body is opaque to static
+//   analysis; its reads via GRAPH_STATE / GRAPH_STATE_FILE can't be
+//   inferred at load time)
+//
+// Scoped variables produced by THIS node's own execution are excluded from
+// state_updates value scanning:
+// - llm/agent/rag → "output" (the node's body output)
+// - approval → "choice" (the user's selected option)
+// - input → "input" (the user's typed text)
+// These are bindings created inside the node, not reads from prior state, so
+// they cannot race with a sibling's writes.
+fn read_set_of(node: &Node) -> HashSet<String> {
+    // Variables bound by this node's own execution; they are only excluded when
+    // they appear in state_updates values, never in the primary fields.
+    // NOTE(review): scripts get no scoped `output`, yet a script's state_updates
+    // may reference `{{output}}` — confirm that this is meant to count as a read.
+    let scoped: &[&str] = match &node.node_type {
+        NodeType::Llm(_) | NodeType::Agent(_) | NodeType::Rag(_) => &["output"],
+        NodeType::Approval(_) => &["choice"],
+        NodeType::Input(_) => &["input"],
+        NodeType::Script(_) | NodeType::End(_) | NodeType::Map(_) => &[],
+    };
+
+    // Primary fields: every referenced root key counts as a read.
+    let mut reads: HashSet<String> = primary_templated_fields(node)
+        .iter()
+        .flat_map(|field| template_root_keys(field))
+        .collect();
+
+    if let Some(updates) = node_state_updates_map(node) {
+        reads.extend(
+            updates
+                .values()
+                .flat_map(|value| template_root_keys(value))
+                .filter(|key| !scoped.contains(&key.as_str())),
+        );
+    }
+
+    reads
+}
+
+// Collects the node's main templated strings (everything except state_updates
+// values, which the caller scans separately). Script nodes contribute none.
+fn primary_templated_fields(node: &Node) -> Vec<String> {
+    match &node.node_type {
+        NodeType::Llm(n) => std::iter::once(&n.prompt)
+            .chain(n.instructions.as_ref())
+            .cloned()
+            .collect(),
+        NodeType::Agent(n) => vec![n.prompt.clone()],
+        // A rag node without an explicit query is scanned as `{{initial_prompt}}`.
+        NodeType::Rag(n) => vec![
+            n.query
+                .clone()
+                .unwrap_or_else(|| "{{initial_prompt}}".to_string()),
+        ],
+        NodeType::Approval(n) => vec![n.question.clone()],
+        NodeType::Input(n) => std::iter::once(&n.question)
+            .chain(n.default.as_ref())
+            .cloned()
+            .collect(),
+        NodeType::End(n) => vec![n.output.clone()],
+        NodeType::Map(n) => vec![n.over.clone()],
+        NodeType::Script(_) => Vec::new(),
+    }
+}
+
+// Borrows the node's `state_updates` map when the node type carries one and it
+// is set; map nodes always yield `None`.
+// NOTE(review): key/value types reconstructed as `String` from usage — confirm.
+fn node_state_updates_map(node: &Node) -> Option<&std::collections::HashMap<String, String>> {
+    match &node.node_type {
+        NodeType::Llm(n) => n.state_updates.as_ref(),
+        NodeType::Agent(n) => n.state_updates.as_ref(),
+        NodeType::Rag(n) => n.state_updates.as_ref(),
+        NodeType::Approval(n) => n.state_updates.as_ref(),
+        NodeType::Input(n) => n.state_updates.as_ref(),
+        NodeType::Script(n) => n.state_updates.as_ref(),
+        NodeType::End(n) => n.state_updates.as_ref(),
+        NodeType::Map(_) => None,
+    }
+}
+
 fn node_state_updates_keys(node: &Node) -> Option> {
     let updates = match &node.node_type {
         NodeType::Agent(n) => n.state_updates.as_ref(),
@@ -2064,4 +2208,140 @@ mod tests {
             result.errors
         );
     }
+
+    fn llm_with_prompt(id: &str, prompt: &str, next: Option<&str>) -> Node {
+        let mut node = llm_node(id, None, next); // fixture helper defined outside this hunk
+        if let NodeType::Llm(ref mut n) = node.node_type {
+            n.prompt = prompt.into();
+        }
+        node
+    }
+
+    #[test]
+    fn parallel_read_of_sibling_write_errors() {
+        let reader = llm_with_prompt("worker_a", "Hello {{summary}}!", Some("end"));
+        let writer = llm_with_state_updates("worker_b", &[("summary", "static")], Some("end"));
+        let graph = fan_out_graph_with_two_workers(reader, writer);
+
+        let result = validator().validate(&graph);
+        // worker_a's prompt references `summary`, which sibling worker_b writes: must error.
+        assert!(
+            result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("reads state key(s) `summary`")
+                    && e.message.contains("'worker_b'")),
+            "expected cross-branch read error mentioning `summary` and sibling writer: {:?}",
+            result.errors
+        );
+    }
+
+    #[test]
+    fn parallel_read_of_upstream_key_passes() {
+        let reader_a = llm_with_prompt("worker_a", "Topic is {{topic}}", Some("end"));
+        let reader_b = llm_with_prompt("worker_b", "Also {{topic}}", Some("end"));
+        let graph = fan_out_graph_with_two_workers(reader_a, reader_b);
+
+        let result = validator().validate(&graph);
+        // No state_updates are declared on either worker here, so `topic` must not be flagged.
+        assert!(
+            !result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("reads state key")),
+            "upstream `topic` shouldn't trigger cross-branch read error: {:?}",
+            result.errors
+        );
+    }
+
+    #[test]
+    fn scoped_output_var_in_state_updates_not_treated_as_read() {
+        let scoped_user =
+            llm_with_state_updates("worker_a", &[("a_key", "{{output}}")], Some("end"));
+        let writes_output =
+            llm_with_state_updates("worker_b", &[("output", "{{output}}")], Some("end"));
+        let graph = fan_out_graph_with_two_workers(scoped_user, writes_output);
+
+        let result = validator().validate(&graph);
+        // `{{output}}` in an llm node's state_updates value is excluded from its read-set.
+        assert!(
+            !result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("reads state key(s) `output`")
+                    && e.message.contains("worker_a")),
+            "scoped `{{{{output}}}}` inside state_updates value should NOT be treated as a read: {:?}",
+            result.errors
+        );
+    }
+
+    #[test]
+    fn rag_query_reading_sibling_script_write_errors() {
+        let mut rag = rag_node("worker_a", &["./k"], true);
+        if let NodeType::Rag(ref mut n) = rag.node_type {
+            n.query = Some("codes: {{loinc_codes}}\n{{db_result}}".into());
+            if let Some(m) = n.state_updates.as_mut() {
+                m.insert("rag_ctx".into(), "{{output.context}}".into());
+            }
+        }
+        rag.next = Some("end".into());
+        let mut script = script_with_state_updates("worker_b", &[("db_result", "{{output}}")]);
+        script.next = Some("end".into());
+        let graph = fan_out_graph_with_two_workers(rag, script);
+
+        let result = validator().validate(&graph);
+        // The rag query references `db_result`, which the sibling script writes: must error.
+        assert!(
+            result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("reads state key(s) `db_result`")
+                    && e.message.contains("'worker_b'")),
+            "expected cross-branch read error for rag query reading db_result: {:?}",
+            result.errors
+        );
+    }
+
+    #[test]
+    fn map_over_reading_sibling_write_errors() {
+        let map_n = Node {
+            id: "fan".into(),
+            description: String::new(),
+            node_type: NodeType::Map(MapNode {
+                over: "{{items}}".into(),
+                as_name: "item".into(),
+                branch: "branch_n".into(),
+                output_key: "output".into(),
+                collect_into: "results".into(),
+                max_concurrency: None,
+            }),
+            next: Some("end".into()),
+        };
+        let branch_n = llm_with_prompt("branch_n", "Process {{item}}", None);
+        let producer = llm_with_state_updates("producer", &[("items", "[1,2,3]")], Some("end"));
+        let mut start = end_node("start");
+        start.next = Some(NextTargets::Many(vec!["fan".into(), "producer".into()]));
+        let graph = graph_with(
+            vec![
+                ("start", start),
+                ("fan", map_n),
+                ("branch_n", branch_n),
+                ("producer", producer),
+                ("end", end_node("end")),
+            ],
+            "start",
+        );
+
+        let result = validator().validate(&graph);
+        // The map's `over` template reads `items`, which sibling `producer` writes: must error.
+        assert!(
+            result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("reads state key(s) `items`")
+                    && e.message.contains("'producer'")),
+            "expected cross-branch read error for map `over` reading sibling write: {:?}",
+            result.errors
+        );
+    }
 }