diff --git a/src/graph/mod.rs b/src/graph/mod.rs index c236d2c..3268ed2 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -5,7 +5,9 @@ pub mod llm; pub mod logging; pub mod parser; pub mod rag; +pub mod reducer; pub mod script; +pub mod staging; pub mod state; pub mod structured; pub mod types; diff --git a/src/graph/reducer.rs b/src/graph/reducer.rs new file mode 100644 index 0000000..1ceb058 --- /dev/null +++ b/src/graph/reducer.rs @@ -0,0 +1,407 @@ +use super::types::Reducer; +use anyhow::{Result, bail}; +use serde_json::{Number, Value}; + +/// Combines a branch's incoming write with the current state value (if any) +/// via the specified reducer. The result is what gets written back to live +/// state during the super-step merge phase. +/// +/// `current = None` means the key has no prior value in this super-step or in +/// live state. Most reducers treat absent as their identity (empty array, +/// empty string, no prior value). `Overwrite` ignores `current` entirely. +/// +/// Errors clearly when types are incompatible with the reducer (e.g. +/// `Sum` on a string), naming the reducer and which side (`current` / `incoming`) +/// has the wrong type. +pub fn apply(reducer: Reducer, current: Option<&Value>, incoming: Value) -> Result { + match reducer { + Reducer::Append => apply_append(current, incoming), + Reducer::Extend => apply_extend(current, incoming), + Reducer::Concat => apply_concat(current, incoming), + Reducer::Sum => apply_sum(current, incoming), + Reducer::Max => apply_max(current, incoming), + Reducer::Min => apply_min(current, incoming), + Reducer::Merge => apply_merge(current, incoming), + Reducer::Overwrite => Ok(incoming), + } +} + +fn apply_append(current: Option<&Value>, incoming: Value) -> Result { + let mut arr = match current { + None => Vec::new(), + Some(Value::Array(a)) => a.clone(), + Some(other) => bail!( + "reducer 'append' requires an array (or absent) for the current value, got {}", + type_name(other) + ), + }; + arr.push(incoming); + Ok(Value::Array(arr)) +} + +fn apply_extend(current: Option<&Value>, incoming: Value) -> Result { + let mut arr = match current { + None => Vec::new(), + Some(Value::Array(a)) => a.clone(), + Some(other) => bail!( + "reducer 'extend' requires an array (or absent) for the current value, got {}", + type_name(other) + ), + }; + match incoming { + Value::Array(items) => arr.extend(items), + other => bail!( + "reducer 'extend' requires an array for the incoming value, got {}", + type_name(&other) + ), + } + Ok(Value::Array(arr)) +} + +fn apply_concat(current: Option<&Value>, incoming: Value) -> Result { + let incoming_str = match incoming { + Value::String(s) => s, + other => bail!( + "reducer 'concat' requires a string for the incoming value, got {}", + type_name(&other) + ), + }; + let result = match current { + None => incoming_str, + Some(Value::String(c)) => { + if c.is_empty() { + incoming_str + } else { + format!("{c}\n{incoming_str}") + } + } + Some(other) => bail!( + "reducer 'concat' requires a string (or absent) for the current value, got {}", + type_name(other) + ), + }; + Ok(Value::String(result)) +} + +fn apply_sum(current: Option<&Value>, incoming: Value) -> Result { + let i = number_or_error(&incoming, "sum", "incoming")?; + let c = match current { + None => 0.0, + Some(value) => number_or_error(value, "sum", "current")?, + }; + Ok(json_number(c + i)) +} + +fn apply_max(current: Option<&Value>, incoming: Value) -> Result { + let i = number_or_error(&incoming, "max", "incoming")?; + match current { + None => Ok(json_number(i)), + Some(value) => { + let c = number_or_error(value, "max", "current")?; + Ok(json_number(c.max(i))) + } + } +} + +fn apply_min(current: Option<&Value>, incoming: Value) -> Result { + let i = number_or_error(&incoming, "min", "incoming")?; + match current { + None => Ok(json_number(i)), + Some(value) => { + let c = number_or_error(value, "min", "current")?; + Ok(json_number(c.min(i))) + } + } +} + +fn apply_merge(current: Option<&Value>, incoming: Value) -> Result { + let mut map = match current { + None => serde_json::Map::new(), + Some(Value::Object(m)) => m.clone(), + Some(other) => bail!( + "reducer 'merge' requires an object (or absent) for the current value, got {}", + type_name(other) + ), + }; + match incoming { + Value::Object(items) => { + for (k, v) in items { + map.insert(k, v); + } + } + other => bail!( + "reducer 'merge' requires an object for the incoming value, got {}", + type_name(&other) + ), + } + Ok(Value::Object(map)) +} + +fn number_or_error(value: &Value, reducer_name: &str, position: &str) -> Result { + match value.as_f64() { + Some(n) => Ok(n), + None => bail!( + "reducer '{reducer_name}' requires a number for the {position} value, got {}", + type_name(value) + ), + } +} + +fn type_name(value: &Value) -> &'static str { + match value { + Value::Null => "null", + Value::Bool(_) => "bool", + Value::Number(_) => "number", + Value::String(_) => "string", + Value::Array(_) => "array", + Value::Object(_) => "object", + } +} + +// Numeric reducers compute in f64 for simplicity. We preserve integer typing +// when the result is losslessly representable as i64 so `count: sum` stays an +// integer rather than degrading to a float. Non-finite values (NaN, Inf) can't +// arise from finite inputs to +/max/min, so the fallback never fires in practice. +fn json_number(n: f64) -> Value { + if n.fract() == 0.0 && n.is_finite() && n.abs() <= (i64::MAX as f64) { + Value::Number(Number::from(n as i64)) + } else { + match Number::from_f64(n) { + Some(num) => Value::Number(num), + None => Value::Null, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn append_to_absent_creates_single_element_array() { + let result = apply(Reducer::Append, None, json!("a")).unwrap(); + + assert_eq!(result, json!(["a"])); + } + + #[test] + fn append_pushes_onto_existing_array() { + let current = json!(["a", "b"]); + let result = apply(Reducer::Append, Some(¤t), json!("c")).unwrap(); + + assert_eq!(result, json!(["a", "b", "c"])); + } + + #[test] + fn append_errors_when_current_is_not_array() { + let current = json!("not an array"); + + let err = apply(Reducer::Append, Some(¤t), json!("x")) + .unwrap_err() + .to_string(); + + assert!(err.contains("'append'"), "got: {err}"); + assert!(err.contains("string"), "got: {err}"); + } + + #[test] + fn extend_concatenates_arrays() { + let current = json!([1, 2]); + + let result = apply(Reducer::Extend, Some(¤t), json!([3, 4])).unwrap(); + + assert_eq!(result, json!([1, 2, 3, 4])); + } + + #[test] + fn extend_from_absent_with_array() { + let result = apply(Reducer::Extend, None, json!([1, 2])).unwrap(); + + assert_eq!(result, json!([1, 2])); + } + + #[test] + fn extend_errors_when_incoming_is_not_array() { + let err = apply(Reducer::Extend, None, json!(42)) + .unwrap_err() + .to_string(); + + assert!(err.contains("'extend'"), "got: {err}"); + assert!(err.contains("number"), "got: {err}"); + assert!(err.contains("incoming"), "got: {err}"); + } + + #[test] + fn concat_joins_strings_with_newline() { + let current = json!("first"); + + let result = apply(Reducer::Concat, Some(¤t), json!("second")).unwrap(); + + assert_eq!(result, json!("first\nsecond")); + } + + #[test] + fn concat_from_absent_yields_incoming() { + let result = apply(Reducer::Concat, None, json!("hello")).unwrap(); + + assert_eq!(result, json!("hello")); + } + + #[test] + fn concat_skips_separator_when_current_is_empty_string() { + let current = json!(""); + + let result = apply(Reducer::Concat, Some(¤t), json!("first")).unwrap(); + + assert_eq!(result, json!("first")); + } + + #[test] + fn concat_errors_when_incoming_is_not_string() { + let err = apply(Reducer::Concat, None, json!(42)) + .unwrap_err() + .to_string(); + + assert!(err.contains("'concat'"), "got: {err}"); + assert!(err.contains("number"), "got: {err}"); + } + + #[test] + fn sum_adds_numbers() { + let current = json!(5); + + let result = apply(Reducer::Sum, Some(¤t), json!(7)).unwrap(); + + assert_eq!(result, json!(12)); + } + + #[test] + fn sum_starts_from_zero_when_current_absent() { + let result = apply(Reducer::Sum, None, json!(3.5)).unwrap(); + + assert_eq!(result, json!(3.5)); + } + + #[test] + fn sum_preserves_integer_type_for_whole_results() { + let current = json!(2); + + let result = apply(Reducer::Sum, Some(¤t), json!(3)).unwrap(); + + assert!(result.is_i64(), "expected integer, got {result:?}"); + assert_eq!(result, json!(5)); + } + + #[test] + fn sum_uses_float_when_result_has_fractional() { + let current = json!(1.5); + let result = apply(Reducer::Sum, Some(¤t), json!(2.25)).unwrap(); + + assert_eq!(result, json!(3.75)); + } + + #[test] + fn sum_errors_on_string_incoming() { + let err = apply(Reducer::Sum, None, json!("not a number")) + .unwrap_err() + .to_string(); + + assert!(err.contains("'sum'"), "got: {err}"); + assert!(err.contains("string"), "got: {err}"); + } + + #[test] + fn max_returns_larger_of_two() { + let current = json!(5); + let result = apply(Reducer::Max, Some(¤t), json!(3)).unwrap(); + assert_eq!(result, json!(5)); + + let result = apply(Reducer::Max, Some(¤t), json!(10)).unwrap(); + assert_eq!(result, json!(10)); + } + + #[test] + fn max_yields_incoming_when_current_absent() { + let result = apply(Reducer::Max, None, json!(42)).unwrap(); + + assert_eq!(result, json!(42)); + } + + #[test] + fn min_returns_smaller_of_two() { + let current = json!(5); + let result = apply(Reducer::Min, Some(¤t), json!(3)).unwrap(); + assert_eq!(result, json!(3)); + + let result = apply(Reducer::Min, Some(¤t), json!(10)).unwrap(); + assert_eq!(result, json!(5)); + } + + #[test] + fn min_errors_on_non_numeric_current() { + let current = json!("oops"); + + let err = apply(Reducer::Min, Some(¤t), json!(1)) + .unwrap_err() + .to_string(); + + assert!(err.contains("'min'"), "got: {err}"); + assert!(err.contains("current"), "got: {err}"); + } + + #[test] + fn merge_unions_objects_with_incoming_winning_collisions() { + let current = json!({ "a": 1, "b": 2 }); + let incoming = json!({ "b": 99, "c": 3 }); + + let result = apply(Reducer::Merge, Some(¤t), incoming).unwrap(); + + assert_eq!(result, json!({ "a": 1, "b": 99, "c": 3 })); + } + + #[test] + fn merge_from_absent_yields_incoming_object() { + let result = apply(Reducer::Merge, None, json!({ "k": "v" })).unwrap(); + + assert_eq!(result, json!({ "k": "v" })); + } + + #[test] + fn merge_errors_when_incoming_is_not_object() { + let err = apply(Reducer::Merge, None, json!([1, 2])) + .unwrap_err() + .to_string(); + + assert!(err.contains("'merge'"), "got: {err}"); + assert!(err.contains("array"), "got: {err}"); + } + + #[test] + fn merge_errors_when_current_is_not_object() { + let current = json!("not object"); + + let err = apply(Reducer::Merge, Some(¤t), json!({ "k": "v" })) + .unwrap_err() + .to_string(); + + assert!(err.contains("'merge'"), "got: {err}"); + assert!(err.contains("current"), "got: {err}"); + } + + #[test] + fn overwrite_ignores_current_and_returns_incoming() { + let current = json!("old"); + + let result = apply(Reducer::Overwrite, Some(¤t), json!("new")).unwrap(); + + assert_eq!(result, json!("new")); + } + + #[test] + fn overwrite_works_with_absent_current() { + let result = apply(Reducer::Overwrite, None, json!(42)).unwrap(); + + assert_eq!(result, json!(42)); + } +} diff --git a/src/graph/staging.rs b/src/graph/staging.rs new file mode 100644 index 0000000..c89631c --- /dev/null +++ b/src/graph/staging.rs @@ -0,0 +1,97 @@ +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, Default, Clone)] +pub struct StagingArea { + writes: HashMap, +} + +#[allow(dead_code)] +impl StagingArea { + pub fn new() -> Self { + Self::default() + } + + pub fn write(&mut self, key: impl Into, value: Value) { + self.writes.insert(key.into(), value); + } + + pub fn get(&self, key: &str) -> Option<&Value> { + self.writes.get(key) + } + + pub fn is_empty(&self) -> bool { + self.writes.is_empty() + } + + pub fn len(&self) -> usize { + self.writes.len() + } + + pub fn into_writes(self) -> HashMap { + self.writes + } +} + +/// Published form of one branch's writes for the super-step merge phase. +/// Callers assemble these into a deterministically-ordered `Vec` keyed by +/// `(node_id, invocation_index)` before passing to +/// `StateManager::apply_branch_writes`. `invocation_index` is 0 for normal +/// branches and the input-list position for map sub-branches — so multiple +/// invocations of the same `branch:` node by a `map` are still totally ordered. +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct BranchWrites { + pub node_id: String, + pub invocation_index: usize, + pub writes: HashMap, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn new_staging_area_is_empty() { + let s = StagingArea::new(); + + assert!(s.is_empty()); + assert_eq!(s.len(), 0); + } + + #[test] + fn write_stores_value_under_key() { + let mut s = StagingArea::new(); + + s.write("key", json!("value")); + + assert_eq!(s.get("key"), Some(&json!("value"))); + assert_eq!(s.len(), 1); + assert!(!s.is_empty()); + } + + #[test] + fn write_overwrites_existing_key() { + let mut s = StagingArea::new(); + + s.write("k", json!(1)); + s.write("k", json!(2)); + + assert_eq!(s.get("k"), Some(&json!(2))); + assert_eq!(s.len(), 1); + } + + #[test] + fn into_writes_consumes_and_yields_map() { + let mut s = StagingArea::new(); + s.write("a", json!(1)); + s.write("b", json!(2)); + + let writes = s.into_writes(); + + assert_eq!(writes.len(), 2); + assert_eq!(writes.get("a"), Some(&json!(1))); + assert_eq!(writes.get("b"), Some(&json!(2))); + } +} diff --git a/src/graph/state.rs b/src/graph/state.rs index 7b78caa..a7d6f8b 100644 --- a/src/graph/state.rs +++ b/src/graph/state.rs @@ -1,14 +1,16 @@ use super::MAX_STATE_SIZE_BYTES; -use super::types::GraphState; +use super::reducer; +use super::staging::BranchWrites; +use super::types::{GraphState, Reducer}; use crate::utils::temp_file; use anyhow::{Context, Result, bail}; use fancy_regex::Regex; use serde_json::Value; -use std::collections::HashMap; +use std::collections::{BTreeMap, HashMap}; use std::fs; use std::fs::write; use std::path::PathBuf; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; static TEMPLATE_VAR_RE: LazyLock = LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.\[\]]+)\}\}").expect("invalid template regex")); @@ -156,6 +158,96 @@ impl StateManager { let _ = fs::remove_file(path); } } + + /// Returns an `Arc`-wrapped snapshot of the current graph state. Each branch + /// in a parallel super-step shares this snapshot for reads; their writes + /// accumulate into per-branch `StagingArea` instances, which are merged via + /// `apply_branch_writes` at the end of the super-step. + /// + /// Distinct from the older `snapshot()` method (returns a `HashMap` clone of + /// the data only — used by `script_executor` to ship state to child processes). + #[allow(dead_code)] + pub fn read_snapshot(&self) -> Arc { + Arc::new(self.state.clone()) + } + + /// Commits a deterministically-ordered set of per-branch writes back into + /// live state, applying declared reducers where they exist. + /// + /// Caller must pre-sort `writes` by `(node_id, invocation_index)` so that + /// non-commutative reducers (`Concat`, `Merge`) produce reproducible output. + /// + /// Errors when a key has writers from ≥2 branches but no reducer declared. + /// The validator (Phase C) catches this at load time; this runtime check is + /// defense-in-depth against a malformed or out-of-date validator missing it. + #[allow(dead_code)] + pub fn apply_branch_writes( + &mut self, + writes: Vec, + reducers: &HashMap, + ) -> Result<()> { + let mut by_key: BTreeMap> = BTreeMap::new(); + for branch in writes { + for (key, value) in branch.writes { + by_key.entry(key).or_default().push(value); + } + } + + for (key, values) in by_key { + match reducers.get(&key).copied() { + Some(r) => { + let mut current = self.state.get(&key).cloned(); + for value in values { + current = Some(reducer::apply(r, current.as_ref(), value)?); + } + if let Some(final_value) = current { + self.state.set(key, final_value); + } + } + None if values.len() == 1 => { + self.state.set(key, values.into_iter().next().unwrap()); + } + None => { + bail!( + "Key '{key}' was written by {} parallel branches but has no \ + reducer declared. Add a reducer for '{key}' to the graph's \ + `reducers:` block, or rename one writer.", + values.len() + ); + } + } + } + + Ok(()) + } + + /// Interpolates a template and returns a typed JSON `Value`. + /// + /// Two paths depending on the template shape: + /// - **Pure single reference** (the entire trimmed template is a single + /// `{{key}}` expression, e.g. `"{{subjects}}"`, `"{{user.name}}"`, + /// `"{{items[0]}}"`) — returns the typed `Value` at that key, preserving + /// numbers, bools, arrays, and objects. Errors if the key is missing. + /// - **Mixed template** (multiple refs, surrounding text, or no refs) — + /// falls back to string interpolation via `interpolate()` and returns + /// `Value::String(...)`. Strict on missing keys. + /// + /// Required by: + /// - `map.over: "{{subjects}}"` — must resolve to a JSON array, not its string form + /// - `state_updates` writes that should preserve the source type (a `cost_usd: "{{api_cost}}"` + /// write should land as a Number, not a String) + #[allow(dead_code)] + pub fn interpolate_raw(&self, template: &str) -> Result { + let trimmed = template.trim(); + if let Some(key) = single_reference_key(trimmed) { + match self.get_nested_value(key) { + Some(value) => Ok(value.clone()), + None => bail!("Template interpolation failed: '{key}' not found in state"), + } + } else { + Ok(Value::String(self.interpolate(template)?)) + } + } } impl Drop for StateManager { @@ -214,6 +306,22 @@ fn split_indices(segment: &str) -> Option<(&str, Vec)> { Some((key, indices)) } +// Returns the inner key when `template` is exactly a single `{{key}}` reference +// (no surrounding text, no other braces). Mirrors the character set the +// TEMPLATE_VAR_RE regex accepts so `interpolate_raw` and `interpolate` stay +// consistent about what counts as a valid key. +fn single_reference_key(template: &str) -> Option<&str> { + let inner = template.strip_prefix("{{")?.strip_suffix("}}")?; + if inner.contains("{{") || inner.contains("}}") { + return None; + } + let valid = !inner.is_empty() + && inner + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '.' || c == '[' || c == ']'); + valid.then_some(inner) +} + fn value_to_string(value: &Value) -> String { match value { Value::String(s) => s.clone(), @@ -532,4 +640,308 @@ mod tests { assert_eq!(manager.state().get("status"), Some(&json!("complete"))); } + + fn branch(node_id: &str, idx: usize, writes: &[(&str, Value)]) -> BranchWrites { + let mut map = HashMap::new(); + for (k, v) in writes { + map.insert((*k).into(), v.clone()); + } + BranchWrites { + node_id: node_id.into(), + invocation_index: idx, + writes: map, + } + } + + #[test] + fn read_snapshot_returns_arc_with_current_state() { + let manager = manager_with(&[("k", json!("v"))]); + + let snap = manager.read_snapshot(); + + assert_eq!(snap.get("k"), Some(&json!("v"))); + } + + #[test] + fn read_snapshot_is_independent_of_later_mutations() { + let mut manager = manager_with(&[("count", json!(1))]); + let snap = manager.read_snapshot(); + + manager.state_mut().set("count".into(), json!(999)); + + assert_eq!(snap.get("count"), Some(&json!(1))); + assert_eq!(manager.state().get("count"), Some(&json!(999))); + } + + #[test] + fn apply_branch_writes_empty_is_noop() { + let mut manager = manager_with(&[("k", json!("v"))]); + let reducers = HashMap::new(); + + manager.apply_branch_writes(vec![], &reducers).unwrap(); + + assert_eq!(manager.state().get("k"), Some(&json!("v"))); + } + + #[test] + fn apply_branch_writes_single_writer_no_reducer_overwrites() { + let mut manager = manager_with(&[]); + let reducers = HashMap::new(); + + manager + .apply_branch_writes(vec![branch("n", 0, &[("k", json!(42))])], &reducers) + .unwrap(); + + assert_eq!(manager.state().get("k"), Some(&json!(42))); + } + + #[test] + fn apply_branch_writes_disjoint_keys_all_land() { + let mut manager = manager_with(&[]); + let reducers = HashMap::new(); + + manager + .apply_branch_writes( + vec![ + branch("a", 0, &[("x", json!(1))]), + branch("b", 0, &[("y", json!(2))]), + branch("c", 0, &[("z", json!(3))]), + ], + &reducers, + ) + .unwrap(); + + assert_eq!(manager.state().get("x"), Some(&json!(1))); + assert_eq!(manager.state().get("y"), Some(&json!(2))); + assert_eq!(manager.state().get("z"), Some(&json!(3))); + } + + #[test] + fn apply_branch_writes_three_appends_preserve_input_order() { + let mut manager = manager_with(&[]); + let mut reducers = HashMap::new(); + reducers.insert("items".into(), Reducer::Append); + + manager + .apply_branch_writes( + vec![ + branch("a", 0, &[("items", json!("first"))]), + branch("b", 0, &[("items", json!("second"))]), + branch("c", 0, &[("items", json!("third"))]), + ], + &reducers, + ) + .unwrap(); + + assert_eq!( + manager.state().get("items"), + Some(&json!(["first", "second", "third"])) + ); + } + + #[test] + fn apply_branch_writes_collision_without_reducer_bails() { + let mut manager = manager_with(&[]); + let reducers = HashMap::new(); + + let err = manager + .apply_branch_writes( + vec![ + branch("a", 0, &[("k", json!("first"))]), + branch("b", 0, &[("k", json!("second"))]), + ], + &reducers, + ) + .unwrap_err() + .to_string(); + + assert!(err.contains("'k'"), "got: {err}"); + assert!(err.contains("no reducer"), "got: {err}"); + assert!(err.contains("2 parallel branches"), "got: {err}"); + } + + #[test] + fn apply_branch_writes_sum_reducer_accumulates_with_existing_state() { + let mut manager = manager_with(&[("cost", json!(10))]); + let mut reducers = HashMap::new(); + reducers.insert("cost".into(), Reducer::Sum); + + manager + .apply_branch_writes( + vec![ + branch("a", 0, &[("cost", json!(5))]), + branch("b", 0, &[("cost", json!(7))]), + ], + &reducers, + ) + .unwrap(); + + assert_eq!(manager.state().get("cost"), Some(&json!(22))); + } + + #[test] + fn apply_branch_writes_concat_respects_branch_order() { + let mut manager = manager_with(&[]); + let mut reducers = HashMap::new(); + reducers.insert("log".into(), Reducer::Concat); + + manager + .apply_branch_writes( + vec![ + branch("a", 0, &[("log", json!("alpha"))]), + branch("b", 0, &[("log", json!("bravo"))]), + ], + &reducers, + ) + .unwrap(); + + assert_eq!(manager.state().get("log"), Some(&json!("alpha\nbravo"))); + } + + #[test] + fn apply_branch_writes_mixed_keys_with_and_without_reducers() { + let mut manager = manager_with(&[]); + let mut reducers = HashMap::new(); + reducers.insert("results".into(), Reducer::Append); + + manager + .apply_branch_writes( + vec![ + branch( + "a", + 0, + &[("results", json!("x")), ("status", json!("ok_a"))], + ), + branch("b", 0, &[("results", json!("y"))]), + ], + &reducers, + ) + .unwrap(); + + assert_eq!(manager.state().get("results"), Some(&json!(["x", "y"]))); + assert_eq!(manager.state().get("status"), Some(&json!("ok_a"))); + } + + #[test] + fn interpolate_raw_pure_ref_returns_typed_number() { + let manager = manager_with(&[("count", json!(42))]); + + let result = manager.interpolate_raw("{{count}}").unwrap(); + + assert_eq!(result, json!(42)); + assert!(result.is_i64()); + } + + #[test] + fn interpolate_raw_pure_ref_returns_typed_array() { + let manager = manager_with(&[("items", json!(["a", "b", "c"]))]); + + let result = manager.interpolate_raw("{{items}}").unwrap(); + + assert_eq!(result, json!(["a", "b", "c"])); + assert!(result.is_array()); + } + + #[test] + fn interpolate_raw_pure_ref_returns_typed_object() { + let manager = manager_with(&[("user", json!({ "name": "alice", "age": 30 }))]); + + let result = manager.interpolate_raw("{{user}}").unwrap(); + + assert_eq!(result, json!({ "name": "alice", "age": 30 })); + assert!(result.is_object()); + } + + #[test] + fn interpolate_raw_pure_ref_returns_typed_bool() { + let manager = manager_with(&[("flag", json!(true))]); + + let result = manager.interpolate_raw("{{flag}}").unwrap(); + + assert_eq!(result, json!(true)); + assert!(result.is_boolean()); + } + + #[test] + fn interpolate_raw_nested_path_returns_typed_value() { + let manager = manager_with(&[("user", json!({ "email": "x@y.com" }))]); + + let result = manager.interpolate_raw("{{user.email}}").unwrap(); + + assert_eq!(result, json!("x@y.com")); + assert!(result.is_string()); + } + + #[test] + fn interpolate_raw_array_index_returns_typed_value() { + let manager = manager_with(&[("items", json!([10, 20, 30]))]); + + let result = manager.interpolate_raw("{{items[1]}}").unwrap(); + + assert_eq!(result, json!(20)); + assert!(result.is_i64()); + } + + #[test] + fn interpolate_raw_missing_pure_ref_errors() { + let manager = manager_with(&[]); + + let err = manager + .interpolate_raw("{{ghost}}") + .unwrap_err() + .to_string(); + + assert!(err.contains("'ghost'"), "got: {err}"); + assert!(err.contains("not found"), "got: {err}"); + } + + #[test] + fn interpolate_raw_mixed_template_falls_back_to_string() { + let manager = manager_with(&[("name", json!("alice"))]); + + let result = manager.interpolate_raw("Hello {{name}}!").unwrap(); + + assert_eq!(result, json!("Hello alice!")); + assert!(result.is_string()); + } + + #[test] + fn interpolate_raw_multiple_refs_fall_back_to_string() { + let manager = manager_with(&[("a", json!(1)), ("b", json!(2))]); + + let result = manager.interpolate_raw("{{a}}{{b}}").unwrap(); + + assert_eq!(result, json!("12")); + assert!(result.is_string()); + } + + #[test] + fn interpolate_raw_no_refs_is_literal_string() { + let manager = manager_with(&[]); + + let result = manager.interpolate_raw("literal text").unwrap(); + + assert_eq!(result, json!("literal text")); + } + + #[test] + fn interpolate_raw_whitespace_padding_still_resolves_pure_ref() { + let manager = manager_with(&[("k", json!("v"))]); + + let result = manager.interpolate_raw(" {{k}} ").unwrap(); + + assert_eq!(result, json!("v")); + } + + #[test] + fn interpolate_raw_inner_spaces_treated_as_mixed() { + let manager = manager_with(&[("k", json!("v"))]); + + // `{{ k }}` is not a valid pure reference (spaces inside braces are + // outside the allowed character set). Fall back to string interpolation + // -- which doesn't match the regex either, so the literal passes through. + let result = manager.interpolate_raw("{{ k }}").unwrap(); + + assert_eq!(result, json!("{{ k }}")); + } } diff --git a/src/graph/types.rs b/src/graph/types.rs index 81ba8f9..70e8ccb 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -940,9 +940,9 @@ prompt: Classify next: [retrieve_local, retrieve_web, retrieve_docs] "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); - + let targets = node.next.as_ref().expect("next should be present"); - + assert!(targets.is_fan_out()); assert_eq!( targets.as_slice(), @@ -963,9 +963,9 @@ prompt: Classify next: retrieve "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); - + let targets = node.next.as_ref().expect("next should be present"); - + assert!(!targets.is_fan_out()); assert_eq!(node.next_target(), Some("retrieve")); } @@ -979,9 +979,9 @@ prompt: Classify next: [a, b] "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); - + let err = node.next_single().unwrap_err().to_string(); - + assert!(err.contains("Parallel fan-out"), "got: {err}"); assert!(err.contains("not yet implemented"), "got: {err}"); } @@ -994,9 +994,9 @@ type: llm prompt: Classify next: [retrieve] "#; - + let node: Node = serde_yaml::from_str(yaml).unwrap(); - + assert_eq!(node.next_single().unwrap(), Some("retrieve")); assert_eq!(node.next_target(), Some("retrieve")); } @@ -1036,9 +1036,9 @@ nodes: type: end output: ok "#; - + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); - + assert_eq!(graph.reducers.len(), 8); assert_eq!(graph.reducers.get("sources"), Some(&Reducer::Append)); assert_eq!(graph.reducers.get("findings"), Some(&Reducer::Extend)); @@ -1053,18 +1053,18 @@ nodes: #[test] fn reducers_default_to_empty_when_block_absent() { let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n"; - + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); - + assert!(graph.reducers.is_empty()); } #[test] fn max_concurrency_defaults_to_eight() { let yaml = "name: g\nstart: x\nnodes:\n x:\n type: end\n"; - + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); - + assert_eq!(graph.settings.max_concurrency, 8); } @@ -1079,9 +1079,9 @@ nodes: x: type: end "#; - + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); - + assert_eq!(graph.settings.max_concurrency, 16); } @@ -1099,12 +1099,12 @@ max_concurrency: 5 next: rank "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); - + let map = match node.node_type { NodeType::Map(m) => m, _ => panic!("expected Map variant"), }; - + assert_eq!(map.over, "{{subjects}}"); assert_eq!(map.as_name, "subject"); assert_eq!(map.branch, "research_subject"); @@ -1124,12 +1124,12 @@ branch: process collect_into: results "#; let node: Node = serde_yaml::from_str(yaml).unwrap(); - + let map = match node.node_type { NodeType::Map(m) => m, _ => panic!("expected Map variant"), }; - + assert_eq!(map.output_key, "output"); assert!(map.max_concurrency.is_none()); }