test: implemented integration tests for the parallel frontier-based graph scheduling

This commit is contained in:
2026-05-20 16:09:07 -06:00
parent 20c28b55d5
commit 26de81e84e
+339 -39
View File
@@ -228,15 +228,17 @@ impl GraphExecutor {
let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?; let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?;
match step_outcome { match step_outcome {
StepResult::Continue(target) => { StepResult::Continue(targets) => {
logger.routing(&node_id, &target); for target in &targets {
logger.routing(&node_id, target);
}
let diff = branch_state.diff_against(snapshot.as_ref()); let diff = branch_state.diff_against(snapshot.as_ref());
branch_writes.push(BranchWrites { branch_writes.push(BranchWrites {
node_id: node_id.clone(), node_id: node_id.clone(),
invocation_index: 0, invocation_index: 0,
writes: diff, writes: diff,
}); });
next_frontier.insert(target); next_frontier.extend(targets);
} }
StepResult::End(output) => { StepResult::End(output) => {
end_results.push((node_id.clone(), branch_state, output)); end_results.push((node_id.clone(), branch_state, output));
@@ -309,7 +311,12 @@ impl StepContext<'_> {
} }
enum StepResult { enum StepResult {
Continue(String), // The set of next-node ids the executor should add to the next super-step's
// frontier. A `Vec` of length 1 for sequential routing (default) and the
// full target list for fan-out (`next: [a, b, ...]`). Dynamic single-route
// decisions (script `_next`, approval routes, LLM/RAG fallback) always emit
// a single-element vec.
Continue(Vec<String>),
End(String), End(String),
} }
@@ -323,13 +330,8 @@ async fn step(
match &node.node_type { match &node.node_type {
NodeType::Agent(agent_node) => { NodeType::Agent(agent_node) => {
AgentNodeExecutor::execute(agent_node, state, ctx).await?; AgentNodeExecutor::execute(agent_node, state, ctx).await?;
let next = node let targets = static_next_targets(node, current, "agent")?;
.next_single()? Ok(StepResult::Continue(targets))
.ok_or_else(|| {
anyhow!("agent node '{current}' has no `next` and is not an end node")
})?
.to_string();
Ok(StepResult::Continue(next))
} }
NodeType::Script(script_node) => { NodeType::Script(script_node) => {
let dynamic = match step_ctx.script_executor.execute(script_node, state).await { let dynamic = match step_ctx.script_executor.execute(script_node, state).await {
@@ -343,57 +345,96 @@ async fn step(
fallback, fallback,
e e
); );
return Ok(StepResult::Continue(fallback.clone())); return Ok(StepResult::Continue(vec![fallback.clone()]));
} }
return Err(e); return Err(e);
} }
}; };
let next = match dynamic { let targets = match dynamic {
Some(n) => n, Some(n) => vec![n],
None => node None => static_next_targets(node, current, "script")?,
.next_single()?
.ok_or_else(|| {
anyhow!(
"script node '{current}' did not emit `_next` and has no static `next`"
)
})?
.to_string(),
}; };
Ok(StepResult::Continue(next)) Ok(StepResult::Continue(targets))
} }
NodeType::Approval(approval_node) => { NodeType::Approval(approval_node) => {
let next = ApprovalNodeExecutor::execute(approval_node, state, ctx).await?; let next = ApprovalNodeExecutor::execute(approval_node, state, ctx).await?;
Ok(StepResult::Continue(next)) Ok(StepResult::Continue(vec![next]))
} }
NodeType::Input(input_node) => { NodeType::Input(input_node) => {
let next_id = node.next_single()?; let next_id = first_next_target(node);
let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?; let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?;
Ok(StepResult::Continue(next)) Ok(StepResult::Continue(vec![next]))
} }
NodeType::Llm(llm_node) => { NodeType::Llm(llm_node) => {
let next_id = node.next_single()?; let primary = first_next_target(node).map(str::to_string);
let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?; let llm_routing =
Ok(StepResult::Continue(next)) LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?;
let targets = resolve_branching_next(node, &llm_routing);
Ok(StepResult::Continue(targets))
} }
NodeType::Rag(rag_node) => { NodeType::Rag(rag_node) => {
let next_id = node.next_single()?; let primary = first_next_target(node).map(str::to_string);
let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?; let rag_routing =
Ok(StepResult::Continue(next)) RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?;
let targets = resolve_branching_next(node, &rag_routing);
Ok(StepResult::Continue(targets))
} }
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))), NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
NodeType::Map(map_node) => { NodeType::Map(map_node) => {
let next = node let targets = static_next_targets(node, current, "map")?;
.next_single()?
.ok_or_else(|| {
anyhow!("map node '{current}' has no `next` and is not an end node")
})?
.to_string();
MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?; MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?;
Ok(StepResult::Continue(next)) Ok(StepResult::Continue(targets))
} }
} }
} }
// Returns all `next:` targets from the node (handles both `One` and `Many`),
// erroring if no `next` is set.
fn static_next_targets(node: &Node, current: &str, kind: &str) -> Result<Vec<String>> {
node.next
.as_ref()
.map(|t| t.as_slice().to_vec())
.ok_or_else(|| anyhow!("{kind} node '{current}' has no `next` and is not an end node"))
}
// Returns the first declared `next:` target as a borrowed `&str`, or `None` if
// no `next` is set. Used by node executors that take `Option<&str>` for their
// primary routing argument (LLM, RAG, Input).
fn first_next_target(node: &Node) -> Option<&str> {
node.next
.as_ref()
.and_then(|t| t.as_slice().first().map(|s| s.as_str()))
}
// Resolves the actual frontier-advance targets after an LLM/RAG node ran.
//
// LLM/RAG executors return their chosen routing as a String — either the
// primary `next:` target (success path) or the node's `fallback:` (failure
// path with retry exhausted). We can't tell these apart from inside step()
// without an API refactor, so we compare strings: if the returned routing
// matches the first declared `next` target, treat as success and (for
// fan-out) use ALL declared targets; otherwise treat as fallback and use the
// returned target alone.
//
// Known limitation: if a fan-out node's `fallback:` is set to the same node
// id as its first `next:` target, a successful run is indistinguishable from
// a fallback run — both look like "returned the first target". The result is
// that the executor advances to all Many targets in the fallback case (which
// is the OPPOSITE of the user's likely intent). Workaround: choose a
// `fallback:` distinct from any `next:` target.
fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec<String> {
let Some(targets) = &node.next else {
return vec![returned_routing.to_string()];
};
let slice = targets.as_slice();
let first_matches = slice.first().is_some_and(|s| s == returned_routing);
if first_matches && slice.len() > 1 {
slice.to_vec()
} else {
vec![returned_routing.to_string()]
}
}
fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String { fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String {
apply_simple_state_updates(end_node.state_updates.as_ref(), state); apply_simple_state_updates(end_node.state_updates.as_ref(), state);
state.interpolate_lenient(&end_node.output) state.interpolate_lenient(&end_node.output)
@@ -490,3 +531,262 @@ mod tests {
assert_eq!(state.state().get("k"), Some(&json!("new-old"))); assert_eq!(state.state().get("k"), Some(&json!("new-old")));
} }
} }
#[cfg(test)]
mod integration_tests {
use super::*;
use crate::config::{AppState, WorkingMode};
use crate::utils::{create_abort_signal, temp_file};
use std::fs;
fn cmd_available(name: &str) -> bool {
which::which(name).is_ok()
}
struct TestWorkspace {
dir: PathBuf,
}
impl TestWorkspace {
fn new() -> Self {
let dir = temp_file("-graph-integration-", "");
fs::create_dir_all(&dir).unwrap();
Self { dir }
}
fn write_script(&self, name: &str, contents: &str) {
fs::write(self.dir.join(name), contents).unwrap();
}
}
impl Drop for TestWorkspace {
fn drop(&mut self) {
let _ = fs::remove_dir_all(&self.dir);
}
}
fn make_ctx() -> RequestContext {
RequestContext::new(Arc::new(AppState::test_default()), WorkingMode::Cmd)
}
#[tokio::test]
async fn static_fan_out_merges_branch_writes_via_append_reducer() {
if !cmd_available("bash") {
eprintln!("skipping: bash not available");
return;
}
let ws = TestWorkspace::new();
ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n");
ws.write_script(
"worker_a.sh",
"#!/bin/bash\necho '{\"results\": \"alpha\"}'\n",
);
ws.write_script(
"worker_b.sh",
"#!/bin/bash\necho '{\"results\": \"beta\"}'\n",
);
let yaml = r#"
name: static_fan_out_test
start: dispatcher
reducers:
results: append
nodes:
dispatcher:
type: script
script: dispatcher.sh
state_updates: {}
next: [worker_a, worker_b]
worker_a:
type: script
script: worker_a.sh
state_updates: {}
next: join
worker_b:
type: script
script: worker_b.sh
state_updates: {}
next: join
join:
type: end
output: "{{results}}"
"#;
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
let mut ctx = make_ctx();
let abort = create_abort_signal();
let result = GraphExecutor::new(graph, &ws.dir)
.execute(&mut ctx, abort)
.await
.unwrap_or_else(|e| panic!("executor failed: {e:#}"));
let parsed: Value = serde_json::from_str(&result)
.unwrap_or_else(|_| panic!("expected JSON array, got: {result}"));
let arr = parsed.as_array().expect("results should be an array");
assert_eq!(arr.len(), 2, "expected 2 elements, got: {result}");
let strs: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
assert!(strs.contains(&"alpha"), "missing 'alpha' in {strs:?}");
assert!(strs.contains(&"beta"), "missing 'beta' in {strs:?}");
}
#[tokio::test]
async fn map_over_list_collects_outputs_in_input_order() {
if !cmd_available("python3") {
eprintln!("skipping: python3 not available");
return;
}
let ws = TestWorkspace::new();
ws.write_script(
"doubler.py",
r#"#!/usr/bin/env python3
import os, json
state = json.loads(os.environ.get("GRAPH_STATE", "{}"))
val = state["item"]
print(json.dumps({"output": val * 2}))
"#,
);
let yaml = r#"
name: map_input_order_test
start: fan_out
initial_state:
items: [1, 2, 3, 4, 5]
nodes:
fan_out:
type: map
over: "{{items}}"
as: item
branch: doubler
collect_into: doubled
next: done
doubler:
type: script
script: doubler.py
state_updates: {}
done:
type: end
output: "{{doubled}}"
"#;
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
let mut ctx = make_ctx();
let abort = create_abort_signal();
let result = GraphExecutor::new(graph, &ws.dir)
.execute(&mut ctx, abort)
.await
.unwrap_or_else(|e| panic!("executor failed: {e:#}"));
let parsed: Value = serde_json::from_str(&result)
.unwrap_or_else(|_| panic!("expected JSON array, got: {result}"));
let arr = parsed.as_array().expect("doubled should be an array");
let nums: Vec<i64> = arr
.iter()
.map(|v| v.as_i64().expect("each item should be int"))
.collect();
assert_eq!(
nums,
vec![2, 4, 6, 8, 10],
"map outputs should be in input order, not finish order"
);
}
#[tokio::test]
async fn parallel_branch_error_aborts_super_step() {
if !cmd_available("bash") {
eprintln!("skipping: bash not available");
return;
}
let ws = TestWorkspace::new();
ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n");
ws.write_script(
"worker_ok.sh",
"#!/bin/bash\necho '{\"results\": \"ok\"}'\n",
);
ws.write_script(
"worker_fail.sh",
"#!/bin/bash\necho 'simulated failure' >&2\nexit 1\n",
);
let yaml = r#"
name: branch_error_test
start: dispatcher
reducers:
results: append
nodes:
dispatcher:
type: script
script: dispatcher.sh
state_updates: {}
next: [worker_ok, worker_fail]
worker_ok:
type: script
script: worker_ok.sh
state_updates: {}
next: join
worker_fail:
type: script
script: worker_fail.sh
state_updates: {}
next: join
join:
type: end
output: "{{results}}"
"#;
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
let mut ctx = make_ctx();
let abort = create_abort_signal();
let result = GraphExecutor::new(graph, &ws.dir)
.execute(&mut ctx, abort)
.await;
assert!(result.is_err(), "expected branch error to propagate");
let err = format!("{:#}", result.unwrap_err());
assert!(
err.contains("worker_fail"),
"error should mention failing node: {err}"
);
}
#[tokio::test]
async fn multi_end_in_super_step_is_rejected() {
if !cmd_available("bash") {
eprintln!("skipping: bash not available");
return;
}
let ws = TestWorkspace::new();
ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n");
let yaml = r#"
name: multi_end_test
start: dispatcher
nodes:
dispatcher:
type: script
script: dispatcher.sh
state_updates: {}
next: [end_a, end_b]
end_a:
type: end
output: "from a"
end_b:
type: end
output: "from b"
"#;
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
let mut ctx = make_ctx();
let abort = create_abort_signal();
let result = GraphExecutor::new(graph, &ws.dir)
.execute(&mut ctx, abort)
.await;
assert!(result.is_err(), "expected multi-End to be rejected");
let err = format!("{:#}", result.unwrap_err());
assert!(
err.contains("multiple End targets"),
"error should explain multi-End cause: {err}"
);
assert!(
err.contains("end_a") && err.contains("end_b"),
"error should list both End nodes: {err}"
);
}
}