feat: migrated llm node validation to graph loading time instead of graph runtime
This commit is contained in:
@@ -15,13 +15,14 @@ use super::script::ScriptExecutor;
|
||||
use super::state::StateManager;
|
||||
use super::types::{EndNode, Graph, Node, NodeType};
|
||||
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
||||
use super::validator::GraphValidator;
|
||||
use super::validator::{AgentValidationContext, GraphValidator};
|
||||
use crate::config::RequestContext;
|
||||
use crate::utils::AbortSignal;
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
pub struct GraphExecutor {
|
||||
@@ -72,7 +73,13 @@ impl GraphExecutor {
|
||||
let GraphExecutor { graph, base_dir } = self;
|
||||
|
||||
if graph.settings.validate_before_run {
|
||||
let validator = GraphValidator::new(&base_dir);
|
||||
let mut validator = GraphValidator::new(&base_dir);
|
||||
if let Some(agent) = &ctx.agent {
|
||||
validator = validator.with_agent_context(AgentValidationContext::from_agent(
|
||||
agent,
|
||||
Arc::clone(&ctx.app.config),
|
||||
));
|
||||
}
|
||||
let result = validator.validate(&graph);
|
||||
for w in &result.warnings {
|
||||
logger.validation_warning(w.node_id.as_deref(), &w.message);
|
||||
|
||||
+198
-1
@@ -10,10 +10,12 @@
|
||||
//! against dynamically-routed graphs.
|
||||
|
||||
use super::types::{Graph, Node, NodeType};
|
||||
use crate::config::paths;
|
||||
use crate::client::{Model, ModelType};
|
||||
use crate::config::{Agent, AppConfig, paths};
|
||||
use anyhow::{Result, bail};
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A single validation finding, optionally scoped to a node.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -81,19 +83,50 @@ impl ValidationResult {
|
||||
}
|
||||
}
|
||||
|
||||
/// Agent-level context that `llm`-node `tools` and `model` references are
|
||||
/// validated against. Supplying it lets a typo in a node's `tools` or
|
||||
/// `model` be caught at graph load time instead of when the node runs.
|
||||
pub struct AgentValidationContext {
|
||||
pub tool_names: HashSet<String>,
|
||||
pub mcp_servers: HashSet<String>,
|
||||
pub app_config: Arc<AppConfig>,
|
||||
}
|
||||
|
||||
impl AgentValidationContext {
|
||||
pub fn from_agent(agent: &Agent, app_config: Arc<AppConfig>) -> Self {
|
||||
Self {
|
||||
tool_names: agent
|
||||
.functions()
|
||||
.declarations()
|
||||
.iter()
|
||||
.map(|d| d.name.clone())
|
||||
.collect(),
|
||||
mcp_servers: agent.mcp_server_names().iter().cloned().collect(),
|
||||
app_config,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validator for graph structures. `base_dir` is used to resolve relative
|
||||
/// script paths (typically the owning agent's data directory).
|
||||
pub struct GraphValidator {
|
||||
base_dir: PathBuf,
|
||||
agent_ctx: Option<AgentValidationContext>,
|
||||
}
|
||||
|
||||
impl GraphValidator {
|
||||
pub fn new(base_dir: impl Into<PathBuf>) -> Self {
|
||||
Self {
|
||||
base_dir: base_dir.into(),
|
||||
agent_ctx: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_agent_context(mut self, ctx: AgentValidationContext) -> Self {
|
||||
self.agent_ctx = Some(ctx);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn validate(&self, graph: &Graph) -> ValidationResult {
|
||||
let mut result = ValidationResult::default();
|
||||
self.validate_node_references(graph, &mut result);
|
||||
@@ -104,6 +137,7 @@ impl GraphValidator {
|
||||
self.validate_agents(graph, &mut result);
|
||||
self.validate_approval_routes(graph, &mut result);
|
||||
self.validate_rag_nodes(graph, &mut result);
|
||||
self.validate_llm_nodes(graph, &mut result);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -128,6 +162,43 @@ impl GraphValidator {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_llm_nodes(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||
let Some(ctx) = &self.agent_ctx else {
|
||||
return;
|
||||
};
|
||||
for (node_id, node) in &graph.nodes {
|
||||
let NodeType::Llm(llm) = &node.node_type else {
|
||||
continue;
|
||||
};
|
||||
if let Some(tools) = &llm.tools {
|
||||
for entry in tools {
|
||||
if let Some(server) = entry.strip_prefix("mcp:") {
|
||||
if !ctx.mcp_servers.contains(server) {
|
||||
result.error(ValidationError::with_node(
|
||||
node_id,
|
||||
format!("llm node references unknown MCP server 'mcp:{server}'"),
|
||||
));
|
||||
}
|
||||
} else if !ctx.tool_names.contains(entry) {
|
||||
result.error(ValidationError::with_node(
|
||||
node_id,
|
||||
format!("llm node references unknown tool '{entry}'"),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(model_id) = &llm.model
|
||||
&& Model::retrieve_model(ctx.app_config.as_ref(), model_id, ModelType::Chat)
|
||||
.is_err()
|
||||
{
|
||||
result.error(ValidationError::with_node(
|
||||
node_id,
|
||||
format!("llm node references unknown model '{model_id}'"),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||
for (node_id, node) in &graph.nodes {
|
||||
for (target, label) in declared_targets(node) {
|
||||
@@ -490,6 +561,132 @@ mod tests {
|
||||
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
|
||||
}
|
||||
|
||||
fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
|
||||
AgentValidationContext {
|
||||
tool_names: tools.iter().map(|s| s.to_string()).collect(),
|
||||
mcp_servers: mcp.iter().map(|s| s.to_string()).collect(),
|
||||
app_config: Arc::new(AppConfig::default()),
|
||||
}
|
||||
}
|
||||
|
||||
fn llm_node_with(id: &str, tools: Option<Vec<&str>>, model: Option<&str>) -> Node {
|
||||
let mut node = llm_node(id, None, Some("end"));
|
||||
if let NodeType::Llm(ref mut n) = node.node_type {
|
||||
n.tools = tools.map(|t| t.iter().map(|s| s.to_string()).collect());
|
||||
n.model = model.map(String::from);
|
||||
}
|
||||
node
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_unknown_tool_is_an_error() {
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"l",
|
||||
);
|
||||
let result = validator()
|
||||
.with_agent_context(agent_ctx(&["read_query"], &[]))
|
||||
.validate(&graph);
|
||||
assert!(!result.is_valid());
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("bogus_tool"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_known_tool_passes() {
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("l", llm_node_with("l", Some(vec!["read_query"]), None)),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"l",
|
||||
);
|
||||
let result = validator()
|
||||
.with_agent_context(agent_ctx(&["read_query"], &[]))
|
||||
.validate(&graph);
|
||||
assert!(result.is_valid());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_unknown_mcp_server_is_an_error() {
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("l", llm_node_with("l", Some(vec!["mcp:bogus"]), None)),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"l",
|
||||
);
|
||||
let result = validator()
|
||||
.with_agent_context(agent_ctx(&[], &["pubmed-search"]))
|
||||
.validate(&graph);
|
||||
assert!(!result.is_valid());
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("mcp:bogus"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_known_mcp_server_passes() {
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
(
|
||||
"l",
|
||||
llm_node_with("l", Some(vec!["mcp:pubmed-search"]), None),
|
||||
),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"l",
|
||||
);
|
||||
let result = validator()
|
||||
.with_agent_context(agent_ctx(&[], &["pubmed-search"]))
|
||||
.validate(&graph);
|
||||
assert!(result.is_valid());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_unknown_model_is_an_error() {
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("l", llm_node_with("l", None, Some("nonexistent:model"))),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"l",
|
||||
);
|
||||
let result = validator()
|
||||
.with_agent_context(agent_ctx(&[], &[]))
|
||||
.validate(&graph);
|
||||
assert!(!result.is_valid());
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("nonexistent:model"))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_validation_skipped_without_agent_context() {
|
||||
let graph = graph_with(
|
||||
vec![
|
||||
("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
|
||||
("end", end_node("end")),
|
||||
],
|
||||
"l",
|
||||
);
|
||||
let result = validator().validate(&graph);
|
||||
assert!(result.is_valid());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rag_node_without_documents_errors() {
|
||||
let graph = graph_with(
|
||||
|
||||
Reference in New Issue
Block a user