test: implemented integration tests for the parallel frontier-based graph scheduling
This commit is contained in:
+339
-39
@@ -228,15 +228,17 @@ impl GraphExecutor {
|
||||
let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?;
|
||||
|
||||
match step_outcome {
|
||||
StepResult::Continue(target) => {
|
||||
logger.routing(&node_id, &target);
|
||||
StepResult::Continue(targets) => {
|
||||
for target in &targets {
|
||||
logger.routing(&node_id, target);
|
||||
}
|
||||
let diff = branch_state.diff_against(snapshot.as_ref());
|
||||
branch_writes.push(BranchWrites {
|
||||
node_id: node_id.clone(),
|
||||
invocation_index: 0,
|
||||
writes: diff,
|
||||
});
|
||||
next_frontier.insert(target);
|
||||
next_frontier.extend(targets);
|
||||
}
|
||||
StepResult::End(output) => {
|
||||
end_results.push((node_id.clone(), branch_state, output));
|
||||
@@ -309,7 +311,12 @@ impl StepContext<'_> {
|
||||
}
|
||||
|
||||
enum StepResult {
|
||||
Continue(String),
|
||||
// The set of next-node ids the executor should add to the next super-step's
|
||||
// frontier. A `Vec` of length 1 for sequential routing (default) and the
|
||||
// full target list for fan-out (`next: [a, b, ...]`). Dynamic single-route
|
||||
// decisions (script `_next`, approval routes, LLM/RAG fallback) always emit
|
||||
// a single-element vec.
|
||||
Continue(Vec<String>),
|
||||
End(String),
|
||||
}
|
||||
|
||||
@@ -323,13 +330,8 @@ async fn step(
|
||||
match &node.node_type {
|
||||
NodeType::Agent(agent_node) => {
|
||||
AgentNodeExecutor::execute(agent_node, state, ctx).await?;
|
||||
let next = node
|
||||
.next_single()?
|
||||
.ok_or_else(|| {
|
||||
anyhow!("agent node '{current}' has no `next` and is not an end node")
|
||||
})?
|
||||
.to_string();
|
||||
Ok(StepResult::Continue(next))
|
||||
let targets = static_next_targets(node, current, "agent")?;
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
NodeType::Script(script_node) => {
|
||||
let dynamic = match step_ctx.script_executor.execute(script_node, state).await {
|
||||
@@ -343,57 +345,96 @@ async fn step(
|
||||
fallback,
|
||||
e
|
||||
);
|
||||
return Ok(StepResult::Continue(fallback.clone()));
|
||||
return Ok(StepResult::Continue(vec![fallback.clone()]));
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
let next = match dynamic {
|
||||
Some(n) => n,
|
||||
None => node
|
||||
.next_single()?
|
||||
.ok_or_else(|| {
|
||||
anyhow!(
|
||||
"script node '{current}' did not emit `_next` and has no static `next`"
|
||||
)
|
||||
})?
|
||||
.to_string(),
|
||||
let targets = match dynamic {
|
||||
Some(n) => vec![n],
|
||||
None => static_next_targets(node, current, "script")?,
|
||||
};
|
||||
Ok(StepResult::Continue(next))
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
NodeType::Approval(approval_node) => {
|
||||
let next = ApprovalNodeExecutor::execute(approval_node, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
Ok(StepResult::Continue(vec![next]))
|
||||
}
|
||||
NodeType::Input(input_node) => {
|
||||
let next_id = node.next_single()?;
|
||||
let next_id = first_next_target(node);
|
||||
let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
Ok(StepResult::Continue(vec![next]))
|
||||
}
|
||||
NodeType::Llm(llm_node) => {
|
||||
let next_id = node.next_single()?;
|
||||
let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
let primary = first_next_target(node).map(str::to_string);
|
||||
let llm_routing =
|
||||
LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?;
|
||||
let targets = resolve_branching_next(node, &llm_routing);
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
NodeType::Rag(rag_node) => {
|
||||
let next_id = node.next_single()?;
|
||||
let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
let primary = first_next_target(node).map(str::to_string);
|
||||
let rag_routing =
|
||||
RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?;
|
||||
let targets = resolve_branching_next(node, &rag_routing);
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||
NodeType::Map(map_node) => {
|
||||
let next = node
|
||||
.next_single()?
|
||||
.ok_or_else(|| {
|
||||
anyhow!("map node '{current}' has no `next` and is not an end node")
|
||||
})?
|
||||
.to_string();
|
||||
let targets = static_next_targets(node, current, "map")?;
|
||||
MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?;
|
||||
Ok(StepResult::Continue(next))
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns all `next:` targets from the node (handles both `One` and `Many`),
|
||||
// erroring if no `next` is set.
|
||||
fn static_next_targets(node: &Node, current: &str, kind: &str) -> Result<Vec<String>> {
|
||||
node.next
|
||||
.as_ref()
|
||||
.map(|t| t.as_slice().to_vec())
|
||||
.ok_or_else(|| anyhow!("{kind} node '{current}' has no `next` and is not an end node"))
|
||||
}
|
||||
|
||||
// Returns the first declared `next:` target as a borrowed `&str`, or `None` if
|
||||
// no `next` is set. Used by node executors that take `Option<&str>` for their
|
||||
// primary routing argument (LLM, RAG, Input).
|
||||
fn first_next_target(node: &Node) -> Option<&str> {
|
||||
node.next
|
||||
.as_ref()
|
||||
.and_then(|t| t.as_slice().first().map(|s| s.as_str()))
|
||||
}
|
||||
|
||||
// Resolves the actual frontier-advance targets after an LLM/RAG node ran.
|
||||
//
|
||||
// LLM/RAG executors return their chosen routing as a String — either the
|
||||
// primary `next:` target (success path) or the node's `fallback:` (failure
|
||||
// path with retry exhausted). We can't tell these apart from inside step()
|
||||
// without an API refactor, so we compare strings: if the returned routing
|
||||
// matches the first declared `next` target, treat as success and (for
|
||||
// fan-out) use ALL declared targets; otherwise treat as fallback and use the
|
||||
// returned target alone.
|
||||
//
|
||||
// Known limitation: if a fan-out node's `fallback:` is set to the same node
|
||||
// id as its first `next:` target, a successful run is indistinguishable from
|
||||
// a fallback run — both look like "returned the first target". The result is
|
||||
// that the executor advances to all Many targets in the fallback case (which
|
||||
// is the OPPOSITE of the user's likely intent). Workaround: choose a
|
||||
// `fallback:` distinct from any `next:` target.
|
||||
fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec<String> {
|
||||
let Some(targets) = &node.next else {
|
||||
return vec![returned_routing.to_string()];
|
||||
};
|
||||
let slice = targets.as_slice();
|
||||
let first_matches = slice.first().is_some_and(|s| s == returned_routing);
|
||||
if first_matches && slice.len() > 1 {
|
||||
slice.to_vec()
|
||||
} else {
|
||||
vec![returned_routing.to_string()]
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String {
|
||||
apply_simple_state_updates(end_node.state_updates.as_ref(), state);
|
||||
state.interpolate_lenient(&end_node.output)
|
||||
@@ -490,3 +531,262 @@ mod tests {
|
||||
assert_eq!(state.state().get("k"), Some(&json!("new-old")));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests {
|
||||
use super::*;
|
||||
use crate::config::{AppState, WorkingMode};
|
||||
use crate::utils::{create_abort_signal, temp_file};
|
||||
use std::fs;
|
||||
|
||||
fn cmd_available(name: &str) -> bool {
|
||||
which::which(name).is_ok()
|
||||
}
|
||||
|
||||
struct TestWorkspace {
|
||||
dir: PathBuf,
|
||||
}
|
||||
|
||||
impl TestWorkspace {
|
||||
fn new() -> Self {
|
||||
let dir = temp_file("-graph-integration-", "");
|
||||
fs::create_dir_all(&dir).unwrap();
|
||||
Self { dir }
|
||||
}
|
||||
|
||||
fn write_script(&self, name: &str, contents: &str) {
|
||||
fs::write(self.dir.join(name), contents).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for TestWorkspace {
|
||||
fn drop(&mut self) {
|
||||
let _ = fs::remove_dir_all(&self.dir);
|
||||
}
|
||||
}
|
||||
|
||||
fn make_ctx() -> RequestContext {
|
||||
RequestContext::new(Arc::new(AppState::test_default()), WorkingMode::Cmd)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn static_fan_out_merges_branch_writes_via_append_reducer() {
|
||||
if !cmd_available("bash") {
|
||||
eprintln!("skipping: bash not available");
|
||||
return;
|
||||
}
|
||||
let ws = TestWorkspace::new();
|
||||
ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n");
|
||||
ws.write_script(
|
||||
"worker_a.sh",
|
||||
"#!/bin/bash\necho '{\"results\": \"alpha\"}'\n",
|
||||
);
|
||||
ws.write_script(
|
||||
"worker_b.sh",
|
||||
"#!/bin/bash\necho '{\"results\": \"beta\"}'\n",
|
||||
);
|
||||
|
||||
let yaml = r#"
|
||||
name: static_fan_out_test
|
||||
start: dispatcher
|
||||
reducers:
|
||||
results: append
|
||||
nodes:
|
||||
dispatcher:
|
||||
type: script
|
||||
script: dispatcher.sh
|
||||
state_updates: {}
|
||||
next: [worker_a, worker_b]
|
||||
worker_a:
|
||||
type: script
|
||||
script: worker_a.sh
|
||||
state_updates: {}
|
||||
next: join
|
||||
worker_b:
|
||||
type: script
|
||||
script: worker_b.sh
|
||||
state_updates: {}
|
||||
next: join
|
||||
join:
|
||||
type: end
|
||||
output: "{{results}}"
|
||||
"#;
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
let mut ctx = make_ctx();
|
||||
let abort = create_abort_signal();
|
||||
let result = GraphExecutor::new(graph, &ws.dir)
|
||||
.execute(&mut ctx, abort)
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("executor failed: {e:#}"));
|
||||
|
||||
let parsed: Value = serde_json::from_str(&result)
|
||||
.unwrap_or_else(|_| panic!("expected JSON array, got: {result}"));
|
||||
let arr = parsed.as_array().expect("results should be an array");
|
||||
assert_eq!(arr.len(), 2, "expected 2 elements, got: {result}");
|
||||
let strs: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect();
|
||||
assert!(strs.contains(&"alpha"), "missing 'alpha' in {strs:?}");
|
||||
assert!(strs.contains(&"beta"), "missing 'beta' in {strs:?}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn map_over_list_collects_outputs_in_input_order() {
|
||||
if !cmd_available("python3") {
|
||||
eprintln!("skipping: python3 not available");
|
||||
return;
|
||||
}
|
||||
let ws = TestWorkspace::new();
|
||||
ws.write_script(
|
||||
"doubler.py",
|
||||
r#"#!/usr/bin/env python3
|
||||
import os, json
|
||||
state = json.loads(os.environ.get("GRAPH_STATE", "{}"))
|
||||
val = state["item"]
|
||||
print(json.dumps({"output": val * 2}))
|
||||
"#,
|
||||
);
|
||||
|
||||
let yaml = r#"
|
||||
name: map_input_order_test
|
||||
start: fan_out
|
||||
initial_state:
|
||||
items: [1, 2, 3, 4, 5]
|
||||
nodes:
|
||||
fan_out:
|
||||
type: map
|
||||
over: "{{items}}"
|
||||
as: item
|
||||
branch: doubler
|
||||
collect_into: doubled
|
||||
next: done
|
||||
doubler:
|
||||
type: script
|
||||
script: doubler.py
|
||||
state_updates: {}
|
||||
done:
|
||||
type: end
|
||||
output: "{{doubled}}"
|
||||
"#;
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
let mut ctx = make_ctx();
|
||||
let abort = create_abort_signal();
|
||||
let result = GraphExecutor::new(graph, &ws.dir)
|
||||
.execute(&mut ctx, abort)
|
||||
.await
|
||||
.unwrap_or_else(|e| panic!("executor failed: {e:#}"));
|
||||
|
||||
let parsed: Value = serde_json::from_str(&result)
|
||||
.unwrap_or_else(|_| panic!("expected JSON array, got: {result}"));
|
||||
let arr = parsed.as_array().expect("doubled should be an array");
|
||||
let nums: Vec<i64> = arr
|
||||
.iter()
|
||||
.map(|v| v.as_i64().expect("each item should be int"))
|
||||
.collect();
|
||||
|
||||
assert_eq!(
|
||||
nums,
|
||||
vec![2, 4, 6, 8, 10],
|
||||
"map outputs should be in input order, not finish order"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn parallel_branch_error_aborts_super_step() {
|
||||
if !cmd_available("bash") {
|
||||
eprintln!("skipping: bash not available");
|
||||
return;
|
||||
}
|
||||
let ws = TestWorkspace::new();
|
||||
ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n");
|
||||
ws.write_script(
|
||||
"worker_ok.sh",
|
||||
"#!/bin/bash\necho '{\"results\": \"ok\"}'\n",
|
||||
);
|
||||
ws.write_script(
|
||||
"worker_fail.sh",
|
||||
"#!/bin/bash\necho 'simulated failure' >&2\nexit 1\n",
|
||||
);
|
||||
|
||||
let yaml = r#"
|
||||
name: branch_error_test
|
||||
start: dispatcher
|
||||
reducers:
|
||||
results: append
|
||||
nodes:
|
||||
dispatcher:
|
||||
type: script
|
||||
script: dispatcher.sh
|
||||
state_updates: {}
|
||||
next: [worker_ok, worker_fail]
|
||||
worker_ok:
|
||||
type: script
|
||||
script: worker_ok.sh
|
||||
state_updates: {}
|
||||
next: join
|
||||
worker_fail:
|
||||
type: script
|
||||
script: worker_fail.sh
|
||||
state_updates: {}
|
||||
next: join
|
||||
join:
|
||||
type: end
|
||||
output: "{{results}}"
|
||||
"#;
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
let mut ctx = make_ctx();
|
||||
let abort = create_abort_signal();
|
||||
let result = GraphExecutor::new(graph, &ws.dir)
|
||||
.execute(&mut ctx, abort)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err(), "expected branch error to propagate");
|
||||
let err = format!("{:#}", result.unwrap_err());
|
||||
assert!(
|
||||
err.contains("worker_fail"),
|
||||
"error should mention failing node: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn multi_end_in_super_step_is_rejected() {
|
||||
if !cmd_available("bash") {
|
||||
eprintln!("skipping: bash not available");
|
||||
return;
|
||||
}
|
||||
let ws = TestWorkspace::new();
|
||||
ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n");
|
||||
|
||||
let yaml = r#"
|
||||
name: multi_end_test
|
||||
start: dispatcher
|
||||
nodes:
|
||||
dispatcher:
|
||||
type: script
|
||||
script: dispatcher.sh
|
||||
state_updates: {}
|
||||
next: [end_a, end_b]
|
||||
end_a:
|
||||
type: end
|
||||
output: "from a"
|
||||
end_b:
|
||||
type: end
|
||||
output: "from b"
|
||||
"#;
|
||||
let graph: Graph = serde_yaml::from_str(yaml).unwrap();
|
||||
let mut ctx = make_ctx();
|
||||
let abort = create_abort_signal();
|
||||
let result = GraphExecutor::new(graph, &ws.dir)
|
||||
.execute(&mut ctx, abort)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err(), "expected multi-End to be rejected");
|
||||
let err = format!("{:#}", result.unwrap_err());
|
||||
assert!(
|
||||
err.contains("multiple End targets"),
|
||||
"error should explain multi-End cause: {err}"
|
||||
);
|
||||
assert!(
|
||||
err.contains("end_a") && err.contains("end_b"),
|
||||
"error should list both End nodes: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user