fix: bug in next_single method and improved outcome handling for LLM node execution
This commit is contained in:
+8
-38
@@ -1,5 +1,5 @@
|
|||||||
use super::agent::AgentNodeExecutor;
|
use super::agent::AgentNodeExecutor;
|
||||||
use super::llm::LlmNodeExecutor;
|
use super::llm::{LlmExecutionOutcome, LlmNodeExecutor};
|
||||||
use super::logging::GraphLogger;
|
use super::logging::GraphLogger;
|
||||||
use super::map::MapNodeExecutor;
|
use super::map::MapNodeExecutor;
|
||||||
use super::progress::{BranchProgressHandle, BranchProgressTracker};
|
use super::progress::{BranchProgressHandle, BranchProgressTracker};
|
||||||
@@ -366,17 +366,16 @@ async fn step(
|
|||||||
Ok(StepResult::Continue(vec![next]))
|
Ok(StepResult::Continue(vec![next]))
|
||||||
}
|
}
|
||||||
NodeType::Llm(llm_node) => {
|
NodeType::Llm(llm_node) => {
|
||||||
let primary = first_next_target(node).map(str::to_string);
|
let outcome = LlmNodeExecutor::execute(llm_node, state, ctx).await?;
|
||||||
let llm_routing =
|
let targets = match outcome {
|
||||||
LlmNodeExecutor::execute(llm_node, primary.as_deref(), state, ctx).await?;
|
LlmExecutionOutcome::Continue => static_next_targets(node, current, "llm")?,
|
||||||
let targets = resolve_branching_next(node, &llm_routing);
|
LlmExecutionOutcome::FellBack(target) => vec![target],
|
||||||
|
};
|
||||||
Ok(StepResult::Continue(targets))
|
Ok(StepResult::Continue(targets))
|
||||||
}
|
}
|
||||||
NodeType::Rag(rag_node) => {
|
NodeType::Rag(rag_node) => {
|
||||||
let primary = first_next_target(node).map(str::to_string);
|
RagNodeExecutor::execute(rag_node, current, state, ctx).await?;
|
||||||
let rag_routing =
|
let targets = static_next_targets(node, current, "rag")?;
|
||||||
RagNodeExecutor::execute(rag_node, current, primary.as_deref(), state, ctx).await?;
|
|
||||||
let targets = resolve_branching_next(node, &rag_routing);
|
|
||||||
Ok(StepResult::Continue(targets))
|
Ok(StepResult::Continue(targets))
|
||||||
}
|
}
|
||||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||||
@@ -406,35 +405,6 @@ fn first_next_target(node: &Node) -> Option<&str> {
|
|||||||
.and_then(|t| t.as_slice().first().map(|s| s.as_str()))
|
.and_then(|t| t.as_slice().first().map(|s| s.as_str()))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolves the actual frontier-advance targets after an LLM/RAG node ran.
|
|
||||||
//
|
|
||||||
// LLM/RAG executors return their chosen routing as a String — either the
|
|
||||||
// primary `next:` target (success path) or the node's `fallback:` (failure
|
|
||||||
// path with retry exhausted). We can't tell these apart from inside step()
|
|
||||||
// without an API refactor, so we compare strings: if the returned routing
|
|
||||||
// matches the first declared `next` target, treat as success and (for
|
|
||||||
// fan-out) use ALL declared targets; otherwise treat as fallback and use the
|
|
||||||
// returned target alone.
|
|
||||||
//
|
|
||||||
// Known limitation: if a fan-out node's `fallback:` is set to the same node
|
|
||||||
// id as its first `next:` target, a successful run is indistinguishable from
|
|
||||||
// a fallback run — both look like "returned the first target". The result is
|
|
||||||
// that the executor advances to all Many targets in the fallback case (which
|
|
||||||
// is the OPPOSITE of the user's likely intent). Workaround: choose a
|
|
||||||
// `fallback:` distinct from any `next:` target.
|
|
||||||
fn resolve_branching_next(node: &Node, returned_routing: &str) -> Vec<String> {
|
|
||||||
let Some(targets) = &node.next else {
|
|
||||||
return vec![returned_routing.to_string()];
|
|
||||||
};
|
|
||||||
let slice = targets.as_slice();
|
|
||||||
let first_matches = slice.first().is_some_and(|s| s == returned_routing);
|
|
||||||
if first_matches && slice.len() > 1 {
|
|
||||||
slice.to_vec()
|
|
||||||
} else {
|
|
||||||
vec![returned_routing.to_string()]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String {
|
fn resolve_end_output(end_node: &EndNode, state: &mut StateManager) -> String {
|
||||||
apply_simple_state_updates(end_node.state_updates.as_ref(), state);
|
apply_simple_state_updates(end_node.state_updates.as_ref(), state);
|
||||||
state.interpolate_lenient(&end_node.output)
|
state.interpolate_lenient(&end_node.output)
|
||||||
|
|||||||
+41
-36
@@ -13,15 +13,26 @@ use tokio::time::timeout;
|
|||||||
|
|
||||||
const OUTPUT_KEY: &str = "output";
|
const OUTPUT_KEY: &str = "output";
|
||||||
|
|
||||||
|
/// What happened during an LLM node's execution, from the caller's routing
|
||||||
|
/// perspective. `Continue` means the caller should advance via the node's
|
||||||
|
/// declared `next:` targets (whether the LLM actually succeeded or failed
|
||||||
|
/// without a fallback — either way, the executor uses node.next). `FellBack`
|
||||||
|
/// means the LLM failed after retries and the node had a `fallback:` declared,
|
||||||
|
/// so routing should go to that fallback target only.
|
||||||
|
#[derive(Debug, PartialEq, Eq)]
|
||||||
|
pub(super) enum LlmExecutionOutcome {
|
||||||
|
Continue,
|
||||||
|
FellBack(String),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct LlmNodeExecutor;
|
pub struct LlmNodeExecutor;
|
||||||
|
|
||||||
impl LlmNodeExecutor {
|
impl LlmNodeExecutor {
|
||||||
pub async fn execute(
|
pub(super) async fn execute(
|
||||||
node: &LlmNode,
|
node: &LlmNode,
|
||||||
node_next: Option<&str>,
|
|
||||||
state_manager: &mut StateManager,
|
state_manager: &mut StateManager,
|
||||||
parent_ctx: &mut RequestContext,
|
parent_ctx: &mut RequestContext,
|
||||||
) -> Result<String> {
|
) -> Result<LlmExecutionOutcome> {
|
||||||
let result = run(node, state_manager, parent_ctx).await;
|
let result = run(node, state_manager, parent_ctx).await;
|
||||||
let (output, failed) = match result {
|
let (output, failed) = match result {
|
||||||
Ok(raw) => match &node.output_schema {
|
Ok(raw) => match &node.output_schema {
|
||||||
@@ -44,7 +55,15 @@ impl LlmNodeExecutor {
|
|||||||
};
|
};
|
||||||
|
|
||||||
apply_state_updates_with_output(node, state_manager, &output);
|
apply_state_updates_with_output(node, state_manager, &output);
|
||||||
next_for_llm_node(node_next, failed, node.fallback.as_deref())
|
Ok(outcome_from(failed, node.fallback.as_deref()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn outcome_from(failed: bool, fallback: Option<&str>) -> LlmExecutionOutcome {
|
||||||
|
if failed && let Some(fb) = fallback {
|
||||||
|
LlmExecutionOutcome::FellBack(fb.to_string())
|
||||||
|
} else {
|
||||||
|
LlmExecutionOutcome::Continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,19 +317,6 @@ fn is_transient(err: &Error) -> bool {
|
|||||||
|| s.contains("produced no output")
|
|| s.contains("produced no output")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn next_for_llm_node(
|
|
||||||
node_next: Option<&str>,
|
|
||||||
failed: bool,
|
|
||||||
fallback: Option<&str>,
|
|
||||||
) -> Result<String> {
|
|
||||||
if failed && let Some(fb) = fallback {
|
|
||||||
return Ok(fb.to_string());
|
|
||||||
}
|
|
||||||
node_next
|
|
||||||
.map(String::from)
|
|
||||||
.ok_or_else(|| anyhow!("llm node has no `next` set; llm nodes need static routing"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn apply_state_updates_with_output(
|
fn apply_state_updates_with_output(
|
||||||
node: &LlmNode,
|
node: &LlmNode,
|
||||||
state_manager: &mut StateManager,
|
state_manager: &mut StateManager,
|
||||||
@@ -457,30 +463,29 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn next_for_llm_node_success_routes_to_next() {
|
fn outcome_from_success_is_continue() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
next_for_llm_node(Some("nx"), false, Some("fb")).unwrap(),
|
outcome_from(false, Some("fb")),
|
||||||
"nx"
|
LlmExecutionOutcome::Continue
|
||||||
|
);
|
||||||
|
assert_eq!(outcome_from(false, None), LlmExecutionOutcome::Continue);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn outcome_from_failure_with_fallback_is_fell_back() {
|
||||||
|
assert_eq!(
|
||||||
|
outcome_from(true, Some("fb")),
|
||||||
|
LlmExecutionOutcome::FellBack("fb".to_string())
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn next_for_llm_node_failure_with_fallback_routes_to_fallback() {
|
fn outcome_from_failure_without_fallback_is_continue() {
|
||||||
assert_eq!(
|
// Failed but no fallback: caller routes via node.next as if successful.
|
||||||
next_for_llm_node(Some("nx"), true, Some("fb")).unwrap(),
|
// The error has already been recorded to state via the OUTPUT_KEY by
|
||||||
"fb"
|
// execute(); the caller's `static_next_targets` will error if node.next
|
||||||
);
|
// is also missing.
|
||||||
}
|
assert_eq!(outcome_from(true, None), LlmExecutionOutcome::Continue);
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn next_for_llm_node_failure_without_fallback_routes_to_next() {
|
|
||||||
assert_eq!(next_for_llm_node(Some("nx"), true, None).unwrap(), "nx");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn next_for_llm_node_errors_without_next_or_fallback() {
|
|
||||||
assert!(next_for_llm_node(None, false, None).is_err());
|
|
||||||
assert!(next_for_llm_node(None, true, None).is_err());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn node_with_schema(updates: Option<HashMap<String, String>>, schema: Value) -> LlmNode {
|
fn node_with_schema(updates: Option<HashMap<String, String>>, schema: Value) -> LlmNode {
|
||||||
|
|||||||
+6
-23
@@ -14,12 +14,6 @@ use std::collections::HashMap;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::Semaphore;
|
use tokio::sync::Semaphore;
|
||||||
|
|
||||||
// Map sub-branches are atomic — the branch node has no `next` (enforced by
|
|
||||||
// validator rule C.5). But LLM/RAG node executors require an `Option<&str>` for
|
|
||||||
// their routing argument and error if it's `None`. Passing this sentinel
|
|
||||||
// satisfies their contract; the map executor discards the returned routing.
|
|
||||||
const MAP_BRANCH_SENTINEL_NEXT: &str = "__map_branch_continuation__";
|
|
||||||
|
|
||||||
pub(super) struct MapNodeExecutor;
|
pub(super) struct MapNodeExecutor;
|
||||||
|
|
||||||
impl MapNodeExecutor {
|
impl MapNodeExecutor {
|
||||||
@@ -99,26 +93,15 @@ impl MapNodeExecutor {
|
|||||||
let mut ctx = sub_ctx;
|
let mut ctx = sub_ctx;
|
||||||
|
|
||||||
let exec_result: Result<()> = match &branch_clone.node_type {
|
let exec_result: Result<()> = match &branch_clone.node_type {
|
||||||
NodeType::Llm(n) => LlmNodeExecutor::execute(
|
NodeType::Llm(n) => LlmNodeExecutor::execute(n, &mut state, &mut ctx)
|
||||||
n,
|
.await
|
||||||
Some(MAP_BRANCH_SENTINEL_NEXT),
|
.map(|_| ()),
|
||||||
&mut state,
|
|
||||||
&mut ctx,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map(|_| ()),
|
|
||||||
NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
|
NodeType::Agent(n) => AgentNodeExecutor::execute(n, &mut state, &mut ctx)
|
||||||
.await
|
.await
|
||||||
.map(|_| ()),
|
.map(|_| ()),
|
||||||
NodeType::Rag(n) => RagNodeExecutor::execute(
|
NodeType::Rag(n) => {
|
||||||
n,
|
RagNodeExecutor::execute(n, &sub_branch_id, &mut state, &mut ctx).await
|
||||||
&sub_branch_id,
|
}
|
||||||
Some(MAP_BRANCH_SENTINEL_NEXT),
|
|
||||||
&mut state,
|
|
||||||
&mut ctx,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.map(|_| ()),
|
|
||||||
NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
|
NodeType::Script(n) => script_clone.execute(n, &mut state).await.map(|_| ()),
|
||||||
_ => Err(anyhow!(
|
_ => Err(anyhow!(
|
||||||
"map branch '{}' has type that cannot run inside a map \
|
"map branch '{}' has type that cannot run inside a map \
|
||||||
|
|||||||
+3
-7
@@ -14,13 +14,12 @@ const DEFAULT_RAG_TIMEOUT_SECS: u64 = 120;
|
|||||||
pub struct RagNodeExecutor;
|
pub struct RagNodeExecutor;
|
||||||
|
|
||||||
impl RagNodeExecutor {
|
impl RagNodeExecutor {
|
||||||
pub async fn execute(
|
pub(super) async fn execute(
|
||||||
node: &RagNode,
|
node: &RagNode,
|
||||||
node_id: &str,
|
node_id: &str,
|
||||||
node_next: Option<&str>,
|
|
||||||
state_manager: &mut StateManager,
|
state_manager: &mut StateManager,
|
||||||
ctx: &mut RequestContext,
|
ctx: &mut RequestContext,
|
||||||
) -> Result<String> {
|
) -> Result<()> {
|
||||||
let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY);
|
let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY);
|
||||||
let query = state_manager
|
let query = state_manager
|
||||||
.interpolate(query_template)
|
.interpolate(query_template)
|
||||||
@@ -55,10 +54,7 @@ impl RagNodeExecutor {
|
|||||||
|
|
||||||
let output = build_rag_output(context, &sources_str);
|
let output = build_rag_output(context, &sources_str);
|
||||||
apply_state_updates(node, state_manager, &output);
|
apply_state_updates(node, state_manager, &output);
|
||||||
|
Ok(())
|
||||||
node_next
|
|
||||||
.map(String::from)
|
|
||||||
.ok_or_else(|| anyhow!("rag node '{node_id}' has no `next` set"))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+7
-49
@@ -1,4 +1,4 @@
|
|||||||
use anyhow::{Result, bail};
|
use anyhow::Result;
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
@@ -129,11 +129,11 @@ pub struct Node {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Node {
|
impl Node {
|
||||||
/// Returns the single next target as a string slice, or `None` if no next is
|
/// Returns the single next target as a string slice for tests and other
|
||||||
/// declared or if a multi-target fan-out is declared. Use this for read-only
|
/// read-only inspection. Returns `None` when no `next:` is declared at all,
|
||||||
/// inspection (e.g. tests). For execution paths that require single-target
|
/// OR when a real multi-target fan-out is declared (since a fan-out has no
|
||||||
/// semantics, use `next_single()` — it errors explicitly when a fan-out is
|
/// "single" target). Execution paths use `static_next_targets` in the graph
|
||||||
/// declared so the caller can surface a clear failure instead of skipping it.
|
/// executor instead.
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
pub fn next_target(&self) -> Option<&str> {
|
pub fn next_target(&self) -> Option<&str> {
|
||||||
match &self.next {
|
match &self.next {
|
||||||
@@ -143,16 +143,6 @@ impl Node {
|
|||||||
Some(NextTargets::Many(_)) => None,
|
Some(NextTargets::Many(_)) => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the single next target as a string slice, or an explicit error if
|
|
||||||
/// the node declares a multi-target fan-out (which is not yet supported
|
|
||||||
/// pre-Phase-D). Returns `Ok(None)` when no next is declared at all.
|
|
||||||
pub fn next_single(&self) -> Result<Option<&str>> {
|
|
||||||
match &self.next {
|
|
||||||
None => Ok(None),
|
|
||||||
Some(targets) => Ok(Some(targets.single()?.as_str())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -172,24 +162,9 @@ impl NextTargets {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// True if this declares more than one parallel target (i.e., a real fan-out).
|
/// True if this declares more than one parallel target (i.e., a real fan-out).
|
||||||
#[allow(dead_code)]
|
|
||||||
pub fn is_fan_out(&self) -> bool {
|
pub fn is_fan_out(&self) -> bool {
|
||||||
matches!(self, NextTargets::Many(v) if v.len() > 1)
|
matches!(self, NextTargets::Many(v) if v.len() > 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the single target if exactly one is declared, else errors with a
|
|
||||||
/// clear "not yet supported" message. Used by the v1 executor until parallel
|
|
||||||
/// branch execution lands in Phase D.
|
|
||||||
pub fn single(&self) -> Result<&String> {
|
|
||||||
match self {
|
|
||||||
NextTargets::One(s) => Ok(s),
|
|
||||||
NextTargets::Many(v) if v.len() == 1 => Ok(&v[0]),
|
|
||||||
NextTargets::Many(_) => bail!(
|
|
||||||
"Parallel fan-out (`next: [a, b, ...]`) is declared, but parallel \
|
|
||||||
branch execution is not yet implemented in this build."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<String> for NextTargets {
|
impl From<String> for NextTargets {
|
||||||
@@ -971,23 +946,7 @@ next: retrieve
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn next_single_errors_on_real_fan_out_with_clear_message() {
|
fn next_target_treats_many_of_one_as_single() {
|
||||||
let yaml = r#"
|
|
||||||
id: triage
|
|
||||||
type: llm
|
|
||||||
prompt: Classify
|
|
||||||
next: [a, b]
|
|
||||||
"#;
|
|
||||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
|
||||||
|
|
||||||
let err = node.next_single().unwrap_err().to_string();
|
|
||||||
|
|
||||||
assert!(err.contains("Parallel fan-out"), "got: {err}");
|
|
||||||
assert!(err.contains("not yet implemented"), "got: {err}");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn next_single_accepts_many_containing_exactly_one_target() {
|
|
||||||
let yaml = r#"
|
let yaml = r#"
|
||||||
id: triage
|
id: triage
|
||||||
type: llm
|
type: llm
|
||||||
@@ -997,7 +956,6 @@ next: [retrieve]
|
|||||||
|
|
||||||
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
let node: Node = serde_yaml::from_str(yaml).unwrap();
|
||||||
|
|
||||||
assert_eq!(node.next_single().unwrap(), Some("retrieve"));
|
|
||||||
assert_eq!(node.next_target(), Some("retrieve"));
|
assert_eq!(node.next_target(), Some("retrieve"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user