diff --git a/src/graph/executor.rs b/src/graph/executor.rs index f9bf69c..56bc95c 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -228,15 +228,17 @@ impl GraphExecutor { let step_outcome = step_result.with_context(|| format!("at node '{node_id}'"))?; match step_outcome { - StepResult::Continue(target) => { - logger.routing(&node_id, &target); + StepResult::Continue(targets) => { + for target in &targets { + logger.routing(&node_id, target); + } let diff = branch_state.diff_against(snapshot.as_ref()); branch_writes.push(BranchWrites { node_id: node_id.clone(), invocation_index: 0, writes: diff, }); - next_frontier.insert(target); + next_frontier.extend(targets); } StepResult::End(output) => { end_results.push((node_id.clone(), branch_state, output)); @@ -309,7 +311,12 @@ impl StepContext<'_> { } enum StepResult { - Continue(String), + // The set of next-node ids the executor should add to the next super-step's + // frontier. A `Vec` of length 1 for sequential routing (default) and the + // full target list for fan-out (`next: [a, b, ...]`). Dynamic single-route + // decisions (script `_next`, approval routes, LLM/RAG fallback) always emit + // a single-element vec. + Continue(Vec), End(String), } @@ -323,13 +330,8 @@ async fn step( match &node.node_type { NodeType::Agent(agent_node) => { AgentNodeExecutor::execute(agent_node, state, ctx).await?; - let next = node - .next_single()? - .ok_or_else(|| { - anyhow!("agent node '{current}' has no `next` and is not an end node") - })? - .to_string(); - Ok(StepResult::Continue(next)) + let targets = static_next_targets(node, current, "agent")?; + Ok(StepResult::Continue(targets)) } NodeType::Script(script_node) => { let dynamic = match step_ctx.script_executor.execute(script_node, state).await { @@ -343,57 +345,96 @@ async fn step( fallback, e ); - return Ok(StepResult::Continue(fallback.clone())); + return Ok(StepResult::Continue(vec![fallback.clone()])); } return Err(e); } }; - let next = match dynamic { - Some(n) => n, - None => node - .next_single()? - .ok_or_else(|| { - anyhow!( - "script node '{current}' did not emit `_next` and has no static `next`" - ) - })? - .to_string(), + let targets = match dynamic { + Some(n) => vec![n], + None => static_next_targets(node, current, "script")?, }; - Ok(StepResult::Continue(next)) + Ok(StepResult::Continue(targets)) } NodeType::Approval(approval_node) => { let next = ApprovalNodeExecutor::execute(approval_node, state, ctx).await?; - Ok(StepResult::Continue(next)) + Ok(StepResult::Continue(vec![next])) } NodeType::Input(input_node) => { - let next_id = node.next_single()?; + let next_id = first_next_target(node); let next = InputNodeExecutor::execute(input_node, next_id, state, ctx).await?; - Ok(StepResult::Continue(next)) + Ok(StepResult::Continue(vec![next])) } NodeType::Llm(llm_node) => { - let next_id = node.next_single()?; - let next = LlmNodeExecutor::execute(llm_node, next_id, state, ctx).await?; - Ok(StepResult::Continue(next)) + let primary = first_next_target(node).map(str::to_string); + let llm_routing = + LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?; + let targets = resolve_branching_next(node, &llm_routing); + Ok(StepResult::Continue(targets)) } NodeType::Rag(rag_node) => { - let next_id = node.next_single()?; - let next = RagNodeExecutor::execute(rag_node, current, next_id, state, ctx).await?; - Ok(StepResult::Continue(next)) + let primary = first_next_target(node).map(str::to_string); + let rag_routing = + RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?; + let targets = resolve_branching_next(node, &rag_routing); + Ok(StepResult::Continue(targets)) } NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))), NodeType::Map(map_node) => { - let next = node - .next_single()? - .ok_or_else(|| { - anyhow!("map node '{current}' has no `next` and is not an end node") - })? - .to_string(); + let targets = static_next_targets(node, current, "map")?; MapNodeExecutor::execute(map_node, state, ctx, step_ctx, current).await?; - Ok(StepResult::Continue(next)) + Ok(StepResult::Continue(targets)) } } } +// Returns all `next:` targets from the node (handles both `One` and `Many`), +// erroring if no `next` is set. +fn static_next_targets(node: &Node, current: &str, kind: &str) -> Result> { + node.next + .as_ref() + .map(|t| t.as_slice().to_vec()) + .ok_or_else(|| anyhow!("{kind} node '{current}' has no `next` and is not an end node")) +} + +// Returns the first declared `next:` target as a borrowed `&str`, or `None` if +// no `next` is set. Used by node executors that take `Option<&str>` for their +// primary routing argument (LLM, RAG, Input). +fn first_next_target(node: &Node) -> Option<&str> { + node.next + .as_ref() + .and_then(|t| t.as_slice().first().map(|s| s.as_str())) +} + +// Resolves the actual frontier-advance targets after an LLM/RAG node ran. +// +// LLM/RAG executors return their chosen routing as a String — either the +// primary `next:` target (success path) or the node's `fallback:` (failure +// path with retry exhausted). We can't tell these apart from inside step() +// without an API refactor, so we compare strings: if the returned routing +// matches the first declared `next` target, treat as success and (for +// fan-out) use ALL declared targets; otherwise treat as fallback and use the +// returned target alone. +// +// Known limitation: if a fan-out node's `fallback:` is set to the same node +// id as its first `next:` target, a successful run is indistinguishable from +// a fallback run — both look like "returned the first target". The result is +// that the executor advances to all Many targets in the fallback case (which +// is the OPPOSITE of the user's likely intent). Workaround: choose a +// `fallback:` distinct from any `next:` target. +fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec { + let Some(targets) = &node.next else { + return vec![returned_routing.to_string()]; + }; + let slice = targets.as_slice(); + let first_matches = slice.first().is_some_and(|s| s == returned_routing); + if first_matches && slice.len() > 1 { + slice.to_vec() + } else { + vec![returned_routing.to_string()] + } +} + fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String { apply_simple_state_updates(end_node.state_updates.as_ref(), state); state.interpolate_lenient(&end_node.output) @@ -490,3 +531,262 @@ mod tests { assert_eq!(state.state().get("k"), Some(&json!("new-old"))); } } + +#[cfg(test)] +mod integration_tests { + use super::*; + use crate::config::{AppState, WorkingMode}; + use crate::utils::{create_abort_signal, temp_file}; + use std::fs; + + fn cmd_available(name: &str) -> bool { + which::which(name).is_ok() + } + + struct TestWorkspace { + dir: PathBuf, + } + + impl TestWorkspace { + fn new() -> Self { + let dir = temp_file("-graph-integration-", ""); + fs::create_dir_all(&dir).unwrap(); + Self { dir } + } + + fn write_script(&self, name: &str, contents: &str) { + fs::write(self.dir.join(name), contents).unwrap(); + } + } + + impl Drop for TestWorkspace { + fn drop(&mut self) { + let _ = fs::remove_dir_all(&self.dir); + } + } + + fn make_ctx() -> RequestContext { + RequestContext::new(Arc::new(AppState::test_default()), WorkingMode::Cmd) + } + + #[tokio::test] + async fn static_fan_out_merges_branch_writes_via_append_reducer() { + if !cmd_available("bash") { + eprintln!("skipping: bash not available"); + return; + } + let ws = TestWorkspace::new(); + ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n"); + ws.write_script( + "worker_a.sh", + "#!/bin/bash\necho '{\"results\": \"alpha\"}'\n", + ); + ws.write_script( + "worker_b.sh", + "#!/bin/bash\necho '{\"results\": \"beta\"}'\n", + ); + + let yaml = r#" +name: static_fan_out_test +start: dispatcher +reducers: + results: append +nodes: + dispatcher: + type: script + script: dispatcher.sh + state_updates: {} + next: [worker_a, worker_b] + worker_a: + type: script + script: worker_a.sh + state_updates: {} + next: join + worker_b: + type: script + script: worker_b.sh + state_updates: {} + next: join + join: + type: end + output: "{{results}}" +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + let mut ctx = make_ctx(); + let abort = create_abort_signal(); + let result = GraphExecutor::new(graph, &ws.dir) + .execute(&mut ctx, abort) + .await + .unwrap_or_else(|e| panic!("executor failed: {e:#}")); + + let parsed: Value = serde_json::from_str(&result) + .unwrap_or_else(|_| panic!("expected JSON array, got: {result}")); + let arr = parsed.as_array().expect("results should be an array"); + assert_eq!(arr.len(), 2, "expected 2 elements, got: {result}"); + let strs: Vec<&str> = arr.iter().filter_map(|v| v.as_str()).collect(); + assert!(strs.contains(&"alpha"), "missing 'alpha' in {strs:?}"); + assert!(strs.contains(&"beta"), "missing 'beta' in {strs:?}"); + } + + #[tokio::test] + async fn map_over_list_collects_outputs_in_input_order() { + if !cmd_available("python3") { + eprintln!("skipping: python3 not available"); + return; + } + let ws = TestWorkspace::new(); + ws.write_script( + "doubler.py", + r#"#!/usr/bin/env python3 +import os, json +state = json.loads(os.environ.get("GRAPH_STATE", "{}")) +val = state["item"] +print(json.dumps({"output": val * 2})) +"#, + ); + + let yaml = r#" +name: map_input_order_test +start: fan_out +initial_state: + items: [1, 2, 3, 4, 5] +nodes: + fan_out: + type: map + over: "{{items}}" + as: item + branch: doubler + collect_into: doubled + next: done + doubler: + type: script + script: doubler.py + state_updates: {} + done: + type: end + output: "{{doubled}}" +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + let mut ctx = make_ctx(); + let abort = create_abort_signal(); + let result = GraphExecutor::new(graph, &ws.dir) + .execute(&mut ctx, abort) + .await + .unwrap_or_else(|e| panic!("executor failed: {e:#}")); + + let parsed: Value = serde_json::from_str(&result) + .unwrap_or_else(|_| panic!("expected JSON array, got: {result}")); + let arr = parsed.as_array().expect("doubled should be an array"); + let nums: Vec = arr + .iter() + .map(|v| v.as_i64().expect("each item should be int")) + .collect(); + + assert_eq!( + nums, + vec![2, 4, 6, 8, 10], + "map outputs should be in input order, not finish order" + ); + } + + #[tokio::test] + async fn parallel_branch_error_aborts_super_step() { + if !cmd_available("bash") { + eprintln!("skipping: bash not available"); + return; + } + let ws = TestWorkspace::new(); + ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n"); + ws.write_script( + "worker_ok.sh", + "#!/bin/bash\necho '{\"results\": \"ok\"}'\n", + ); + ws.write_script( + "worker_fail.sh", + "#!/bin/bash\necho 'simulated failure' >&2\nexit 1\n", + ); + + let yaml = r#" +name: branch_error_test +start: dispatcher +reducers: + results: append +nodes: + dispatcher: + type: script + script: dispatcher.sh + state_updates: {} + next: [worker_ok, worker_fail] + worker_ok: + type: script + script: worker_ok.sh + state_updates: {} + next: join + worker_fail: + type: script + script: worker_fail.sh + state_updates: {} + next: join + join: + type: end + output: "{{results}}" +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + let mut ctx = make_ctx(); + let abort = create_abort_signal(); + let result = GraphExecutor::new(graph, &ws.dir) + .execute(&mut ctx, abort) + .await; + + assert!(result.is_err(), "expected branch error to propagate"); + let err = format!("{:#}", result.unwrap_err()); + assert!( + err.contains("worker_fail"), + "error should mention failing node: {err}" + ); + } + + #[tokio::test] + async fn multi_end_in_super_step_is_rejected() { + if !cmd_available("bash") { + eprintln!("skipping: bash not available"); + return; + } + let ws = TestWorkspace::new(); + ws.write_script("dispatcher.sh", "#!/bin/bash\necho '{}'\n"); + + let yaml = r#" +name: multi_end_test +start: dispatcher +nodes: + dispatcher: + type: script + script: dispatcher.sh + state_updates: {} + next: [end_a, end_b] + end_a: + type: end + output: "from a" + end_b: + type: end + output: "from b" +"#; + let graph: Graph = serde_yaml::from_str(yaml).unwrap(); + let mut ctx = make_ctx(); + let abort = create_abort_signal(); + let result = GraphExecutor::new(graph, &ws.dir) + .execute(&mut ctx, abort) + .await; + + assert!(result.is_err(), "expected multi-End to be rejected"); + let err = format!("{:#}", result.unwrap_err()); + assert!( + err.contains("multiple End targets"), + "error should explain multi-End cause: {err}" + ); + assert!( + err.contains("end_a") && err.contains("end_b"), + "error should list both End nodes: {err}" + ); + } +}