feat: created the staging area for state merges per super-step and created the built-in reducers (and their application) for the state merge phase of a super step
This commit is contained in:
@@ -5,7 +5,9 @@ pub mod llm;
|
||||
pub mod logging;
|
||||
pub mod parser;
|
||||
pub mod rag;
|
||||
pub mod reducer;
|
||||
pub mod script;
|
||||
pub mod staging;
|
||||
pub mod state;
|
||||
pub mod structured;
|
||||
pub mod types;
|
||||
|
||||
@@ -0,0 +1,407 @@
|
||||
use super::types::Reducer;
|
||||
use anyhow::{Result, bail};
|
||||
use serde_json::{Number, Value};
|
||||
|
||||
/// Combines a branch's incoming write with the current state value (if any)
|
||||
/// via the specified reducer. The result is what gets written back to live
|
||||
/// state during the super-step merge phase.
|
||||
///
|
||||
/// `current = None` means the key has no prior value in this super-step or in
|
||||
/// live state. Most reducers treat absent as their identity (empty array,
|
||||
/// empty string, no prior value). `Overwrite` ignores `current` entirely.
|
||||
///
|
||||
/// Errors clearly when types are incompatible with the reducer (e.g.
|
||||
/// `Sum` on a string), naming the reducer and which side (`current` / `incoming`)
|
||||
/// has the wrong type.
|
||||
pub fn apply(reducer: Reducer, current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
match reducer {
|
||||
Reducer::Append => apply_append(current, incoming),
|
||||
Reducer::Extend => apply_extend(current, incoming),
|
||||
Reducer::Concat => apply_concat(current, incoming),
|
||||
Reducer::Sum => apply_sum(current, incoming),
|
||||
Reducer::Max => apply_max(current, incoming),
|
||||
Reducer::Min => apply_min(current, incoming),
|
||||
Reducer::Merge => apply_merge(current, incoming),
|
||||
Reducer::Overwrite => Ok(incoming),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_append(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let mut arr = match current {
|
||||
None => Vec::new(),
|
||||
Some(Value::Array(a)) => a.clone(),
|
||||
Some(other) => bail!(
|
||||
"reducer 'append' requires an array (or absent) for the current value, got {}",
|
||||
type_name(other)
|
||||
),
|
||||
};
|
||||
arr.push(incoming);
|
||||
Ok(Value::Array(arr))
|
||||
}
|
||||
|
||||
fn apply_extend(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let mut arr = match current {
|
||||
None => Vec::new(),
|
||||
Some(Value::Array(a)) => a.clone(),
|
||||
Some(other) => bail!(
|
||||
"reducer 'extend' requires an array (or absent) for the current value, got {}",
|
||||
type_name(other)
|
||||
),
|
||||
};
|
||||
match incoming {
|
||||
Value::Array(items) => arr.extend(items),
|
||||
other => bail!(
|
||||
"reducer 'extend' requires an array for the incoming value, got {}",
|
||||
type_name(&other)
|
||||
),
|
||||
}
|
||||
Ok(Value::Array(arr))
|
||||
}
|
||||
|
||||
fn apply_concat(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let incoming_str = match incoming {
|
||||
Value::String(s) => s,
|
||||
other => bail!(
|
||||
"reducer 'concat' requires a string for the incoming value, got {}",
|
||||
type_name(&other)
|
||||
),
|
||||
};
|
||||
let result = match current {
|
||||
None => incoming_str,
|
||||
Some(Value::String(c)) => {
|
||||
if c.is_empty() {
|
||||
incoming_str
|
||||
} else {
|
||||
format!("{c}\n{incoming_str}")
|
||||
}
|
||||
}
|
||||
Some(other) => bail!(
|
||||
"reducer 'concat' requires a string (or absent) for the current value, got {}",
|
||||
type_name(other)
|
||||
),
|
||||
};
|
||||
Ok(Value::String(result))
|
||||
}
|
||||
|
||||
fn apply_sum(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let i = number_or_error(&incoming, "sum", "incoming")?;
|
||||
let c = match current {
|
||||
None => 0.0,
|
||||
Some(value) => number_or_error(value, "sum", "current")?,
|
||||
};
|
||||
Ok(json_number(c + i))
|
||||
}
|
||||
|
||||
fn apply_max(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let i = number_or_error(&incoming, "max", "incoming")?;
|
||||
match current {
|
||||
None => Ok(json_number(i)),
|
||||
Some(value) => {
|
||||
let c = number_or_error(value, "max", "current")?;
|
||||
Ok(json_number(c.max(i)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_min(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let i = number_or_error(&incoming, "min", "incoming")?;
|
||||
match current {
|
||||
None => Ok(json_number(i)),
|
||||
Some(value) => {
|
||||
let c = number_or_error(value, "min", "current")?;
|
||||
Ok(json_number(c.min(i)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_merge(current: Option<&Value>, incoming: Value) -> Result<Value> {
|
||||
let mut map = match current {
|
||||
None => serde_json::Map::new(),
|
||||
Some(Value::Object(m)) => m.clone(),
|
||||
Some(other) => bail!(
|
||||
"reducer 'merge' requires an object (or absent) for the current value, got {}",
|
||||
type_name(other)
|
||||
),
|
||||
};
|
||||
match incoming {
|
||||
Value::Object(items) => {
|
||||
for (k, v) in items {
|
||||
map.insert(k, v);
|
||||
}
|
||||
}
|
||||
other => bail!(
|
||||
"reducer 'merge' requires an object for the incoming value, got {}",
|
||||
type_name(&other)
|
||||
),
|
||||
}
|
||||
Ok(Value::Object(map))
|
||||
}
|
||||
|
||||
fn number_or_error(value: &Value, reducer_name: &str, position: &str) -> Result<f64> {
|
||||
match value.as_f64() {
|
||||
Some(n) => Ok(n),
|
||||
None => bail!(
|
||||
"reducer '{reducer_name}' requires a number for the {position} value, got {}",
|
||||
type_name(value)
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_name(value: &Value) -> &'static str {
|
||||
match value {
|
||||
Value::Null => "null",
|
||||
Value::Bool(_) => "bool",
|
||||
Value::Number(_) => "number",
|
||||
Value::String(_) => "string",
|
||||
Value::Array(_) => "array",
|
||||
Value::Object(_) => "object",
|
||||
}
|
||||
}
|
||||
|
||||
// Numeric reducers compute in f64 for simplicity. We preserve integer typing
|
||||
// when the result is losslessly representable as i64 so `count: sum` stays an
|
||||
// integer rather than degrading to a float. Non-finite values (NaN, Inf) can't
|
||||
// arise from finite inputs to +/max/min, so the fallback never fires in practice.
|
||||
fn json_number(n: f64) -> Value {
|
||||
if n.fract() == 0.0 && n.is_finite() && n.abs() <= (i64::MAX as f64) {
|
||||
Value::Number(Number::from(n as i64))
|
||||
} else {
|
||||
match Number::from_f64(n) {
|
||||
Some(num) => Value::Number(num),
|
||||
None => Value::Null,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn append_to_absent_creates_single_element_array() {
|
||||
let result = apply(Reducer::Append, None, json!("a")).unwrap();
|
||||
|
||||
assert_eq!(result, json!(["a"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn append_pushes_onto_existing_array() {
|
||||
let current = json!(["a", "b"]);
|
||||
let result = apply(Reducer::Append, Some(¤t), json!("c")).unwrap();
|
||||
|
||||
assert_eq!(result, json!(["a", "b", "c"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn append_errors_when_current_is_not_array() {
|
||||
let current = json!("not an array");
|
||||
|
||||
let err = apply(Reducer::Append, Some(¤t), json!("x"))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'append'"), "got: {err}");
|
||||
assert!(err.contains("string"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extend_concatenates_arrays() {
|
||||
let current = json!([1, 2]);
|
||||
|
||||
let result = apply(Reducer::Extend, Some(¤t), json!([3, 4])).unwrap();
|
||||
|
||||
assert_eq!(result, json!([1, 2, 3, 4]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extend_from_absent_with_array() {
|
||||
let result = apply(Reducer::Extend, None, json!([1, 2])).unwrap();
|
||||
|
||||
assert_eq!(result, json!([1, 2]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extend_errors_when_incoming_is_not_array() {
|
||||
let err = apply(Reducer::Extend, None, json!(42))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'extend'"), "got: {err}");
|
||||
assert!(err.contains("number"), "got: {err}");
|
||||
assert!(err.contains("incoming"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concat_joins_strings_with_newline() {
|
||||
let current = json!("first");
|
||||
|
||||
let result = apply(Reducer::Concat, Some(¤t), json!("second")).unwrap();
|
||||
|
||||
assert_eq!(result, json!("first\nsecond"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concat_from_absent_yields_incoming() {
|
||||
let result = apply(Reducer::Concat, None, json!("hello")).unwrap();
|
||||
|
||||
assert_eq!(result, json!("hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concat_skips_separator_when_current_is_empty_string() {
|
||||
let current = json!("");
|
||||
|
||||
let result = apply(Reducer::Concat, Some(¤t), json!("first")).unwrap();
|
||||
|
||||
assert_eq!(result, json!("first"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn concat_errors_when_incoming_is_not_string() {
|
||||
let err = apply(Reducer::Concat, None, json!(42))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'concat'"), "got: {err}");
|
||||
assert!(err.contains("number"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_adds_numbers() {
|
||||
let current = json!(5);
|
||||
|
||||
let result = apply(Reducer::Sum, Some(¤t), json!(7)).unwrap();
|
||||
|
||||
assert_eq!(result, json!(12));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_starts_from_zero_when_current_absent() {
|
||||
let result = apply(Reducer::Sum, None, json!(3.5)).unwrap();
|
||||
|
||||
assert_eq!(result, json!(3.5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_preserves_integer_type_for_whole_results() {
|
||||
let current = json!(2);
|
||||
|
||||
let result = apply(Reducer::Sum, Some(¤t), json!(3)).unwrap();
|
||||
|
||||
assert!(result.is_i64(), "expected integer, got {result:?}");
|
||||
assert_eq!(result, json!(5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_uses_float_when_result_has_fractional() {
|
||||
let current = json!(1.5);
|
||||
let result = apply(Reducer::Sum, Some(¤t), json!(2.25)).unwrap();
|
||||
|
||||
assert_eq!(result, json!(3.75));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum_errors_on_string_incoming() {
|
||||
let err = apply(Reducer::Sum, None, json!("not a number"))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'sum'"), "got: {err}");
|
||||
assert!(err.contains("string"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_returns_larger_of_two() {
|
||||
let current = json!(5);
|
||||
let result = apply(Reducer::Max, Some(¤t), json!(3)).unwrap();
|
||||
assert_eq!(result, json!(5));
|
||||
|
||||
let result = apply(Reducer::Max, Some(¤t), json!(10)).unwrap();
|
||||
assert_eq!(result, json!(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_yields_incoming_when_current_absent() {
|
||||
let result = apply(Reducer::Max, None, json!(42)).unwrap();
|
||||
|
||||
assert_eq!(result, json!(42));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_returns_smaller_of_two() {
|
||||
let current = json!(5);
|
||||
let result = apply(Reducer::Min, Some(¤t), json!(3)).unwrap();
|
||||
assert_eq!(result, json!(3));
|
||||
|
||||
let result = apply(Reducer::Min, Some(¤t), json!(10)).unwrap();
|
||||
assert_eq!(result, json!(5));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn min_errors_on_non_numeric_current() {
|
||||
let current = json!("oops");
|
||||
|
||||
let err = apply(Reducer::Min, Some(¤t), json!(1))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'min'"), "got: {err}");
|
||||
assert!(err.contains("current"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_unions_objects_with_incoming_winning_collisions() {
|
||||
let current = json!({ "a": 1, "b": 2 });
|
||||
let incoming = json!({ "b": 99, "c": 3 });
|
||||
|
||||
let result = apply(Reducer::Merge, Some(¤t), incoming).unwrap();
|
||||
|
||||
assert_eq!(result, json!({ "a": 1, "b": 99, "c": 3 }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_from_absent_yields_incoming_object() {
|
||||
let result = apply(Reducer::Merge, None, json!({ "k": "v" })).unwrap();
|
||||
|
||||
assert_eq!(result, json!({ "k": "v" }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_errors_when_incoming_is_not_object() {
|
||||
let err = apply(Reducer::Merge, None, json!([1, 2]))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'merge'"), "got: {err}");
|
||||
assert!(err.contains("array"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn merge_errors_when_current_is_not_object() {
|
||||
let current = json!("not object");
|
||||
|
||||
let err = apply(Reducer::Merge, Some(¤t), json!({ "k": "v" }))
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'merge'"), "got: {err}");
|
||||
assert!(err.contains("current"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overwrite_ignores_current_and_returns_incoming() {
|
||||
let current = json!("old");
|
||||
|
||||
let result = apply(Reducer::Overwrite, Some(¤t), json!("new")).unwrap();
|
||||
|
||||
assert_eq!(result, json!("new"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overwrite_works_with_absent_current() {
|
||||
let result = apply(Reducer::Overwrite, None, json!(42)).unwrap();
|
||||
|
||||
assert_eq!(result, json!(42));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct StagingArea {
|
||||
writes: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl StagingArea {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn write(&mut self, key: impl Into<String>, value: Value) {
|
||||
self.writes.insert(key.into(), value);
|
||||
}
|
||||
|
||||
pub fn get(&self, key: &str) -> Option<&Value> {
|
||||
self.writes.get(key)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.writes.is_empty()
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.writes.len()
|
||||
}
|
||||
|
||||
pub fn into_writes(self) -> HashMap<String, Value> {
|
||||
self.writes
|
||||
}
|
||||
}
|
||||
|
||||
/// Published form of one branch's writes for the super-step merge phase.
|
||||
/// Callers assemble these into a deterministically-ordered `Vec` keyed by
|
||||
/// `(node_id, invocation_index)` before passing to
|
||||
/// `StateManager::apply_branch_writes`. `invocation_index` is 0 for normal
|
||||
/// branches and the input-list position for map sub-branches — so multiple
|
||||
/// invocations of the same `branch:` node by a `map` are still totally ordered.
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
pub struct BranchWrites {
|
||||
pub node_id: String,
|
||||
pub invocation_index: usize,
|
||||
pub writes: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn new_staging_area_is_empty() {
|
||||
let s = StagingArea::new();
|
||||
|
||||
assert!(s.is_empty());
|
||||
assert_eq!(s.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_stores_value_under_key() {
|
||||
let mut s = StagingArea::new();
|
||||
|
||||
s.write("key", json!("value"));
|
||||
|
||||
assert_eq!(s.get("key"), Some(&json!("value")));
|
||||
assert_eq!(s.len(), 1);
|
||||
assert!(!s.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn write_overwrites_existing_key() {
|
||||
let mut s = StagingArea::new();
|
||||
|
||||
s.write("k", json!(1));
|
||||
s.write("k", json!(2));
|
||||
|
||||
assert_eq!(s.get("k"), Some(&json!(2)));
|
||||
assert_eq!(s.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn into_writes_consumes_and_yields_map() {
|
||||
let mut s = StagingArea::new();
|
||||
s.write("a", json!(1));
|
||||
s.write("b", json!(2));
|
||||
|
||||
let writes = s.into_writes();
|
||||
|
||||
assert_eq!(writes.len(), 2);
|
||||
assert_eq!(writes.get("a"), Some(&json!(1)));
|
||||
assert_eq!(writes.get("b"), Some(&json!(2)));
|
||||
}
|
||||
}
|
||||
+415
-3
@@ -1,14 +1,16 @@
|
||||
use super::MAX_STATE_SIZE_BYTES;
|
||||
use super::types::GraphState;
|
||||
use super::reducer;
|
||||
use super::staging::BranchWrites;
|
||||
use super::types::{GraphState, Reducer};
|
||||
use crate::utils::temp_file;
|
||||
use anyhow::{Context, Result, bail};
|
||||
use fancy_regex::Regex;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::fs;
|
||||
use std::fs::write;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
static TEMPLATE_VAR_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.\[\]]+)\}\}").expect("invalid template regex"));
|
||||
@@ -156,6 +158,96 @@ impl StateManager {
|
||||
let _ = fs::remove_file(path);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an `Arc`-wrapped snapshot of the current graph state. Each branch
|
||||
/// in a parallel super-step shares this snapshot for reads; their writes
|
||||
/// accumulate into per-branch `StagingArea` instances, which are merged via
|
||||
/// `apply_branch_writes` at the end of the super-step.
|
||||
///
|
||||
/// Distinct from the older `snapshot()` method (returns a `HashMap` clone of
|
||||
/// the data only — used by `script_executor` to ship state to child processes).
|
||||
#[allow(dead_code)]
|
||||
pub fn read_snapshot(&self) -> Arc<GraphState> {
|
||||
Arc::new(self.state.clone())
|
||||
}
|
||||
|
||||
/// Commits a deterministically-ordered set of per-branch writes back into
|
||||
/// live state, applying declared reducers where they exist.
|
||||
///
|
||||
/// Caller must pre-sort `writes` by `(node_id, invocation_index)` so that
|
||||
/// non-commutative reducers (`Concat`, `Merge`) produce reproducible output.
|
||||
///
|
||||
/// Errors when a key has writers from ≥2 branches but no reducer declared.
|
||||
/// The validator (Phase C) catches this at load time; this runtime check is
|
||||
/// defense-in-depth against a malformed or out-of-date validator missing it.
|
||||
#[allow(dead_code)]
|
||||
pub fn apply_branch_writes(
|
||||
&mut self,
|
||||
writes: Vec<BranchWrites>,
|
||||
reducers: &HashMap<String, Reducer>,
|
||||
) -> Result<()> {
|
||||
let mut by_key: BTreeMap<String, Vec<Value>> = BTreeMap::new();
|
||||
for branch in writes {
|
||||
for (key, value) in branch.writes {
|
||||
by_key.entry(key).or_default().push(value);
|
||||
}
|
||||
}
|
||||
|
||||
for (key, values) in by_key {
|
||||
match reducers.get(&key).copied() {
|
||||
Some(r) => {
|
||||
let mut current = self.state.get(&key).cloned();
|
||||
for value in values {
|
||||
current = Some(reducer::apply(r, current.as_ref(), value)?);
|
||||
}
|
||||
if let Some(final_value) = current {
|
||||
self.state.set(key, final_value);
|
||||
}
|
||||
}
|
||||
None if values.len() == 1 => {
|
||||
self.state.set(key, values.into_iter().next().unwrap());
|
||||
}
|
||||
None => {
|
||||
bail!(
|
||||
"Key '{key}' was written by {} parallel branches but has no \
|
||||
reducer declared. Add a reducer for '{key}' to the graph's \
|
||||
`reducers:` block, or rename one writer.",
|
||||
values.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Interpolates a template and returns a typed JSON `Value`.
|
||||
///
|
||||
/// Two paths depending on the template shape:
|
||||
/// - **Pure single reference** (the entire trimmed template is a single
|
||||
/// `{{key}}` expression, e.g. `"{{subjects}}"`, `"{{user.name}}"`,
|
||||
/// `"{{items[0]}}"`) — returns the typed `Value` at that key, preserving
|
||||
/// numbers, bools, arrays, and objects. Errors if the key is missing.
|
||||
/// - **Mixed template** (multiple refs, surrounding text, or no refs) —
|
||||
/// falls back to string interpolation via `interpolate()` and returns
|
||||
/// `Value::String(...)`. Strict on missing keys.
|
||||
///
|
||||
/// Required by:
|
||||
/// - `map.over: "{{subjects}}"` — must resolve to a JSON array, not its string form
|
||||
/// - `state_updates` writes that should preserve the source type (a `cost_usd: "{{api_cost}}"`
|
||||
/// write should land as a Number, not a String)
|
||||
#[allow(dead_code)]
|
||||
pub fn interpolate_raw(&self, template: &str) -> Result<Value> {
|
||||
let trimmed = template.trim();
|
||||
if let Some(key) = single_reference_key(trimmed) {
|
||||
match self.get_nested_value(key) {
|
||||
Some(value) => Ok(value.clone()),
|
||||
None => bail!("Template interpolation failed: '{key}' not found in state"),
|
||||
}
|
||||
} else {
|
||||
Ok(Value::String(self.interpolate(template)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StateManager {
|
||||
@@ -214,6 +306,22 @@ fn split_indices(segment: &str) -> Option<(&str, Vec<usize>)> {
|
||||
Some((key, indices))
|
||||
}
|
||||
|
||||
// Returns the inner key when `template` is exactly a single `{{key}}` reference
|
||||
// (no surrounding text, no other braces). Mirrors the character set the
|
||||
// TEMPLATE_VAR_RE regex accepts so `interpolate_raw` and `interpolate` stay
|
||||
// consistent about what counts as a valid key.
|
||||
fn single_reference_key(template: &str) -> Option<&str> {
|
||||
let inner = template.strip_prefix("{{")?.strip_suffix("}}")?;
|
||||
if inner.contains("{{") || inner.contains("}}") {
|
||||
return None;
|
||||
}
|
||||
let valid = !inner.is_empty()
|
||||
&& inner
|
||||
.chars()
|
||||
.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == '[' || c == ']');
|
||||
valid.then_some(inner)
|
||||
}
|
||||
|
||||
fn value_to_string(value: &Value) -> String {
|
||||
match value {
|
||||
Value::String(s) => s.clone(),
|
||||
@@ -532,4 +640,308 @@ mod tests {
|
||||
|
||||
assert_eq!(manager.state().get("status"), Some(&json!("complete")));
|
||||
}
|
||||
|
||||
fn branch(node_id: &str, idx: usize, writes: &[(&str, Value)]) -> BranchWrites {
|
||||
let mut map = HashMap::new();
|
||||
for (k, v) in writes {
|
||||
map.insert((*k).into(), v.clone());
|
||||
}
|
||||
BranchWrites {
|
||||
node_id: node_id.into(),
|
||||
invocation_index: idx,
|
||||
writes: map,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_snapshot_returns_arc_with_current_state() {
|
||||
let manager = manager_with(&[("k", json!("v"))]);
|
||||
|
||||
let snap = manager.read_snapshot();
|
||||
|
||||
assert_eq!(snap.get("k"), Some(&json!("v")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn read_snapshot_is_independent_of_later_mutations() {
|
||||
let mut manager = manager_with(&[("count", json!(1))]);
|
||||
let snap = manager.read_snapshot();
|
||||
|
||||
manager.state_mut().set("count".into(), json!(999));
|
||||
|
||||
assert_eq!(snap.get("count"), Some(&json!(1)));
|
||||
assert_eq!(manager.state().get("count"), Some(&json!(999)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_empty_is_noop() {
|
||||
let mut manager = manager_with(&[("k", json!("v"))]);
|
||||
let reducers = HashMap::new();
|
||||
|
||||
manager.apply_branch_writes(vec![], &reducers).unwrap();
|
||||
|
||||
assert_eq!(manager.state().get("k"), Some(&json!("v")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_single_writer_no_reducer_overwrites() {
|
||||
let mut manager = manager_with(&[]);
|
||||
let reducers = HashMap::new();
|
||||
|
||||
manager
|
||||
.apply_branch_writes(vec![branch("n", 0, &[("k", json!(42))])], &reducers)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.state().get("k"), Some(&json!(42)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_disjoint_keys_all_land() {
|
||||
let mut manager = manager_with(&[]);
|
||||
let reducers = HashMap::new();
|
||||
|
||||
manager
|
||||
.apply_branch_writes(
|
||||
vec![
|
||||
branch("a", 0, &[("x", json!(1))]),
|
||||
branch("b", 0, &[("y", json!(2))]),
|
||||
branch("c", 0, &[("z", json!(3))]),
|
||||
],
|
||||
&reducers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.state().get("x"), Some(&json!(1)));
|
||||
assert_eq!(manager.state().get("y"), Some(&json!(2)));
|
||||
assert_eq!(manager.state().get("z"), Some(&json!(3)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_three_appends_preserve_input_order() {
|
||||
let mut manager = manager_with(&[]);
|
||||
let mut reducers = HashMap::new();
|
||||
reducers.insert("items".into(), Reducer::Append);
|
||||
|
||||
manager
|
||||
.apply_branch_writes(
|
||||
vec![
|
||||
branch("a", 0, &[("items", json!("first"))]),
|
||||
branch("b", 0, &[("items", json!("second"))]),
|
||||
branch("c", 0, &[("items", json!("third"))]),
|
||||
],
|
||||
&reducers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
manager.state().get("items"),
|
||||
Some(&json!(["first", "second", "third"]))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_collision_without_reducer_bails() {
|
||||
let mut manager = manager_with(&[]);
|
||||
let reducers = HashMap::new();
|
||||
|
||||
let err = manager
|
||||
.apply_branch_writes(
|
||||
vec![
|
||||
branch("a", 0, &[("k", json!("first"))]),
|
||||
branch("b", 0, &[("k", json!("second"))]),
|
||||
],
|
||||
&reducers,
|
||||
)
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'k'"), "got: {err}");
|
||||
assert!(err.contains("no reducer"), "got: {err}");
|
||||
assert!(err.contains("2 parallel branches"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_sum_reducer_accumulates_with_existing_state() {
|
||||
let mut manager = manager_with(&[("cost", json!(10))]);
|
||||
let mut reducers = HashMap::new();
|
||||
reducers.insert("cost".into(), Reducer::Sum);
|
||||
|
||||
manager
|
||||
.apply_branch_writes(
|
||||
vec![
|
||||
branch("a", 0, &[("cost", json!(5))]),
|
||||
branch("b", 0, &[("cost", json!(7))]),
|
||||
],
|
||||
&reducers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.state().get("cost"), Some(&json!(22)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_concat_respects_branch_order() {
|
||||
let mut manager = manager_with(&[]);
|
||||
let mut reducers = HashMap::new();
|
||||
reducers.insert("log".into(), Reducer::Concat);
|
||||
|
||||
manager
|
||||
.apply_branch_writes(
|
||||
vec![
|
||||
branch("a", 0, &[("log", json!("alpha"))]),
|
||||
branch("b", 0, &[("log", json!("bravo"))]),
|
||||
],
|
||||
&reducers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.state().get("log"), Some(&json!("alpha\nbravo")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_branch_writes_mixed_keys_with_and_without_reducers() {
|
||||
let mut manager = manager_with(&[]);
|
||||
let mut reducers = HashMap::new();
|
||||
reducers.insert("results".into(), Reducer::Append);
|
||||
|
||||
manager
|
||||
.apply_branch_writes(
|
||||
vec![
|
||||
branch(
|
||||
"a",
|
||||
0,
|
||||
&[("results", json!("x")), ("status", json!("ok_a"))],
|
||||
),
|
||||
branch("b", 0, &[("results", json!("y"))]),
|
||||
],
|
||||
&reducers,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(manager.state().get("results"), Some(&json!(["x", "y"])));
|
||||
assert_eq!(manager.state().get("status"), Some(&json!("ok_a")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_pure_ref_returns_typed_number() {
|
||||
let manager = manager_with(&[("count", json!(42))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{count}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!(42));
|
||||
assert!(result.is_i64());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_pure_ref_returns_typed_array() {
|
||||
let manager = manager_with(&[("items", json!(["a", "b", "c"]))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{items}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!(["a", "b", "c"]));
|
||||
assert!(result.is_array());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_pure_ref_returns_typed_object() {
|
||||
let manager = manager_with(&[("user", json!({ "name": "alice", "age": 30 }))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{user}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!({ "name": "alice", "age": 30 }));
|
||||
assert!(result.is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_pure_ref_returns_typed_bool() {
|
||||
let manager = manager_with(&[("flag", json!(true))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{flag}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!(true));
|
||||
assert!(result.is_boolean());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_nested_path_returns_typed_value() {
|
||||
let manager = manager_with(&[("user", json!({ "email": "x@y.com" }))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{user.email}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!("x@y.com"));
|
||||
assert!(result.is_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_array_index_returns_typed_value() {
|
||||
let manager = manager_with(&[("items", json!([10, 20, 30]))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{items[1]}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!(20));
|
||||
assert!(result.is_i64());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_missing_pure_ref_errors() {
|
||||
let manager = manager_with(&[]);
|
||||
|
||||
let err = manager
|
||||
.interpolate_raw("{{ghost}}")
|
||||
.unwrap_err()
|
||||
.to_string();
|
||||
|
||||
assert!(err.contains("'ghost'"), "got: {err}");
|
||||
assert!(err.contains("not found"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_mixed_template_falls_back_to_string() {
|
||||
let manager = manager_with(&[("name", json!("alice"))]);
|
||||
|
||||
let result = manager.interpolate_raw("Hello {{name}}!").unwrap();
|
||||
|
||||
assert_eq!(result, json!("Hello alice!"));
|
||||
assert!(result.is_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_multiple_refs_fall_back_to_string() {
|
||||
let manager = manager_with(&[("a", json!(1)), ("b", json!(2))]);
|
||||
|
||||
let result = manager.interpolate_raw("{{a}}{{b}}").unwrap();
|
||||
|
||||
assert_eq!(result, json!("12"));
|
||||
assert!(result.is_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_no_refs_is_literal_string() {
|
||||
let manager = manager_with(&[]);
|
||||
|
||||
let result = manager.interpolate_raw("literal text").unwrap();
|
||||
|
||||
assert_eq!(result, json!("literal text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_whitespace_padding_still_resolves_pure_ref() {
|
||||
let manager = manager_with(&[("k", json!("v"))]);
|
||||
|
||||
let result = manager.interpolate_raw(" {{k}} ").unwrap();
|
||||
|
||||
assert_eq!(result, json!("v"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolate_raw_inner_spaces_treated_as_mixed() {
|
||||
let manager = manager_with(&[("k", json!("v"))]);
|
||||
|
||||
// `{{ k }}` is not a valid pure reference (spaces inside braces are
|
||||
// outside the allowed character set). Fall back to string interpolation
|
||||
// -- which doesn't match the regex either, so the literal passes through.
|
||||
let result = manager.interpolate_raw("{{ k }}").unwrap();
|
||||
|
||||
assert_eq!(result, json!("{{ k }}"));
|
||||
}
|
||||
}
|
||||
|
||||
+20
-20
@@ -940,9 +940,9 @@ prompt: Classify
|
||||
next: [retrieve_local, retrieve_web, retrieve_docs]
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
let targets = node.next.as_ref().expect("next should be present");
|
||||
|
||||
|
||||
assert!(targets.is_fan_out());
|
||||
assert_eq!(
|
||||
targets.as_slice(),
|
||||
@@ -963,9 +963,9 @@ prompt: Classify
|
||||
next: retrieve
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
let targets = node.next.as_ref().expect("next should be present");
|
||||
|
||||
|
||||
assert!(!targets.is_fan_out());
|
||||
assert_eq!(node.next_target(), Some("retrieve"));
|
||||
}
|
||||
@@ -979,9 +979,9 @@ prompt: Classify
|
||||
next: [a, b]
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
let err = node.next_single().unwrap_err().to_string();
|
||||
|
||||
|
||||
assert!(err.contains("Parallel fan-out"), "got: {err}");
|
||||
assert!(err.contains("not yet implemented"), "got: {err}");
|
||||
}
|
||||
@@ -994,9 +994,9 @@ type: llm
|
||||
prompt: Classify
|
||||
next: [retrieve]
|
||||
"#;
|
||||
|
||||
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
assert_eq!(node.next_single().unwrap(), Some("retrieve"));
|
||||
assert_eq!(node.next_target(), Some("retrieve"));
|
||||
}
|
||||
@@ -1036,9 +1036,9 @@ nodes:
|
||||
type: end
|
||||
output: ok
|
||||
"#;
|
||||
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
assert_eq!(graph.reducers.len(), 8);
|
||||
assert_eq!(graph.reducers.get("sources"), Some(&Reducer::Append));
|
||||
assert_eq!(graph.reducers.get("findings"), Some(&Reducer::Extend));
|
||||
@@ -1053,18 +1053,18 @@ nodes:
|
||||
#[test]
|
||||
fn reducers_default_to_empty_when_block_absent() {
|
||||
let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n";
|
||||
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
assert!(graph.reducers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_concurrency_defaults_to_eight() {
|
||||
let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n";
|
||||
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
assert_eq!(graph.settings.max_concurrency, 8);
|
||||
}
|
||||
|
||||
@@ -1079,9 +1079,9 @@ nodes:
|
||||
x:
|
||||
type: end
|
||||
"#;
|
||||
|
||||
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
assert_eq!(graph.settings.max_concurrency, 16);
|
||||
}
|
||||
|
||||
@@ -1099,12 +1099,12 @@ max_concurrency: 5
|
||||
next: rank
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
let map = match node.node_type {
|
||||
NodeType::Map(m) => m,
|
||||
_ => panic!("expected Map variant"),
|
||||
};
|
||||
|
||||
|
||||
assert_eq!(map.over, "{{subjects}}");
|
||||
assert_eq!(map.as_name, "subject");
|
||||
assert_eq!(map.branch, "research_subject");
|
||||
@@ -1124,12 +1124,12 @@ branch: process
|
||||
collect_into: results
|
||||
"#;
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
|
||||
let map = match node.node_type {
|
||||
NodeType::Map(m) => m,
|
||||
_ => panic!("expected Map variant"),
|
||||
};
|
||||
|
||||
|
||||
assert_eq!(map.output_key, "output");
|
||||
assert!(map.max_concurrency.is_none());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user