feat: migrated llm node validation to graph loading time instead of graph runtime

This commit is contained in:
2026-05-18 11:51:47 -06:00
parent a615559d9c
commit 87ab900481
2 changed files with 207 additions and 3 deletions
+9 -2
View File
@@ -15,13 +15,14 @@ use super::script::ScriptExecutor;
use super::state::StateManager;
use super::types::{EndNode, Graph, Node, NodeType};
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
use super::validator::GraphValidator;
use super::validator::{AgentValidationContext, GraphValidator};
use crate::config::RequestContext;
use crate::utils::AbortSignal;
use anyhow::{Context, Result, anyhow, bail};
use serde_json::Value;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, Instant};
pub struct GraphExecutor {
@@ -72,7 +73,13 @@ impl GraphExecutor {
let GraphExecutor { graph, base_dir } = self;
if graph.settings.validate_before_run {
let validator = GraphValidator::new(&base_dir);
let mut validator = GraphValidator::new(&base_dir);
if let Some(agent) = &ctx.agent {
validator = validator.with_agent_context(AgentValidationContext::from_agent(
agent,
Arc::clone(&ctx.app.config),
));
}
let result = validator.validate(&graph);
for w in &result.warnings {
logger.validation_warning(w.node_id.as_deref(), &w.message);
+198 -1
View File
@@ -10,10 +10,12 @@
//! against dynamically-routed graphs.
use super::types::{Graph, Node, NodeType};
use crate::config::paths;
use crate::client::{Model, ModelType};
use crate::config::{Agent, AppConfig, paths};
use anyhow::{Result, bail};
use std::collections::{HashSet, VecDeque};
use std::path::PathBuf;
use std::sync::Arc;
/// A single validation finding, optionally scoped to a node.
#[derive(Debug, Clone)]
@@ -81,19 +83,50 @@ impl ValidationResult {
}
}
/// Agent-level context that `llm`-node `tools` and `model` references are
/// validated against. Supplying it lets a typo in a node's `tools` or
/// `model` be caught at graph load time instead of when the node runs.
pub struct AgentValidationContext {
pub tool_names: HashSet<String>,
pub mcp_servers: HashSet<String>,
pub app_config: Arc<AppConfig>,
}
impl AgentValidationContext {
pub fn from_agent(agent: &Agent, app_config: Arc<AppConfig>) -> Self {
Self {
tool_names: agent
.functions()
.declarations()
.iter()
.map(|d| d.name.clone())
.collect(),
mcp_servers: agent.mcp_server_names().iter().cloned().collect(),
app_config,
}
}
}
/// Validator for graph structures. `base_dir` is used to resolve relative
/// script paths (typically the owning agent's data directory).
pub struct GraphValidator {
base_dir: PathBuf,
agent_ctx: Option<AgentValidationContext>,
}
impl GraphValidator {
pub fn new(base_dir: impl Into<PathBuf>) -> Self {
Self {
base_dir: base_dir.into(),
agent_ctx: None,
}
}
pub fn with_agent_context(mut self, ctx: AgentValidationContext) -> Self {
self.agent_ctx = Some(ctx);
self
}
pub fn validate(&self, graph: &Graph) -> ValidationResult {
let mut result = ValidationResult::default();
self.validate_node_references(graph, &mut result);
@@ -104,6 +137,7 @@ impl GraphValidator {
self.validate_agents(graph, &mut result);
self.validate_approval_routes(graph, &mut result);
self.validate_rag_nodes(graph, &mut result);
self.validate_llm_nodes(graph, &mut result);
result
}
@@ -128,6 +162,43 @@ impl GraphValidator {
}
}
fn validate_llm_nodes(&self, graph: &Graph, result: &mut ValidationResult) {
let Some(ctx) = &self.agent_ctx else {
return;
};
for (node_id, node) in &graph.nodes {
let NodeType::Llm(llm) = &node.node_type else {
continue;
};
if let Some(tools) = &llm.tools {
for entry in tools {
if let Some(server) = entry.strip_prefix("mcp:") {
if !ctx.mcp_servers.contains(server) {
result.error(ValidationError::with_node(
node_id,
format!("llm node references unknown MCP server 'mcp:{server}'"),
));
}
} else if !ctx.tool_names.contains(entry) {
result.error(ValidationError::with_node(
node_id,
format!("llm node references unknown tool '{entry}'"),
));
}
}
}
if let Some(model_id) = &llm.model
&& Model::retrieve_model(ctx.app_config.as_ref(), model_id, ModelType::Chat)
.is_err()
{
result.error(ValidationError::with_node(
node_id,
format!("llm node references unknown model '{model_id}'"),
));
}
}
}
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
for (node_id, node) in &graph.nodes {
for (target, label) in declared_targets(node) {
@@ -490,6 +561,132 @@ mod tests {
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
}
fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
AgentValidationContext {
tool_names: tools.iter().map(|s| s.to_string()).collect(),
mcp_servers: mcp.iter().map(|s| s.to_string()).collect(),
app_config: Arc::new(AppConfig::default()),
}
}
fn llm_node_with(id: &str, tools: Option<Vec<&str>>, model: Option<&str>) -> Node {
let mut node = llm_node(id, None, Some("end"));
if let NodeType::Llm(ref mut n) = node.node_type {
n.tools = tools.map(|t| t.iter().map(|s| s.to_string()).collect());
n.model = model.map(String::from);
}
node
}
#[test]
fn llm_node_unknown_tool_is_an_error() {
let graph = graph_with(
vec![
("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
("end", end_node("end")),
],
"l",
);
let result = validator()
.with_agent_context(agent_ctx(&["read_query"], &[]))
.validate(&graph);
assert!(!result.is_valid());
assert!(
result
.errors
.iter()
.any(|e| e.message.contains("bogus_tool"))
);
}
#[test]
fn llm_node_known_tool_passes() {
let graph = graph_with(
vec![
("l", llm_node_with("l", Some(vec!["read_query"]), None)),
("end", end_node("end")),
],
"l",
);
let result = validator()
.with_agent_context(agent_ctx(&["read_query"], &[]))
.validate(&graph);
assert!(result.is_valid());
}
#[test]
fn llm_node_unknown_mcp_server_is_an_error() {
let graph = graph_with(
vec![
("l", llm_node_with("l", Some(vec!["mcp:bogus"]), None)),
("end", end_node("end")),
],
"l",
);
let result = validator()
.with_agent_context(agent_ctx(&[], &["pubmed-search"]))
.validate(&graph);
assert!(!result.is_valid());
assert!(
result
.errors
.iter()
.any(|e| e.message.contains("mcp:bogus"))
);
}
#[test]
fn llm_node_known_mcp_server_passes() {
let graph = graph_with(
vec![
(
"l",
llm_node_with("l", Some(vec!["mcp:pubmed-search"]), None),
),
("end", end_node("end")),
],
"l",
);
let result = validator()
.with_agent_context(agent_ctx(&[], &["pubmed-search"]))
.validate(&graph);
assert!(result.is_valid());
}
#[test]
fn llm_node_unknown_model_is_an_error() {
let graph = graph_with(
vec![
("l", llm_node_with("l", None, Some("nonexistent:model"))),
("end", end_node("end")),
],
"l",
);
let result = validator()
.with_agent_context(agent_ctx(&[], &[]))
.validate(&graph);
assert!(!result.is_valid());
assert!(
result
.errors
.iter()
.any(|e| e.message.contains("nonexistent:model"))
);
}
#[test]
fn llm_node_validation_skipped_without_agent_context() {
let graph = graph_with(
vec![
("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
("end", end_node("end")),
],
"l",
);
let result = validator().validate(&graph);
assert!(result.is_valid());
}
#[test]
fn rag_node_without_documents_errors() {
let graph = graph_with(