feat: created full llm node runtime implementation
This commit is contained in:
@@ -317,25 +317,6 @@ impl Functions {
|
||||
self.declarations.iter().find(|v| v.name == name)
|
||||
}
|
||||
|
||||
/// Narrow the declared tool list to a caller-supplied whitelist.
|
||||
/// Entries are matched by exact name. The shorthand `mcp:<server>`
|
||||
/// expands to the three MCP meta-functions Loki registers per
|
||||
/// server (`mcp_invoke_<server>`, `mcp_search_<server>`,
|
||||
/// `mcp_describe_<server>`).
|
||||
pub fn retain_named(&mut self, allowed: &[String]) {
|
||||
let mut expanded: std::collections::HashSet<String> = std::collections::HashSet::new();
|
||||
for entry in allowed {
|
||||
if let Some(server) = entry.strip_prefix("mcp:") {
|
||||
expanded.insert(format!("{MCP_INVOKE_META_FUNCTION_NAME_PREFIX}_{server}"));
|
||||
expanded.insert(format!("{MCP_SEARCH_META_FUNCTION_NAME_PREFIX}_{server}"));
|
||||
expanded.insert(format!("{MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX}_{server}"));
|
||||
} else {
|
||||
expanded.insert(entry.clone());
|
||||
}
|
||||
}
|
||||
self.declarations.retain(|d| expanded.contains(&d.name));
|
||||
}
|
||||
|
||||
pub fn contains(&self, name: &str) -> bool {
|
||||
self.declarations.iter().any(|v| v.name == name)
|
||||
}
|
||||
@@ -1749,76 +1730,4 @@ mod tests {
|
||||
assert_eq!(result.call.name, "my_tool");
|
||||
assert_eq!(result.output, json!({"result": "ok"}));
|
||||
}
|
||||
|
||||
fn function_with_names(names: &[&str]) -> Functions {
|
||||
let declarations = names
|
||||
.iter()
|
||||
.map(|n| FunctionDeclaration {
|
||||
name: (*n).to_string(),
|
||||
description: String::new(),
|
||||
parameters: JsonSchema::default(),
|
||||
agent: false,
|
||||
})
|
||||
.collect();
|
||||
Functions { declarations }
|
||||
}
|
||||
|
||||
#[test]
fn retain_named_keeps_only_exact_matches() {
    let mut functions = function_with_names(&["read_query", "describe_table", "web_search_loki"]);
    let whitelist = vec![String::from("read_query")];

    functions.retain_named(&whitelist);

    // Only the whitelisted name survives.
    assert!(functions.contains("read_query"));
    for dropped in ["describe_table", "web_search_loki"] {
        assert!(!functions.contains(dropped), "{dropped} should have been removed");
    }
}
|
||||
|
||||
#[test]
fn retain_named_with_empty_list_removes_all() {
    let mut functions = function_with_names(&["a", "b", "c"]);

    // An empty whitelist allows nothing.
    functions.retain_named(&[]);

    for name in ["a", "b", "c"] {
        assert!(!functions.contains(name), "{name} should have been removed");
    }
    assert!(functions.is_empty());
}
|
||||
|
||||
#[test]
fn retain_named_with_unknown_name_drops_everything() {
    let mut functions = function_with_names(&["a", "b"]);
    let whitelist = [String::from("nonexistent")];

    // A whitelist that matches no declaration leaves nothing behind.
    functions.retain_named(&whitelist);

    assert!(functions.is_empty());
}
|
||||
|
||||
#[test]
fn retain_named_with_mcp_shorthand_keeps_all_three_meta_functions() {
    let mut functions = Functions::default();
    let servers = vec!["github".to_string(), "slack".to_string()];
    functions.append_mcp_meta_functions(servers);

    functions.retain_named(&["mcp:github".to_string()]);

    // The shorthand keeps every meta-function of the named server...
    for kept in ["mcp_invoke_github", "mcp_search_github", "mcp_describe_github"] {
        assert!(functions.contains(kept), "{kept} should have been kept");
    }
    // ...and nothing belonging to the server that was not listed.
    for dropped in ["mcp_invoke_slack", "mcp_search_slack", "mcp_describe_slack"] {
        assert!(!functions.contains(dropped), "{dropped} should have been removed");
    }
}
|
||||
|
||||
#[test]
fn retain_named_mixes_exact_names_and_mcp_shorthand() {
    let mut functions = function_with_names(&["read_query", "describe_table"]);
    functions.append_mcp_meta_functions(vec!["pubmed".to_string()]);

    let whitelist = vec!["read_query".to_string(), "mcp:pubmed".to_string()];
    functions.retain_named(&whitelist);

    // Exact names and shorthand entries are honoured side by side.
    for kept in [
        "read_query",
        "mcp_invoke_pubmed",
        "mcp_search_pubmed",
        "mcp_describe_pubmed",
    ] {
        assert!(functions.contains(kept), "{kept} should have been kept");
    }
    assert!(!functions.contains("describe_table"));
}
|
||||
|
||||
#[test]
fn retain_named_with_mcp_shorthand_for_unknown_server_drops_other_servers() {
    let mut functions = Functions::default();
    functions.append_mcp_meta_functions(vec!["alpha".to_string()]);

    // Whitelisting a server that was never registered must not keep
    // another server's meta-functions alive.
    let whitelist = ["mcp:beta".to_string()];
    functions.retain_named(&whitelist);

    assert!(!functions.contains("mcp_invoke_alpha"));
}
|
||||
}
|
||||
|
||||
+262
-20
@@ -1,20 +1,18 @@
|
||||
//! Execution of `llm`-type graph nodes — one-shot LLM calls with a
|
||||
//! bounded tool-call loop, an opt-in tool whitelist, and per-node
|
||||
//! overrides for model/temperature/top_p.
|
||||
//!
|
||||
//! See `docs/implementation/graph-agents/10.5-llm-nodes.md` for the
|
||||
//! design. The current implementation provides the routing and
|
||||
//! state-update plumbing; the actual call_chat_completions loop lives
|
||||
//! in `run_llm_once` and is the next implementation step. Calling
|
||||
//! `LlmNodeExecutor::execute` today produces a controlled error so the
|
||||
//! tolerant-fail routing in the executor still flows.
|
||||
//! bounded tool-call loop, an opt-in tool whitelist (delegated to
|
||||
//! `Role.enabled_tools` / `Role.enabled_mcp_servers`), and per-node
|
||||
//! overrides for model/temperature/top_p. See
|
||||
//! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design.
|
||||
|
||||
use super::state::StateManager;
|
||||
use super::types::LlmNode;
|
||||
use crate::config::RequestContext;
|
||||
use crate::utils::dimmed_text;
|
||||
use anyhow::{Context, Result, bail};
|
||||
use crate::client::{Model, ModelType, call_chat_completions};
|
||||
use crate::config::{RequestContext, Role, RoleLike};
|
||||
use crate::utils::{create_abort_signal, dimmed_text};
|
||||
use anyhow::{Context, Result, anyhow, bail, Error};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
const OUTPUT_KEY: &str = "output";
|
||||
|
||||
@@ -49,9 +47,9 @@ impl LlmNodeExecutor {
|
||||
async fn run(
|
||||
node: &LlmNode,
|
||||
state_manager: &mut StateManager,
|
||||
_parent_ctx: &mut RequestContext,
|
||||
parent_ctx: &mut RequestContext,
|
||||
) -> Result<String> {
|
||||
let _instructions: Option<String> = match &node.instructions {
|
||||
let instructions: Option<String> = match &node.instructions {
|
||||
Some(s) => Some(
|
||||
state_manager
|
||||
.interpolate(s)
|
||||
@@ -59,10 +57,13 @@ async fn run(
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
let _prompt = state_manager
|
||||
let prompt = state_manager
|
||||
.interpolate(&node.prompt)
|
||||
.context("Failed to interpolate llm node prompt")?;
|
||||
|
||||
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
|
||||
validate_tools_subset(&regular_tools, &mcp_servers, parent_ctx)?;
|
||||
|
||||
eprintln!(
|
||||
"{}",
|
||||
dimmed_text(&format!(
|
||||
@@ -72,11 +73,205 @@ async fn run(
|
||||
))
|
||||
);
|
||||
|
||||
bail!(
|
||||
"llm node execution body not yet implemented — see \
|
||||
docs/implementation/graph-agents/10.5-llm-nodes.md \
|
||||
(steps 3 & 5 of the implementation order)"
|
||||
);
|
||||
let role = build_inline_role(
|
||||
node,
|
||||
instructions.as_deref(),
|
||||
&regular_tools,
|
||||
&mcp_servers,
|
||||
parent_ctx,
|
||||
)?;
|
||||
|
||||
let saved_role = parent_ctx.role.clone();
|
||||
parent_ctx.role = Some(role);
|
||||
let result = run_with_retries(node, &prompt, parent_ctx).await;
|
||||
parent_ctx.role = saved_role;
|
||||
result
|
||||
}
|
||||
|
||||
async fn run_with_retries(
|
||||
node: &LlmNode,
|
||||
prompt: &str,
|
||||
ctx: &mut RequestContext,
|
||||
) -> Result<String> {
|
||||
let mut last_err: Option<Error> = None;
|
||||
for attempt in 1..=node.max_attempts {
|
||||
match run_chat_loop(node, prompt, ctx).await {
|
||||
Ok(out) => return Ok(out),
|
||||
Err(e) if is_transient(&e) && attempt < node.max_attempts => {
|
||||
warn!("llm node attempt {attempt} failed (transient): {e}; retrying");
|
||||
last_err = Some(e);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
Err(last_err.unwrap_or_else(|| anyhow!("llm node exhausted retries")))
|
||||
}
|
||||
|
||||
async fn run_chat_loop(node: &LlmNode, prompt: &str, ctx: &mut RequestContext) -> Result<String> {
|
||||
let abort = create_abort_signal();
|
||||
let app_cfg = Arc::clone(&ctx.app.config);
|
||||
let role_for_input = ctx.role.clone();
|
||||
let mut input = crate::config::Input::from_str(ctx, prompt, role_for_input);
|
||||
let mut accumulated = String::new();
|
||||
|
||||
for turn in 0..node.max_iterations {
|
||||
let client = input.create_client()?;
|
||||
ctx.before_chat_completion(&input)?;
|
||||
let (output, tool_results) =
|
||||
call_chat_completions(&input, false, false, client.as_ref(), ctx, abort.clone())
|
||||
.await?;
|
||||
ctx.after_chat_completion(app_cfg.as_ref(), &input, &output, &tool_results)?;
|
||||
|
||||
if !output.is_empty() {
|
||||
if !accumulated.is_empty() {
|
||||
accumulated.push('\n');
|
||||
}
|
||||
accumulated.push_str(&output);
|
||||
}
|
||||
|
||||
if tool_results.is_empty() {
|
||||
return Ok(accumulated);
|
||||
}
|
||||
|
||||
if turn + 1 == node.max_iterations {
|
||||
bail!(
|
||||
"llm node hit max_iterations ({}) before LLM concluded",
|
||||
node.max_iterations
|
||||
);
|
||||
}
|
||||
|
||||
input = input.merge_tool_results(output, tool_results);
|
||||
}
|
||||
|
||||
bail!("llm node ended without producing output")
|
||||
}
|
||||
|
||||
fn build_inline_role(
|
||||
node: &LlmNode,
|
||||
instructions: Option<&str>,
|
||||
regular_tools: &[String],
|
||||
mcp_servers: &[String],
|
||||
parent_ctx: &RequestContext,
|
||||
) -> Result<Role> {
|
||||
let mut role = Role::new("llm_node", instructions.unwrap_or(""));
|
||||
|
||||
let model = match &node.model {
|
||||
Some(model_id) => {
|
||||
Model::retrieve_model(parent_ctx.app.config.as_ref(), model_id, ModelType::Chat)
|
||||
.with_context(|| format!("Unknown model '{model_id}' on llm node"))?
|
||||
}
|
||||
None => parent_ctx.current_model().clone(),
|
||||
};
|
||||
role.set_model(model);
|
||||
|
||||
if let Some(t) = node.temperature {
|
||||
role.set_temperature(Some(t));
|
||||
}
|
||||
if let Some(p) = node.top_p {
|
||||
role.set_top_p(Some(p));
|
||||
}
|
||||
|
||||
if let Some(tool_entries) = &node.tools {
|
||||
if tool_entries.is_empty() {
|
||||
role.set_enabled_tools(Some(String::new()));
|
||||
role.set_enabled_mcp_servers(Some(String::new()));
|
||||
} else {
|
||||
if !regular_tools.is_empty() {
|
||||
role.set_enabled_tools(Some(regular_tools.join(",")));
|
||||
} else {
|
||||
role.set_enabled_tools(Some(String::new()));
|
||||
}
|
||||
if !mcp_servers.is_empty() {
|
||||
role.set_enabled_mcp_servers(Some(mcp_servers.join(",")));
|
||||
} else {
|
||||
role.set_enabled_mcp_servers(Some(String::new()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(role)
|
||||
}
|
||||
|
||||
/// Split a node's tool entries into plain tool names and MCP server names.
///
/// Entries starting with `mcp:` contribute the text after the prefix to the
/// second list; every other entry is copied verbatim into the first list.
/// Input order is preserved within each list. `None` and an empty slice
/// both produce two empty lists.
fn categorize_tools(entries: Option<&[String]>) -> (Vec<String>, Vec<String>) {
    let mut plain_names = Vec::new();
    let mut server_names = Vec::new();

    for entry in entries.into_iter().flatten() {
        match entry.strip_prefix("mcp:") {
            Some(server) => server_names.push(server.to_owned()),
            None => plain_names.push(entry.to_owned()),
        }
    }

    (plain_names, server_names)
}
|
||||
|
||||
fn validate_tools_subset(
|
||||
regular: &[String],
|
||||
mcp_servers: &[String],
|
||||
parent_ctx: &RequestContext,
|
||||
) -> Result<()> {
|
||||
let agent = parent_ctx
|
||||
.agent
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow!("llm node requires an active agent"))?;
|
||||
|
||||
if !regular.is_empty() {
|
||||
let known: HashSet<&str> = agent
|
||||
.functions()
|
||||
.declarations()
|
||||
.iter()
|
||||
.map(|d| d.name.as_str())
|
||||
.collect();
|
||||
for name in regular {
|
||||
if !known.contains(name.as_str()) {
|
||||
let mut avail: Vec<&str> = known.iter().copied().collect();
|
||||
avail.sort();
|
||||
bail!(
|
||||
"llm node references unknown tool '{name}'. Agent '{}' provides: {}",
|
||||
agent.name(),
|
||||
avail.join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !mcp_servers.is_empty() {
|
||||
let known: HashSet<&str> = agent
|
||||
.mcp_server_names()
|
||||
.iter()
|
||||
.map(|s| s.as_str())
|
||||
.collect();
|
||||
for server in mcp_servers {
|
||||
if !known.contains(server.as_str()) {
|
||||
let mut avail: Vec<&str> = known.iter().copied().collect();
|
||||
avail.sort();
|
||||
bail!(
|
||||
"llm node references unknown MCP server 'mcp:{server}'. \
|
||||
Agent '{}' has MCP servers: [{}]",
|
||||
agent.name(),
|
||||
avail.join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Loose substring match against the transient error messages we expect
|
||||
/// from network/API failures or empty-output cases. Brittle by nature;
|
||||
/// a typed error enum would be better long-term.
|
||||
fn is_transient(err: &Error) -> bool {
|
||||
let s = format!("{err:#}");
|
||||
s.contains("timed out")
|
||||
|| s.contains("rate limit")
|
||||
|| s.contains("429")
|
||||
|| s.contains("Connection reset")
|
||||
|| s.contains("Connection refused")
|
||||
|| s.contains("produced no output")
|
||||
}
|
||||
|
||||
fn next_for_llm_node(
|
||||
@@ -218,4 +413,51 @@ mod tests {
|
||||
let tools = vec!["a".to_string(), "b".to_string()];
|
||||
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
|
||||
}
|
||||
|
||||
#[test]
fn categorize_tools_splits_mcp_and_regular() {
    let entries: Vec<String> = [
        "read_query",
        "mcp:pubmed-search",
        "web_search_loki",
        "mcp:github",
    ]
    .into_iter()
    .map(String::from)
    .collect();

    let (regular, mcp) = categorize_tools(Some(&entries));

    // Each list keeps input order; the `mcp:` prefix is stripped.
    assert_eq!(regular, vec!["read_query", "web_search_loki"]);
    assert_eq!(mcp, vec!["pubmed-search", "github"]);
}
|
||||
|
||||
#[test]
fn categorize_tools_with_none_returns_empty() {
    // No tool list at all yields two empty lists.
    let (plain, servers) = categorize_tools(None);

    assert!(plain.is_empty());
    assert!(servers.is_empty());
}
|
||||
|
||||
#[test]
fn categorize_tools_with_empty_returns_empty() {
    // A declared-but-empty tool list also yields two empty lists.
    let (plain, servers) = categorize_tools(Some(&[]));

    assert!(plain.is_empty());
    assert!(servers.is_empty());
}
|
||||
|
||||
#[test]
fn is_transient_matches_expected_signatures() {
    // One sample message per recognised marker.
    let retryable = [
        "request timed out after 30s",
        "rate limit reached",
        "429 too many requests",
        "Connection reset by peer",
        "Connection refused",
        "llm produced no output",
    ];

    for message in retryable {
        assert!(
            is_transient(&anyhow!(message)),
            "expected transient: {message}"
        );
    }
}
|
||||
|
||||
#[test]
fn is_transient_rejects_non_transient_errors() {
    // Configuration, budget and auth failures must not be retried.
    let permanent = [
        "Unknown model 'foo'",
        "llm node references unknown tool 'bad'",
        "hit max_iterations",
        "authentication failed",
    ];

    for message in permanent {
        assert!(
            !is_transient(&anyhow!(message)),
            "expected non-transient: {message}"
        );
    }
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user