fix: bug in next_single method and improve outcome handling for LLM node execution
This commit is contained in:
+8
-38
@@ -1,5 +1,5 @@
|
||||
use super::agent::AgentNodeExecutor;
|
||||
use super::llm::LlmNodeExecutor;
|
||||
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
|
||||
use super::logging::GraphLogger;
|
||||
use super::map::MapNodeExecutor;
|
||||
use super::progress::{BranchProgressHandle, BranchProgressTracker};
|
||||
@@ -366,17 +366,16 @@ async fn step(
|
||||
Ok(StepResult::Continue(vec![next]))
|
||||
}
|
||||
NodeType::Llm(llm_node) => {
|
||||
let primary = first_next_target(node).map(str::to_string);
|
||||
let llm_routing =
|
||||
LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?;
|
||||
let targets = resolve_branching_next(node, &llm_routing);
|
||||
let outcome = LlmNodeExecutor::execute(llm_node, state, ctx).await?;
|
||||
let targets = match outcome {
|
||||
LlmExecutionOutcome::Continue => static_next_targets(node, current, "llm")?,
|
||||
LlmExecutionOutcome::FellBack(target) => vec![target],
|
||||
};
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
NodeType::Rag(rag_node) => {
|
||||
let primary = first_next_target(node).map(str::to_string);
|
||||
let rag_routing =
|
||||
RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?;
|
||||
let targets = resolve_branching_next(node, &rag_routing);
|
||||
RagNodeExecutor::execute(rag_node, current, state, ctx).await?;
|
||||
let targets = static_next_targets(node, current, "rag")?;
|
||||
Ok(StepResult::Continue(targets))
|
||||
}
|
||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||
@@ -406,35 +405,6 @@ fn first_next_target(node: &Node) -> Option<&str> {
|
||||
.and_then(|t| t.as_slice().first().map(|s| s.as_str()))
|
||||
}
|
||||
|
||||
// Resolves the actual frontier-advance targets after an LLM/RAG node ran.
|
||||
//
|
||||
// LLM/RAG executors return their chosen routing as a String — either the
|
||||
// primary `next:` target (success path) or the node's `fallback:` (failure
|
||||
// path with retry exhausted). We can't tell these apart from inside step()
|
||||
// without an API refactor, so we compare strings: if the returned routing
|
||||
// matches the first declared `next` target, treat as success and (for
|
||||
// fan-out) use ALL declared targets; otherwise treat as fallback and use the
|
||||
// returned target alone.
|
||||
//
|
||||
// Known limitation: if a fan-out node's `fallback:` is set to the same node
|
||||
// id as its first `next:` target, a successful run is indistinguishable from
|
||||
// a fallback run — both look like "returned the first target". The result is
|
||||
// that the executor advances to all Many targets in the fallback case (which
|
||||
// is the OPPOSITE of the user's likely intent). Workaround: choose a
|
||||
// `fallback:` distinct from any `next:` target.
|
||||
fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec<String> {
|
||||
let Some(targets) = &node.next else {
|
||||
return vec![returned_routing.to_string()];
|
||||
};
|
||||
let slice = targets.as_slice();
|
||||
let first_matches = slice.first().is_some_and(|s| s == returned_routing);
|
||||
if first_matches && slice.len() > 1 {
|
||||
slice.to_vec()
|
||||
} else {
|
||||
vec![returned_routing.to_string()]
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String {
|
||||
apply_simple_state_updates(end_node.state_updates.as_ref(), state);
|
||||
state.interpolate_lenient(&end_node.output)
|
||||
|
||||
+41
-36
@@ -13,15 +13,26 @@ use tokio::time::timeout;
|
||||
|
||||
const OUTPUT_KEY: &str = "output";
|
||||
|
||||
/// What happened during an LLM node's execution, from the caller's routing
|
||||
/// perspective. `Continue` means the caller should advance via the node's
|
||||
/// declared `next:` targets (whether the LLM actually succeeded or failed
|
||||
/// without a fallback — either way, the executor uses node.next). `FellBack`
|
||||
/// means the LLM failed after retries and the node had a `fallback:` declared,
|
||||
/// so routing should go to that fallback target only.
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub(super) enum LlmExecutionOutcome {
|
||||
/// Route through the node's declared `next:` targets.
Continue,
|
||||
/// Route to the node's `fallback:` target only; the payload is that target,
/// copied from the node's `fallback` field.
FellBack(String),
|
||||
}
|
||||
|
||||
pub struct LlmNodeExecutor;
|
||||
|
||||
impl LlmNodeExecutor {
|
||||
pub async fn execute(
|
||||
pub(super) async fn execute(
|
||||
node: &LlmNode,
|
||||
node_next: Option<&str>,
|
||||
state_manager: &mut StateManager,
|
||||
parent_ctx: &mut RequestContext,
|
||||
) -> Result<String> {
|
||||
) -> Result<LlmExecutionOutcome> {
|
||||
let result = run(node, state_manager, parent_ctx).await;
|
||||
let (output, failed) = match result {
|
||||
Ok(raw) => match &node.output_schema {
|
||||
@@ -44,7 +55,15 @@ impl LlmNodeExecutor {
|
||||
};
|
||||
|
||||
apply_state_updates_with_output(node, state_manager, &output);
|
||||
next_for_llm_node(node_next, failed, node.fallback.as_deref())
|
||||
Ok(outcome_from(failed, node.fallback.as_deref()))
|
||||
}
|
||||
}
|
||||
|
||||
fn outcome_from(failed: bool, fallback: Option<&str>) -> LlmExecutionOutcome {
|
||||
if failed && let Some(fb) = fallback {
|
||||
LlmExecutionOutcome::FellBack(fb.to_string())
|
||||
} else {
|
||||
LlmExecutionOutcome::Continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,19 +317,6 @@ fn is_transient(err: &Error) -> bool {
|
||||
|| s.contains("produced no output")
|
||||
}
|
||||
|
||||
fn next_for_llm_node(
|
||||
node_next: Option<&str>,
|
||||
failed: bool,
|
||||
fallback: Option<&str>,
|
||||
) -> Result<String> {
|
||||
if failed && let Some(fb) = fallback {
|
||||
return Ok(fb.to_string());
|
||||
}
|
||||
node_next
|
||||
.map(String::from)
|
||||
.ok_or_else(|| anyhow!("llm node has no `next` set; llm nodes need static routing"))
|
||||
}
|
||||
|
||||
fn apply_state_updates_with_output(
|
||||
node: &LlmNode,
|
||||
state_manager: &mut StateManager,
|
||||
@@ -457,30 +463,29 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_for_llm_node_success_routes_to_next() {
|
||||
fn outcome_from_success_is_continue() {
|
||||
assert_eq!(
|
||||
next_for_llm_node(Some("nx"), false, Some("fb")).unwrap(),
|
||||
"nx"
|
||||
outcome_from(false, Some("fb")),
|
||||
LlmExecutionOutcome::Continue
|
||||
);
|
||||
assert_eq!(outcome_from(false, None), LlmExecutionOutcome::Continue);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn outcome_from_failure_with_fallback_is_fell_back() {
|
||||
assert_eq!(
|
||||
outcome_from(true, Some("fb")),
|
||||
LlmExecutionOutcome::FellBack("fb".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_for_llm_node_failure_with_fallback_routes_to_fallback() {
|
||||
assert_eq!(
|
||||
next_for_llm_node(Some("nx"), true, Some("fb")).unwrap(),
|
||||
"fb"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_for_llm_node_failure_without_fallback_routes_to_next() {
|
||||
assert_eq!(next_for_llm_node(Some("nx"), true, None).unwrap(), "nx");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn next_for_llm_node_errors_without_next_or_fallback() {
|
||||
assert!(next_for_llm_node(None, false, None).is_err());
|
||||
assert!(next_for_llm_node(None, true, None).is_err());
|
||||
fn outcome_from_failure_without_fallback_is_continue() {
|
||||
// Failed but no fallback: caller routes via node.next as if successful.
|
||||
// The error has already been recorded to state via the OUTPUT_KEY by
|
||||
// execute(); the caller's `static_next_targets` will error if node.next
|
||||
// is also missing.
|
||||
assert_eq!(outcome_from(true, None), LlmExecutionOutcome::Continue);
|
||||
}
|
||||
|
||||
fn node_with_schema(updates: Option<HashMap<String, String>>, schema: Value) -> LlmNode {
|
||||
|
||||
+6
-23
@@ -14,12 +14,6 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
// Map sub-branches are atomic — the branch node has no `next` (enforced by
|
||||
// validator rule C.5). But LLM/RAG node executors require an `Option<&str>` for
|
||||
// their routing argument and error if it's `None`. Passing this sentinel
|
||||
// satisfies their contract; the map executor discards the returned routing.
|
||||
const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__";
|
||||
|
||||
pub(super) struct MapNodeExecutor;
|
||||
|
||||
impl MapNodeExecutor {
|
||||
@@ -99,26 +93,15 @@ impl MapNodeExecutor {
|
||||
let mut ctx = sub_ctx;
|
||||
|
||||
let exec_result: Result<()> = match &branch_clone.node_type {
|
||||
NodeType::Llm(n) => LlmNodeExecutor::execute(
|
||||
n,
|
||||
Some(MAP_BRANCH_SENTINEL_NEXT),
|
||||
&mut state,
|
||||
&mut ctx,
|
||||
)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Llm(n) => LlmNodeExecutor::execute(n, &mut state, &mut ctx)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Rag(n) => RagNodeExecutor::execute(
|
||||
n,
|
||||
&sub_branch_id,
|
||||
Some(MAP_BRANCH_SENTINEL_NEXT),
|
||||
&mut state,
|
||||
&mut ctx,
|
||||
)
|
||||
.await
|
||||
.map(|_| ()),
|
||||
NodeType::Rag(n) => {
|
||||
RagNodeExecutor::execute(n, &sub_branch_id, &mut state, &mut ctx).await
|
||||
}
|
||||
NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
|
||||
_ => Err(anyhow!(
|
||||
"map branch '{}' has type that cannot run inside a map \
|
||||
|
||||
+3
-7
@@ -14,13 +14,12 @@ const DEFAULT_RAG_TIMEOUT_SECS: u64 = 120;
|
||||
pub struct RagNodeExecutor;
|
||||
|
||||
impl RagNodeExecutor {
|
||||
pub async fn execute(
|
||||
pub(super) async fn execute(
|
||||
node: &RagNode,
|
||||
node_id: &str,
|
||||
node_next: Option<&str>,
|
||||
state_manager: &mut StateManager,
|
||||
ctx: &mut RequestContext,
|
||||
) -> Result<String> {
|
||||
) -> Result<()> {
|
||||
let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY);
|
||||
let query = state_manager
|
||||
.interpolate(query_template)
|
||||
@@ -55,10 +54,7 @@ impl RagNodeExecutor {
|
||||
|
||||
let output = build_rag_output(context, &sources_str);
|
||||
apply_state_updates(node, state_manager, &output);
|
||||
|
||||
node_next
|
||||
.map(String::from)
|
||||
.ok_or_else(|| anyhow!("rag node '{node_id}' has no `next` set"))
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+7
-49
@@ -1,4 +1,4 @@
|
||||
use anyhow::{Result, bail};
|
||||
use anyhow::Result;
|
||||
use indexmap::IndexMap;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
@@ -129,11 +129,11 @@ pub struct Node {
|
||||
}
|
||||
|
||||
impl Node {
|
||||
/// Returns the single next target as a string slice, or `None` if no next is
|
||||
/// declared or if a multi-target fan-out is declared. Use this for read-only
|
||||
/// inspection (e.g. tests). For execution paths that require single-target
|
||||
/// semantics, use `next_single()` — it errors explicitly when a fan-out is
|
||||
/// declared so the caller can surface a clear failure instead of skipping it.
|
||||
/// Returns the single next target as a string slice for tests and other
|
||||
/// read-only inspection. Returns `None` when no `next:` is declared at all,
|
||||
/// OR when a real multi-target fan-out is declared (since a fan-out has no
|
||||
/// "single" target). Execution paths use `static_next_targets` in the graph
|
||||
/// executor instead.
|
||||
#[allow(dead_code)]
|
||||
pub fn next_target(&self) -> Option<&str> {
|
||||
match &self.next {
|
||||
@@ -143,16 +143,6 @@ impl Node {
|
||||
Some(NextTargets::Many(_)) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the single next target as a string slice, or an explicit error if
|
||||
/// the node declares a multi-target fan-out (which is not yet supported
|
||||
/// pre-Phase-D). Returns `Ok(None)` when no next is declared at all.
|
||||
pub fn next_single(&self) -> Result<Option<&str>> {
|
||||
match &self.next {
|
||||
None => Ok(None),
|
||||
Some(targets) => Ok(Some(targets.single()?.as_str())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
@@ -172,24 +162,9 @@ impl NextTargets {
|
||||
}
|
||||
|
||||
/// True if this declares more than one parallel target (i.e., a real fan-out).
|
||||
#[allow(dead_code)]
|
||||
pub fn is_fan_out(&self) -> bool {
|
||||
matches!(self, NextTargets::Many(v) if v.len() > 1)
|
||||
}
|
||||
|
||||
/// Returns the single target if exactly one is declared, else errors with a
|
||||
/// clear "not yet supported" message. Used by the v1 executor until parallel
|
||||
/// branch execution lands in Phase D.
|
||||
pub fn single(&self) -> Result<&String> {
|
||||
match self {
|
||||
NextTargets::One(s) => Ok(s),
|
||||
NextTargets::Many(v) if v.len() == 1 => Ok(&v[0]),
|
||||
NextTargets::Many(_) => bail!(
|
||||
"Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \
|
||||
branch execution is not yet implemented in this build."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for NextTargets {
|
||||
@@ -971,23 +946,7 @@ next: retrieve
|
||||
}
|
||||
|
||||
#[test]
fn next_single_errors_on_real_fan_out_with_clear_message() {
    // An LLM node declaring two parallel targets: a real fan-out.
    let node: Node = serde_yaml::from_str(
        r#"
id: triage
type: llm
prompt: Classify
next: [a, b]
"#,
    )
    .unwrap();

    let err = node.next_single().unwrap_err().to_string();

    // The message must name the feature and say that it is unsupported.
    for expected in ["Parallel fan-out", "not yet implemented"] {
        assert!(err.contains(expected), "got: {err}");
    }
}
|
||||
|
||||
#[test]
|
||||
fn next_single_accepts_many_containing_exactly_one_target() {
|
||||
fn next_target_treats_many_of_one_as_single() {
|
||||
let yaml = r#"
|
||||
id: triage
|
||||
type: llm
|
||||
@@ -997,7 +956,6 @@ next: [retrieve]
|
||||
|
||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||
|
||||
assert_eq!(node.next_single().unwrap(), Some("retrieve"));
|
||||
assert_eq!(node.next_target(), Some("retrieve"));
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user