feat: created full llm node runtime implementation
This commit is contained in:
@@ -317,25 +317,6 @@ impl Functions {
|
|||||||
self.declarations.iter().find(|v| v.name == name)
|
self.declarations.iter().find(|v| v.name == name)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Narrow the declared tool list to a caller-supplied whitelist.
|
|
||||||
/// Entries are matched by exact name. The shorthand `mcp:<server>`
|
|
||||||
/// expands to the three MCP meta-functions Loki registers per
|
|
||||||
/// server (`mcp_invoke_<server>`, `mcp_search_<server>`,
|
|
||||||
/// `mcp_describe_<server>`).
|
|
||||||
pub fn retain_named(&mut self, allowed: &[String]) {
|
|
||||||
let mut expanded: std::collections::HashSet<String> = std::collections::HashSet::new();
|
|
||||||
for entry in allowed {
|
|
||||||
if let Some(server) = entry.strip_prefix("mcp:") {
|
|
||||||
expanded.insert(format!("{MCP_INVOKE_META_FUNCTION_NAME_PREFIX}_{server}"));
|
|
||||||
expanded.insert(format!("{MCP_SEARCH_META_FUNCTION_NAME_PREFIX}_{server}"));
|
|
||||||
expanded.insert(format!("{MCP_DESCRIBE_META_FUNCTION_NAME_PREFIX}_{server}"));
|
|
||||||
} else {
|
|
||||||
expanded.insert(entry.clone());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.declarations.retain(|d| expanded.contains(&d.name));
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn contains(&self, name: &str) -> bool {
|
pub fn contains(&self, name: &str) -> bool {
|
||||||
self.declarations.iter().any(|v| v.name == name)
|
self.declarations.iter().any(|v| v.name == name)
|
||||||
}
|
}
|
||||||
@@ -1749,76 +1730,4 @@ mod tests {
|
|||||||
assert_eq!(result.call.name, "my_tool");
|
assert_eq!(result.call.name, "my_tool");
|
||||||
assert_eq!(result.output, json!({"result": "ok"}));
|
assert_eq!(result.output, json!({"result": "ok"}));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn function_with_names(names: &[&str]) -> Functions {
|
|
||||||
let declarations = names
|
|
||||||
.iter()
|
|
||||||
.map(|n| FunctionDeclaration {
|
|
||||||
name: (*n).to_string(),
|
|
||||||
description: String::new(),
|
|
||||||
parameters: JsonSchema::default(),
|
|
||||||
agent: false,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Functions { declarations }
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retain_named_keeps_only_exact_matches() {
|
|
||||||
let mut f = function_with_names(&["read_query", "describe_table", "web_search_loki"]);
|
|
||||||
f.retain_named(&["read_query".to_string()]);
|
|
||||||
assert!(f.contains("read_query"));
|
|
||||||
assert!(!f.contains("describe_table"));
|
|
||||||
assert!(!f.contains("web_search_loki"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retain_named_with_empty_list_removes_all() {
|
|
||||||
let mut f = function_with_names(&["a", "b", "c"]);
|
|
||||||
f.retain_named(&[]);
|
|
||||||
assert!(!f.contains("a"));
|
|
||||||
assert!(!f.contains("b"));
|
|
||||||
assert!(!f.contains("c"));
|
|
||||||
assert!(f.is_empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retain_named_with_unknown_name_drops_everything() {
|
|
||||||
let mut f = function_with_names(&["a", "b"]);
|
|
||||||
f.retain_named(&["nonexistent".to_string()]);
|
|
||||||
assert!(f.is_empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retain_named_with_mcp_shorthand_keeps_all_three_meta_functions() {
|
|
||||||
let mut f = Functions::default();
|
|
||||||
f.append_mcp_meta_functions(vec!["github".to_string(), "slack".to_string()]);
|
|
||||||
f.retain_named(&["mcp:github".to_string()]);
|
|
||||||
assert!(f.contains("mcp_invoke_github"));
|
|
||||||
assert!(f.contains("mcp_search_github"));
|
|
||||||
assert!(f.contains("mcp_describe_github"));
|
|
||||||
assert!(!f.contains("mcp_invoke_slack"));
|
|
||||||
assert!(!f.contains("mcp_search_slack"));
|
|
||||||
assert!(!f.contains("mcp_describe_slack"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retain_named_mixes_exact_names_and_mcp_shorthand() {
|
|
||||||
let mut f = function_with_names(&["read_query", "describe_table"]);
|
|
||||||
f.append_mcp_meta_functions(vec!["pubmed".to_string()]);
|
|
||||||
f.retain_named(&["read_query".to_string(), "mcp:pubmed".to_string()]);
|
|
||||||
assert!(f.contains("read_query"));
|
|
||||||
assert!(!f.contains("describe_table"));
|
|
||||||
assert!(f.contains("mcp_invoke_pubmed"));
|
|
||||||
assert!(f.contains("mcp_search_pubmed"));
|
|
||||||
assert!(f.contains("mcp_describe_pubmed"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn retain_named_with_mcp_shorthand_for_unknown_server_drops_other_servers() {
|
|
||||||
let mut f = Functions::default();
|
|
||||||
f.append_mcp_meta_functions(vec!["alpha".to_string()]);
|
|
||||||
f.retain_named(&["mcp:beta".to_string()]);
|
|
||||||
assert!(!f.contains("mcp_invoke_alpha"));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
+260
-18
@@ -1,20 +1,18 @@
|
|||||||
//! Execution of `llm`-type graph nodes — one-shot LLM calls with a
|
//! Execution of `llm`-type graph nodes — one-shot LLM calls with a
|
||||||
//! bounded tool-call loop, an opt-in tool whitelist, and per-node
|
//! bounded tool-call loop, an opt-in tool whitelist (delegated to
|
||||||
//! overrides for model/temperature/top_p.
|
//! `Role.enabled_tools` / `Role.enabled_mcp_servers`), and per-node
|
||||||
//!
|
//! overrides for model/temperature/top_p. See
|
||||||
//! See `docs/implementation/graph-agents/10.5-llm-nodes.md` for the
|
//! `docs/implementation/graph-agents/10.5-llm-nodes.md` for the design.
|
||||||
//! design. The current implementation provides the routing and
|
|
||||||
//! state-update plumbing; the actual call_chat_completions loop lives
|
|
||||||
//! in `run_llm_once` and is the next implementation step. Calling
|
|
||||||
//! `LlmNodeExecutor::execute` today produces a controlled error so the
|
|
||||||
//! tolerant-fail routing in the executor still flows.
|
|
||||||
|
|
||||||
use super::state::StateManager;
|
use super::state::StateManager;
|
||||||
use super::types::LlmNode;
|
use super::types::LlmNode;
|
||||||
use crate::config::RequestContext;
|
use crate::client::{Model, ModelType, call_chat_completions};
|
||||||
use crate::utils::dimmed_text;
|
use crate::config::{RequestContext, Role, RoleLike};
|
||||||
use anyhow::{Context, Result, bail};
|
use crate::utils::{create_abort_signal, dimmed_text};
|
||||||
|
use anyhow::{Context, Result, anyhow, bail, Error};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
const OUTPUT_KEY: &str = "output";
|
const OUTPUT_KEY: &str = "output";
|
||||||
|
|
||||||
@@ -49,9 +47,9 @@ impl LlmNodeExecutor {
|
|||||||
async fn run(
|
async fn run(
|
||||||
node: &LlmNode,
|
node: &LlmNode,
|
||||||
state_manager: &mut StateManager,
|
state_manager: &mut StateManager,
|
||||||
_parent_ctx: &mut RequestContext,
|
parent_ctx: &mut RequestContext,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
let _instructions: Option<String> = match &node.instructions {
|
let instructions: Option<String> = match &node.instructions {
|
||||||
Some(s) => Some(
|
Some(s) => Some(
|
||||||
state_manager
|
state_manager
|
||||||
.interpolate(s)
|
.interpolate(s)
|
||||||
@@ -59,10 +57,13 @@ async fn run(
|
|||||||
),
|
),
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
let _prompt = state_manager
|
let prompt = state_manager
|
||||||
.interpolate(&node.prompt)
|
.interpolate(&node.prompt)
|
||||||
.context("Failed to interpolate llm node prompt")?;
|
.context("Failed to interpolate llm node prompt")?;
|
||||||
|
|
||||||
|
let (regular_tools, mcp_servers) = categorize_tools(node.tools.as_deref());
|
||||||
|
validate_tools_subset(®ular_tools, &mcp_servers, parent_ctx)?;
|
||||||
|
|
||||||
eprintln!(
|
eprintln!(
|
||||||
"{}",
|
"{}",
|
||||||
dimmed_text(&format!(
|
dimmed_text(&format!(
|
||||||
@@ -72,13 +73,207 @@ async fn run(
|
|||||||
))
|
))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let role = build_inline_role(
|
||||||
|
node,
|
||||||
|
instructions.as_deref(),
|
||||||
|
®ular_tools,
|
||||||
|
&mcp_servers,
|
||||||
|
parent_ctx,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let saved_role = parent_ctx.role.clone();
|
||||||
|
parent_ctx.role = Some(role);
|
||||||
|
let result = run_with_retries(node, &prompt, parent_ctx).await;
|
||||||
|
parent_ctx.role = saved_role;
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_with_retries(
|
||||||
|
node: &LlmNode,
|
||||||
|
prompt: &str,
|
||||||
|
ctx: &mut RequestContext,
|
||||||
|
) -> Result<String> {
|
||||||
|
let mut last_err: Option<Error> = None;
|
||||||
|
for attempt in 1..=node.max_attempts {
|
||||||
|
match run_chat_loop(node, prompt, ctx).await {
|
||||||
|
Ok(out) => return Ok(out),
|
||||||
|
Err(e) if is_transient(&e) && attempt < node.max_attempts => {
|
||||||
|
warn!("llm node attempt {attempt} failed (transient): {e}; retrying");
|
||||||
|
last_err = Some(e);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(last_err.unwrap_or_else(|| anyhow!("llm node exhausted retries")))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_chat_loop(node: &LlmNode, prompt: &str, ctx: &mut RequestContext) -> Result<String> {
|
||||||
|
let abort = create_abort_signal();
|
||||||
|
let app_cfg = Arc::clone(&ctx.app.config);
|
||||||
|
let role_for_input = ctx.role.clone();
|
||||||
|
let mut input = crate::config::Input::from_str(ctx, prompt, role_for_input);
|
||||||
|
let mut accumulated = String::new();
|
||||||
|
|
||||||
|
for turn in 0..node.max_iterations {
|
||||||
|
let client = input.create_client()?;
|
||||||
|
ctx.before_chat_completion(&input)?;
|
||||||
|
let (output, tool_results) =
|
||||||
|
call_chat_completions(&input, false, false, client.as_ref(), ctx, abort.clone())
|
||||||
|
.await?;
|
||||||
|
ctx.after_chat_completion(app_cfg.as_ref(), &input, &output, &tool_results)?;
|
||||||
|
|
||||||
|
if !output.is_empty() {
|
||||||
|
if !accumulated.is_empty() {
|
||||||
|
accumulated.push('\n');
|
||||||
|
}
|
||||||
|
accumulated.push_str(&output);
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_results.is_empty() {
|
||||||
|
return Ok(accumulated);
|
||||||
|
}
|
||||||
|
|
||||||
|
if turn + 1 == node.max_iterations {
|
||||||
bail!(
|
bail!(
|
||||||
"llm node execution body not yet implemented — see \
|
"llm node hit max_iterations ({}) before LLM concluded",
|
||||||
docs/implementation/graph-agents/10.5-llm-nodes.md \
|
node.max_iterations
|
||||||
(steps 3 & 5 of the implementation order)"
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
input = input.merge_tool_results(output, tool_results);
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("llm node ended without producing output")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_inline_role(
|
||||||
|
node: &LlmNode,
|
||||||
|
instructions: Option<&str>,
|
||||||
|
regular_tools: &[String],
|
||||||
|
mcp_servers: &[String],
|
||||||
|
parent_ctx: &RequestContext,
|
||||||
|
) -> Result<Role> {
|
||||||
|
let mut role = Role::new("llm_node", instructions.unwrap_or(""));
|
||||||
|
|
||||||
|
let model = match &node.model {
|
||||||
|
Some(model_id) => {
|
||||||
|
Model::retrieve_model(parent_ctx.app.config.as_ref(), model_id, ModelType::Chat)
|
||||||
|
.with_context(|| format!("Unknown model '{model_id}' on llm node"))?
|
||||||
|
}
|
||||||
|
None => parent_ctx.current_model().clone(),
|
||||||
|
};
|
||||||
|
role.set_model(model);
|
||||||
|
|
||||||
|
if let Some(t) = node.temperature {
|
||||||
|
role.set_temperature(Some(t));
|
||||||
|
}
|
||||||
|
if let Some(p) = node.top_p {
|
||||||
|
role.set_top_p(Some(p));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tool_entries) = &node.tools {
|
||||||
|
if tool_entries.is_empty() {
|
||||||
|
role.set_enabled_tools(Some(String::new()));
|
||||||
|
role.set_enabled_mcp_servers(Some(String::new()));
|
||||||
|
} else {
|
||||||
|
if !regular_tools.is_empty() {
|
||||||
|
role.set_enabled_tools(Some(regular_tools.join(",")));
|
||||||
|
} else {
|
||||||
|
role.set_enabled_tools(Some(String::new()));
|
||||||
|
}
|
||||||
|
if !mcp_servers.is_empty() {
|
||||||
|
role.set_enabled_mcp_servers(Some(mcp_servers.join(",")));
|
||||||
|
} else {
|
||||||
|
role.set_enabled_mcp_servers(Some(String::new()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(role)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn categorize_tools(entries: Option<&[String]>) -> (Vec<String>, Vec<String>) {
|
||||||
|
let mut regular = Vec::new();
|
||||||
|
let mut mcp = Vec::new();
|
||||||
|
let Some(entries) = entries else {
|
||||||
|
return (regular, mcp);
|
||||||
|
};
|
||||||
|
for e in entries {
|
||||||
|
if let Some(server) = e.strip_prefix("mcp:") {
|
||||||
|
mcp.push(server.to_string());
|
||||||
|
} else {
|
||||||
|
regular.push(e.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(regular, mcp)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_tools_subset(
|
||||||
|
regular: &[String],
|
||||||
|
mcp_servers: &[String],
|
||||||
|
parent_ctx: &RequestContext,
|
||||||
|
) -> Result<()> {
|
||||||
|
let agent = parent_ctx
|
||||||
|
.agent
|
||||||
|
.as_ref()
|
||||||
|
.ok_or_else(|| anyhow!("llm node requires an active agent"))?;
|
||||||
|
|
||||||
|
if !regular.is_empty() {
|
||||||
|
let known: HashSet<&str> = agent
|
||||||
|
.functions()
|
||||||
|
.declarations()
|
||||||
|
.iter()
|
||||||
|
.map(|d| d.name.as_str())
|
||||||
|
.collect();
|
||||||
|
for name in regular {
|
||||||
|
if !known.contains(name.as_str()) {
|
||||||
|
let mut avail: Vec<&str> = known.iter().copied().collect();
|
||||||
|
avail.sort();
|
||||||
|
bail!(
|
||||||
|
"llm node references unknown tool '{name}'. Agent '{}' provides: {}",
|
||||||
|
agent.name(),
|
||||||
|
avail.join(", ")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !mcp_servers.is_empty() {
|
||||||
|
let known: HashSet<&str> = agent
|
||||||
|
.mcp_server_names()
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.as_str())
|
||||||
|
.collect();
|
||||||
|
for server in mcp_servers {
|
||||||
|
if !known.contains(server.as_str()) {
|
||||||
|
let mut avail: Vec<&str> = known.iter().copied().collect();
|
||||||
|
avail.sort();
|
||||||
|
bail!(
|
||||||
|
"llm node references unknown MCP server 'mcp:{server}'. \
|
||||||
|
Agent '{}' has MCP servers: [{}]",
|
||||||
|
agent.name(),
|
||||||
|
avail.join(", ")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Loose substring match against the transient error messages we expect
|
||||||
|
/// from network/API failures or empty-output cases. Brittle by nature;
|
||||||
|
/// a typed error enum would be better long-term.
|
||||||
|
fn is_transient(err: &Error) -> bool {
|
||||||
|
let s = format!("{err:#}");
|
||||||
|
s.contains("timed out")
|
||||||
|
|| s.contains("rate limit")
|
||||||
|
|| s.contains("429")
|
||||||
|
|| s.contains("Connection reset")
|
||||||
|
|| s.contains("Connection refused")
|
||||||
|
|| s.contains("produced no output")
|
||||||
|
}
|
||||||
|
|
||||||
fn next_for_llm_node(
|
fn next_for_llm_node(
|
||||||
node_next: Option<&str>,
|
node_next: Option<&str>,
|
||||||
failed: bool,
|
failed: bool,
|
||||||
@@ -218,4 +413,51 @@ mod tests {
|
|||||||
let tools = vec!["a".to_string(), "b".to_string()];
|
let tools = vec!["a".to_string(), "b".to_string()];
|
||||||
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
|
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn categorize_tools_splits_mcp_and_regular() {
|
||||||
|
let entries = vec![
|
||||||
|
"read_query".to_string(),
|
||||||
|
"mcp:pubmed-search".to_string(),
|
||||||
|
"web_search_loki".to_string(),
|
||||||
|
"mcp:github".to_string(),
|
||||||
|
];
|
||||||
|
let (regular, mcp) = categorize_tools(Some(&entries));
|
||||||
|
assert_eq!(regular, vec!["read_query", "web_search_loki"]);
|
||||||
|
assert_eq!(mcp, vec!["pubmed-search", "github"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn categorize_tools_with_none_returns_empty() {
|
||||||
|
let (regular, mcp) = categorize_tools(None);
|
||||||
|
assert!(regular.is_empty());
|
||||||
|
assert!(mcp.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn categorize_tools_with_empty_returns_empty() {
|
||||||
|
let (regular, mcp) = categorize_tools(Some(&[]));
|
||||||
|
assert!(regular.is_empty());
|
||||||
|
assert!(mcp.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_transient_matches_expected_signatures() {
|
||||||
|
assert!(is_transient(&anyhow!("request timed out after 30s")));
|
||||||
|
assert!(is_transient(&anyhow!("rate limit reached")));
|
||||||
|
assert!(is_transient(&anyhow!("429 too many requests")));
|
||||||
|
assert!(is_transient(&anyhow!("Connection reset by peer")));
|
||||||
|
assert!(is_transient(&anyhow!("Connection refused")));
|
||||||
|
assert!(is_transient(&anyhow!("llm produced no output")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_transient_rejects_non_transient_errors() {
|
||||||
|
assert!(!is_transient(&anyhow!("Unknown model 'foo'")));
|
||||||
|
assert!(!is_transient(&anyhow!(
|
||||||
|
"llm node references unknown tool 'bad'"
|
||||||
|
)));
|
||||||
|
assert!(!is_transient(&anyhow!("hit max_iterations")));
|
||||||
|
assert!(!is_transient(&anyhow!("authentication failed")));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user