fix: bug in next_single method and improve outcome handling for LLM node execution

This commit is contained in:
2026-05-20 16:27:25 -06:00
parent a3bfa2fbe9
commit 76549a9911
5 changed files with 65 additions and 153 deletions
+8 -38
View File
@@ -1,5 +1,5 @@
use super::agent::AgentNodeExecutor;
use super::llm::LlmNodeExecutor;
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
use super::logging::GraphLogger;
use super::map::MapNodeExecutor;
use super::progress::{BranchProgressHandle, BranchProgressTracker};
@@ -366,17 +366,16 @@ async fn step(
Ok(StepResult::Continue(vec![next]))
}
NodeType::Llm(llm_node) => {
let primary = first_next_target(node).map(str::to_string);
let llm_routing =
LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?;
let targets = resolve_branching_next(node, &llm_routing);
let outcome = LlmNodeExecutor::execute(llm_node, state, ctx).await?;
let targets = match outcome {
LlmExecutionOutcome::Continue => static_next_targets(node, current, "llm")?,
LlmExecutionOutcome::FellBack(target) => vec![target],
};
Ok(StepResult::Continue(targets))
}
NodeType::Rag(rag_node) => {
let primary = first_next_target(node).map(str::to_string);
let rag_routing =
RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?;
let targets = resolve_branching_next(node, &rag_routing);
RagNodeExecutor::execute(rag_node, current, state, ctx).await?;
let targets = static_next_targets(node, current, "rag")?;
Ok(StepResult::Continue(targets))
}
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
@@ -406,35 +405,6 @@ fn first_next_target(node: &Node) -> Option<&str> {
.and_then(|t| t.as_slice().first().map(|s| s.as_str()))
}
// Resolves the actual frontier-advance targets after an LLM/RAG node ran.
//
// LLM/RAG executors return their chosen routing as a String — either the
// primary `next:` target (success path) or the node's `fallback:` (failure
// path with retry exhausted). We can't tell these apart from inside step()
// without an API refactor, so we compare strings: if the returned routing
// matches the first declared `next` target, treat as success and (for
// fan-out) use ALL declared targets; otherwise treat as fallback and use the
// returned target alone.
//
// Known limitation: if a fan-out node's `fallback:` is set to the same node
// id as its first `next:` target, a successful run is indistinguishable from
// a fallback run — both look like "returned the first target". The result is
// that the executor advances to all Many targets in the fallback case (which
// is the OPPOSITE of the user's likely intent). Workaround: choose a
// `fallback:` distinct from any `next:` target.
fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec<String> {
    // Borrow the declared targets (if any) as a slice so both the length and
    // the first entry can be inspected in a single match guard.
    let declared = node.next.as_ref().map(|targets| targets.as_slice());
    match declared {
        // Fan-out: more than one target is declared AND the executor handed
        // back the first of them, which is read as the success signal.
        Some(all)
            if all.len() > 1 && all.first().is_some_and(|head| head == returned_routing) =>
        {
            all.to_vec()
        }
        // No `next`, a single target, or a routing that differs from the first
        // declared target: advance to the returned target alone.
        _ => vec![returned_routing.to_string()],
    }
}
fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String {
apply_simple_state_updates(end_node.state_updates.as_ref(), state);
state.interpolate_lenient(&end_node.output)
+41 -36
View File
@@ -13,15 +13,26 @@ use tokio::time::timeout;
const OUTPUT_KEY: &str = "output";
/// What happened during an LLM node's execution, from the caller's routing
/// perspective. `Continue` means the caller should advance via the node's
/// declared `next:` targets (whether the LLM actually succeeded or failed
/// without a fallback — either way, the executor uses node.next). `FellBack`
/// means the LLM failed after retries and the node had a `fallback:` declared,
/// so routing should go to that fallback target only.
///
/// Values are built by `outcome_from`, which yields `FellBack` only when the
/// run failed AND a fallback is present; every other combination is
/// `Continue`.
#[derive(Debug, PartialEq, Eq)]
pub(super) enum LlmExecutionOutcome {
/// Route through the node's declared `next:` targets. Carries no payload
/// because the caller resolves the targets from the node itself.
Continue,
/// Route to the fallback target only. The payload is an owned copy of the
/// node's `fallback:` id.
FellBack(String),
}
pub struct LlmNodeExecutor;
impl LlmNodeExecutor {
pub async fn execute(
pub(super) async fn execute(
node: &LlmNode,
node_next: Option<&str>,
state_manager: &mut StateManager,
parent_ctx: &mut RequestContext,
) -> Result<String> {
) -> Result<LlmExecutionOutcome> {
let result = run(node, state_manager, parent_ctx).await;
let (output, failed) = match result {
Ok(raw) => match &node.output_schema {
@@ -44,7 +55,15 @@ impl LlmNodeExecutor {
};
apply_state_updates_with_output(node, state_manager, &output);
next_for_llm_node(node_next, failed, node.fallback.as_deref())
Ok(outcome_from(failed, node.fallback.as_deref()))
}
}
/// Maps an execution result onto the caller-facing routing outcome.
///
/// Only a failed run on a node that declares a fallback produces `FellBack`
/// (carrying an owned copy of the fallback id). Every other combination,
/// including a failure with no fallback, is `Continue`.
fn outcome_from(failed: bool, fallback: Option<&str>) -> LlmExecutionOutcome {
    match (failed, fallback) {
        (true, Some(target)) => LlmExecutionOutcome::FellBack(target.to_string()),
        (true, None) | (false, _) => LlmExecutionOutcome::Continue,
    }
}
@@ -298,19 +317,6 @@ fn is_transient(err: &Error) -> bool {
|| s.contains("produced no output")
}
fn next_for_llm_node(
node_next: Option<&str>,
failed: bool,
fallback: Option<&str>,
) -> Result<String> {
if failed && let Some(fb) = fallback {
return Ok(fb.to_string());
}
node_next
.map(String::from)
.ok_or_else(|| anyhow!("llm node has no `next` set; llm nodes need static routing"))
}
fn apply_state_updates_with_output(
node: &LlmNode,
state_manager: &mut StateManager,
@@ -457,30 +463,29 @@ mod tests {
}
#[test]
fn next_for_llm_node_success_routes_to_next() {
fn outcome_from_success_is_continue() {
assert_eq!(
next_for_llm_node(Some("nx"), false, Some("fb")).unwrap(),
"nx"
outcome_from(false, Some("fb")),
LlmExecutionOutcome::Continue
);
assert_eq!(outcome_from(false, None), LlmExecutionOutcome::Continue);
}
#[test]
fn outcome_from_failure_with_fallback_is_fell_back() {
assert_eq!(
outcome_from(true, Some("fb")),
LlmExecutionOutcome::FellBack("fb".to_string())
);
}
#[test]
fn next_for_llm_node_failure_with_fallback_routes_to_fallback() {
assert_eq!(
next_for_llm_node(Some("nx"), true, Some("fb")).unwrap(),
"fb"
);
}
#[test]
fn next_for_llm_node_failure_without_fallback_routes_to_next() {
assert_eq!(next_for_llm_node(Some("nx"), true, None).unwrap(), "nx");
}
#[test]
fn next_for_llm_node_errors_without_next_or_fallback() {
assert!(next_for_llm_node(None, false, None).is_err());
assert!(next_for_llm_node(None, true, None).is_err());
fn outcome_from_failure_without_fallback_is_continue() {
// Failed but no fallback: caller routes via node.next as if successful.
// The error has already been recorded to state via the OUTPUT_KEY by
// execute(); the caller's `static_next_targets` will error if node.next
// is also missing.
assert_eq!(outcome_from(true, None), LlmExecutionOutcome::Continue);
}
fn node_with_schema(updates: Option<HashMap<String, String>>, schema: Value) -> LlmNode {
+6 -23
View File
@@ -14,12 +14,6 @@ use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Semaphore;
// Map sub-branches are atomic — the branch node has no `next` (enforced by
// validator rule C.5). But LLM/RAG node executors require an `Option<&str>` for
// their routing argument and error if it's `None`. Passing this sentinel
// satisfies their contract; the map executor discards the returned routing.
const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__";
pub(super) struct MapNodeExecutor;
impl MapNodeExecutor {
@@ -99,26 +93,15 @@ impl MapNodeExecutor {
let mut ctx = sub_ctx;
let exec_result: Result<()> = match &branch_clone.node_type {
NodeType::Llm(n) => LlmNodeExecutor::execute(
n,
Some(MAP_BRANCH_SENTINEL_NEXT),
&mut state,
&mut ctx,
)
.await
.map(|_| ()),
NodeType::Llm(n) => LlmNodeExecutor::execute(n, &mut state, &mut ctx)
.await
.map(|_| ()),
NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
.await
.map(|_| ()),
NodeType::Rag(n) => RagNodeExecutor::execute(
n,
&sub_branch_id,
Some(MAP_BRANCH_SENTINEL_NEXT),
&mut state,
&mut ctx,
)
.await
.map(|_| ()),
NodeType::Rag(n) => {
RagNodeExecutor::execute(n, &sub_branch_id, &mut state, &mut ctx).await
}
NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
_ => Err(anyhow!(
"map branch '{}' has type that cannot run inside a map \
+3 -7
View File
@@ -14,13 +14,12 @@ const DEFAULT_RAG_TIMEOUT_SECS: u64 = 120;
pub struct RagNodeExecutor;
impl RagNodeExecutor {
pub async fn execute(
pub(super) async fn execute(
node: &RagNode,
node_id: &str,
node_next: Option<&str>,
state_manager: &mut StateManager,
ctx: &mut RequestContext,
) -> Result<String> {
) -> Result<()> {
let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY);
let query = state_manager
.interpolate(query_template)
@@ -55,10 +54,7 @@ impl RagNodeExecutor {
let output = build_rag_output(context, &sources_str);
apply_state_updates(node, state_manager, &output);
node_next
.map(String::from)
.ok_or_else(|| anyhow!("rag node '{node_id}' has no `next` set"))
Ok(())
}
}
+7 -49
View File
@@ -1,4 +1,4 @@
use anyhow::{Result, bail};
use anyhow::Result;
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
@@ -129,11 +129,11 @@ pub struct Node {
}
impl Node {
/// Returns the single next target as a string slice, or `None` if no next is
/// declared or if a multi-target fan-out is declared. Use this for read-only
/// inspection (e.g. tests). For execution paths that require single-target
/// semantics, use `next_single()` — it errors explicitly when a fan-out is
/// declared so the caller can surface a clear failure instead of skipping it.
/// Returns the single next target as a string slice for tests and other
/// read-only inspection. Returns `None` when no `next:` is declared at all,
/// OR when a real multi-target fan-out is declared (since a fan-out has no
/// "single" target). Execution paths use `static_next_targets` in the graph
/// executor instead.
#[allow(dead_code)]
pub fn next_target(&self) -> Option<&str> {
match &self.next {
@@ -143,16 +143,6 @@ impl Node {
Some(NextTargets::Many(_)) => None,
}
}
/// Single-target accessor for execution paths.
///
/// Yields `Ok(None)` when the node declares no `next:` at all, and
/// `Ok(Some(id))` when exactly one target is declared. When a multi-target
/// fan-out is declared, the error produced by `NextTargets::single` is
/// propagated unchanged.
pub fn next_single(&self) -> Result<Option<&str>> {
    self.next
        .as_ref()
        // Option<Result<&str>>: absent `next` stays `None`, present `next`
        // is resolved to its single target (or the fan-out error).
        .map(|targets| targets.single().map(String::as_str))
        // Flip to Result<Option<&str>> so the error surfaces to the caller.
        .transpose()
}
}
#[derive(Debug, Clone, Deserialize, Serialize)]
@@ -172,24 +162,9 @@ impl NextTargets {
}
/// Reports whether this is a real fan-out: a `Many` list holding two or more
/// parallel targets. A `One`, or a `Many` with fewer than two entries, is not.
#[allow(dead_code)]
pub fn is_fan_out(&self) -> bool {
    match self {
        NextTargets::Many(targets) => targets.len() > 1,
        NextTargets::One(_) => false,
    }
}
/// Resolves to the one declared target, or fails with a "not yet
/// implemented" error for anything else.
///
/// Both `One(id)` and a `Many` list of exactly one entry count as a single
/// target. Any other `Many` list (including an empty one) is rejected.
pub fn single(&self) -> Result<&String> {
    match self {
        NextTargets::One(only) => Ok(only),
        NextTargets::Many(list) => match list.as_slice() {
            [only] => Ok(only),
            _ => bail!(
                "Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \
                 branch execution is not yet implemented in this build."
            ),
        },
    }
}
}
impl From<String> for NextTargets {
@@ -971,23 +946,7 @@ next: retrieve
}
#[test]
fn next_single_errors_on_real_fan_out_with_clear_message() {
let yaml = r#"
id: triage
type: llm
prompt: Classify
next: [a, b]
"#;
let node: Node = serde_yaml::from_str(yaml).unwrap();
let err = node.next_single().unwrap_err().to_string();
assert!(err.contains("Parallel fan-out"), "got: {err}");
assert!(err.contains("not yet implemented"), "got: {err}");
}
#[test]
fn next_single_accepts_many_containing_exactly_one_target() {
fn next_target_treats_many_of_one_as_single() {
let yaml = r#"
id: triage
type: llm
@@ -997,7 +956,6 @@ next: [retrieve]
let node: Node = serde_yaml::from_str(yaml).unwrap();
assert_eq!(node.next_single().unwrap(), Some("retrieve"));
assert_eq!(node.next_target(), Some("retrieve"));
}