feat: added structured-output extraction for llm and agent nodes
This commit is contained in:
+92
-11
@@ -6,6 +6,7 @@
|
||||
//! `{{output}}` for the agent's stdout).
|
||||
|
||||
use super::state::StateManager;
|
||||
use super::structured;
|
||||
use super::types::AgentNode;
|
||||
use crate::config::RequestContext;
|
||||
use crate::function::supervisor::run_agent_for_graph;
|
||||
@@ -42,7 +43,7 @@ impl AgentNodeExecutor {
|
||||
|
||||
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS));
|
||||
|
||||
let output = timeout(
|
||||
let raw = timeout(
|
||||
timeout_dur,
|
||||
run_agent_for_graph(parent_ctx, &node.agent, &prompt),
|
||||
)
|
||||
@@ -56,9 +57,21 @@ impl AgentNodeExecutor {
|
||||
})?
|
||||
.with_context(|| format!("Agent '{}' failed", node.agent))?;
|
||||
|
||||
apply_state_updates(node, state_manager, &output);
|
||||
let output_value = match &node.output_schema {
|
||||
Some(schema) => structured::extract(&raw, schema, parent_ctx)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Agent '{}' output failed structured-output extraction",
|
||||
node.agent
|
||||
)
|
||||
})?,
|
||||
None => Value::String(raw.clone()),
|
||||
};
|
||||
|
||||
Ok(output)
|
||||
apply_state_updates(node, state_manager, &output_value);
|
||||
|
||||
Ok(raw)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,14 +94,26 @@ fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec<String> {
|
||||
/// applies every key/template in `state_updates`. The temporary `output`
|
||||
/// state key is removed at the end so it doesn't leak into subsequent
|
||||
/// nodes' templates.
|
||||
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &str) {
|
||||
///
|
||||
/// When `node.output_schema` is set AND the output is a JSON object, its
|
||||
/// top-level keys are also auto-merged into state permanently (before
|
||||
/// state_updates evaluation, so explicit state_updates can override).
|
||||
fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) {
|
||||
if node.output_schema.is_some()
|
||||
&& let Some(obj) = output.as_object()
|
||||
{
|
||||
for (k, v) in obj {
|
||||
state_manager.state_mut().set(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let Some(updates) = &node.state_updates else {
|
||||
return;
|
||||
};
|
||||
let prev_output = state_manager.state().get(OUTPUT_KEY).cloned();
|
||||
state_manager
|
||||
.state_mut()
|
||||
.set(OUTPUT_KEY.into(), Value::String(output.to_string()));
|
||||
.set(OUTPUT_KEY.into(), output.clone());
|
||||
|
||||
for (key, template) in updates {
|
||||
let value = state_manager.interpolate_lenient(template);
|
||||
@@ -127,6 +152,7 @@ mod tests {
|
||||
agent: "test_agent".into(),
|
||||
prompt: prompt.into(),
|
||||
state_updates: updates,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
}
|
||||
}
|
||||
@@ -139,7 +165,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates(&node, &mut state, "agent finished its work");
|
||||
apply_state_updates(&node, &mut state, &json!("agent finished its work"));
|
||||
assert_eq!(
|
||||
state.state().get("findings"),
|
||||
Some(&json!("agent finished its work"))
|
||||
@@ -154,7 +180,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[("topic", json!("auth"))]);
|
||||
apply_state_updates(&node, &mut state, "JWT vs sessions");
|
||||
apply_state_updates(&node, &mut state, &json!("JWT vs sessions"));
|
||||
assert_eq!(
|
||||
state.state().get("summary"),
|
||||
Some(&json!("auth: JWT vs sessions"))
|
||||
@@ -169,7 +195,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates(&node, &mut state, "anything");
|
||||
apply_state_updates(&node, &mut state, &json!("anything"));
|
||||
assert_eq!(state.state().get("output"), Some(&Value::Null));
|
||||
}
|
||||
|
||||
@@ -181,7 +207,7 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[("output", json!("preserved"))]);
|
||||
apply_state_updates(&node, &mut state, "new agent output");
|
||||
apply_state_updates(&node, &mut state, &json!("new agent output"));
|
||||
assert_eq!(
|
||||
state.state().get("greeting"),
|
||||
Some(&json!("new agent output"))
|
||||
@@ -193,7 +219,7 @@ mod tests {
|
||||
fn no_state_updates_is_a_noop() {
|
||||
let node = node_with("hi", None);
|
||||
let mut state = manager_with(&[("k", json!("v"))]);
|
||||
apply_state_updates(&node, &mut state, "ignored");
|
||||
apply_state_updates(&node, &mut state, &json!("ignored"));
|
||||
assert_eq!(state.state().get("k"), Some(&json!("v")));
|
||||
assert!(state.state().get("output").is_none());
|
||||
}
|
||||
@@ -206,7 +232,62 @@ mod tests {
|
||||
node_with("hi", Some(u))
|
||||
};
|
||||
let mut state = manager_with(&[]);
|
||||
apply_state_updates(&node, &mut state, "DATA");
|
||||
apply_state_updates(&node, &mut state, &json!("DATA"));
|
||||
assert_eq!(state.state().get("decorated"), Some(&json!("[] DATA")));
|
||||
}
|
||||
|
||||
fn node_with_schema(
|
||||
prompt: &str,
|
||||
updates: Option<HashMap<String, String>>,
|
||||
schema: Value,
|
||||
) -> AgentNode {
|
||||
let mut n = node_with(prompt, updates);
|
||||
n.output_schema = Some(schema);
|
||||
n
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_auto_merges_top_level_keys() {
|
||||
let node = node_with_schema("hi", None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X", "summary": "details"});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("do X")));
|
||||
assert_eq!(state.state().get("summary"), Some(&json!("details")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_preserves_nested_value_types() {
|
||||
let node = node_with_schema("hi", None, json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({
|
||||
"tags": ["a", "b"],
|
||||
"config": { "key": "value" },
|
||||
"count": 42
|
||||
});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("tags"), Some(&json!(["a", "b"])));
|
||||
assert_eq!(state.state().get("config"), Some(&json!({"key": "value"})));
|
||||
assert_eq!(state.state().get("count"), Some(&json!(42)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn output_schema_explicit_state_updates_override_auto_merge() {
|
||||
let mut u = HashMap::new();
|
||||
u.insert("goal".into(), "renamed-{{output.goal}}".into());
|
||||
let node = node_with_schema("hi", Some(u), json!({"type": "object"}));
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert_eq!(state.state().get("goal"), Some(&json!("renamed-do X")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_schema_does_not_auto_merge() {
|
||||
let node = node_with("hi", None);
|
||||
let mut state = manager_with(&[]);
|
||||
let output = json!({"goal": "do X"});
|
||||
apply_state_updates(&node, &mut state, &output);
|
||||
assert!(state.state().get("goal").is_none());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user