feat: implemented the frontier-based scheduling for the graph executor with simplified state management (gotta love .clone)
This commit is contained in:
@@ -149,6 +149,56 @@ impl RequestContext {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Forks the context for one parallel branch of a graph super-step.
|
||||||
|
///
|
||||||
|
/// Each branch gets a fresh, owned clone — mutations (role swap,
|
||||||
|
/// `before/after_chat_completion`, tool tracker, last_message, etc.) are
|
||||||
|
/// scoped to the branch and discarded when the branch finishes. The
|
||||||
|
/// user-visible state communication happens through the graph's
|
||||||
|
/// `StateManager` (via `fork_for_branch_state` + `diff_against` +
|
||||||
|
/// `apply_branch_writes` reducers), NOT through `RequestContext`.
|
||||||
|
///
|
||||||
|
/// Distinction from `new_for_child`: `new_for_child` builds a fresh context
|
||||||
|
/// for a SPAWNED SUB-AGENT (different agent identity, different supervisor
|
||||||
|
/// hierarchy, depth+1, fresh tool tracker). `fork_for_branch` keeps the
|
||||||
|
/// caller's identity and supervisor hierarchy — it's a sibling clone of the
|
||||||
|
/// SAME logical agent, running one of N parallel work items.
|
||||||
|
///
|
||||||
|
/// Behavior of per-field cloning:
|
||||||
|
/// - `Arc`-wrapped fields (`app`, `rag`, `supervisor`, `parent_supervisor`,
|
||||||
|
/// `inbox`, `escalation_queue`) — shared via Arc::clone
|
||||||
|
/// - Owned heap fields (`model`, `role`, `session`, `agent`, `tool_scope`,
|
||||||
|
/// `todo_list`, etc.) — deep `.clone()` so the branch can mutate freely
|
||||||
|
/// - `auto_continue_count` reset to 0 (each branch starts a fresh
|
||||||
|
/// continuation budget)
|
||||||
|
/// - `last_continuation_response` reset to None
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn fork_for_branch(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
app: Arc::clone(&self.app),
|
||||||
|
macro_flag: self.macro_flag,
|
||||||
|
info_flag: self.info_flag,
|
||||||
|
working_mode: self.working_mode,
|
||||||
|
model: self.model.clone(),
|
||||||
|
agent_variables: self.agent_variables.clone(),
|
||||||
|
role: self.role.clone(),
|
||||||
|
session: self.session.clone(),
|
||||||
|
rag: self.rag.clone(),
|
||||||
|
agent: self.agent.clone(),
|
||||||
|
last_message: self.last_message.clone(),
|
||||||
|
tool_scope: self.tool_scope.clone(),
|
||||||
|
supervisor: self.supervisor.clone(),
|
||||||
|
parent_supervisor: self.parent_supervisor.clone(),
|
||||||
|
self_agent_id: self.self_agent_id.clone(),
|
||||||
|
inbox: self.inbox.clone(),
|
||||||
|
escalation_queue: self.escalation_queue.clone(),
|
||||||
|
current_depth: self.current_depth,
|
||||||
|
auto_continue_count: 0,
|
||||||
|
todo_list: self.todo_list.clone(),
|
||||||
|
last_continuation_response: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_for_child(
|
pub fn new_for_child(
|
||||||
app: Arc<AppState>,
|
app: Arc<AppState>,
|
||||||
parent: &Self,
|
parent: &Self,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use serde_json::{Value, json};
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ToolScope {
|
pub struct ToolScope {
|
||||||
pub functions: Functions,
|
pub functions: Functions,
|
||||||
pub mcp_runtime: McpRuntime,
|
pub mcp_runtime: McpRuntime,
|
||||||
@@ -24,7 +25,7 @@ impl Default for ToolScope {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default, Clone)]
|
||||||
pub struct McpRuntime {
|
pub struct McpRuntime {
|
||||||
pub servers: HashMap<String, Arc<ConnectedServer>>,
|
pub servers: HashMap<String, Arc<ConnectedServer>>,
|
||||||
}
|
}
|
||||||
|
|||||||
+170
-45
@@ -3,6 +3,7 @@ use super::llm::LlmNodeExecutor;
|
|||||||
use super::logging::GraphLogger;
|
use super::logging::GraphLogger;
|
||||||
use super::rag::RagNodeExecutor;
|
use super::rag::RagNodeExecutor;
|
||||||
use super::script::ScriptExecutor;
|
use super::script::ScriptExecutor;
|
||||||
|
use super::staging::BranchWrites;
|
||||||
use super::state::StateManager;
|
use super::state::StateManager;
|
||||||
use super::types::{EndNode, Graph, Node, NodeType};
|
use super::types::{EndNode, Graph, Node, NodeType};
|
||||||
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
||||||
@@ -10,11 +11,13 @@ use super::validator::{AgentValidationContext, GraphValidator};
|
|||||||
use crate::config::RequestContext;
|
use crate::config::RequestContext;
|
||||||
use crate::utils::AbortSignal;
|
use crate::utils::AbortSignal;
|
||||||
use anyhow::{Context, Result, anyhow, bail};
|
use anyhow::{Context, Result, anyhow, bail};
|
||||||
|
use futures_util::future::join_all;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
pub struct GraphExecutor {
|
pub struct GraphExecutor {
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
@@ -70,74 +73,196 @@ impl GraphExecutor {
|
|||||||
let script_executor = ScriptExecutor::new(&base_dir);
|
let script_executor = ScriptExecutor::new(&base_dir);
|
||||||
let max_iterations = graph.settings.max_loop_iterations;
|
let max_iterations = graph.settings.max_loop_iterations;
|
||||||
let graph_timeout = graph.settings.timeout.map(Duration::from_secs);
|
let graph_timeout = graph.settings.timeout.map(Duration::from_secs);
|
||||||
|
let max_concurrency = graph.settings.max_concurrency;
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let mut current = graph.start.clone();
|
let mut frontier: HashSet<String> = HashSet::from([graph.start.clone()]);
|
||||||
logger.graph_start(¤t, graph.nodes.len());
|
logger.graph_start(&graph.start, graph.nodes.len());
|
||||||
|
|
||||||
|
loop {
|
||||||
|
if frontier.is_empty() {
|
||||||
|
bail!(
|
||||||
|
"Graph '{}' frontier emptied without reaching an End node",
|
||||||
|
graph.name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let output = loop {
|
|
||||||
if abort_signal.aborted() {
|
if abort_signal.aborted() {
|
||||||
bail!("Graph '{}' aborted at '{}'", graph.name, current);
|
bail!(
|
||||||
|
"Graph '{}' aborted before super-step with frontier {:?}",
|
||||||
|
graph.name,
|
||||||
|
sorted_frontier(&frontier)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if let Some(t) = graph_timeout
|
if let Some(t) = graph_timeout
|
||||||
&& start.elapsed() > t
|
&& start.elapsed() > t
|
||||||
{
|
{
|
||||||
bail!(
|
bail!(
|
||||||
"Graph '{}' timed out after {}s at '{}'",
|
"Graph '{}' timed out after {}s before super-step with frontier {:?}",
|
||||||
graph.name,
|
graph.name,
|
||||||
t.as_secs(),
|
t.as_secs(),
|
||||||
current
|
sorted_frontier(&frontier)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
state.state_mut().visit_node(¤t);
|
// Loop-count and visit tracking on live state, BEFORE forking.
|
||||||
let visits = state.state().loop_count(¤t);
|
// This counts every entry to a node toward max_loop_iterations
|
||||||
if visits > max_iterations {
|
// regardless of how many parallel branches converged on it.
|
||||||
|
for node_id in &frontier {
|
||||||
|
state.state_mut().visit_node(node_id);
|
||||||
|
let visits = state.state().loop_count(node_id);
|
||||||
|
if visits > max_iterations {
|
||||||
|
bail!(
|
||||||
|
"Node '{}' visited {} times (max_loop_iterations={}). \
|
||||||
|
Possible infinite loop.",
|
||||||
|
node_id,
|
||||||
|
visits,
|
||||||
|
max_iterations
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for node_id in &frontier {
|
||||||
|
let node = graph.get_node(node_id).ok_or_else(|| {
|
||||||
|
anyhow!("Node '{}' not found in graph '{}'", node_id, graph.name)
|
||||||
|
})?;
|
||||||
|
let visits = state.state().loop_count(node_id);
|
||||||
|
logger.node_entry(node, visits);
|
||||||
|
}
|
||||||
|
let snapshot_label = if frontier.len() == 1 {
|
||||||
|
frontier.iter().next().cloned().unwrap_or_default()
|
||||||
|
} else {
|
||||||
|
format!("super-step {{{}}}", sorted_frontier(&frontier).join(","))
|
||||||
|
};
|
||||||
|
logger.state_snapshot(&snapshot_label, &state);
|
||||||
|
|
||||||
|
let snapshot = state.read_snapshot();
|
||||||
|
let semaphore = Arc::new(Semaphore::new(max_concurrency));
|
||||||
|
|
||||||
|
let mut branch_tasks = Vec::with_capacity(frontier.len());
|
||||||
|
for node_id in &frontier {
|
||||||
|
let node = graph
|
||||||
|
.get_node(node_id)
|
||||||
|
.ok_or_else(|| {
|
||||||
|
anyhow!("Node '{}' not found in graph '{}'", node_id, graph.name)
|
||||||
|
})?
|
||||||
|
.clone();
|
||||||
|
let branch_state = state.fork_for_branch_state();
|
||||||
|
let branch_ctx = ctx.fork_for_branch();
|
||||||
|
let script_exec_clone = script_executor.clone();
|
||||||
|
let graph_name = graph.name.clone();
|
||||||
|
let current = node_id.clone();
|
||||||
|
let sem_clone = semaphore.clone();
|
||||||
|
let abort_clone = abort_signal.clone();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
let _permit = sem_clone
|
||||||
|
.acquire()
|
||||||
|
.await
|
||||||
|
.expect("semaphore should not be closed");
|
||||||
|
if abort_clone.aborted() {
|
||||||
|
return (
|
||||||
|
current.clone(),
|
||||||
|
branch_state,
|
||||||
|
Err(anyhow!("branch aborted")),
|
||||||
|
Duration::default(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let node_start = Instant::now();
|
||||||
|
let mut state = branch_state;
|
||||||
|
let mut ctx = branch_ctx;
|
||||||
|
let result = step(
|
||||||
|
&node,
|
||||||
|
&mut state,
|
||||||
|
&mut ctx,
|
||||||
|
&script_exec_clone,
|
||||||
|
&graph_name,
|
||||||
|
¤t,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
let elapsed = node_start.elapsed();
|
||||||
|
(current, state, result, elapsed)
|
||||||
|
});
|
||||||
|
branch_tasks.push(task);
|
||||||
|
}
|
||||||
|
|
||||||
|
let joined = join_all(branch_tasks).await;
|
||||||
|
|
||||||
|
let mut branch_writes: Vec<BranchWrites> = Vec::new();
|
||||||
|
let mut next_frontier: HashSet<String> = HashSet::new();
|
||||||
|
let mut end_results: Vec<(String, StateManager, String)> = Vec::new();
|
||||||
|
|
||||||
|
for join_result in joined {
|
||||||
|
let (node_id, branch_state, step_result, elapsed) =
|
||||||
|
join_result.map_err(|e| anyhow!("Branch task panicked: {e}"))?;
|
||||||
|
logger.record_timing(&node_id, elapsed);
|
||||||
|
|
||||||
|
let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?;
|
||||||
|
|
||||||
|
match step_outcome {
|
||||||
|
StepResult::Continue(target) => {
|
||||||
|
logger.routing(&node_id, &target);
|
||||||
|
let diff = branch_state.diff_against(snapshot.as_ref());
|
||||||
|
branch_writes.push(BranchWrites {
|
||||||
|
node_id: node_id.clone(),
|
||||||
|
invocation_index: 0,
|
||||||
|
writes: diff,
|
||||||
|
});
|
||||||
|
next_frontier.insert(target);
|
||||||
|
}
|
||||||
|
StepResult::End(output) => {
|
||||||
|
end_results.push((node_id.clone(), branch_state, output));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if end_results.len() > 1 {
|
||||||
|
let mut ids: Vec<String> =
|
||||||
|
end_results.iter().map(|(id, _, _)| id.clone()).collect();
|
||||||
|
ids.sort();
|
||||||
bail!(
|
bail!(
|
||||||
"Node '{}' visited {} times (max_loop_iterations={}). \
|
"super-step ended with multiple End targets ({}). \
|
||||||
Possible infinite loop.",
|
Fan-out branches must converge at a join node before \
|
||||||
current,
|
terminating. To fix: route all parallel branches to a \
|
||||||
visits,
|
single shared next-node, then terminate from there.",
|
||||||
max_iterations
|
ids.join(", ")
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let node = graph
|
// Sort by (node_id, invocation_index) so non-commutative reducers
|
||||||
.get_node(¤t)
|
// like Concat/Merge produce deterministic output across runs.
|
||||||
.ok_or_else(|| anyhow!("Node '{}' not found in graph '{}'", current, graph.name))?;
|
branch_writes.sort_by(|a, b| {
|
||||||
|
a.node_id
|
||||||
|
.cmp(&b.node_id)
|
||||||
|
.then(a.invocation_index.cmp(&b.invocation_index))
|
||||||
|
});
|
||||||
|
state.apply_branch_writes(branch_writes, &graph.reducers)?;
|
||||||
|
|
||||||
logger.node_entry(node, visits);
|
if let Some((node_id, end_state, output)) = end_results.into_iter().next() {
|
||||||
logger.state_snapshot(¤t, &state);
|
let diff = end_state.diff_against(snapshot.as_ref());
|
||||||
|
state.apply_branch_writes(
|
||||||
let node_start = Instant::now();
|
vec![BranchWrites {
|
||||||
let step_result = step(
|
node_id: node_id.clone(),
|
||||||
node,
|
invocation_index: 0,
|
||||||
&mut state,
|
writes: diff,
|
||||||
ctx,
|
}],
|
||||||
&script_executor,
|
&graph.reducers,
|
||||||
&graph.name,
|
)?;
|
||||||
¤t,
|
logger.graph_complete(&node_id, start.elapsed());
|
||||||
)
|
return Ok(output);
|
||||||
.await;
|
|
||||||
logger.record_timing(¤t, node_start.elapsed());
|
|
||||||
let next = step_result.with_context(|| format!("at node '{current}'"))?;
|
|
||||||
|
|
||||||
match next {
|
|
||||||
StepResult::Continue(next_id) => {
|
|
||||||
logger.routing(¤t, &next_id);
|
|
||||||
current = next_id;
|
|
||||||
}
|
|
||||||
StepResult::End(out) => {
|
|
||||||
logger.graph_complete(¤t, start.elapsed());
|
|
||||||
break out;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
Ok(output)
|
frontier = next_frontier;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sorted_frontier(frontier: &HashSet<String>) -> Vec<String> {
|
||||||
|
let mut v: Vec<String> = frontier.iter().cloned().collect();
|
||||||
|
v.sort();
|
||||||
|
v
|
||||||
|
}
|
||||||
|
|
||||||
enum StepResult {
|
enum StepResult {
|
||||||
Continue(String),
|
Continue(String),
|
||||||
End(String),
|
End(String),
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use std::time::Duration;
|
|||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
pub struct ScriptExecutor {
|
pub struct ScriptExecutor {
|
||||||
base_dir: PathBuf,
|
base_dir: PathBuf,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,38 +1,6 @@
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
#[derive(Debug, Default, Clone)]
|
|
||||||
pub struct StagingArea {
|
|
||||||
writes: HashMap<String, Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
impl StagingArea {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self::default()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn write(&mut self, key: impl Into<String>, value: Value) {
|
|
||||||
self.writes.insert(key.into(), value);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get(&self, key: &str) -> Option<&Value> {
|
|
||||||
self.writes.get(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.writes.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
self.writes.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_writes(self) -> HashMap<String, Value> {
|
|
||||||
self.writes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Published form of one branch's writes for the super-step merge phase.
|
/// Published form of one branch's writes for the super-step merge phase.
|
||||||
/// Callers assemble these into a deterministically-ordered `Vec` keyed by
|
/// Callers assemble these into a deterministically-ordered `Vec` keyed by
|
||||||
/// `(node_id, invocation_index)` before passing to
|
/// `(node_id, invocation_index)` before passing to
|
||||||
@@ -40,58 +8,8 @@ impl StagingArea {
|
|||||||
/// branches and the input-list position for map sub-branches — so multiple
|
/// branches and the input-list position for map sub-branches — so multiple
|
||||||
/// invocations of the same `branch:` node by a `map` are still totally ordered.
|
/// invocations of the same `branch:` node by a `map` are still totally ordered.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
#[allow(dead_code)]
|
|
||||||
pub struct BranchWrites {
|
pub struct BranchWrites {
|
||||||
pub node_id: String,
|
pub node_id: String,
|
||||||
pub invocation_index: usize,
|
pub invocation_index: usize,
|
||||||
pub writes: HashMap<String, Value>,
|
pub writes: HashMap<String, Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn new_staging_area_is_empty() {
|
|
||||||
let s = StagingArea::new();
|
|
||||||
|
|
||||||
assert!(s.is_empty());
|
|
||||||
assert_eq!(s.len(), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn write_stores_value_under_key() {
|
|
||||||
let mut s = StagingArea::new();
|
|
||||||
|
|
||||||
s.write("key", json!("value"));
|
|
||||||
|
|
||||||
assert_eq!(s.get("key"), Some(&json!("value")));
|
|
||||||
assert_eq!(s.len(), 1);
|
|
||||||
assert!(!s.is_empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn write_overwrites_existing_key() {
|
|
||||||
let mut s = StagingArea::new();
|
|
||||||
|
|
||||||
s.write("k", json!(1));
|
|
||||||
s.write("k", json!(2));
|
|
||||||
|
|
||||||
assert_eq!(s.get("k"), Some(&json!(2)));
|
|
||||||
assert_eq!(s.len(), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn into_writes_consumes_and_yields_map() {
|
|
||||||
let mut s = StagingArea::new();
|
|
||||||
s.write("a", json!(1));
|
|
||||||
s.write("b", json!(2));
|
|
||||||
|
|
||||||
let writes = s.into_writes();
|
|
||||||
|
|
||||||
assert_eq!(writes.len(), 2);
|
|
||||||
assert_eq!(writes.get("a"), Some(&json!(1)));
|
|
||||||
assert_eq!(writes.get("b"), Some(&json!(2)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
+118
-8
@@ -159,13 +159,44 @@ impl StateManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an `Arc`-wrapped snapshot of the current graph state. Each branch
|
/// Forks state for a parallel branch: returns a fully-owned `StateManager`
|
||||||
/// in a parallel super-step shares this snapshot for reads; their writes
|
/// seeded from the current state's data. The branch mutates its fork
|
||||||
/// accumulate into per-branch `StagingArea` instances, which are merged via
|
/// freely; callers extract its writes via `diff_against` after the branch
|
||||||
/// `apply_branch_writes` at the end of the super-step.
|
/// completes, then merge them via `apply_branch_writes`.
|
||||||
///
|
///
|
||||||
/// Distinct from the older `snapshot()` method (returns a `HashMap` clone of
|
/// Distinct from `read_snapshot` (returns a shared `Arc<GraphState>` for
|
||||||
/// the data only — used by `script_executor` to ship state to child processes).
|
/// reads) — `fork_for_branch_state` returns a writable owned clone.
|
||||||
|
pub fn fork_for_branch_state(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
state: self.state.clone(),
|
||||||
|
temp_file: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the keys whose values differ from `snapshot`. Use this after a
|
||||||
|
/// branch finishes to extract its writes (input to `apply_branch_writes`).
|
||||||
|
/// Keys present in `self` but absent from `snapshot`, or with different
|
||||||
|
/// values, count as writes. Deletions are not represented (no current node
|
||||||
|
/// executor deletes state).
|
||||||
|
pub fn diff_against(&self, snapshot: &GraphState) -> HashMap<String, Value> {
|
||||||
|
let mut diff = HashMap::new();
|
||||||
|
for (k, v) in self.state.data() {
|
||||||
|
if snapshot.get(k) != Some(v) {
|
||||||
|
diff.insert(k.clone(), v.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
diff
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns an `Arc`-wrapped snapshot of the current graph state. Each
|
||||||
|
/// branch in a parallel super-step uses this snapshot as the baseline for
|
||||||
|
/// its `diff_against` call at branch end. The executor extracts each
|
||||||
|
/// branch's writes (the diff) and merges them via `apply_branch_writes` at
|
||||||
|
/// the super-step boundary.
|
||||||
|
///
|
||||||
|
/// Distinct from the older `snapshot()` method (returns a `HashMap` clone
|
||||||
|
/// of the data only — used by `script_executor` to ship state to child
|
||||||
|
/// processes).
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn read_snapshot(&self) -> Arc<GraphState> {
|
pub fn read_snapshot(&self) -> Arc<GraphState> {
|
||||||
Arc::new(self.state.clone())
|
Arc::new(self.state.clone())
|
||||||
@@ -936,12 +967,91 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn interpolate_raw_inner_spaces_treated_as_mixed() {
|
fn interpolate_raw_inner_spaces_treated_as_mixed() {
|
||||||
let manager = manager_with(&[("k", json!("v"))]);
|
let manager = manager_with(&[("k", json!("v"))]);
|
||||||
|
|
||||||
// `{{ k }}` is not a valid pure reference (spaces inside braces are
|
// `{{ k }}` is not a valid pure reference (spaces inside braces are
|
||||||
// outside the allowed character set). Fall back to string interpolation
|
// outside the allowed character set). Fall back to string interpolation
|
||||||
// -- which doesn't match the regex either, so the literal passes through.
|
// -- which doesn't match the regex either, so the literal passes through.
|
||||||
let result = manager.interpolate_raw("{{ k }}").unwrap();
|
let result = manager.interpolate_raw("{{ k }}").unwrap();
|
||||||
|
|
||||||
assert_eq!(result, json!("{{ k }}"));
|
assert_eq!(result, json!("{{ k }}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fork_for_branch_state_copies_data() {
|
||||||
|
let parent = manager_with(&[("a", json!(1)), ("b", json!("x"))]);
|
||||||
|
|
||||||
|
let fork = parent.fork_for_branch_state();
|
||||||
|
|
||||||
|
assert_eq!(fork.state().get("a"), Some(&json!(1)));
|
||||||
|
assert_eq!(fork.state().get("b"), Some(&json!("x")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fork_for_branch_state_isolates_writes_from_parent() {
|
||||||
|
let parent = manager_with(&[("count", json!(10))]);
|
||||||
|
let mut fork = parent.fork_for_branch_state();
|
||||||
|
|
||||||
|
fork.state_mut().set("count".into(), json!(999));
|
||||||
|
|
||||||
|
assert_eq!(fork.state().get("count"), Some(&json!(999)));
|
||||||
|
assert_eq!(parent.state().get("count"), Some(&json!(10)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fork_for_branch_state_does_not_share_temp_file_lifecycle() {
|
||||||
|
let parent = manager_with(&[("k", json!("v"))]);
|
||||||
|
let fork = parent.fork_for_branch_state();
|
||||||
|
|
||||||
|
assert!(fork.temp_file.is_none());
|
||||||
|
// Dropping the fork must not affect the parent's data
|
||||||
|
drop(fork);
|
||||||
|
assert_eq!(parent.state().get("k"), Some(&json!("v")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn diff_against_returns_empty_when_unchanged() {
|
||||||
|
let original = manager_with(&[("a", json!(1)), ("b", json!(2))]);
|
||||||
|
let fork = original.fork_for_branch_state();
|
||||||
|
|
||||||
|
let diff = fork.diff_against(original.state());
|
||||||
|
|
||||||
|
assert!(diff.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn diff_against_reports_newly_written_keys() {
|
||||||
|
let original = manager_with(&[]);
|
||||||
|
let mut fork = original.fork_for_branch_state();
|
||||||
|
fork.state_mut().set("new".into(), json!(42));
|
||||||
|
|
||||||
|
let diff = fork.diff_against(original.state());
|
||||||
|
|
||||||
|
assert_eq!(diff.len(), 1);
|
||||||
|
assert_eq!(diff.get("new"), Some(&json!(42)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn diff_against_reports_changed_values_only() {
|
||||||
|
let original = manager_with(&[("a", json!(1)), ("b", json!(2)), ("c", json!(3))]);
|
||||||
|
let mut fork = original.fork_for_branch_state();
|
||||||
|
fork.state_mut().set("b".into(), json!(99));
|
||||||
|
|
||||||
|
let diff = fork.diff_against(original.state());
|
||||||
|
|
||||||
|
assert_eq!(diff.len(), 1);
|
||||||
|
assert_eq!(diff.get("b"), Some(&json!(99)));
|
||||||
|
assert!(!diff.contains_key("a"));
|
||||||
|
assert!(!diff.contains_key("c"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn diff_against_does_not_report_reverted_writes() {
|
||||||
|
// Branch writes then writes back to the original value; net change = 0.
|
||||||
|
let original = manager_with(&[("x", json!("initial"))]);
|
||||||
|
let mut fork = original.fork_for_branch_state();
|
||||||
|
fork.state_mut().set("x".into(), json!("modified"));
|
||||||
|
fork.state_mut().set("x".into(), json!("initial"));
|
||||||
|
|
||||||
|
let diff = fork.diff_against(original.state());
|
||||||
|
|
||||||
|
assert!(diff.is_empty(), "reverted write should not appear in diff");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user