feat: created the staging area for state merges per super-step and created the built-in reducers (and their application) for the state merge phase of a super-step

This commit is contained in:
2026-05-20 12:16:14 -06:00
parent 8c398b6360
commit 07c1f70df3
5 changed files with 941 additions and 23 deletions
+415 -3
View File
@@ -1,14 +1,16 @@
use super::MAX_STATE_SIZE_BYTES;
use super::types::GraphState;
use super::reducer;
use super::staging::BranchWrites;
use super::types::{GraphState, Reducer};
use crate::utils::temp_file;
use anyhow::{Context, Result, bail};
use fancy_regex::Regex;
use serde_json::Value;
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};
use std::fs;
use std::fs::write;
use std::path::PathBuf;
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};
static TEMPLATE_VAR_RE: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.\[\]]+)\}\}").expect("invalid template regex"));
@@ -156,6 +158,96 @@ impl StateManager {
let _ = fs::remove_file(path);
}
}
/// Produces a shared, read-only copy of the graph state as it is right now.
///
/// The state is cloned once and wrapped in an `Arc`, so every branch of a
/// parallel super-step can read from the same copy. Branch writes do not go
/// here: they accumulate in per-branch `StagingArea` instances and are merged
/// by `apply_branch_writes` when the super-step ends.
///
/// Not to be confused with the older `snapshot()` method, which returns a
/// `HashMap` clone of the data only and is used by `script_executor` to ship
/// state to child processes.
#[allow(dead_code)]
pub fn read_snapshot(&self) -> Arc<GraphState> {
    // Clone first, then wrap: later mutations of `self.state` cannot be
    // observed through the returned handle.
    let frozen: GraphState = self.state.clone();
    Arc::new(frozen)
}
/// Commits a deterministically-ordered set of per-branch writes back into
/// live state, applying declared reducers where they exist.
///
/// The merge is all-or-nothing: every key is resolved to its final value
/// first, and live state is only touched once all keys resolved successfully.
/// A failure therefore never leaves a partially merged super-step behind.
///
/// Caller must pre-sort `writes` by `(node_id, invocation_index)` so that
/// non-commutative reducers (`Concat`, `Merge`) produce reproducible output.
/// Values for a key are folded in the order their branches appear in `writes`.
///
/// # Arguments
/// * `writes` - the per-branch write sets collected during the super-step.
/// * `reducers` - reducer declared for each key, if any.
///
/// # Errors
/// * When a key has writers from ≥2 branches but no reducer declared. The
///   validator (Phase C) catches this at load time; this runtime check is
///   defense-in-depth against a malformed or out-of-date validator missing it.
/// * When `reducer::apply` fails for any value.
///
/// In both cases live state is left untouched.
#[allow(dead_code)]
pub fn apply_branch_writes(
    &mut self,
    writes: Vec<BranchWrites>,
    reducers: &HashMap<String, Reducer>,
) -> Result<()> {
    // Group written values by key. A BTreeMap gives a stable key order, and
    // each key's values keep the branch order supplied by the caller.
    let mut by_key: BTreeMap<String, Vec<Value>> = BTreeMap::new();
    for branch in writes {
        for (key, value) in branch.writes {
            by_key.entry(key).or_default().push(value);
        }
    }

    // Phase 1: resolve every key to its final value without mutating state,
    // so an error on a later key cannot leave earlier keys committed.
    let mut resolved: Vec<(String, Value)> = Vec::with_capacity(by_key.len());
    for (key, values) in by_key {
        let final_value = match reducers.get(&key).copied() {
            Some(r) => {
                // Fold the writes onto whatever is currently stored under the key.
                let mut current = self.state.get(&key).cloned();
                for value in values {
                    current = Some(reducer::apply(r, current.as_ref(), value)?);
                }
                // Every grouped key holds at least one value, so `current` is
                // `Some` here; skip defensively instead of panicking.
                let Some(final_value) = current else {
                    continue;
                };
                final_value
            }
            None => {
                let writer_count = values.len();
                let mut pending = values.into_iter();
                match (pending.next(), writer_count) {
                    // A lone writer without a reducer simply overwrites.
                    (Some(only), 1) => only,
                    _ => bail!(
                        "Key '{key}' was written by {} parallel branches but has no \
                         reducer declared. Add a reducer for '{key}' to the graph's \
                         `reducers:` block, or rename one writer.",
                        writer_count
                    ),
                }
            }
        };
        resolved.push((key, final_value));
    }

    // Phase 2: nothing can fail from here on; commit in key order.
    for (key, value) in resolved {
        self.state.set(key, value);
    }
    Ok(())
}
/// Interpolates a template and returns a typed JSON `Value`.
///
/// Two paths depending on the template shape:
/// - **Pure single reference** (the entire trimmed template is a single
/// `{{key}}` expression, e.g. `"{{subjects}}"`, `"{{user.name}}"`,
/// `"{{items[0]}}"`) — returns the typed `Value` at that key, preserving
/// numbers, bools, arrays, and objects. Errors if the key is missing.
/// - **Mixed template** (multiple refs, surrounding text, or no refs) —
/// falls back to string interpolation via `interpolate()` and returns
/// `Value::String(...)`. Strict on missing keys.
///
/// Required by:
/// - `map.over: "{{subjects}}"` — must resolve to a JSON array, not its string form
/// - `state_updates` writes that should preserve the source type (a `cost_usd: "{{api_cost}}"`
/// write should land as a Number, not a String)
#[allow(dead_code)]
pub fn interpolate_raw(&self, template: &str) -> Result<Value> {
let trimmed = template.trim();
if let Some(key) = single_reference_key(trimmed) {
match self.get_nested_value(key) {
Some(value) => Ok(value.clone()),
None => bail!("Template interpolation failed: '{key}' not found in state"),
}
} else {
Ok(Value::String(self.interpolate(template)?))
}
}
}
impl Drop for StateManager {
@@ -214,6 +306,22 @@ fn split_indices(segment: &str) -> Option<(&str, Vec<usize>)> {
Some((key, indices))
}
// Extracts `key` from a template that consists of exactly one `{{key}}`
// reference and nothing else. Returns `None` for anything with surrounding
// text, an empty key, or characters outside the key alphabet.
//
// The accepted alphabet (ASCII letters, digits, `_`, `.`, `[`, `]`) mirrors
// the character class in TEMPLATE_VAR_RE so that `interpolate_raw` and
// `interpolate` agree on what a valid key looks like. Braces are not part of
// that alphabet, so a nested or second `{{`/`}}` is rejected by the same check.
fn single_reference_key(template: &str) -> Option<&str> {
    let body = template.strip_prefix("{{")?.strip_suffix("}}")?;
    if body.is_empty() {
        return None;
    }

    let is_key_char =
        |ch: char| matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '.' | '[' | ']');
    body.chars().all(is_key_char).then_some(body)
}
fn value_to_string(value: &Value) -> String {
match value {
Value::String(s) => s.clone(),
@@ -532,4 +640,308 @@ mod tests {
assert_eq!(manager.state().get("status"), Some(&json!("complete")));
}
// Test helper: builds a `BranchWrites` for `node_id` / `idx` from a slice of
// `(key, value)` pairs. A key listed more than once keeps its last value.
fn branch(node_id: &str, idx: usize, writes: &[(&str, Value)]) -> BranchWrites {
    let staged: HashMap<_, _> = writes
        .iter()
        .map(|(name, payload)| ((*name).to_owned(), payload.clone()))
        .collect();

    BranchWrites {
        node_id: node_id.into(),
        invocation_index: idx,
        writes: staged,
    }
}
// A snapshot taken from a populated manager exposes the stored value.
#[test]
fn read_snapshot_returns_arc_with_current_state() {
    let expected = json!("v");
    let mgr = manager_with(&[("k", expected.clone())]);

    let shared = mgr.read_snapshot();

    assert_eq!(shared.get("k"), Some(&expected));
}
// Writing to live state after taking a snapshot must not change the snapshot.
#[test]
fn read_snapshot_is_independent_of_later_mutations() {
    let mut mgr = manager_with(&[("count", json!(1))]);
    let frozen = mgr.read_snapshot();

    // Mutate live state only after the snapshot exists.
    mgr.state_mut().set("count".into(), json!(999));

    let (before, after) = (json!(1), json!(999));
    assert_eq!(frozen.get("count"), Some(&before));
    assert_eq!(mgr.state().get("count"), Some(&after));
}
// Merging zero branches succeeds and leaves existing state as it was.
#[test]
fn apply_branch_writes_empty_is_noop() {
    let mut mgr = manager_with(&[("k", json!("v"))]);
    let no_reducers: HashMap<String, Reducer> = HashMap::new();
    let no_writes: Vec<BranchWrites> = Vec::new();

    mgr.apply_branch_writes(no_writes, &no_reducers).unwrap();

    assert_eq!(mgr.state().get("k"), Some(&json!("v")));
}
// One writer and no reducer: the written value lands in state directly.
#[test]
fn apply_branch_writes_single_writer_no_reducer_overwrites() {
    let mut mgr = manager_with(&[]);
    let only_branch = branch("n", 0, &[("k", json!(42))]);

    mgr.apply_branch_writes(vec![only_branch], &HashMap::new())
        .unwrap();

    assert_eq!(mgr.state().get("k"), Some(&json!(42)));
}
// Branches writing different keys never collide, so no reducer is needed.
#[test]
fn apply_branch_writes_disjoint_keys_all_land() {
    let mut mgr = manager_with(&[]);
    let batches = vec![
        branch("a", 0, &[("x", json!(1))]),
        branch("b", 0, &[("y", json!(2))]),
        branch("c", 0, &[("z", json!(3))]),
    ];

    mgr.apply_branch_writes(batches, &HashMap::new()).unwrap();

    for (key, want) in [("x", json!(1)), ("y", json!(2)), ("z", json!(3))] {
        assert_eq!(mgr.state().get(key), Some(&want));
    }
}
// With an `Append` reducer the merged array follows the order of `writes`.
#[test]
fn apply_branch_writes_three_appends_preserve_input_order() {
    let mut mgr = manager_with(&[]);
    let reducers = HashMap::from([("items".to_string(), Reducer::Append)]);
    let batches = vec![
        branch("a", 0, &[("items", json!("first"))]),
        branch("b", 0, &[("items", json!("second"))]),
        branch("c", 0, &[("items", json!("third"))]),
    ];

    mgr.apply_branch_writes(batches, &reducers).unwrap();

    let expected = json!(["first", "second", "third"]);
    assert_eq!(mgr.state().get("items"), Some(&expected));
}
// Two branches writing the same key with no reducer must fail, and the error
// names the key, the missing reducer, and the number of writers.
#[test]
fn apply_branch_writes_collision_without_reducer_bails() {
    let mut mgr = manager_with(&[]);
    let colliding = vec![
        branch("a", 0, &[("k", json!("first"))]),
        branch("b", 0, &[("k", json!("second"))]),
    ];

    let err = mgr
        .apply_branch_writes(colliding, &HashMap::new())
        .unwrap_err()
        .to_string();

    for needle in ["'k'", "no reducer", "2 parallel branches"] {
        assert!(err.contains(needle), "got: {err}");
    }
}
// A `Sum` reducer starts from the value already in state and adds each write.
#[test]
fn apply_branch_writes_sum_reducer_accumulates_with_existing_state() {
    let mut mgr = manager_with(&[("cost", json!(10))]);
    let reducers = HashMap::from([("cost".to_string(), Reducer::Sum)]);
    let batches = vec![
        branch("a", 0, &[("cost", json!(5))]),
        branch("b", 0, &[("cost", json!(7))]),
    ];

    mgr.apply_branch_writes(batches, &reducers).unwrap();

    // 10 (existing) + 5 + 7
    assert_eq!(mgr.state().get("cost"), Some(&json!(22)));
}
// A `Concat` reducer joins the writes in the order the branches were given.
#[test]
fn apply_branch_writes_concat_respects_branch_order() {
    let mut mgr = manager_with(&[]);
    let reducers = HashMap::from([("log".to_string(), Reducer::Concat)]);
    let batches = vec![
        branch("a", 0, &[("log", json!("alpha"))]),
        branch("b", 0, &[("log", json!("bravo"))]),
    ];

    mgr.apply_branch_writes(batches, &reducers).unwrap();

    let expected = json!("alpha\nbravo");
    assert_eq!(mgr.state().get("log"), Some(&expected));
}
// One merge may contain both a reduced key (`results`) and a single-writer
// key with no reducer (`status`).
#[test]
fn apply_branch_writes_mixed_keys_with_and_without_reducers() {
    let mut mgr = manager_with(&[]);
    let reducers = HashMap::from([("results".to_string(), Reducer::Append)]);
    let first = branch(
        "a",
        0,
        &[("results", json!("x")), ("status", json!("ok_a"))],
    );
    let second = branch("b", 0, &[("results", json!("y"))]);

    mgr.apply_branch_writes(vec![first, second], &reducers)
        .unwrap();

    assert_eq!(mgr.state().get("results"), Some(&json!(["x", "y"])));
    assert_eq!(mgr.state().get("status"), Some(&json!("ok_a")));
}
// A pure `{{key}}` reference to a number comes back as a JSON integer.
#[test]
fn interpolate_raw_pure_ref_returns_typed_number() {
    let mgr = manager_with(&[("count", json!(42))]);

    let resolved = mgr.interpolate_raw("{{count}}").unwrap();

    assert!(resolved.is_i64());
    assert_eq!(resolved, json!(42));
}
// A pure reference to an array comes back as an array, not its string form.
#[test]
fn interpolate_raw_pure_ref_returns_typed_array() {
    let stored = json!(["a", "b", "c"]);
    let mgr = manager_with(&[("items", stored.clone())]);

    let resolved = mgr.interpolate_raw("{{items}}").unwrap();

    assert!(matches!(resolved, Value::Array(_)));
    assert_eq!(resolved, stored);
}
// A pure reference to an object comes back as an object.
#[test]
fn interpolate_raw_pure_ref_returns_typed_object() {
    let stored = json!({ "name": "alice", "age": 30 });
    let mgr = manager_with(&[("user", stored.clone())]);

    let resolved = mgr.interpolate_raw("{{user}}").unwrap();

    assert!(matches!(resolved, Value::Object(_)));
    assert_eq!(resolved, stored);
}
// A pure reference to a boolean comes back as a boolean.
#[test]
fn interpolate_raw_pure_ref_returns_typed_bool() {
    let mgr = manager_with(&[("flag", json!(true))]);

    let resolved = mgr.interpolate_raw("{{flag}}").unwrap();

    assert!(matches!(resolved, Value::Bool(_)));
    assert_eq!(resolved, json!(true));
}
// A dotted path inside a pure reference resolves to the nested value.
#[test]
fn interpolate_raw_nested_path_returns_typed_value() {
    let mgr = manager_with(&[("user", json!({ "email": "x@y.com" }))]);

    let resolved = mgr.interpolate_raw("{{user.email}}").unwrap();

    assert!(matches!(resolved, Value::String(_)));
    assert_eq!(resolved, json!("x@y.com"));
}
// An indexed path inside a pure reference resolves to that array element.
#[test]
fn interpolate_raw_array_index_returns_typed_value() {
    let mgr = manager_with(&[("items", json!([10, 20, 30]))]);

    let resolved = mgr.interpolate_raw("{{items[1]}}").unwrap();

    assert!(resolved.is_i64());
    assert_eq!(resolved, json!(20));
}
// A pure reference to a key that is not in state is an error naming the key.
#[test]
fn interpolate_raw_missing_pure_ref_errors() {
    let mgr = manager_with(&[]);

    let failure = mgr.interpolate_raw("{{ghost}}").unwrap_err();
    let err = failure.to_string();

    assert!(err.contains("'ghost'"), "got: {err}");
    assert!(err.contains("not found"), "got: {err}");
}
// Text around a reference makes the template "mixed": the result is a string.
#[test]
fn interpolate_raw_mixed_template_falls_back_to_string() {
    let mgr = manager_with(&[("name", json!("alice"))]);

    let rendered = mgr.interpolate_raw("Hello {{name}}!").unwrap();

    assert!(matches!(rendered, Value::String(_)));
    assert_eq!(rendered, json!("Hello alice!"));
}
// Two adjacent references are not a pure reference; their values are
// rendered into a single string.
#[test]
fn interpolate_raw_multiple_refs_fall_back_to_string() {
    let mgr = manager_with(&[("a", json!(1)), ("b", json!(2))]);

    let rendered = mgr.interpolate_raw("{{a}}{{b}}").unwrap();

    assert!(matches!(rendered, Value::String(_)));
    assert_eq!(rendered, json!("12"));
}
// A template without any reference is returned unchanged as a JSON string.
#[test]
fn interpolate_raw_no_refs_is_literal_string() {
    let mgr = manager_with(&[]);
    let literal = "literal text";

    let rendered = mgr.interpolate_raw(literal).unwrap();

    assert_eq!(rendered, json!("literal text"));
}
// Whitespace outside the braces is trimmed before the pure-reference check.
#[test]
fn interpolate_raw_whitespace_padding_still_resolves_pure_ref() {
    let mgr = manager_with(&[("k", json!("v"))]);
    let padded = " {{k}} ";

    let resolved = mgr.interpolate_raw(padded).unwrap();

    assert_eq!(resolved, json!("v"));
}
// Spaces inside the braces are outside the allowed key character set, so
// `{{ k }}` is not a pure reference. It falls back to string interpolation,
// where the template regex does not match it either, and the text comes back
// verbatim.
#[test]
fn interpolate_raw_inner_spaces_treated_as_mixed() {
    let mgr = manager_with(&[("k", json!("v"))]);
    let spaced = "{{ k }}";

    let rendered = mgr.interpolate_raw(spaced).unwrap();

    assert_eq!(rendered, json!("{{ k }}"));
}
}