From 98d16d9a56a4c3c3955c0eed0b2c6badb713c29c Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Wed, 20 May 2026 16:27:25 -0600 Subject: [PATCH] fix: bug in next_single method and improved outcome handling for LLM node execution --- src/graph/executor.rs | 46 +++++--------------------- src/graph/llm.rs | 77 +++++++++++++++++++++++-------------------- src/graph/map.rs | 29 ++++------------ src/graph/rag.rs | 10 ++---- src/graph/types.rs | 56 ++++--------------------------- 5 files changed, 65 insertions(+), 153 deletions(-) diff --git a/src/graph/executor.rs b/src/graph/executor.rs index 56bc95c..552c270 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -1,5 +1,5 @@ use super::agent::AgentNodeExecutor; -use super::llm::LlmNodeExecutor; +use super::llm::{LlmExecutionOutcome, LlmNodeExecutor}; use super::logging::GraphLogger; use super::map::MapNodeExecutor; use super::progress::{BranchProgressHandle, BranchProgressTracker}; @@ -366,17 +366,16 @@ async fn step( Ok(StepResult::Continue(vec![next])) } NodeType::Llm(llm_node) => { - let primary = first_next_target(node).map(str::to_string); - let llm_routing = - LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?; - let targets = resolve_branching_next(node, &llm_routing); + let outcome = LlmNodeExecutor::execute(llm_node, state, ctx).await?; + let targets = match outcome { + LlmExecutionOutcome::Continue => static_next_targets(node, current, "llm")?, + LlmExecutionOutcome::FellBack(target) => vec![target], + }; Ok(StepResult::Continue(targets)) } NodeType::Rag(rag_node) => { - let primary = first_next_target(node).map(str::to_string); - let rag_routing = - RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?; - let targets = resolve_branching_next(node, &rag_routing); + RagNodeExecutor::execute(rag_node, current, state, ctx).await?; + let targets = static_next_targets(node, current, "rag")?; Ok(StepResult::Continue(targets)) } NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))), @@ -406,35 +405,6 @@ fn first_next_target(node: &Node) -> Option<&str> { .and_then(|t| t.as_slice().first().map(|s| s.as_str())) } -// Resolves the actual frontier-advance targets after an LLM/RAG node ran. -// -// LLM/RAG executors return their chosen routing as a String — either the -// primary `next:` target (success path) or the node's `fallback:` (failure -// path with retry exhausted). We can't tell these apart from inside step() -// without an API refactor, so we compare strings: if the returned routing -// matches the first declared `next` target, treat as success and (for -// fan-out) use ALL declared targets; otherwise treat as fallback and use the -// returned target alone. -// -// Known limitation: if a fan-out node's `fallback:` is set to the same node -// id as its first `next:` target, a successful run is indistinguishable from -// a fallback run — both look like "returned the first target". The result is -// that the executor advances to all Many targets in the fallback case (which -// is the OPPOSITE of the user's likely intent). Workaround: choose a -// `fallback:` distinct from any `next:` target. -fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec { - let Some(targets) = &node.next else { - return vec![returned_routing.to_string()]; - }; - let slice = targets.as_slice(); - let first_matches = slice.first().is_some_and(|s| s == returned_routing); - if first_matches && slice.len() > 1 { - slice.to_vec() - } else { - vec![returned_routing.to_string()] - } -} - fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String { apply_simple_state_updates(end_node.state_updates.as_ref(), state); state.interpolate_lenient(&end_node.output) diff --git a/src/graph/llm.rs b/src/graph/llm.rs index 1a02498..98b47f1 100644 --- a/src/graph/llm.rs +++ b/src/graph/llm.rs @@ -13,15 +13,26 @@ use tokio::time::timeout; const OUTPUT_KEY: &str = "output"; +/// What happened during an LLM node's execution, from the caller's routing +/// perspective. `Continue` means the caller should advance via the node's +/// declared `next:` targets (whether the LLM actually succeeded or failed +/// without a fallback — either way, the executor uses node.next). `FellBack` +/// means the LLM failed after retries and the node had a `fallback:` declared, +/// so routing should go to that fallback target only. +#[derive(Debug, PartialEq, Eq)] +pub(super) enum LlmExecutionOutcome { + Continue, + FellBack(String), +} + pub struct LlmNodeExecutor; impl LlmNodeExecutor { - pub async fn execute( + pub(super) async fn execute( node: &LlmNode, - node_next: Option<&str>, state_manager: &mut StateManager, parent_ctx: &mut RequestContext, - ) -> Result { + ) -> Result { let result = run(node, state_manager, parent_ctx).await; let (output, failed) = match result { Ok(raw) => match &node.output_schema { @@ -44,7 +55,15 @@ impl LlmNodeExecutor { }; apply_state_updates_with_output(node, state_manager, &output); - next_for_llm_node(node_next, failed, node.fallback.as_deref()) + Ok(outcome_from(failed, node.fallback.as_deref())) + } +} + +fn outcome_from(failed: bool, fallback: Option<&str>) -> LlmExecutionOutcome { + if failed && let Some(fb) = fallback { + LlmExecutionOutcome::FellBack(fb.to_string()) + } else { + LlmExecutionOutcome::Continue } } @@ -298,19 +317,6 @@ fn is_transient(err: &Error) -> bool { || s.contains("produced no output") } -fn next_for_llm_node( - node_next: Option<&str>, - failed: bool, - fallback: Option<&str>, -) -> Result { - if failed && let Some(fb) = fallback { - return Ok(fb.to_string()); - } - node_next - .map(String::from) - .ok_or_else(|| anyhow!("llm node has no `next` set; llm nodes need static routing")) -} - fn apply_state_updates_with_output( node: &LlmNode, state_manager: &mut StateManager, @@ -457,30 +463,29 @@ mod tests { } #[test] - fn next_for_llm_node_success_routes_to_next() { + fn outcome_from_success_is_continue() { assert_eq!( - next_for_llm_node(Some("nx"), false, Some("fb")).unwrap(), - "nx" + outcome_from(false, Some("fb")), + LlmExecutionOutcome::Continue + ); + assert_eq!(outcome_from(false, None), LlmExecutionOutcome::Continue); + } + + #[test] + fn outcome_from_failure_with_fallback_is_fell_back() { + assert_eq!( + outcome_from(true, Some("fb")), + LlmExecutionOutcome::FellBack("fb".to_string()) ); } #[test] - fn next_for_llm_node_failure_with_fallback_routes_to_fallback() { - assert_eq!( - next_for_llm_node(Some("nx"), true, Some("fb")).unwrap(), - "fb" - ); - } - - #[test] - fn next_for_llm_node_failure_without_fallback_routes_to_next() { - assert_eq!(next_for_llm_node(Some("nx"), true, None).unwrap(), "nx"); - } - - #[test] - fn next_for_llm_node_errors_without_next_or_fallback() { - assert!(next_for_llm_node(None, false, None).is_err()); - assert!(next_for_llm_node(None, true, None).is_err()); + fn outcome_from_failure_without_fallback_is_continue() { + // Failed but no fallback: caller routes via node.next as if successful. + // The error has already been recorded to state via the OUTPUT_KEY by + // execute(); the caller's `static_next_targets` will error if node.next + // is also missing. + assert_eq!(outcome_from(true, None), LlmExecutionOutcome::Continue); } fn node_with_schema(updates: Option>, schema: Value) -> LlmNode { diff --git a/src/graph/map.rs b/src/graph/map.rs index b798610..a81156f 100644 --- a/src/graph/map.rs +++ b/src/graph/map.rs @@ -14,12 +14,6 @@ use std::collections::HashMap; use std::sync::Arc; use tokio::sync::Semaphore; -// Map sub-branches are atomic — the branch node has no `next` (enforced by -// validator rule C.5). But LLM/RAG node executors require an `Option<&str>` for -// their routing argument and error if it's `None`. Passing this sentinel -// satisfies their contract; the map executor discards the returned routing. -const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__"; - pub(super) struct MapNodeExecutor; impl MapNodeExecutor { @@ -99,26 +93,15 @@ impl MapNodeExecutor { let mut ctx = sub_ctx; let exec_result: Result<()> = match &branch_clone.node_type { - NodeType::Llm(n) => LlmNodeExecutor::execute( - n, - Some(MAP_BRANCH_SENTINEL_NEXT), - &mut state, - &mut ctx, - ) - .await - .map(|_| ()), + NodeType::Llm(n) => LlmNodeExecutor::execute(n, &mut state, &mut ctx) + .await + .map(|_| ()), NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx) .await .map(|_| ()), - NodeType::Rag(n) => RagNodeExecutor::execute( - n, - &sub_branch_id, - Some(MAP_BRANCH_SENTINEL_NEXT), - &mut state, - &mut ctx, - ) - .await - .map(|_| ()), + NodeType::Rag(n) => { + RagNodeExecutor::execute(n, &sub_branch_id, &mut state, &mut ctx).await + } NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()), _ => Err(anyhow!( "map branch '{}' has type that cannot run inside a map \ diff --git a/src/graph/rag.rs b/src/graph/rag.rs index 2017e50..3439952 100644 --- a/src/graph/rag.rs +++ b/src/graph/rag.rs @@ -14,13 +14,12 @@ const DEFAULT_RAG_TIMEOUT_SECS: u64 = 120; pub struct RagNodeExecutor; impl RagNodeExecutor { - pub async fn execute( + pub(super) async fn execute( node: &RagNode, node_id: &str, - node_next: Option<&str>, state_manager: &mut StateManager, ctx: &mut RequestContext, - ) -> Result { + ) -> Result<()> { let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY); let query = state_manager .interpolate(query_template) @@ -55,10 +54,7 @@ impl RagNodeExecutor { let output = build_rag_output(context, &sources_str); apply_state_updates(node, state_manager, &output); - - node_next - .map(String::from) - .ok_or_else(|| anyhow!("rag node '{node_id}' has no `next` set")) + Ok(()) } } diff --git a/src/graph/types.rs b/src/graph/types.rs index 70e8ccb..6a668e1 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -1,4 +1,4 @@ -use anyhow::{Result, bail}; +use anyhow::Result; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -129,11 +129,11 @@ pub struct Node { } impl Node { - /// Returns the single next target as a string slice, or `None` if no next is - /// declared or if a multi-target fan-out is declared. Use this for read-only - /// inspection (e.g. tests). For execution paths that require single-target - /// semantics, use `next_single()` — it errors explicitly when a fan-out is - /// declared so the caller can surface a clear failure instead of skipping it. + /// Returns the single next target as a string slice for tests and other + /// read-only inspection. Returns `None` when no `next:` is declared at all, + /// OR when a real multi-target fan-out is declared (since a fan-out has no + /// "single" target). Execution paths use `static_next_targets` in the graph + /// executor instead. #[allow(dead_code)] pub fn next_target(&self) -> Option<&str> { match &self.next { @@ -143,16 +143,6 @@ impl Node { Some(NextTargets::Many(_)) => None, } } - - /// Returns the single next target as a string slice, or an explicit error if - /// the node declares a multi-target fan-out (which is not yet supported - /// pre-Phase-D). Returns `Ok(None)` when no next is declared at all. - pub fn next_single(&self) -> Result> { - match &self.next { - None => Ok(None), - Some(targets) => Ok(Some(targets.single()?.as_str())), - } - } } #[derive(Debug, Clone, Deserialize, Serialize)] @@ -172,24 +162,9 @@ impl NextTargets { } /// True if this declares more than one parallel target (i.e., a real fan-out). - #[allow(dead_code)] pub fn is_fan_out(&self) -> bool { matches!(self, NextTargets::Many(v) if v.len() > 1) } - - /// Returns the single target if exactly one is declared, else errors with a - /// clear "not yet supported" message. Used by the v1 executor until parallel - /// branch execution lands in Phase D. - pub fn single(&self) -> Result<&String> { - match self { - NextTargets::One(s) => Ok(s), - NextTargets::Many(v) if v.len() == 1 => Ok(&v[0]), - NextTargets::Many(_) => bail!( - "Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \ - branch execution is not yet implemented in this build." - ), - } - } } impl From for NextTargets { @@ -971,23 +946,7 @@ next: retrieve } #[test] - fn next_single_errors_on_real_fan_out_with_clear_message() { - let yaml = r#" -id: triage -type: llm -prompt: Classify -next: [a, b] -"#; - let node: Node = serde_yaml::from_str(yaml).unwrap(); - - let err = node.next_single().unwrap_err().to_string(); - - assert!(err.contains("Parallel fan-out"), "got: {err}"); - assert!(err.contains("not yet implemented"), "got: {err}"); - } - - #[test] - fn next_single_accepts_many_containing_exactly_one_target() { + fn next_target_treats_many_of_one_as_single() { let yaml = r#" id: triage type: llm @@ -997,7 +956,6 @@ next: [retrieve] let node: Node = serde_yaml::from_str(yaml).unwrap(); - assert_eq!(node.next_single().unwrap(), Some("retrieve")); assert_eq!(node.next_target(), Some("retrieve")); }