feat: created full llm node runtime implementation

This commit is contained in:
2026-05-14 14:00:24 -06:00
parent 5669830510
commit 33782c59a8
2 changed files with 262 additions and 111 deletions
-91
View File
@@ -317,25 +317,6 @@ impl Functions {
self.declarations.iter().find(|v| v.name == name)
}
/// Narrow the declared tool list to a caller-supplied whitelist.
/// Entries are matched by exact name. The shorthand `mcp:<server>`
/// expands to the three MCP meta-functions Loki registers per
/// server (`mcp_invoke_<server>`, `mcp_search_<server>`,
/// `mcp_describe_<server>`).
pub fn retain_named(&mut self, allowed: &[String]) {
let mut expanded: std::collections::HashSet<String> = std::collections::HashSet::new();
for entry in allowed {
if let Some(server) = entry.strip_prefix("mcp:") {
expanded.insert(format!("{MCP_INVOKE_META_FUNCTION_NAME_PREFIX}_{server}"));
expanded.insert(format!("{MCP_SEARCH_META_FUNCTION_NAME_PREFIX}_{server}"));
expanded.insert(format!("{MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX}_{server}"));
} else {
expanded.insert(entry.clone());
}
}
self.declarations.retain(|d| expanded.contains(&d.name));
}
pub fn contains(&self, name: &str) -> bool {
self.declarations.iter().any(|v| v.name == name)
}
@@ -1749,76 +1730,4 @@ mod tests {
assert_eq!(result.call.name, "my_tool");
assert_eq!(result.output, json!({"result": "ok"}));
}
fn function_with_names(names: &[&str]) -> Functions {
let declarations = names
.iter()
.map(|n| FunctionDeclaration {
name: (*n).to_string(),
description: String::new(),
parameters: JsonSchema::default(),
agent: false,
})
.collect();
Functions { declarations }
}
#[test]
fn retain_named_keeps_only_exact_matches() {
let mut f = function_with_names(&["read_query", "describe_table", "web_search_loki"]);
f.retain_named(&["read_query".to_string()]);
assert!(f.contains("read_query"));
assert!(!f.contains("describe_table"));
assert!(!f.contains("web_search_loki"));
}
#[test]
fn retain_named_with_empty_list_removes_all() {
let mut f = function_with_names(&["a", "b", "c"]);
f.retain_named(&[]);
assert!(!f.contains("a"));
assert!(!f.contains("b"));
assert!(!f.contains("c"));
assert!(f.is_empty());
}
#[test]
fn retain_named_with_unknown_name_drops_everything() {
let mut f = function_with_names(&["a", "b"]);
f.retain_named(&["nonexistent".to_string()]);
assert!(f.is_empty());
}
#[test]
fn retain_named_with_mcp_shorthand_keeps_all_three_meta_functions() {
let mut f = Functions::default();
f.append_mcp_meta_functions(vec!["github".to_string(), "slack".to_string()]);
f.retain_named(&["mcp:github".to_string()]);
assert!(f.contains("mcp_invoke_github"));
assert!(f.contains("mcp_search_github"));
assert!(f.contains("mcp_describe_github"));
assert!(!f.contains("mcp_invoke_slack"));
assert!(!f.contains("mcp_search_slack"));
assert!(!f.contains("mcp_describe_slack"));
}
#[test]
fn retain_named_mixes_exact_names_and_mcp_shorthand() {
let mut f = function_with_names(&["read_query", "describe_table"]);
f.append_mcp_meta_functions(vec!["pubmed".to_string()]);
f.retain_named(&["read_query".to_string(), "mcp:pubmed".to_string()]);
assert!(f.contains("read_query"));
assert!(!f.contains("describe_table"));
assert!(f.contains("mcp_invoke_pubmed"));
assert!(f.contains("mcp_search_pubmed"));
assert!(f.contains("mcp_describe_pubmed"));
}
#[test]
fn retain_named_with_mcp_shorthand_for_unknown_server_drops_other_servers() {
let mut f = Functions::default();
f.append_mcp_meta_functions(vec!["alpha".to_string()]);
f.retain_named(&["mcp:beta".to_string()]);
assert!(!f.contains("mcp_invoke_alpha"));
}
}
+262 -20
View File
@@ -1,20 +1,18 @@
//! Execution of `llm`-type graph nodes — one-shot LLM calls with a
//! bounded tool-call loop, an opt-in tool whitelist, and per-node
//! overrides for model/temperature/top_p.
//!
//! See `docs/implementation/graph-agents/10.5-llm-nodes.md` for the
//! design. The current implementation provides the routing and
//! state-update plumbing; the actual call_chat_completions loop lives
//! in `run_llm_once` and is the next implementation step. Calling
//! `LlmNodeExecutor::execute` today produces a controlled error so the
//! tolerant-fail routing in the executor still flows.
//! bounded tool-call loop, an opt-in tool whitelist (delegated to
//! `Role.enabled_tools` / `Role.enabled_mcp_servers`), and per-node
//! overrides for model/temperature/top_p. See
//! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design.
use super::state::StateManager;
use super::types::LlmNode;
use crate::config::RequestContext;
use crate::utils::dimmed_text;
use anyhow::{Context, Result, bail};
use crate::client::{Model, ModelType, call_chat_completions};
use crate::config::{RequestContext, Role, RoleLike};
use crate::utils::{create_abort_signal, dimmed_text};
use anyhow::{Context, Result, anyhow, bail, Error};
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
const OUTPUT_KEY: &str = "output";
@@ -49,9 +47,9 @@ impl LlmNodeExecutor {
async fn run(
node: &LlmNode,
state_manager: &mut StateManager,
_parent_ctx: &mut RequestContext,
parent_ctx: &mut RequestContext,
) -> Result<String> {
let _instructions: Option<String> = match &node.instructions {
let instructions: Option<String> = match &node.instructions {
Some(s) => Some(
state_manager
.interpolate(s)
@@ -59,10 +57,13 @@ async fn run(
),
None => None,
};
let _prompt = state_manager
let prompt = state_manager
.interpolate(&node.prompt)
.context("Failed to interpolate llm node prompt")?;
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
validate_tools_subset(&regular_tools, &mcp_servers, parent_ctx)?;
eprintln!(
"{}",
dimmed_text(&format!(
@@ -72,11 +73,205 @@ async fn run(
))
);
bail!(
"llm node execution body not yet implemented — see \
docs/implementation/graph-agents/10.5-llm-nodes.md \
(steps 3 & 5 of the implementation order)"
);
let role = build_inline_role(
node,
instructions.as_deref(),
&regular_tools,
&mcp_servers,
parent_ctx,
)?;
let saved_role = parent_ctx.role.clone();
parent_ctx.role = Some(role);
let result = run_with_retries(node, &prompt, parent_ctx).await;
parent_ctx.role = saved_role;
result
}
async fn run_with_retries(
node: &LlmNode,
prompt: &str,
ctx: &mut RequestContext,
) -> Result<String> {
let mut last_err: Option<Error> = None;
for attempt in 1..=node.max_attempts {
match run_chat_loop(node, prompt, ctx).await {
Ok(out) => return Ok(out),
Err(e) if is_transient(&e) && attempt < node.max_attempts => {
warn!("llm node attempt {attempt} failed (transient): {e}; retrying");
last_err = Some(e);
}
Err(e) => return Err(e),
}
}
Err(last_err.unwrap_or_else(|| anyhow!("llm node exhausted retries")))
}
async fn run_chat_loop(node: &LlmNode, prompt: &str, ctx: &mut RequestContext) -> Result<String> {
let abort = create_abort_signal();
let app_cfg = Arc::clone(&ctx.app.config);
let role_for_input = ctx.role.clone();
let mut input = crate::config::Input::from_str(ctx, prompt, role_for_input);
let mut accumulated = String::new();
for turn in 0..node.max_iterations {
let client = input.create_client()?;
ctx.before_chat_completion(&input)?;
let (output, tool_results) =
call_chat_completions(&input, false, false, client.as_ref(), ctx, abort.clone())
.await?;
ctx.after_chat_completion(app_cfg.as_ref(), &input, &output, &tool_results)?;
if !output.is_empty() {
if !accumulated.is_empty() {
accumulated.push('\n');
}
accumulated.push_str(&output);
}
if tool_results.is_empty() {
return Ok(accumulated);
}
if turn + 1 == node.max_iterations {
bail!(
"llm node hit max_iterations ({}) before LLM concluded",
node.max_iterations
);
}
input = input.merge_tool_results(output, tool_results);
}
bail!("llm node ended without producing output")
}
fn build_inline_role(
node: &LlmNode,
instructions: Option<&str>,
regular_tools: &[String],
mcp_servers: &[String],
parent_ctx: &RequestContext,
) -> Result<Role> {
let mut role = Role::new("llm_node", instructions.unwrap_or(""));
let model = match &node.model {
Some(model_id) => {
Model::retrieve_model(parent_ctx.app.config.as_ref(), model_id, ModelType::Chat)
.with_context(|| format!("Unknown model '{model_id}' on llm node"))?
}
None => parent_ctx.current_model().clone(),
};
role.set_model(model);
if let Some(t) = node.temperature {
role.set_temperature(Some(t));
}
if let Some(p) = node.top_p {
role.set_top_p(Some(p));
}
if let Some(tool_entries) = &node.tools {
if tool_entries.is_empty() {
role.set_enabled_tools(Some(String::new()));
role.set_enabled_mcp_servers(Some(String::new()));
} else {
if !regular_tools.is_empty() {
role.set_enabled_tools(Some(regular_tools.join(",")));
} else {
role.set_enabled_tools(Some(String::new()));
}
if !mcp_servers.is_empty() {
role.set_enabled_mcp_servers(Some(mcp_servers.join(",")));
} else {
role.set_enabled_mcp_servers(Some(String::new()));
}
}
}
Ok(role)
}
fn categorize_tools(entries: Option<&[String]>) -> (Vec<String>, Vec<String>) {
let mut regular = Vec::new();
let mut mcp = Vec::new();
let Some(entries) = entries else {
return (regular, mcp);
};
for e in entries {
if let Some(server) = e.strip_prefix("mcp:") {
mcp.push(server.to_string());
} else {
regular.push(e.clone());
}
}
(regular, mcp)
}
fn validate_tools_subset(
regular: &[String],
mcp_servers: &[String],
parent_ctx: &RequestContext,
) -> Result<()> {
let agent = parent_ctx
.agent
.as_ref()
.ok_or_else(|| anyhow!("llm node requires an active agent"))?;
if !regular.is_empty() {
let known: HashSet<&str> = agent
.functions()
.declarations()
.iter()
.map(|d| d.name.as_str())
.collect();
for name in regular {
if !known.contains(name.as_str()) {
let mut avail: Vec<&str> = known.iter().copied().collect();
avail.sort();
bail!(
"llm node references unknown tool '{name}'. Agent '{}' provides: {}",
agent.name(),
avail.join(", ")
);
}
}
}
if !mcp_servers.is_empty() {
let known: HashSet<&str> = agent
.mcp_server_names()
.iter()
.map(|s| s.as_str())
.collect();
for server in mcp_servers {
if !known.contains(server.as_str()) {
let mut avail: Vec<&str> = known.iter().copied().collect();
avail.sort();
bail!(
"llm node references unknown MCP server 'mcp:{server}'. \
Agent '{}' has MCP servers: [{}]",
agent.name(),
avail.join(", ")
);
}
}
}
Ok(())
}
/// Loose substring match against the transient error messages we expect
/// from network/API failures or empty-output cases. Brittle by nature;
/// a typed error enum would be better long-term.
fn is_transient(err: &Error) -> bool {
let s = format!("{err:#}");
s.contains("timed out")
|| s.contains("rate limit")
|| s.contains("429")
|| s.contains("Connection reset")
|| s.contains("Connection refused")
|| s.contains("produced no output")
}
fn next_for_llm_node(
@@ -218,4 +413,51 @@ mod tests {
let tools = vec!["a".to_string(), "b".to_string()];
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
}
#[test]
fn categorize_tools_splits_mcp_and_regular() {
let entries = vec![
"read_query".to_string(),
"mcp:pubmed-search".to_string(),
"web_search_loki".to_string(),
"mcp:github".to_string(),
];
let (regular, mcp) = categorize_tools(Some(&entries));
assert_eq!(regular, vec!["read_query", "web_search_loki"]);
assert_eq!(mcp, vec!["pubmed-search", "github"]);
}
#[test]
fn categorize_tools_with_none_returns_empty() {
let (regular, mcp) = categorize_tools(None);
assert!(regular.is_empty());
assert!(mcp.is_empty());
}
#[test]
fn categorize_tools_with_empty_returns_empty() {
let (regular, mcp) = categorize_tools(Some(&[]));
assert!(regular.is_empty());
assert!(mcp.is_empty());
}
#[test]
fn is_transient_matches_expected_signatures() {
assert!(is_transient(&anyhow!("request timed out after 30s")));
assert!(is_transient(&anyhow!("rate limit reached")));
assert!(is_transient(&anyhow!("429 too many requests")));
assert!(is_transient(&anyhow!("Connection reset by peer")));
assert!(is_transient(&anyhow!("Connection refused")));
assert!(is_transient(&anyhow!("llm produced no output")));
}
#[test]
fn is_transient_rejects_non_transient_errors() {
assert!(!is_transient(&anyhow!("Unknown model 'foo'")));
assert!(!is_transient(&anyhow!(
"llm node references unknown tool 'bad'"
)));
assert!(!is_transient(&anyhow!("hit max_iterations")));
assert!(!is_transient(&anyhow!("authentication failed")));
}
}