feat: added structured-output extraction for llm and agent nodes
This commit is contained in:
+92
-11
@@ -6,6 +6,7 @@
|
||||
//! `{{output}}` for the agent's stdout).
|
||||
|
||||
use super::state::StateManager;
|
||||
use super::structured;
|
||||
use super::types::AgentNode;
|
||||
use crate::config::RequestContext;
|
||||
use crate::function::supervisor::run_agent_for_graph;
|
||||
@@ -42,7 +43,7 @@ impl AgentNodeExecutor {
|
||||
|
||||
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
|
||||
|
||||
let output = timeout(
|
||||
let raw = timeout(
|
||||
timeout_dur,
|
||||
run_agent_for_graph(parent_ctx, &node.agent, &prompt),
|
||||
)
|
||||
@@ -56,9 +57,21 @@ impl AgentNodeExecutor {
|
||||
})?
|
||||
.with_context(|| format!("Agent '{}' failed", node.agent))?;
|
||||
|
||||
apply_state_updates(node, state_manager, &output);
|
||||
let output_value = match &node.output_schema {
|
||||
Some(schema) => structured::extract(&raw, schema, parent_ctx)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Agent '{}' output failed structured-output extraction",
|
||||
node.agent
|
||||
)
|
||||
})?,
|
||||
None => Value::String(raw.clone()),
|
||||
};
|
||||
|
||||
Ok(output)
|
||||
apply_state_updates(node, state_manager, &output_value);
|
||||
|
||||
Ok(raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,14 +94,26 @@ fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec<String> {
|
||||
/// applies every key/template in `state_updates`. The temporary `output`
|
||||
/// state key is removed at the end so it doesn't leak into subsequent
|
||||
/// nodes' templates.
|
||||
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &str) {
|
||||
///
|
||||
/// When `node.output_schema` is set AND the output is a JSON object, its
|
||||
/// top-level keys are also auto-merged into state permanently (before
|
||||
/// state_updates evaluation, so explicit state_updates can override).
|
||||
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) {
|
||||
if node.output_schema.is_some()
|
||||
&& let Some(obj) = output.as_object()
|
||||
{
|
||||
for (k, v) in obj {
|
||||
state_manager.state_mut().set(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let Some(updates) = &node.state_updates else {
|
||||
return;
|
||||
};
|
||||
let prev_output = state_manager.state().get(OUTPUT_KEY).cloned();
|
||||
state_manager
|
||||
.state_mut()
|
||||
.set(OUTPUT_KEY.into(), Value::String(output.to_string()));
|
||||
.set(OUTPUT_KEY.into(), output.clone());
|
||||
|
||||
for (key, template) in updates {
|
||||
let value = state_manager.interpolate_lenient(template);
|
||||
@@ -127,6 +152,7 @@ mod tests {
|
||||
agent: "test_agent".into(),
|
||||
prompt: prompt.into(),
|
||||
state_updates: updates,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
@@ -139,7 +165,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates(&node, &mut state, "agent finished its work");
|
||||
apply_state_updates(&node, &mut state, &json!("agent finished its work"));
|
||||
assert_eq!(
|
||||
state.state().get("findings"),
|
||||
Some(&json!("agent finished its work"))
|
||||
@@ -154,7 +180,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[("topic", json!("auth"))]);
|
||||
apply_state_updates(&node, &mut state, "JWT vs sessions");
|
||||
apply_state_updates(&node, &mut state, &json!("JWT vs sessions"));
|
||||
assert_eq!(
|
||||
state.state().get("summary"),
|
||||
Some(&json!("auth: JWT vs sessions"))
|
||||
@@ -169,7 +195,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates(&node, &mut state, "anything");
|
||||
apply_state_updates(&node, &mut state, &json!("anything"));
|
||||
assert_eq!(state.state().get("output"), Some(&Value::Null));
|
||||
}
|
||||
|
||||
@@ -181,7 +207,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[("output", json!("preserved"))]);
|
||||
apply_state_updates(&node, &mut state, "new agent output");
|
||||
apply_state_updates(&node, &mut state, &json!("new agent output"));
|
||||
assert_eq!(
|
||||
state.state().get("greeting"),
|
||||
Some(&json!("new agent output"))
|
||||
@@ -193,7 +219,7 @@ mod tests {
|
||||
fn no_state_updates_is_a_noop() {
|
||||
let node = node_with("hi", None);
|
||||
let mut state = manager_with(&[("k", json!("v"))]);
|
||||
apply_state_updates(&node, &mut state, "ignored");
|
||||
apply_state_updates(&node, &mut state, &json!("ignored"));
|
||||
assert_eq!(state.state().get("k"), Some(&json!("v")));
|
||||
assert!(state.state().get("output").is_none());
|
||||
}
|
||||
@@ -206,7 +232,62 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates(&node, &mut state, "DATA");
|
||||
apply_state_updates(&node, &mut state, &json!("DATA"));
|
||||
assert_eq!(state.state().get("decorated"), Some(&json!("[] DATA")));
|
||||
}
|
||||
|
||||
fn node_with_schema(
|
||||
prompt: &str,
|
||||
updates: Option<HashMap<String, String>>,
|
||||
schema: Value,
|
||||
) -> AgentNode {
|
||||
let mut n = node_with(prompt, updates);
|
||||
n.output_schema = Some(schema);
|
||||
n
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_auto_merges_top_level_keys() {
|
||||
let node = node_with_schema("hi", None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X", "summary": "details"});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("do X")));
|
||||
assert_eq!(state.state().get("summary"), Some(&json!("details")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_preserves_nested_value_types() {
|
||||
let node = node_with_schema("hi", None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({
|
||||
"tags": ["a", "b"],
|
||||
"config": { "key": "value" },
|
||||
"count": 42
|
||||
});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("tags"), Some(&json!(["a", "b"])));
|
||||
assert_eq!(state.state().get("config"), Some(&json!({"key": "value"})));
|
||||
assert_eq!(state.state().get("count"), Some(&json!(42)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_explicit_state_updates_override_auto_merge() {
|
||||
let mut u = HashMap::new();
|
||||
u.insert("goal".into(), "renamed-{{output.goal}}".into());
|
||||
let node = node_with_schema("hi", Some(u), json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("renamed-do X")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_schema_does_not_auto_merge() {
|
||||
let node = node_with("hi", None);
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert!(state.state().get("goal").is_none());
|
||||
}
|
||||
}
|
||||
|
||||
+135
-12
@@ -5,11 +5,12 @@
|
||||
//! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design.
|
||||
|
||||
use super::state::StateManager;
|
||||
use super::structured;
|
||||
use super::types::LlmNode;
|
||||
use crate::client::{Model, ModelType, call_chat_completions};
|
||||
use crate::config::{RequestContext, Role, RoleLike};
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use anyhow::{Context, Result, anyhow, bail, Error};
|
||||
use anyhow::{Context, Error, Result, anyhow, bail};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
@@ -33,10 +34,22 @@ impl LlmNodeExecutor {
|
||||
) -> Result<String> {
|
||||
let result = run(node, state_manager, parent_ctx).await;
|
||||
let (output, failed) = match result {
|
||||
Ok(out) => (out, false),
|
||||
Ok(raw) => match &node.output_schema {
|
||||
Some(schema) => match structured::extract(&raw, schema, parent_ctx).await {
|
||||
Ok(value) => (value, false),
|
||||
Err(e) => {
|
||||
warn!("llm node structured extraction failed: {e}");
|
||||
(
|
||||
Value::String(format!("LLM node structured-extraction failed: {e}")),
|
||||
true,
|
||||
)
|
||||
}
|
||||
},
|
||||
None => (Value::String(raw), false),
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("llm node failed: {e}");
|
||||
(format!("LLM node failed: {e}"), true)
|
||||
(Value::String(format!("LLM node failed: {e}")), true)
|
||||
}
|
||||
};
|
||||
apply_state_updates_with_output(node, state_manager, &output);
|
||||
@@ -49,7 +62,7 @@ async fn run(
|
||||
state_manager: &mut StateManager,
|
||||
parent_ctx: &mut RequestContext,
|
||||
) -> Result<String> {
|
||||
let instructions: Option<String> = match &node.instructions {
|
||||
let mut instructions: Option<String> = match &node.instructions {
|
||||
Some(s) => Some(
|
||||
state_manager
|
||||
.interpolate(s)
|
||||
@@ -57,10 +70,24 @@ async fn run(
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
let prompt = state_manager
|
||||
let mut prompt = state_manager
|
||||
.interpolate(&node.prompt)
|
||||
.context("Failed to interpolate llm node prompt")?;
|
||||
|
||||
if let Some(schema) = &node.output_schema {
|
||||
let hint = format_schema_hint(schema);
|
||||
match instructions.as_mut() {
|
||||
Some(s) => {
|
||||
s.push_str("\n\n");
|
||||
s.push_str(&hint);
|
||||
}
|
||||
None => {
|
||||
prompt.push_str("\n\n");
|
||||
prompt.push_str(&hint);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
|
||||
validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?;
|
||||
|
||||
@@ -291,14 +318,30 @@ fn next_for_llm_node(
|
||||
/// of `state_updates` evaluation, then restore the prior value (or set
|
||||
/// it to `Null` if there wasn't one). Same pattern as
|
||||
/// `AgentNodeExecutor`'s `{{output}}` scoping.
|
||||
fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateManager, output: &str) {
|
||||
///
|
||||
/// When `node.output_schema` is set AND the output is a JSON object, its
|
||||
/// top-level keys are also auto-merged into state permanently (before
|
||||
/// state_updates evaluation, so explicit state_updates can override).
|
||||
fn apply_state_updates_with_output(
|
||||
node: &LlmNode,
|
||||
state_manager: &mut StateManager,
|
||||
output: &Value,
|
||||
) {
|
||||
if node.output_schema.is_some()
|
||||
&& let Some(obj) = output.as_object()
|
||||
{
|
||||
for (k, v) in obj {
|
||||
state_manager.state_mut().set(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let Some(updates) = &node.state_updates else {
|
||||
return;
|
||||
};
|
||||
let prev_output = state_manager.state().get(OUTPUT_KEY).cloned();
|
||||
state_manager
|
||||
.state_mut()
|
||||
.set(OUTPUT_KEY.into(), Value::String(output.to_string()));
|
||||
.set(OUTPUT_KEY.into(), output.clone());
|
||||
|
||||
for (key, template) in updates {
|
||||
let value = state_manager.interpolate_lenient(template);
|
||||
@@ -317,6 +360,14 @@ fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateMana
|
||||
}
|
||||
}
|
||||
|
||||
fn format_schema_hint(schema: &Value) -> String {
|
||||
let schema_json = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
|
||||
format!(
|
||||
"Respond with a JSON object that matches this schema. Output ONLY the JSON \
|
||||
object with no surrounding prose or markdown fences.\n\nSchema:\n{schema_json}"
|
||||
)
|
||||
}
|
||||
|
||||
fn describe_tools_filter(tools: Option<&[String]>) -> String {
|
||||
match tools {
|
||||
None => "<inherit>".into(),
|
||||
@@ -352,6 +403,7 @@ mod tests {
|
||||
max_attempts: 1,
|
||||
max_iterations: 10,
|
||||
state_updates: updates,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
@@ -362,7 +414,7 @@ mod tests {
|
||||
u.insert("response".into(), "{{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates_with_output(&node, &mut state, "the answer");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("the answer"));
|
||||
assert_eq!(state.state().get("response"), Some(&json!("the answer")));
|
||||
}
|
||||
|
||||
@@ -372,7 +424,7 @@ mod tests {
|
||||
u.insert("summary".into(), "{{topic}}: {{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[("topic", json!("LOINC"))]);
|
||||
apply_state_updates_with_output(&node, &mut state, "abc");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("abc"));
|
||||
assert_eq!(state.state().get("summary"), Some(&json!("LOINC: abc")));
|
||||
}
|
||||
|
||||
@@ -382,7 +434,7 @@ mod tests {
|
||||
u.insert("k".into(), "{{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates_with_output(&node, &mut state, "anything");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("anything"));
|
||||
assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!(null)));
|
||||
}
|
||||
|
||||
@@ -392,7 +444,7 @@ mod tests {
|
||||
u.insert("greeting".into(), "{{output}}".into());
|
||||
let node = node_with(Some(u));
|
||||
let mut state = manager_with(&[("output", json!("preserved"))]);
|
||||
apply_state_updates_with_output(&node, &mut state, "new");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("new"));
|
||||
assert_eq!(state.state().get("greeting"), Some(&json!("new")));
|
||||
assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!("preserved")));
|
||||
}
|
||||
@@ -401,11 +453,82 @@ mod tests {
|
||||
fn no_state_updates_is_a_noop() {
|
||||
let node = node_with(None);
|
||||
let mut state = manager_with(&[("k", json!("v"))]);
|
||||
apply_state_updates_with_output(&node, &mut state, "x");
|
||||
apply_state_updates_with_output(&node, &mut state, &json!("x"));
|
||||
assert_eq!(state.state().get("k"), Some(&json!("v")));
|
||||
assert!(state.state().get(OUTPUT_KEY).is_none());
|
||||
}
|
||||
|
||||
fn node_with_schema(updates: Option<HashMap<String, String>>, schema: Value) -> LlmNode {
|
||||
let mut n = node_with(updates);
|
||||
n.output_schema = Some(schema);
|
||||
n
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_auto_merges_top_level_keys() {
|
||||
let node = node_with_schema(None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X", "summary": "details"});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("do X")));
|
||||
assert_eq!(state.state().get("summary"), Some(&json!("details")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_preserves_nested_value_types() {
|
||||
let node = node_with_schema(None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({
|
||||
"tags": ["a", "b"],
|
||||
"config": { "key": "value" },
|
||||
"count": 42
|
||||
});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("tags"), Some(&json!(["a", "b"])));
|
||||
assert_eq!(state.state().get("config"), Some(&json!({"key": "value"})));
|
||||
assert_eq!(state.state().get("count"), Some(&json!(42)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_explicit_state_updates_override_auto_merge() {
|
||||
let mut u = HashMap::new();
|
||||
u.insert("goal".into(), "renamed-{{output.goal}}".into());
|
||||
let node = node_with_schema(Some(u), json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("renamed-do X")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_skips_auto_merge_for_non_object() {
|
||||
let node = node_with_schema(None, json!({"type": "array"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!([1, 2, 3]);
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert!(state.state().get("0").is_none());
|
||||
assert!(state.state().get(OUTPUT_KEY).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_schema_does_not_auto_merge() {
|
||||
let node = node_with(None);
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates_with_output(&node, &mut state, &output);
|
||||
assert!(state.state().get("goal").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_schema_hint_includes_schema_and_instruction() {
|
||||
let schema = json!({"type": "object", "properties": {"goal": {"type": "string"}}});
|
||||
let hint = format_schema_hint(&schema);
|
||||
assert!(hint.contains("Schema:"));
|
||||
assert!(hint.contains("\"goal\""));
|
||||
assert!(hint.contains("JSON"));
|
||||
assert!(hint.contains("ONLY"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn describe_tools_filter_renders_each_case() {
|
||||
assert_eq!(describe_tools_filter(None), "<inherit>");
|
||||
|
||||
@@ -8,6 +8,7 @@ pub mod llm;
|
||||
pub mod parser;
|
||||
pub mod script;
|
||||
pub mod state;
|
||||
pub mod structured;
|
||||
pub mod types;
|
||||
pub mod user_interaction;
|
||||
pub mod validator;
|
||||
|
||||
+75
-5
@@ -13,13 +13,14 @@ use std::path::PathBuf;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
static TEMPLATE_VAR_RE: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.]+)\}\}").expect("invalid template regex"));
|
||||
LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.\[\]]+)\}\}").expect("invalid template regex"));
|
||||
|
||||
/// Wraps [`GraphState`] with template interpolation, script-output merging,
|
||||
/// and a large-state temp-file fallback for use with scripts.
|
||||
///
|
||||
/// Template syntax: `{{key}}` for top-level keys, `{{a.b.c}}` for nested
|
||||
/// JSON paths. Use [`StateManager::interpolate`] for strict interpolation
|
||||
/// JSON paths, and `{{arr[0]}}` / `{{a.b[2].c}}` / `{{matrix[0][1]}}` for
|
||||
/// array indices. Use [`StateManager::interpolate`] for strict interpolation
|
||||
/// (errors on missing keys) or [`StateManager::interpolate_lenient`] for
|
||||
/// best-effort (missing keys become empty strings).
|
||||
pub struct StateManager {
|
||||
@@ -89,10 +90,20 @@ impl StateManager {
|
||||
|
||||
fn get_nested_value(&self, key: &str) -> Option<&Value> {
|
||||
let mut parts = key.split('.');
|
||||
let root = parts.next()?;
|
||||
let mut current = self.state.get(root)?;
|
||||
let first = parts.next()?;
|
||||
let (root_key, root_indices) = split_indices(first)?;
|
||||
let mut current = self.state.get(root_key)?;
|
||||
for idx in root_indices {
|
||||
current = current.get(idx)?;
|
||||
}
|
||||
for part in parts {
|
||||
current = current.get(part)?;
|
||||
let (segment_key, indices) = split_indices(part)?;
|
||||
if !segment_key.is_empty() {
|
||||
current = current.get(segment_key)?;
|
||||
}
|
||||
for idx in indices {
|
||||
current = current.get(idx)?;
|
||||
}
|
||||
}
|
||||
Some(current)
|
||||
}
|
||||
@@ -203,6 +214,26 @@ impl StateRepresentation {
|
||||
}
|
||||
}
|
||||
|
||||
fn split_indices(segment: &str) -> Option<(&str, Vec<usize>)> {
|
||||
let bracket_start = segment.find('[');
|
||||
let key = match bracket_start {
|
||||
Some(i) => &segment[..i],
|
||||
None => return Some((segment, Vec::new())),
|
||||
};
|
||||
let mut indices = Vec::new();
|
||||
let mut rest = &segment[bracket_start.unwrap()..];
|
||||
while !rest.is_empty() {
|
||||
if !rest.starts_with('[') {
|
||||
return None;
|
||||
}
|
||||
let close = rest.find(']')?;
|
||||
let idx: usize = rest[1..close].parse().ok()?;
|
||||
indices.push(idx);
|
||||
rest = &rest[close + 1..];
|
||||
}
|
||||
Some((key, indices))
|
||||
}
|
||||
|
||||
fn value_to_string(value: &Value) -> String {
|
||||
match value {
|
||||
Value::String(s) => s.clone(),
|
||||
@@ -321,6 +352,45 @@ mod tests {
|
||||
assert_eq!(result, r#"{"key":"value"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolates_array_indices() {
|
||||
let manager = manager_with(&[("items", json!(["a", "b", "c"]))]);
|
||||
assert_eq!(manager.interpolate("{{items[0]}}").unwrap(), "a");
|
||||
assert_eq!(manager.interpolate("{{items[2]}}").unwrap(), "c");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolates_array_indices_inside_nested_paths() {
|
||||
let manager = manager_with(&[("outer", json!({ "inner": { "arr": ["x", "y", "z"] } }))]);
|
||||
let result = manager
|
||||
.interpolate("first={{outer.inner.arr[0]}} last={{outer.inner.arr[2]}}")
|
||||
.unwrap();
|
||||
assert_eq!(result, "first=x last=z");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolates_object_fields_after_array_index() {
|
||||
let manager = manager_with(&[("users", json!([{ "name": "Alice" }, { "name": "Bob" }]))]);
|
||||
let result = manager
|
||||
.interpolate("{{users[0].name}} and {{users[1].name}}")
|
||||
.unwrap();
|
||||
assert_eq!(result, "Alice and Bob");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn interpolates_nested_array_indices() {
|
||||
let manager = manager_with(&[("matrix", json!([[1, 2], [3, 4]]))]);
|
||||
assert_eq!(manager.interpolate("{{matrix[0][1]}}").unwrap(), "2");
|
||||
assert_eq!(manager.interpolate("{{matrix[1][0]}}").unwrap(), "3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn out_of_bounds_array_index_is_missing() {
|
||||
let manager = manager_with(&[("items", json!(["a", "b"]))]);
|
||||
let err = manager.interpolate("{{items[5]}}").unwrap_err().to_string();
|
||||
assert!(err.contains("not found"), "got: {err}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replaces_all_occurrences_of_same_key() {
|
||||
let manager = manager_with(&[("n", json!("Alice"))]);
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
//! Structured-output extraction for `llm` and `agent` nodes. Takes the
|
||||
//! raw final text of a node and a user-supplied JSON Schema, and returns
|
||||
//! a parsed [`serde_json::Value`] conforming to that schema (best-effort).
|
||||
//!
|
||||
//! Strategy: try to parse `raw` directly first (with light cleanup of
|
||||
//! markdown fences), and only invoke a follow-up LLM call against the
|
||||
//! built-in `structured-output` role if direct parsing fails. On
|
||||
//! extractor-output parse failure, perform one repair retry.
|
||||
|
||||
use crate::client::call_chat_completions;
|
||||
use crate::config::{Input, RequestContext, Role, RoleLike};
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use anyhow::{Context, Result, bail};
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
|
||||
const EXTRACTOR_ROLE_NAME: &str = "__structured_output__";
|
||||
|
||||
const EXTRACTOR_ROLE_PROMPT: &str = "\
|
||||
Extract a JSON object from the user's input that strictly conforms to the provided JSON Schema.
|
||||
|
||||
Rules:
|
||||
- Output ONLY the JSON object. No prose, no explanation, no markdown fences, no <think> tokens.
|
||||
- The first character of your response must be `{` and the last must be `}`.
|
||||
- Every key marked `required` in the schema MUST appear in the output.
|
||||
- All values MUST match the types specified in the schema.
|
||||
- If the input is already a valid JSON object matching the schema, return it unchanged.
|
||||
- If a field cannot be determined from the input, use `null` (when allowed) or your best inferred value.
|
||||
- Do NOT invent fields not present in the schema.";
|
||||
|
||||
pub async fn extract(raw: &str, schema: &Value, parent_ctx: &mut RequestContext) -> Result<Value> {
|
||||
if let Some(parsed) = try_parse_json(raw) {
|
||||
return Ok(parsed);
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text("▸ structured-output: parsing raw output failed, invoking extractor")
|
||||
);
|
||||
extract_via_extractor(raw, schema, parent_ctx, false).await
|
||||
}
|
||||
|
||||
async fn extract_via_extractor(
|
||||
raw: &str,
|
||||
schema: &Value,
|
||||
parent_ctx: &mut RequestContext,
|
||||
is_repair: bool,
|
||||
) -> Result<Value> {
|
||||
let role = build_extractor_role()?;
|
||||
let prompt = build_extractor_prompt(raw, schema, is_repair);
|
||||
|
||||
let saved_role = parent_ctx.role.clone();
|
||||
parent_ctx.role = Some(role);
|
||||
let result = run_one_shot(&prompt, parent_ctx).await;
|
||||
parent_ctx.role = saved_role;
|
||||
|
||||
let output = result.context("Structured-output extractor LLM call failed")?;
|
||||
|
||||
match try_parse_json(&output) {
|
||||
Some(value) => Ok(value),
|
||||
None if is_repair => bail!(
|
||||
"Structured-output extractor failed to produce valid JSON after repair retry. \
|
||||
Last response:\n{output}"
|
||||
),
|
||||
None => {
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text("▸ structured-output: extractor returned invalid JSON, retrying")
|
||||
);
|
||||
Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_extractor_role() -> Result<Role> {
|
||||
let mut role = Role::new(EXTRACTOR_ROLE_NAME, EXTRACTOR_ROLE_PROMPT);
|
||||
role.set_enabled_tools(Some(String::new()));
|
||||
role.set_enabled_mcp_servers(Some(String::new()));
|
||||
Ok(role)
|
||||
}
|
||||
|
||||
fn build_extractor_prompt(raw: &str, schema: &Value, is_repair: bool) -> String {
|
||||
let schema_json = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string());
|
||||
if is_repair {
|
||||
format!(
|
||||
"Your previous response was not valid JSON. Output ONLY a JSON object \
|
||||
matching this schema. No prose, no fences.\n\nSchema:\n{schema_json}\n\nInput:\n{raw}"
|
||||
)
|
||||
} else {
|
||||
format!("Schema:\n{schema_json}\n\nInput:\n{raw}")
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_one_shot(prompt: &str, ctx: &mut RequestContext) -> Result<String> {
|
||||
let abort = create_abort_signal();
|
||||
let app_cfg = Arc::clone(&ctx.app.config);
|
||||
let role_for_input = ctx.role.clone();
|
||||
let input = Input::from_str(ctx, prompt, role_for_input);
|
||||
let client = input.create_client()?;
|
||||
ctx.before_chat_completion(&input)?;
|
||||
let (output, tool_results) =
|
||||
call_chat_completions(&input, false, false, client.as_ref(), ctx, abort).await?;
|
||||
ctx.after_chat_completion(app_cfg.as_ref(), &input, &output, &tool_results)?;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn try_parse_json(raw: &str) -> Option<Value> {
|
||||
let cleaned = strip_code_fences(raw.trim());
|
||||
serde_json::from_str(cleaned).ok()
|
||||
}
|
||||
|
||||
fn strip_code_fences(s: &str) -> &str {
|
||||
let after_open = s
|
||||
.strip_prefix("```json")
|
||||
.or_else(|| s.strip_prefix("```"))
|
||||
.map(str::trim_start)
|
||||
.unwrap_or(s);
|
||||
after_open
|
||||
.strip_suffix("```")
|
||||
.map(str::trim_end)
|
||||
.unwrap_or(after_open)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_accepts_plain_object() {
|
||||
let v = try_parse_json(r#"{"a": 1}"#).unwrap();
|
||||
assert_eq!(v, json!({"a": 1}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_strips_json_fences() {
|
||||
let raw = "```json\n{\"a\": 1}\n```";
|
||||
let v = try_parse_json(raw).unwrap();
|
||||
assert_eq!(v, json!({"a": 1}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_strips_bare_fences() {
|
||||
let raw = "```\n{\"a\": 1}\n```";
|
||||
let v = try_parse_json(raw).unwrap();
|
||||
assert_eq!(v, json!({"a": 1}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_tolerates_whitespace() {
|
||||
let v = try_parse_json(" \n {\"x\": true}\n\n").unwrap();
|
||||
assert_eq!(v, json!({"x": true}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_returns_none_on_prose() {
|
||||
assert!(try_parse_json("Here is the result: it's good").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_returns_none_on_partial_json() {
|
||||
assert!(try_parse_json("{\"a\": ").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_parse_json_accepts_arrays() {
|
||||
let v = try_parse_json("[1, 2, 3]").unwrap();
|
||||
assert_eq!(v, json!([1, 2, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_extractor_prompt_includes_schema_and_input() {
|
||||
let schema = json!({"type": "object"});
|
||||
let prompt = build_extractor_prompt("hello", &schema, false);
|
||||
assert!(prompt.contains("Schema:"));
|
||||
assert!(prompt.contains("Input:"));
|
||||
assert!(prompt.contains("hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_extractor_prompt_repair_includes_repair_instruction() {
|
||||
let schema = json!({"type": "object"});
|
||||
let prompt = build_extractor_prompt("oops", &schema, true);
|
||||
assert!(prompt.contains("previous response"));
|
||||
assert!(prompt.contains("oops"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_extractor_role_disables_tools_and_mcp() {
|
||||
let role = build_extractor_role().expect("builtin role must exist");
|
||||
assert_eq!(role.enabled_tools().as_deref(), Some(""));
|
||||
assert_eq!(role.enabled_mcp_servers().as_deref(), Some(""));
|
||||
}
|
||||
}
|
||||
@@ -131,6 +131,13 @@ pub struct AgentNode {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub state_updates: Option<HashMap<String, String>>,
|
||||
|
||||
/// JSON Schema describing the expected shape of the agent's final
|
||||
/// output. When set, the agent's raw text is post-processed through
|
||||
/// a built-in structured-output extractor and parsed as JSON. Top-
|
||||
/// level keys of the parsed object are auto-merged into state.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output_schema: Option<Value>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u64>,
|
||||
}
|
||||
@@ -256,6 +263,13 @@ pub struct LlmNode {
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub state_updates: Option<HashMap<String, String>>,
|
||||
|
||||
/// JSON Schema (as parsed JSON) describing the expected shape of the
|
||||
/// node's output. When set, the raw LLM response is post-processed
|
||||
/// through a built-in structured-output extractor and parsed as JSON.
|
||||
/// Top-level keys of the parsed object are auto-merged into state.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub output_schema: Option<Value>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u64>,
|
||||
}
|
||||
|
||||
@@ -411,6 +411,7 @@ mod tests {
|
||||
agent: agent.into(),
|
||||
prompt: "hi".into(),
|
||||
state_updates: None,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}),
|
||||
next: next.map(String::from),
|
||||
|
||||
Reference in New Issue
Block a user