From 5f044cab2ba9035d761290c0c7384d0dd2d917bd Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Thu, 14 May 2026 11:57:18 -0600 Subject: [PATCH] feat: scaffolded together the initial llm node type and its executor --- src/graph/executor.rs | 37 +++++++- src/graph/llm.rs | 187 +++++++++++++++++++++++++++++++++++++++++ src/graph/mod.rs | 6 +- src/graph/types.rs | 150 +++++++++++++++++++++++++++++++++ src/graph/validator.rs | 5 ++ 5 files changed, 382 insertions(+), 3 deletions(-) create mode 100644 src/graph/llm.rs diff --git a/src/graph/executor.rs b/src/graph/executor.rs index 76f359c..525d12f 100644 --- a/src/graph/executor.rs +++ b/src/graph/executor.rs @@ -7,10 +7,11 @@ //! template as the graph's return value. use super::agent::AgentNodeExecutor; +use super::llm::{self, LlmNodeExecutor}; use super::parser::GraphParser; use super::script::ScriptExecutor; use super::state::StateManager; -use super::types::{EndNode, Graph, Node, NodeType}; +use super::types::{EndNode, Graph, LlmNode, Node, NodeType}; use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor}; use super::validator::GraphValidator; use crate::config::RequestContext; @@ -169,6 +170,7 @@ fn node_type_label(node: &Node) -> &'static str { NodeType::Script(_) => "script", NodeType::Approval(_) => "approval", NodeType::Input(_) => "input", + NodeType::Llm(_) => "llm", NodeType::End(_) => "end", } } @@ -217,10 +219,43 @@ async fn step( InputNodeExecutor::execute(input_node, node.next.as_deref(), state, ctx).await?; Ok(StepResult::Continue(next)) } + NodeType::Llm(llm_node) => { + let result = LlmNodeExecutor::execute(llm_node, state, ctx).await; + let (output, failed) = match result { + Ok(out) => (out, false), + Err(e) => { + warn!("[graph:{}] llm node '{}' failed: {e}", graph_name, current); + (format!("LLM node failed: {e}"), true) + } + }; + apply_state_updates_with_llm_output(llm_node, state, &output); + let next = next_for_llm_node(node, failed, llm_node.fallback.as_deref())?; + Ok(StepResult::Continue(next)) + } NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))), } } +fn next_for_llm_node(node: &Node, failed: bool, fallback: Option<&str>) -> Result { + if failed && let Some(fb) = fallback { + return Ok(fb.to_string()); + } + node.next.clone().ok_or_else(|| { + anyhow!( + "llm node '{}' has no `next` set; llm nodes need static routing", + node.id + ) + }) +} + +fn apply_state_updates_with_llm_output( + node: &super::types::LlmNode, + state: &mut StateManager, + output: &str, +) { + crate::graph::llm::apply_state_updates_with_output(node, state, output); +} + /// Apply the end node's `state_updates`, then interpolate its `output` /// template against the resulting state. Both use lenient interpolation /// so the graph still produces a result even when some keys are absent. diff --git a/src/graph/llm.rs b/src/graph/llm.rs new file mode 100644 index 0000000..ca31353 --- /dev/null +++ b/src/graph/llm.rs @@ -0,0 +1,187 @@ +//! Execution of `llm`-type graph nodes — one-shot LLM calls with a +//! bounded tool-call loop, an opt-in tool whitelist, and per-node +//! overrides for model/temperature/top_p. +//! +//! See `docs/implementation/graph-agents/10.5-llm-nodes.md` for the +//! design. The current implementation provides the routing and +//! state-update plumbing; the actual call_chat_completions loop lives +//! in `run_llm_once` and is the next implementation step. Calling +//! `LlmNodeExecutor::execute` today produces a controlled error so the +//! tolerant-fail routing in the executor still flows. + +use super::state::StateManager; +use super::types::LlmNode; +use crate::config::RequestContext; +use crate::utils::dimmed_text; +use anyhow::{Context, Result, bail}; +use serde_json::Value; + +const OUTPUT_KEY: &str = "output"; + +pub struct LlmNodeExecutor; + +impl LlmNodeExecutor { + /// Interpolate the node's templates, run the LLM call, then return + /// the model's final response. State updates are applied by the + /// graph executor (which knows whether to use the success path or + /// the failure path). + pub async fn execute( + node: &LlmNode, + state_manager: &mut StateManager, + _parent_ctx: &mut RequestContext, + ) -> Result { + let _instructions = state_manager + .interpolate(&node.instructions) + .context("Failed to interpolate llm node instructions")?; + let _prompt = state_manager + .interpolate(&node.prompt) + .context("Failed to interpolate llm node prompt")?; + + eprintln!( + "{}", + dimmed_text(&format!( + "▸ llm call: model={} tools={}", + node.model.as_deref().unwrap_or(""), + describe_tools_filter(node.tools.as_deref()) + )) + ); + + bail!( + "llm node execution body not yet implemented — see \ + docs/implementation/graph-agents/10.5-llm-nodes.md \ + (steps 3 & 5 of the implementation order)" + ); + } +} + +/// Expose the LLM call's final output as `{{output}}` for the duration +/// of `state_updates` evaluation, then restore the prior value (or set +/// it to `Null` if there wasn't one). Same pattern as +/// `AgentNodeExecutor`'s `{{output}}` scoping. +pub fn apply_state_updates_with_output( + node: &LlmNode, + state_manager: &mut StateManager, + output: &str, +) { + let Some(updates) = &node.state_updates else { + return; + }; + let prev_output = state_manager.state().get(OUTPUT_KEY).cloned(); + state_manager + .state_mut() + .set(OUTPUT_KEY.into(), Value::String(output.to_string())); + + for (key, template) in updates { + let value = state_manager.interpolate_lenient(template); + state_manager + .state_mut() + .set(key.clone(), Value::String(value)); + } + + match prev_output { + Some(v) => state_manager.state_mut().set(OUTPUT_KEY.into(), v), + None => { + state_manager + .state_mut() + .set(OUTPUT_KEY.into(), Value::Null); + } + } +} + +fn describe_tools_filter(tools: Option<&[String]>) -> String { + match tools { + None => "".into(), + Some(t) if t.is_empty() => "".into(), + Some(t) => t.join(","), + } +} + +#[cfg(test)] +mod tests { + use super::super::types::*; + use super::*; + use serde_json::json; + use std::collections::HashMap; + + fn manager_with(pairs: &[(&str, Value)]) -> StateManager { + let mut map = HashMap::new(); + for (k, v) in pairs { + map.insert((*k).into(), v.clone()); + } + StateManager::new(map) + } + + fn node_with(updates: Option>) -> LlmNode { + LlmNode { + instructions: "sys".into(), + prompt: "user".into(), + tools: None, + model: None, + temperature: None, + top_p: None, + fallback: None, + max_attempts: 1, + max_iterations: 10, + state_updates: updates, + timeout: None, + } + } + + #[test] + fn state_updates_expose_output_during_evaluation() { + let mut u = HashMap::new(); + u.insert("response".into(), "{{output}}".into()); + let node = node_with(Some(u)); + let mut state = manager_with(&[]); + apply_state_updates_with_output(&node, &mut state, "the answer"); + assert_eq!(state.state().get("response"), Some(&json!("the answer"))); + } + + #[test] + fn state_updates_can_mix_existing_keys_with_output() { + let mut u = HashMap::new(); + u.insert("summary".into(), "{{topic}}: {{output}}".into()); + let node = node_with(Some(u)); + let mut state = manager_with(&[("topic", json!("LOINC"))]); + apply_state_updates_with_output(&node, &mut state, "abc"); + assert_eq!(state.state().get("summary"), Some(&json!("LOINC: abc"))); + } + + #[test] + fn output_key_is_cleared_after_state_updates() { + let mut u = HashMap::new(); + u.insert("k".into(), "{{output}}".into()); + let node = node_with(Some(u)); + let mut state = manager_with(&[]); + apply_state_updates_with_output(&node, &mut state, "anything"); + assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!(null))); + } + + #[test] + fn pre_existing_output_value_is_restored() { + let mut u = HashMap::new(); + u.insert("greeting".into(), "{{output}}".into()); + let node = node_with(Some(u)); + let mut state = manager_with(&[("output", json!("preserved"))]); + apply_state_updates_with_output(&node, &mut state, "new"); + assert_eq!(state.state().get("greeting"), Some(&json!("new"))); + assert_eq!(state.state().get(OUTPUT_KEY), Some(&json!("preserved"))); + } + + #[test] + fn no_state_updates_is_a_noop() { + let node = node_with(None); + let mut state = manager_with(&[("k", json!("v"))]); + apply_state_updates_with_output(&node, &mut state, "x"); + assert_eq!(state.state().get("k"), Some(&json!("v"))); + assert!(state.state().get(OUTPUT_KEY).is_none()); + } + + #[test] + fn describe_tools_filter_renders_each_case() { + assert_eq!(describe_tools_filter(None), ""); + assert_eq!(describe_tools_filter(Some(&[])), ""); + let tools = vec!["a".to_string(), "b".to_string()]; + assert_eq!(describe_tools_filter(Some(&tools)), "a,b"); + } +} diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 06947cc..6a6f833 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -4,6 +4,7 @@ pub mod agent; pub mod dispatch; pub mod executor; +pub mod llm; pub mod parser; pub mod script; pub mod state; @@ -14,12 +15,13 @@ pub mod validator; pub use agent::AgentNodeExecutor; pub use dispatch::{active_agent_graph_name, run_active_agent_graph}; pub use executor::GraphExecutor; +pub use llm::LlmNodeExecutor; pub use parser::{GraphParser, agent_has_graph, load_agent_graph}; pub use script::ScriptExecutor; pub use state::{StateManager, StateRepresentation}; pub use types::{ - AgentNode, ApprovalNode, EndNode, Graph, GraphSettings, GraphState, InputNode, Node, NodeType, - ScriptNode, + AgentNode, ApprovalNode, EndNode, Graph, GraphSettings, GraphState, InputNode, LlmNode, Node, + NodeType, ScriptNode, }; pub use user_interaction::{ApprovalNodeExecutor, InputNodeExecutor}; pub use validator::{GraphValidator, ValidationError, ValidationResult}; diff --git a/src/graph/types.rs b/src/graph/types.rs index e1dbe3f..c46f5c3 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -112,6 +112,7 @@ pub enum NodeType { Script(ScriptNode), Approval(ApprovalNode), Input(InputNode), + Llm(LlmNode), End(EndNode), } @@ -199,6 +200,66 @@ pub struct InputNode { pub on_timeout: Option, } +/// `llm`-type node: a one-shot LLM call (with bounded tool-call loop) +/// against a caller-supplied system prompt + user prompt. Unlike +/// `agent`-type nodes, this does NOT spawn a sub-agent; it runs in a +/// fresh isolated context. Tool access is opt-in via the `tools` +/// whitelist (no tools when unset). +/// +/// Routing (tolerant-fail): +/// - success → `Node.next` +/// - failure WITH fallback → `fallback` +/// - failure WITHOUT fallback → `Node.next` +/// +/// `state_updates` are always applied. `{{output}}` resolves to the +/// LLM's response on success, or to an error description on failure. +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LlmNode { + pub instructions: String, + + pub prompt: String, + + /// Whitelist of tool names. Each entry is either an exact function + /// name or the shorthand `mcp:` (expands to the three MCP + /// meta-functions for that server). Unset = no tools. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub model: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub fallback: Option, + + /// Number of attempts on transient errors. Default 1 = no retries. + #[serde(default = "default_llm_max_attempts")] + pub max_attempts: u32, + + /// Hard cap on tool-call-loop turns within a single attempt. + #[serde(default = "default_llm_max_iterations")] + pub max_iterations: u32, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub state_updates: Option>, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +fn default_llm_max_attempts() -> u32 { + 1 +} + +fn default_llm_max_iterations() -> u32 { + 10 +} + /// `end`-type node: terminate execution; `output` (templated) is returned /// as the graph's final result. #[derive(Debug, Clone, Deserialize, Serialize)] @@ -548,4 +609,93 @@ routes: assert_eq!(state.get("user"), Some(&json!("alice"))); assert!(state.history().is_empty()); } + + #[test] + fn llm_node_with_all_fields() { + let yaml = r#" +id: classify +type: llm +instructions: "You are a classifier." +prompt: "Classify: {{input_text}}" +tools: + - read_query + - "mcp:pubmed-search" +model: anthropic:claude-3-5-haiku-20241022 +temperature: 0.0 +top_p: 0.5 +fallback: skip_classify +max_attempts: 3 +max_iterations: 5 +state_updates: + category: "{{output}}" +timeout: 30 +next: review +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + let llm = match node.node_type { + NodeType::Llm(l) => l, + _ => panic!("expected Llm variant"), + }; + assert_eq!(llm.instructions, "You are a classifier."); + assert_eq!(llm.prompt, "Classify: {{input_text}}"); + let tools = llm.tools.unwrap(); + assert_eq!(tools, vec!["read_query", "mcp:pubmed-search"]); + assert_eq!( + llm.model.as_deref(), + Some("anthropic:claude-3-5-haiku-20241022") + ); + assert_eq!(llm.temperature, Some(0.0)); + assert_eq!(llm.top_p, Some(0.5)); + assert_eq!(llm.fallback.as_deref(), Some("skip_classify")); + assert_eq!(llm.max_attempts, 3); + assert_eq!(llm.max_iterations, 5); + assert_eq!(llm.timeout, Some(30)); + assert!(llm.state_updates.is_some()); + assert_eq!(node.next.as_deref(), Some("review")); + } + + #[test] + fn llm_node_minimal_fields_use_defaults() { + let yaml = r#" +id: pure_text +type: llm +instructions: "System." +prompt: "User." +next: done +"#; + let node: Node = serde_yaml::from_str(yaml).unwrap(); + let llm = match node.node_type { + NodeType::Llm(l) => l, + _ => panic!("expected Llm variant"), + }; + assert_eq!(llm.instructions, "System."); + assert_eq!(llm.prompt, "User."); + assert!(llm.tools.is_none()); + assert!(llm.model.is_none()); + assert!(llm.fallback.is_none()); + assert_eq!(llm.max_attempts, 1); + assert_eq!(llm.max_iterations, 10); + } + + #[test] + fn llm_node_missing_instructions_fails() { + let yaml = r#" +id: bad +type: llm +prompt: "User only — no system prompt." +"#; + let result: std::result::Result = serde_yaml::from_str(yaml); + assert!(result.is_err()); + } + + #[test] + fn llm_node_missing_prompt_fails() { + let yaml = r#" +id: bad +type: llm +instructions: "System only — no user prompt." +"#; + let result: std::result::Result = serde_yaml::from_str(yaml); + assert!(result.is_err()); + } } diff --git a/src/graph/validator.rs b/src/graph/validator.rs index bb0c590..73b6778 100644 --- a/src/graph/validator.rs +++ b/src/graph/validator.rs @@ -262,6 +262,11 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> { out.push((t.clone(), "'on_timeout'")); } } + NodeType::Llm(l) => { + if let Some(t) = &l.fallback { + out.push((t.clone(), "llm 'fallback'")); + } + } NodeType::Agent(_) | NodeType::End(_) => {} } out