diff --git a/src/graph/agent.rs b/src/graph/agent.rs index 997bb5f..8ec37ff 100644 --- a/src/graph/agent.rs +++ b/src/graph/agent.rs @@ -6,6 +6,7 @@ //! `{{output}}` for the agent's stdout). use super::state::StateManager; +use super::structured; use super::types::AgentNode; use crate::config::RequestContext; use crate::function::supervisor::run_agent_for_graph; @@ -42,7 +43,7 @@ impl AgentNodeExecutor { let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS)); - let output = timeout( + let raw = timeout( timeout_dur, run_agent_for_graph(parent_ctx, &node.agent, &prompt), ) @@ -56,9 +57,21 @@ impl AgentNodeExecutor { })? .with_context(|| format!("Agent '{}' failed", node.agent))?; - apply_state_updates(node, state_manager, &output); + let output_value = match &node.output_schema { + Some(schema) => structured::extract(&raw, schema, parent_ctx) + .await + .with_context(|| { + format!( + "Agent '{}' output failed structured-output extraction", + node.agent + ) + })?, + None => Value::String(raw.clone()), + }; - Ok(output) + apply_state_updates(node, state_manager, &output_value); + + Ok(raw) } } @@ -81,14 +94,26 @@ fn indent_prompt(prompt: &str, prefix_spaces: usize) -> Vec { /// applies every key/template in `state_updates`. The temporary `output` /// state key is removed at the end so it doesn't leak into subsequent /// nodes' templates. -fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &str) { +/// +/// When `node.output_schema` is set AND the output is a JSON object, its +/// top-level keys are also auto-merged into state permanently (before +/// state_updates evaluation, so explicit state_updates can override). +fn apply_state_updates(node: &AgentNode, state_manager: &mut StateManager, output: &Value) { + if node.output_schema.is_some() + && let Some(obj) = output.as_object() + { + for (k, v) in obj { + state_manager.state_mut().set(k.clone(), v.clone()); + } + } + let Some(updates) = &node.state_updates else { return; }; let prev_output = state_manager.state().get(OUTPUT_KEY).cloned(); state_manager .state_mut() - .set(OUTPUT_KEY.into(), Value::String(output.to_string())); + .set(OUTPUT_KEY.into(), output.clone()); for (key, template) in updates { let value = state_manager.interpolate_lenient(template); @@ -127,6 +152,7 @@ mod tests { agent: "test_agent".into(), prompt: prompt.into(), state_updates: updates, + output_schema: None, timeout: None, } } @@ -139,7 +165,7 @@ mod tests { node_with("hi", Some(u)) }; let mut state = manager_with(&[]); - apply_state_updates(&node, &mut state, "agent finished its work"); + apply_state_updates(&node, &mut state, &json!("agent finished its work")); assert_eq!( state.state().get("findings"), Some(&json!("agent finished its work")) @@ -154,7 +180,7 @@ mod tests { node_with("hi", Some(u)) }; let mut state = manager_with(&[("topic", json!("auth"))]); - apply_state_updates(&node, &mut state, "JWT vs sessions"); + apply_state_updates(&node, &mut state, &json!("JWT vs sessions")); assert_eq!( state.state().get("summary"), Some(&json!("auth: JWT vs sessions")) @@ -169,7 +195,7 @@ mod tests { node_with("hi", Some(u)) }; let mut state = manager_with(&[]); - apply_state_updates(&node, &mut state, "anything"); + apply_state_updates(&node, &mut state, &json!("anything")); assert_eq!(state.state().get("output"), Some(&Value::Null)); } @@ -181,7 +207,7 @@ mod tests { node_with("hi", Some(u)) }; let mut state = manager_with(&[("output", json!("preserved"))]); - apply_state_updates(&node, &mut state, "new agent output"); + apply_state_updates(&node, &mut state, &json!("new agent output")); assert_eq!( state.state().get("greeting"), Some(&json!("new agent output")) @@ -193,7 +219,7 @@ mod tests { fn no_state_updates_is_a_noop() { let node = node_with("hi", None); let mut state = manager_with(&[("k", json!("v"))]); - apply_state_updates(&node, &mut state, "ignored"); + apply_state_updates(&node, &mut state, &json!("ignored")); assert_eq!(state.state().get("k"), Some(&json!("v"))); assert!(state.state().get("output").is_none()); } @@ -206,7 +232,62 @@ mod tests { node_with("hi", Some(u)) }; let mut state = manager_with(&[]); - apply_state_updates(&node, &mut state, "DATA"); + apply_state_updates(&node, &mut state, &json!("DATA")); assert_eq!(state.state().get("decorated"), Some(&json!("[] DATA"))); } + + fn node_with_schema( + prompt: &str, + updates: Option>, + schema: Value, + ) -> AgentNode { + let mut n = node_with(prompt, updates); + n.output_schema = Some(schema); + n + } + + #[test] + fn output_schema_auto_merges_top_level_keys() { + let node = node_with_schema("hi", None, json!({"type": "object"})); + let mut state = manager_with(&[]); + let output = json!({"goal": "do X", "summary": "details"}); + apply_state_updates(&node, &mut state, &output); + assert_eq!(state.state().get("goal"), Some(&json!("do X"))); + assert_eq!(state.state().get("summary"), Some(&json!("details"))); + } + + #[test] + fn output_schema_preserves_nested_value_types() { + let node = node_with_schema("hi", None, json!({"type": "object"})); + let mut state = manager_with(&[]); + let output = json!({ + "tags": ["a", "b"], + "config": { "key": "value" }, + "count": 42 + }); + apply_state_updates(&node, &mut state, &output); + assert_eq!(state.state().get("tags"), Some(&json!(["a", "b"]))); + assert_eq!(state.state().get("config"), Some(&json!({"key": "value"}))); + assert_eq!(state.state().get("count"), Some(&json!(42))); + } + + #[test] + fn output_schema_explicit_state_updates_override_auto_merge() { + let mut u = HashMap::new(); + u.insert("goal".into(), "renamed-{{output.goal}}".into()); + let node = node_with_schema("hi", Some(u), json!({"type": "object"})); + let mut state = manager_with(&[]); + let output = json!({"goal": "do X"}); + apply_state_updates(&node, &mut state, &output); + assert_eq!(state.state().get("goal"), Some(&json!("renamed-do X"))); + } + + #[test] + fn no_schema_does_not_auto_merge() { + let node = node_with("hi", None); + let mut state = manager_with(&[]); + let output = json!({"goal": "do X"}); + apply_state_updates(&node, &mut state, &output); + assert!(state.state().get("goal").is_none()); + } } diff --git a/src/graph/llm.rs b/src/graph/llm.rs index d2828cb..b8376ba 100644 --- a/src/graph/llm.rs +++ b/src/graph/llm.rs @@ -5,11 +5,12 @@ //! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design. use super::state::StateManager; +use super::structured; use super::types::LlmNode; use crate::client::{Model, ModelType, call_chat_completions}; use crate::config::{RequestContext, Role, RoleLike}; use crate::utils::{create_abort_signal, dimmed_text}; -use anyhow::{Context, Result, anyhow, bail, Error}; +use anyhow::{Context, Error, Result, anyhow, bail}; use serde_json::Value; use std::collections::HashSet; use std::sync::Arc; @@ -33,10 +34,22 @@ impl LlmNodeExecutor { ) -> Result { let result = run(node, state_manager, parent_ctx).await; let (output, failed) = match result { - Ok(out) => (out, false), + Ok(raw) => match &node.output_schema { + Some(schema) => match structured::extract(&raw, schema, parent_ctx).await { + Ok(value) => (value, false), + Err(e) => { + warn!("llm node structured extraction failed: {e}"); + ( + Value::String(format!("LLM node structured-extraction failed: {e}")), + true, + ) + } + }, + None => (Value::String(raw), false), + }, Err(e) => { warn!("llm node failed: {e}"); - (format!("LLM node failed: {e}"), true) + (Value::String(format!("LLM node failed: {e}")), true) } }; apply_state_updates_with_output(node, state_manager, &output); @@ -49,7 +62,7 @@ async fn run( state_manager: &mut StateManager, parent_ctx: &mut RequestContext, ) -> Result { - let instructions: Option = match &node.instructions { + let mut instructions: Option = match &node.instructions { Some(s) => Some( state_manager .interpolate(s) @@ -57,10 +70,24 @@ async fn run( ), None => None, }; - let prompt = state_manager + let mut prompt = state_manager .interpolate(&node.prompt) .context("Failed to interpolate llm node prompt")?; + if let Some(schema) = &node.output_schema { + let hint = format_schema_hint(schema); + match instructions.as_mut() { + Some(s) => { + s.push_str("\n\n"); + s.push_str(&hint); + } + None => { + prompt.push_str("\n\n"); + prompt.push_str(&hint); + } + } + } + let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref()); validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?; @@ -291,14 +318,30 @@ fn next_for_llm_node( /// of `state_updates` evaluation, then restore the prior value (or set /// it to `Null` if there wasn't one). Same pattern as /// `AgentNodeExecutor`'s `{{output}}` scoping. -fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateManager, output: &str) { +/// +/// When `node.output_schema` is set AND the output is a JSON object, its +/// top-level keys are also auto-merged into state permanently (before +/// state_updates evaluation, so explicit state_updates can override). +fn apply_state_updates_with_output( + node: &LlmNode, + state_manager: &mut StateManager, + output: &Value, +) { + if node.output_schema.is_some() + && let Some(obj) = output.as_object() + { + for (k, v) in obj { + state_manager.state_mut().set(k.clone(), v.clone()); + } + } + let Some(updates) = &node.state_updates else { return; }; let prev_output = state_manager.state().get(OUTPUT_KEY).cloned(); state_manager .state_mut() - .set(OUTPUT_KEY.into(), Value::String(output.to_string())); + .set(OUTPUT_KEY.into(), output.clone()); for (key, template) in updates { let value = state_manager.interpolate_lenient(template); @@ -317,6 +360,14 @@ fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateMana } } +fn format_schema_hint(schema: &Value) -> String { + let schema_json = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string()); + format!( + "Respond with a JSON object that matches this schema. Output ONLY the JSON \ + object with no surrounding prose or markdown fences.\n\nSchema:\n{schema_json}" + ) +} + fn describe_tools_filter(tools: Option<&[String]>) -> String { match tools { None => "".into(), @@ -352,6 +403,7 @@ mod tests { max_attempts: 1, max_iterations: 10, state_updates: updates, + output_schema: None, timeout: None, } } @@ -362,7 +414,7 @@ mod tests { u.insert("response".into(), "{{output}}".into()); let node = node_with(Some(u)); let mut state = manager_with(&[]); - apply_state_updates_with_output(&node, &mut state, "the answer"); + apply_state_updates_with_output(&node, &mut state, &json!("the answer")); assert_eq!(state.state().get("response"), Some(&json!("the answer"))); } @@ -372,7 +424,7 @@ mod tests { u.insert("summary".into(), "{{topic}}: {{output}}".into()); let node = node_with(Some(u)); let mut state = manager_with(&[("topic", json!("LOINC"))]); - apply_state_updates_with_output(&node, &mut state, "abc"); + apply_state_updates_with_output(&node, &mut state, &json!("abc")); assert_eq!(state.state().get("summary"), Some(&json!("LOINC: abc"))); } @@ -382,7 +434,7 @@ mod tests { u.insert("k".into(), "{{output}}".into()); let node = node_with(Some(u)); let mut state = manager_with(&[]); - apply_state_updates_with_output(&node, &mut state, "anything"); + apply_state_updates_with_output(&node, &mut state, &json!("anything")); assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!(null))); } @@ -392,7 +444,7 @@ mod tests { u.insert("greeting".into(), "{{output}}".into()); let node = node_with(Some(u)); let mut state = manager_with(&[("output", json!("preserved"))]); - apply_state_updates_with_output(&node, &mut state, "new"); + apply_state_updates_with_output(&node, &mut state, &json!("new")); assert_eq!(state.state().get("greeting"), Some(&json!("new"))); assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!("preserved"))); } @@ -401,11 +453,82 @@ mod tests { fn no_state_updates_is_a_noop() { let node = node_with(None); let mut state = manager_with(&[("k", json!("v"))]); - apply_state_updates_with_output(&node, &mut state, "x"); + apply_state_updates_with_output(&node, &mut state, &json!("x")); assert_eq!(state.state().get("k"), Some(&json!("v"))); assert!(state.state().get(OUTPUT_KEY).is_none()); } + fn node_with_schema(updates: Option>, schema: Value) -> LlmNode { + let mut n = node_with(updates); + n.output_schema = Some(schema); + n + } + + #[test] + fn output_schema_auto_merges_top_level_keys() { + let node = node_with_schema(None, json!({"type": "object"})); + let mut state = manager_with(&[]); + let output = json!({"goal": "do X", "summary": "details"}); + apply_state_updates_with_output(&node, &mut state, &output); + assert_eq!(state.state().get("goal"), Some(&json!("do X"))); + assert_eq!(state.state().get("summary"), Some(&json!("details"))); + } + + #[test] + fn output_schema_preserves_nested_value_types() { + let node = node_with_schema(None, json!({"type": "object"})); + let mut state = manager_with(&[]); + let output = json!({ + "tags": ["a", "b"], + "config": { "key": "value" }, + "count": 42 + }); + apply_state_updates_with_output(&node, &mut state, &output); + assert_eq!(state.state().get("tags"), Some(&json!(["a", "b"]))); + assert_eq!(state.state().get("config"), Some(&json!({"key": "value"}))); + assert_eq!(state.state().get("count"), Some(&json!(42))); + } + + #[test] + fn output_schema_explicit_state_updates_override_auto_merge() { + let mut u = HashMap::new(); + u.insert("goal".into(), "renamed-{{output.goal}}".into()); + let node = node_with_schema(Some(u), json!({"type": "object"})); + let mut state = manager_with(&[]); + let output = json!({"goal": "do X"}); + apply_state_updates_with_output(&node, &mut state, &output); + assert_eq!(state.state().get("goal"), Some(&json!("renamed-do X"))); + } + + #[test] + fn output_schema_skips_auto_merge_for_non_object() { + let node = node_with_schema(None, json!({"type": "array"})); + let mut state = manager_with(&[]); + let output = json!([1, 2, 3]); + apply_state_updates_with_output(&node, &mut state, &output); + assert!(state.state().get("0").is_none()); + assert!(state.state().get(OUTPUT_KEY).is_none()); + } + + #[test] + fn no_schema_does_not_auto_merge() { + let node = node_with(None); + let mut state = manager_with(&[]); + let output = json!({"goal": "do X"}); + apply_state_updates_with_output(&node, &mut state, &output); + assert!(state.state().get("goal").is_none()); + } + + #[test] + fn format_schema_hint_includes_schema_and_instruction() { + let schema = json!({"type": "object", "properties": {"goal": {"type": "string"}}}); + let hint = format_schema_hint(&schema); + assert!(hint.contains("Schema:")); + assert!(hint.contains("\"goal\"")); + assert!(hint.contains("JSON")); + assert!(hint.contains("ONLY")); + } + #[test] fn describe_tools_filter_renders_each_case() { assert_eq!(describe_tools_filter(None), ""); diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 6a6f833..704f4e2 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -8,6 +8,7 @@ pub mod llm; pub mod parser; pub mod script; pub mod state; +pub mod structured; pub mod types; pub mod user_interaction; pub mod validator; diff --git a/src/graph/state.rs b/src/graph/state.rs index 0ec7396..31f435e 100644 --- a/src/graph/state.rs +++ b/src/graph/state.rs @@ -13,13 +13,14 @@ use std::path::PathBuf; use std::sync::LazyLock; static TEMPLATE_VAR_RE: LazyLock = - LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.]+)\}\}").expect("invalid template regex")); + LazyLock::new(|| Regex::new(r"\{\{([a-zA-Z0-9_\.\[\]]+)\}\}").expect("invalid template regex")); /// Wraps [`GraphState`] with template interpolation, script-output merging, /// and a large-state temp-file fallback for use with scripts. /// /// Template syntax: `{{key}}` for top-level keys, `{{a.b.c}}` for nested -/// JSON paths. Use [`StateManager::interpolate`] for strict interpolation +/// JSON paths, and `{{arr[0]}}` / `{{a.b[2].c}}` / `{{matrix[0][1]}}` for +/// array indices. Use [`StateManager::interpolate`] for strict interpolation /// (errors on missing keys) or [`StateManager::interpolate_lenient`] for /// best-effort (missing keys become empty strings). pub struct StateManager { @@ -89,10 +90,20 @@ impl StateManager { fn get_nested_value(&self, key: &str) -> Option<&Value> { let mut parts = key.split('.'); - let root = parts.next()?; - let mut current = self.state.get(root)?; + let first = parts.next()?; + let (root_key, root_indices) = split_indices(first)?; + let mut current = self.state.get(root_key)?; + for idx in root_indices { + current = current.get(idx)?; + } for part in parts { - current = current.get(part)?; + let (segment_key, indices) = split_indices(part)?; + if !segment_key.is_empty() { + current = current.get(segment_key)?; + } + for idx in indices { + current = current.get(idx)?; + } } Some(current) } @@ -203,6 +214,26 @@ impl StateRepresentation { } } +fn split_indices(segment: &str) -> Option<(&str, Vec)> { + let bracket_start = segment.find('['); + let key = match bracket_start { + Some(i) => &segment[..i], + None => return Some((segment, Vec::new())), + }; + let mut indices = Vec::new(); + let mut rest = &segment[bracket_start.unwrap()..]; + while !rest.is_empty() { + if !rest.starts_with('[') { + return None; + } + let close = rest.find(']')?; + let idx: usize = rest[1..close].parse().ok()?; + indices.push(idx); + rest = &rest[close + 1..]; + } + Some((key, indices)) +} + fn value_to_string(value: &Value) -> String { match value { Value::String(s) => s.clone(), @@ -321,6 +352,45 @@ mod tests { assert_eq!(result, r#"{"key":"value"}"#); } + #[test] + fn interpolates_array_indices() { + let manager = manager_with(&[("items", json!(["a", "b", "c"]))]); + assert_eq!(manager.interpolate("{{items[0]}}").unwrap(), "a"); + assert_eq!(manager.interpolate("{{items[2]}}").unwrap(), "c"); + } + + #[test] + fn interpolates_array_indices_inside_nested_paths() { + let manager = manager_with(&[("outer", json!({ "inner": { "arr": ["x", "y", "z"] } }))]); + let result = manager + .interpolate("first={{outer.inner.arr[0]}} last={{outer.inner.arr[2]}}") + .unwrap(); + assert_eq!(result, "first=x last=z"); + } + + #[test] + fn interpolates_object_fields_after_array_index() { + let manager = manager_with(&[("users", json!([{ "name": "Alice" }, { "name": "Bob" }]))]); + let result = manager + .interpolate("{{users[0].name}} and {{users[1].name}}") + .unwrap(); + assert_eq!(result, "Alice and Bob"); + } + + #[test] + fn interpolates_nested_array_indices() { + let manager = manager_with(&[("matrix", json!([[1, 2], [3, 4]]))]); + assert_eq!(manager.interpolate("{{matrix[0][1]}}").unwrap(), "2"); + assert_eq!(manager.interpolate("{{matrix[1][0]}}").unwrap(), "3"); + } + + #[test] + fn out_of_bounds_array_index_is_missing() { + let manager = manager_with(&[("items", json!(["a", "b"]))]); + let err = manager.interpolate("{{items[5]}}").unwrap_err().to_string(); + assert!(err.contains("not found"), "got: {err}"); + } + #[test] fn replaces_all_occurrences_of_same_key() { let manager = manager_with(&[("n", json!("Alice"))]); diff --git a/src/graph/structured.rs b/src/graph/structured.rs new file mode 100644 index 0000000..c311341 --- /dev/null +++ b/src/graph/structured.rs @@ -0,0 +1,194 @@ +//! Structured-output extraction for `llm` and `agent` nodes. Takes the +//! raw final text of a node and a user-supplied JSON Schema, and returns +//! a parsed [`serde_json::Value`] conforming to that schema (best-effort). +//! +//! Strategy: try to parse `raw` directly first (with light cleanup of +//! markdown fences), and only invoke a follow-up LLM call against the +//! built-in `structured-output` role if direct parsing fails. On +//! extractor-output parse failure, perform one repair retry. + +use crate::client::call_chat_completions; +use crate::config::{Input, RequestContext, Role, RoleLike}; +use crate::utils::{create_abort_signal, dimmed_text}; +use anyhow::{Context, Result, bail}; +use serde_json::Value; +use std::sync::Arc; + +const EXTRACTOR_ROLE_NAME: &str = "__structured_output__"; + +const EXTRACTOR_ROLE_PROMPT: &str = "\ +Extract a JSON object from the user's input that strictly conforms to the provided JSON Schema. + +Rules: +- Output ONLY the JSON object. No prose, no explanation, no markdown fences, no tokens. +- The first character of your response must be `{` and the last must be `}`. +- Every key marked `required` in the schema MUST appear in the output. +- All values MUST match the types specified in the schema. +- If the input is already a valid JSON object matching the schema, return it unchanged. +- If a field cannot be determined from the input, use `null` (when allowed) or your best inferred value. +- Do NOT invent fields not present in the schema."; + +pub async fn extract(raw: &str, schema: &Value, parent_ctx: &mut RequestContext) -> Result { + if let Some(parsed) = try_parse_json(raw) { + return Ok(parsed); + } + + eprintln!( + "{}", + dimmed_text("▸ structured-output: parsing raw output failed, invoking extractor") + ); + extract_via_extractor(raw, schema, parent_ctx, false).await +} + +async fn extract_via_extractor( + raw: &str, + schema: &Value, + parent_ctx: &mut RequestContext, + is_repair: bool, +) -> Result { + let role = build_extractor_role()?; + let prompt = build_extractor_prompt(raw, schema, is_repair); + + let saved_role = parent_ctx.role.clone(); + parent_ctx.role = Some(role); + let result = run_one_shot(&prompt, parent_ctx).await; + parent_ctx.role = saved_role; + + let output = result.context("Structured-output extractor LLM call failed")?; + + match try_parse_json(&output) { + Some(value) => Ok(value), + None if is_repair => bail!( + "Structured-output extractor failed to produce valid JSON after repair retry. \ + Last response:\n{output}" + ), + None => { + eprintln!( + "{}", + dimmed_text("▸ structured-output: extractor returned invalid JSON, retrying") + ); + Box::pin(extract_via_extractor(&output, schema, parent_ctx, true)).await + } + } +} + +fn build_extractor_role() -> Result { + let mut role = Role::new(EXTRACTOR_ROLE_NAME, EXTRACTOR_ROLE_PROMPT); + role.set_enabled_tools(Some(String::new())); + role.set_enabled_mcp_servers(Some(String::new())); + Ok(role) +} + +fn build_extractor_prompt(raw: &str, schema: &Value, is_repair: bool) -> String { + let schema_json = serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string()); + if is_repair { + format!( + "Your previous response was not valid JSON. Output ONLY a JSON object \ + matching this schema. No prose, no fences.\n\nSchema:\n{schema_json}\n\nInput:\n{raw}" + ) + } else { + format!("Schema:\n{schema_json}\n\nInput:\n{raw}") + } +} + +async fn run_one_shot(prompt: &str, ctx: &mut RequestContext) -> Result { + let abort = create_abort_signal(); + let app_cfg = Arc::clone(&ctx.app.config); + let role_for_input = ctx.role.clone(); + let input = Input::from_str(ctx, prompt, role_for_input); + let client = input.create_client()?; + ctx.before_chat_completion(&input)?; + let (output, tool_results) = + call_chat_completions(&input, false, false, client.as_ref(), ctx, abort).await?; + ctx.after_chat_completion(app_cfg.as_ref(), &input, &output, &tool_results)?; + Ok(output) +} + +fn try_parse_json(raw: &str) -> Option { + let cleaned = strip_code_fences(raw.trim()); + serde_json::from_str(cleaned).ok() +} + +fn strip_code_fences(s: &str) -> &str { + let after_open = s + .strip_prefix("```json") + .or_else(|| s.strip_prefix("```")) + .map(str::trim_start) + .unwrap_or(s); + after_open + .strip_suffix("```") + .map(str::trim_end) + .unwrap_or(after_open) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn try_parse_json_accepts_plain_object() { + let v = try_parse_json(r#"{"a": 1}"#).unwrap(); + assert_eq!(v, json!({"a": 1})); + } + + #[test] + fn try_parse_json_strips_json_fences() { + let raw = "```json\n{\"a\": 1}\n```"; + let v = try_parse_json(raw).unwrap(); + assert_eq!(v, json!({"a": 1})); + } + + #[test] + fn try_parse_json_strips_bare_fences() { + let raw = "```\n{\"a\": 1}\n```"; + let v = try_parse_json(raw).unwrap(); + assert_eq!(v, json!({"a": 1})); + } + + #[test] + fn try_parse_json_tolerates_whitespace() { + let v = try_parse_json(" \n {\"x\": true}\n\n").unwrap(); + assert_eq!(v, json!({"x": true})); + } + + #[test] + fn try_parse_json_returns_none_on_prose() { + assert!(try_parse_json("Here is the result: it's good").is_none()); + } + + #[test] + fn try_parse_json_returns_none_on_partial_json() { + assert!(try_parse_json("{\"a\": ").is_none()); + } + + #[test] + fn try_parse_json_accepts_arrays() { + let v = try_parse_json("[1, 2, 3]").unwrap(); + assert_eq!(v, json!([1, 2, 3])); + } + + #[test] + fn build_extractor_prompt_includes_schema_and_input() { + let schema = json!({"type": "object"}); + let prompt = build_extractor_prompt("hello", &schema, false); + assert!(prompt.contains("Schema:")); + assert!(prompt.contains("Input:")); + assert!(prompt.contains("hello")); + } + + #[test] + fn build_extractor_prompt_repair_includes_repair_instruction() { + let schema = json!({"type": "object"}); + let prompt = build_extractor_prompt("oops", &schema, true); + assert!(prompt.contains("previous response")); + assert!(prompt.contains("oops")); + } + + #[test] + fn build_extractor_role_disables_tools_and_mcp() { + let role = build_extractor_role().expect("builtin role must exist"); + assert_eq!(role.enabled_tools().as_deref(), Some("")); + assert_eq!(role.enabled_mcp_servers().as_deref(), Some("")); + } +} diff --git a/src/graph/types.rs b/src/graph/types.rs index 3e973b7..c54130c 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -131,6 +131,13 @@ pub struct AgentNode { #[serde(default, skip_serializing_if = "Option::is_none")] pub state_updates: Option>, + /// JSON Schema describing the expected shape of the agent's final + /// output. When set, the agent's raw text is post-processed through + /// a built-in structured-output extractor and parsed as JSON. Top- + /// level keys of the parsed object are auto-merged into state. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_schema: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, } @@ -256,6 +263,13 @@ pub struct LlmNode { #[serde(default, skip_serializing_if = "Option::is_none")] pub state_updates: Option>, + /// JSON Schema (as parsed JSON) describing the expected shape of the + /// node's output. When set, the raw LLM response is post-processed + /// through a built-in structured-output extractor and parsed as JSON. + /// Top-level keys of the parsed object are auto-merged into state. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub output_schema: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, } diff --git a/src/graph/validator.rs b/src/graph/validator.rs index 73b6778..6b4be8f 100644 --- a/src/graph/validator.rs +++ b/src/graph/validator.rs @@ -411,6 +411,7 @@ mod tests { agent: agent.into(), prompt: "hi".into(), state_updates: None, + output_schema: None, timeout: None, }), next: next.map(String::from),