feat: added structured-output extraction for llm and agent nodes
This commit is contained in:
+135
-12
@@ -5,11 +5,12 @@
|
||||
//! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design.
|
||||
|
||||
use super::state::StateManager;
|
||||
use super::structured;
|
||||
use super::types::LlmNode;
|
||||
use crate::client::{Model, ModelType, call_chat_completions};
|
||||
use crate::config::{RequestContext, Role, RoleLike};
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use anyhow::{Context, Result, anyhow, bail, Error};
|
||||
use anyhow::{Context, Error, Result, anyhow, bail};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
@@ -33,10 +34,22 @@ impl LlmNodeExecutor {
|
||||
) -> Result<String> {
|
||||
let result = run(node, state_manager, parent_ctx).await;
|
||||
let (output, failed) = match result {
|
||||
Ok(out) => (out, false),
|
||||
Ok(raw) => match &node.output_schema {
|
||||
Some(schema) => match structured::extract(&raw, schema, parent_ctx).await {
|
||||
Ok(value) => (value, false),
|
||||
Err(e) => {
|
||||
warn!("llm node structured extraction failed: {e}");
|
||||
(
|
||||
Value::String(format!("LLM node structured-extraction failed: {e}")),
|
||||
true,
|
||||
)
|
||||
}
|
||||
},
|
||||
None => (Value::String(raw), false),
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("llm node failed: {e}");
|
||||
(format!("LLM node failed: {e}"), true)
|
||||
(Value::String(format!("LLM node failed: {e}")), true)
|
||||
}
|
||||
};
|
||||
apply_state_updates_with_output(node, state_manager, &output);
|
||||
@@ -49,7 +62,7 @@ async fn run(
|
||||
state_manager: &mut StateManager,
|
||||
parent_ctx: &mut RequestContext,
|
||||
) -> Result<String> {
|
||||
let instructions: Option<String> = match &node.instructions {
|
||||
let mut instructions: Option<String> = match &node.instructions {
|
||||
Some(s) => Some(
|
||||
state_manager
|
||||
.interpolate(s)
|
||||
@@ -57,10 +70,24 @@ async fn run(
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
let prompt = state_manager
|
||||
let mut prompt = state_manager
|
||||
.interpolate(&node.prompt)
|
||||
.context("Failed to interpolate llm node prompt")?;
|
||||
|
||||
if let Some(schema) = &node.output_schema {
|
||||
let hint = format_schema_hint(schema);
|
||||
match instructions.as_mut() {
|
||||
Some(s) => {
|
||||
s.push_str("\n\n");
|
||||
s.push_str(&hint);
|
||||
}
|
||||
None => {
|
||||
prompt.push_str("\n\n");
|
||||
prompt.push_str(&hint);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
|
||||
validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?;
|
||||
|
||||
@@ -291,14 +318,30 @@ fn next_for_llm_node(
|
||||
/// of `state_updates` evaluation, then restore the prior value (or set
|
||||
/// it to `Null` if there wasn't one). Same pattern as
|
||||
/// `AgentNodeExecutor`'s `{{output}}` scoping.
|
||||
fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateManager, output: &str) {
|
||||
///
|
||||
/// When `node.output_schema` is set AND the output is a JSON object, its
|
||||
/// top-level keys are also auto-merged into state permanently (before
|
||||
/// state_updates evaluation, so explicit state_updates can override).
|
||||
fn apply_state_updates_with_output(
|
||||
node: &LlmNode,
|
||||
state_manager: &mut StateManager,
|
||||
output: &Value,
|
||||
) {
|
||||
if node.output_schema.is_some()
|
||||
&& let Some(obj) = output.as_object()
|
||||
{
|
||||
for (k, v) in obj {
|
||||
state_manager.state_mut().set(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let Some(updates) = &node.state_updates else {
|
||||
return;
|
||||
};
|
||||
let prev_output = state_manager.state().get(OUTPUT_KEY).cloned();
|
||||
state_manager
|
||||
.state_mut()
|
||||
.set(OUTPUT_KEY.into(), Value::String(output.to_string()));
|
||||
.set(OUTPUT_KEY.into(), output.clone());
|
||||
|
||||
for (key, template) in updates {
|
||||
let value = state_manager.interpolate_lenient(template);
|
||||
@@ -317,6 +360,14 @@ fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateMana
|
||||
}
|
||||
}
|
||||
|
||||
fn format_schema_hint(schema: &Value) -> String {
|
||||
let schema_json = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
|
||||
format!(
|
||||
"Respond with a JSON object that matches this schema. Output ONLY the JSON \
|
||||
object with no surrounding prose or markdown fences.\n\nSchema:\n{schema_json}"
|
||||
)
|
||||
}
|
||||
|
||||
fn describe_tools_filter(tools: Option<&[String]>) -> String {
|
||||
match tools {
|
||||
None => "<inherit>".into(),
|
||||
@@ -352,6 +403,7 @@ mod tests {
|
||||
max_attempts: 1,
|
||||
max_iterations: 10,
|
||||
state_updates: updates,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
@@ -362,7 +414,7 @@ mod tests {
|
||||
u.insert("response".into(), "{{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates_with_output(&node, &mut state, "the answer");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("the answer"));
|
||||
assert_eq!(state.state().get("response"), Some(&json!("the answer")));
|
||||
}
|
||||
|
||||
@@ -372,7 +424,7 @@ mod tests {
|
||||
u.insert("summary".into(), "{{topic}}: {{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[("topic", json!("LOINC"))]);
|
||||
apply_state_updates_with_output(&node, &mut state, "abc");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("abc"));
|
||||
assert_eq!(state.state().get("summary"), Some(&json!("LOINC: abc")));
|
||||
}
|
||||
|
||||
@@ -382,7 +434,7 @@ mod tests {
|
||||
u.insert("k".into(), "{{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates_with_output(&node, &mut state, "anything");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("anything"));
|
||||
assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!(null)));
|
||||
}
|
||||
|
||||
@@ -392,7 +444,7 @@ mod tests {
|
||||
u.insert("greeting".into(), "{{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[("output", json!("preserved"))]);
|
||||
apply_state_updates_with_output(&node, &mut state, "new");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("new"));
|
||||
assert_eq!(state.state().get("greeting"), Some(&json!("new")));
|
||||
assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!("preserved")));
|
||||
}
|
||||
@@ -401,11 +453,82 @@ mod tests {
|
||||
fn no_state_updates_is_a_noop() {
|
||||
let node = node_with(None);
|
||||
let mut state = manager_with(&[("k", json!("v"))]);
|
||||
apply_state_updates_with_output(&node, &mut state, "x");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("x"));
|
||||
assert_eq!(state.state().get("k"), Some(&json!("v")));
|
||||
assert!(state.state().get(OUTPUT_KEY).is_none());
|
||||
}
|
||||
|
||||
fn node_with_schema(updates: Option<HashMap<String, String>>, schema: Value) -> LlmNode {
|
||||
let mut n = node_with(updates);
|
||||
n.output_schema = Some(schema);
|
||||
n
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_auto_merges_top_level_keys() {
|
||||
let node = node_with_schema(None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X", "summary": "details"});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("do X")));
|
||||
assert_eq!(state.state().get("summary"), Some(&json!("details")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_preserves_nested_value_types() {
|
||||
let node = node_with_schema(None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({
|
||||
"tags": ["a", "b"],
|
||||
"config": { "key": "value" },
|
||||
"count": 42
|
||||
});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("tags"), Some(&json!(["a", "b"])));
|
||||
assert_eq!(state.state().get("config"), Some(&json!({"key": "value"})));
|
||||
assert_eq!(state.state().get("count"), Some(&json!(42)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_explicit_state_updates_override_auto_merge() {
|
||||
let mut u = HashMap::new();
|
||||
u.insert("goal".into(), "renamed-{{output.goal}}".into());
|
||||
let node = node_with_schema(Some(u), json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("renamed-do X")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_skips_auto_merge_for_non_object() {
|
||||
let node = node_with_schema(None, json!({"type": "array"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!([1, 2, 3]);
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert!(state.state().get("0").is_none());
|
||||
assert!(state.state().get(OUTPUT_KEY).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_schema_does_not_auto_merge() {
|
||||
let node = node_with(None);
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert!(state.state().get("goal").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_schema_hint_includes_schema_and_instruction() {
|
||||
let schema = json!({"type": "object", "properties": {"goal": {"type": "string"}}});
|
||||
let hint = format_schema_hint(&schema);
|
||||
assert!(hint.contains("Schema:"));
|
||||
assert!(hint.contains("\"goal\""));
|
||||
assert!(hint.contains("JSON"));
|
||||
assert!(hint.contains("ONLY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn describe_tools_filter_renders_each_case() {
|
||||
assert_eq!(describe_tools_filter(None), "<inherit>");
|
||||
|
||||
Reference in New Issue
Block a user