From 33782c59a8980108d2dddad076694e0c26ce2c2e Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Thu, 14 May 2026 14:00:24 -0600 Subject: [PATCH] feat: created full llm node runtime implementation --- src/function/mod.rs | 91 -------------- src/graph/llm.rs | 282 ++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 262 insertions(+), 111 deletions(-) diff --git a/src/function/mod.rs b/src/function/mod.rs index 30751d0..1512d83 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -317,25 +317,6 @@ impl Functions { self.declarations.iter().find(|v| v.name == name) } - /// Narrow the declared tool list to a caller-supplied whitelist. - /// Entries are matched by exact name. The shorthand `mcp:` - /// expands to the three MCP meta-functions Loki registers per - /// server (`mcp_invoke_`, `mcp_search_`, - /// `mcp_describe_`). - pub fn retain_named(&mut self, allowed: &[String]) { - let mut expanded: std::collections::HashSet = std::collections::HashSet::new(); - for entry in allowed { - if let Some(server) = entry.strip_prefix("mcp:") { - expanded.insert(format!("{MCP_INVOKE_META_FUNCTION_NAME_PREFIX}_{server}")); - expanded.insert(format!("{MCP_SEARCH_META_FUNCTION_NAME_PREFIX}_{server}")); - expanded.insert(format!("{MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX}_{server}")); - } else { - expanded.insert(entry.clone()); - } - } - self.declarations.retain(|d| expanded.contains(&d.name)); - } - pub fn contains(&self, name: &str) -> bool { self.declarations.iter().any(|v| v.name == name) } @@ -1749,76 +1730,4 @@ mod tests { assert_eq!(result.call.name, "my_tool"); assert_eq!(result.output, json!({"result": "ok"})); } - - fn function_with_names(names: &[&str]) -> Functions { - let declarations = names - .iter() - .map(|n| FunctionDeclaration { - name: (*n).to_string(), - description: String::new(), - parameters: JsonSchema::default(), - agent: false, - }) - .collect(); - Functions { declarations } - } - - #[test] - fn retain_named_keeps_only_exact_matches() { - let mut f = function_with_names(&["read_query", "describe_table", "web_search_loki"]); - f.retain_named(&["read_query".to_string()]); - assert!(f.contains("read_query")); - assert!(!f.contains("describe_table")); - assert!(!f.contains("web_search_loki")); - } - - #[test] - fn retain_named_with_empty_list_removes_all() { - let mut f = function_with_names(&["a", "b", "c"]); - f.retain_named(&[]); - assert!(!f.contains("a")); - assert!(!f.contains("b")); - assert!(!f.contains("c")); - assert!(f.is_empty()); - } - - #[test] - fn retain_named_with_unknown_name_drops_everything() { - let mut f = function_with_names(&["a", "b"]); - f.retain_named(&["nonexistent".to_string()]); - assert!(f.is_empty()); - } - - #[test] - fn retain_named_with_mcp_shorthand_keeps_all_three_meta_functions() { - let mut f = Functions::default(); - f.append_mcp_meta_functions(vec!["github".to_string(), "slack".to_string()]); - f.retain_named(&["mcp:github".to_string()]); - assert!(f.contains("mcp_invoke_github")); - assert!(f.contains("mcp_search_github")); - assert!(f.contains("mcp_describe_github")); - assert!(!f.contains("mcp_invoke_slack")); - assert!(!f.contains("mcp_search_slack")); - assert!(!f.contains("mcp_describe_slack")); - } - - #[test] - fn retain_named_mixes_exact_names_and_mcp_shorthand() { - let mut f = function_with_names(&["read_query", "describe_table"]); - f.append_mcp_meta_functions(vec!["pubmed".to_string()]); - f.retain_named(&["read_query".to_string(), "mcp:pubmed".to_string()]); - assert!(f.contains("read_query")); - assert!(!f.contains("describe_table")); - assert!(f.contains("mcp_invoke_pubmed")); - assert!(f.contains("mcp_search_pubmed")); - assert!(f.contains("mcp_describe_pubmed")); - } - - #[test] - fn retain_named_with_mcp_shorthand_for_unknown_server_drops_other_servers() { - let mut f = Functions::default(); - f.append_mcp_meta_functions(vec!["alpha".to_string()]); - f.retain_named(&["mcp:beta".to_string()]); - assert!(!f.contains("mcp_invoke_alpha")); - } } diff --git a/src/graph/llm.rs b/src/graph/llm.rs index 8f7be3f..d2828cb 100644 --- a/src/graph/llm.rs +++ b/src/graph/llm.rs @@ -1,20 +1,18 @@ //! Execution of `llm`-type graph nodes — one-shot LLM calls with a -//! bounded tool-call loop, an opt-in tool whitelist, and per-node -//! overrides for model/temperature/top_p. -//! -//! See `docs/implementation/graph-agents/10.5-llm-nodes.md` for the -//! design. The current implementation provides the routing and -//! state-update plumbing; the actual call_chat_completions loop lives -//! in `run_llm_once` and is the next implementation step. Calling -//! `LlmNodeExecutor::execute` today produces a controlled error so the -//! tolerant-fail routing in the executor still flows. +//! bounded tool-call loop, an opt-in tool whitelist (delegated to +//! `Role.enabled_tools` / `Role.enabled_mcp_servers`), and per-node +//! overrides for model/temperature/top_p. See +//! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design. use super::state::StateManager; use super::types::LlmNode; -use crate::config::RequestContext; -use crate::utils::dimmed_text; -use anyhow::{Context, Result, bail}; +use crate::client::{Model, ModelType, call_chat_completions}; +use crate::config::{RequestContext, Role, RoleLike}; +use crate::utils::{create_abort_signal, dimmed_text}; +use anyhow::{Context, Result, anyhow, bail, Error}; use serde_json::Value; +use std::collections::HashSet; +use std::sync::Arc; const OUTPUT_KEY: &str = "output"; @@ -49,9 +47,9 @@ impl LlmNodeExecutor { async fn run( node: &LlmNode, state_manager: &mut StateManager, - _parent_ctx: &mut RequestContext, + parent_ctx: &mut RequestContext, ) -> Result { - let _instructions: Option = match &node.instructions { + let instructions: Option = match &node.instructions { Some(s) => Some( state_manager .interpolate(s) @@ -59,10 +57,13 @@ async fn run( ), None => None, }; - let _prompt = state_manager + let prompt = state_manager .interpolate(&node.prompt) .context("Failed to interpolate llm node prompt")?; + let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref()); + validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?; + eprintln!( "{}", dimmed_text(&format!( @@ -72,11 +73,205 @@ async fn run( )) ); - bail!( - "llm node execution body not yet implemented — see \ - docs/implementation/graph-agents/10.5-llm-nodes.md \ - (steps 3 & 5 of the implementation order)" - ); + let role = build_inline_role( + node, + instructions.as_deref(), + ®ular_tools, + &mcp_servers, + parent_ctx, + )?; + + let saved_role = parent_ctx.role.clone(); + parent_ctx.role = Some(role); + let result = run_with_retries(node, &prompt, parent_ctx).await; + parent_ctx.role = saved_role; + result +} + +async fn run_with_retries( + node: &LlmNode, + prompt: &str, + ctx: &mut RequestContext, +) -> Result { + let mut last_err: Option = None; + for attempt in 1..=node.max_attempts { + match run_chat_loop(node, prompt, ctx).await { + Ok(out) => return Ok(out), + Err(e) if is_transient(&e) && attempt < node.max_attempts => { + warn!("llm node attempt {attempt} failed (transient): {e}; retrying"); + last_err = Some(e); + } + Err(e) => return Err(e), + } + } + Err(last_err.unwrap_or_else(|| anyhow!("llm node exhausted retries"))) +} + +async fn run_chat_loop(node: &LlmNode, prompt: &str, ctx: &mut RequestContext) -> Result { + let abort = create_abort_signal(); + let app_cfg = Arc::clone(&ctx.app.config); + let role_for_input = ctx.role.clone(); + let mut input = crate::config::Input::from_str(ctx, prompt, role_for_input); + let mut accumulated = String::new(); + + for turn in 0..node.max_iterations { + let client = input.create_client()?; + ctx.before_chat_completion(&input)?; + let (output, tool_results) = + call_chat_completions(&input, false, false, client.as_ref(), ctx, abort.clone()) + .await?; + ctx.after_chat_completion(app_cfg.as_ref(), &input, &output, &tool_results)?; + + if !output.is_empty() { + if !accumulated.is_empty() { + accumulated.push('\n'); + } + accumulated.push_str(&output); + } + + if tool_results.is_empty() { + return Ok(accumulated); + } + + if turn + 1 == node.max_iterations { + bail!( + "llm node hit max_iterations ({}) before LLM concluded", + node.max_iterations + ); + } + + input = input.merge_tool_results(output, tool_results); + } + + bail!("llm node ended without producing output") +} + +fn build_inline_role( + node: &LlmNode, + instructions: Option<&str>, + regular_tools: &[String], + mcp_servers: &[String], + parent_ctx: &RequestContext, +) -> Result { + let mut role = Role::new("llm_node", instructions.unwrap_or("")); + + let model = match &node.model { + Some(model_id) => { + Model::retrieve_model(parent_ctx.app.config.as_ref(), model_id, ModelType::Chat) + .with_context(|| format!("Unknown model '{model_id}' on llm node"))? + } + None => parent_ctx.current_model().clone(), + }; + role.set_model(model); + + if let Some(t) = node.temperature { + role.set_temperature(Some(t)); + } + if let Some(p) = node.top_p { + role.set_top_p(Some(p)); + } + + if let Some(tool_entries) = &node.tools { + if tool_entries.is_empty() { + role.set_enabled_tools(Some(String::new())); + role.set_enabled_mcp_servers(Some(String::new())); + } else { + if !regular_tools.is_empty() { + role.set_enabled_tools(Some(regular_tools.join(","))); + } else { + role.set_enabled_tools(Some(String::new())); + } + if !mcp_servers.is_empty() { + role.set_enabled_mcp_servers(Some(mcp_servers.join(","))); + } else { + role.set_enabled_mcp_servers(Some(String::new())); + } + } + } + + Ok(role) +} + +fn categorize_tools(entries: Option<&[String]>) -> (Vec, Vec) { + let mut regular = Vec::new(); + let mut mcp = Vec::new(); + let Some(entries) = entries else { + return (regular, mcp); + }; + for e in entries { + if let Some(server) = e.strip_prefix("mcp:") { + mcp.push(server.to_string()); + } else { + regular.push(e.clone()); + } + } + (regular, mcp) +} + +fn validate_tools_subset( + regular: &[String], + mcp_servers: &[String], + parent_ctx: &RequestContext, +) -> Result<()> { + let agent = parent_ctx + .agent + .as_ref() + .ok_or_else(|| anyhow!("llm node requires an active agent"))?; + + if !regular.is_empty() { + let known: HashSet<&str> = agent + .functions() + .declarations() + .iter() + .map(|d| d.name.as_str()) + .collect(); + for name in regular { + if !known.contains(name.as_str()) { + let mut avail: Vec<&str> = known.iter().copied().collect(); + avail.sort(); + bail!( + "llm node references unknown tool '{name}'. Agent '{}' provides: {}", + agent.name(), + avail.join(", ") + ); + } + } + } + + if !mcp_servers.is_empty() { + let known: HashSet<&str> = agent + .mcp_server_names() + .iter() + .map(|s| s.as_str()) + .collect(); + for server in mcp_servers { + if !known.contains(server.as_str()) { + let mut avail: Vec<&str> = known.iter().copied().collect(); + avail.sort(); + bail!( + "llm node references unknown MCP server 'mcp:{server}'. \ + Agent '{}' has MCP servers: [{}]", + agent.name(), + avail.join(", ") + ); + } + } + } + + Ok(()) +} + +/// Loose substring match against the transient error messages we expect +/// from network/API failures or empty-output cases. Brittle by nature; +/// a typed error enum would be better long-term. +fn is_transient(err: &Error) -> bool { + let s = format!("{err:#}"); + s.contains("timed out") + || s.contains("rate limit") + || s.contains("429") + || s.contains("Connection reset") + || s.contains("Connection refused") + || s.contains("produced no output") } fn next_for_llm_node( @@ -218,4 +413,51 @@ mod tests { let tools = vec!["a".to_string(), "b".to_string()]; assert_eq!(describe_tools_filter(Some(&tools)), "a,b"); } + + #[test] + fn categorize_tools_splits_mcp_and_regular() { + let entries = vec![ + "read_query".to_string(), + "mcp:pubmed-search".to_string(), + "web_search_loki".to_string(), + "mcp:github".to_string(), + ]; + let (regular, mcp) = categorize_tools(Some(&entries)); + assert_eq!(regular, vec!["read_query", "web_search_loki"]); + assert_eq!(mcp, vec!["pubmed-search", "github"]); + } + + #[test] + fn categorize_tools_with_none_returns_empty() { + let (regular, mcp) = categorize_tools(None); + assert!(regular.is_empty()); + assert!(mcp.is_empty()); + } + + #[test] + fn categorize_tools_with_empty_returns_empty() { + let (regular, mcp) = categorize_tools(Some(&[])); + assert!(regular.is_empty()); + assert!(mcp.is_empty()); + } + + #[test] + fn is_transient_matches_expected_signatures() { + assert!(is_transient(&anyhow!("request timed out after 30s"))); + assert!(is_transient(&anyhow!("rate limit reached"))); + assert!(is_transient(&anyhow!("429 too many requests"))); + assert!(is_transient(&anyhow!("Connection reset by peer"))); + assert!(is_transient(&anyhow!("Connection refused"))); + assert!(is_transient(&anyhow!("llm produced no output"))); + } + + #[test] + fn is_transient_rejects_non_transient_errors() { + assert!(!is_transient(&anyhow!("Unknown model 'foo'"))); + assert!(!is_transient(&anyhow!( + "llm node references unknown tool 'bad'" + ))); + assert!(!is_transient(&anyhow!("hit max_iterations"))); + assert!(!is_transient(&anyhow!("authentication failed"))); + } }