diff --git a/src/graph/executor.rs b/src/graph/executor.rs
index b2e0f7c..71a389f 100644
--- a/src/graph/executor.rs
+++ b/src/graph/executor.rs
@@ -15,13 +15,14 @@ use super::script::ScriptExecutor;
 use super::state::StateManager;
 use super::types::{EndNode, Graph, Node, NodeType};
 use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
-use super::validator::GraphValidator;
+use super::validator::{AgentValidationContext, GraphValidator};
 use crate::config::RequestContext;
 use crate::utils::AbortSignal;
 use anyhow::{Context, Result, anyhow, bail};
 use serde_json::Value;
 use std::collections::HashMap;
 use std::path::{Path, PathBuf};
+use std::sync::Arc;
 use std::time::{Duration, Instant};
 
 pub struct GraphExecutor {
@@ -72,7 +73,13 @@ impl GraphExecutor {
         let GraphExecutor { graph, base_dir } = self;
 
         if graph.settings.validate_before_run {
-            let validator = GraphValidator::new(&base_dir);
+            let mut validator = GraphValidator::new(&base_dir);
+            if let Some(agent) = &ctx.agent {
+                validator = validator.with_agent_context(AgentValidationContext::from_agent(
+                    agent,
+                    Arc::clone(&ctx.app.config),
+                ));
+            }
             let result = validator.validate(&graph);
             for w in &result.warnings {
                 logger.validation_warning(w.node_id.as_deref(), &w.message);
diff --git a/src/graph/validator.rs b/src/graph/validator.rs
index 0aff21b..cee0ee0 100644
--- a/src/graph/validator.rs
+++ b/src/graph/validator.rs
@@ -10,10 +10,12 @@
 //! against dynamically-routed graphs.
 
 use super::types::{Graph, Node, NodeType};
-use crate::config::paths;
+use crate::client::{Model, ModelType};
+use crate::config::{Agent, AppConfig, paths};
 use anyhow::{Result, bail};
 use std::collections::{HashSet, VecDeque};
 use std::path::PathBuf;
+use std::sync::Arc;
 
 /// A single validation finding, optionally scoped to a node.
 #[derive(Debug, Clone)]
@@ -81,19 +83,57 @@ impl ValidationResult {
     }
 }
 
+/// Agent-level context that `llm`-node `tools` and `model` references are
+/// validated against. Supplying it lets a typo in a node's `tools` or
+/// `model` be caught at graph load time instead of when the node runs.
+pub struct AgentValidationContext {
+    /// Names of the functions declared by the agent.
+    pub tool_names: HashSet<String>,
+    /// Names of the agent's MCP servers, without the `mcp:` prefix.
+    pub mcp_servers: HashSet<String>,
+    /// Application config used to resolve `model` references.
+    pub app_config: Arc<AppConfig>,
+}
+
+impl AgentValidationContext {
+    /// Collects the agent's declared function names and MCP server names,
+    /// and keeps `app_config` for resolving `model` references.
+    pub fn from_agent(agent: &Agent, app_config: Arc<AppConfig>) -> Self {
+        Self {
+            tool_names: agent
+                .functions()
+                .declarations()
+                .iter()
+                .map(|d| d.name.clone())
+                .collect(),
+            mcp_servers: agent.mcp_server_names().iter().cloned().collect(),
+            app_config,
+        }
+    }
+}
+
 /// Validator for graph structures. `base_dir` is used to resolve relative
 /// script paths (typically the owning agent's data directory).
 pub struct GraphValidator {
     base_dir: PathBuf,
+    agent_ctx: Option<AgentValidationContext>,
 }
 
 impl GraphValidator {
     pub fn new(base_dir: impl Into<PathBuf>) -> Self {
         Self {
             base_dir: base_dir.into(),
+            agent_ctx: None,
         }
     }
 
+    /// Attaches agent context so that `validate` also checks `llm`-node
+    /// `tools` and `model` references; without it those checks are skipped.
+    pub fn with_agent_context(mut self, ctx: AgentValidationContext) -> Self {
+        self.agent_ctx = Some(ctx);
+        self
+    }
+
     pub fn validate(&self, graph: &Graph) -> ValidationResult {
         let mut result = ValidationResult::default();
         self.validate_node_references(graph, &mut result);
@@ -104,6 +144,7 @@ impl GraphValidator {
         self.validate_agents(graph, &mut result);
         self.validate_approval_routes(graph, &mut result);
         self.validate_rag_nodes(graph, &mut result);
+        self.validate_llm_nodes(graph, &mut result);
         result
     }
 
@@ -128,6 +169,47 @@ impl GraphValidator {
         }
     }
 
+    /// Checks every `llm` node against the agent context: each `tools` entry
+    /// must name a known tool (or, with an `mcp:` prefix, a known MCP
+    /// server), and `model` must resolve as a chat model. Does nothing when
+    /// no agent context was supplied.
+    fn validate_llm_nodes(&self, graph: &Graph, result: &mut ValidationResult) {
+        let Some(ctx) = &self.agent_ctx else {
+            return;
+        };
+        for (node_id, node) in &graph.nodes {
+            let NodeType::Llm(llm) = &node.node_type else {
+                continue;
+            };
+            if let Some(tools) = &llm.tools {
+                for entry in tools {
+                    if let Some(server) = entry.strip_prefix("mcp:") {
+                        if !ctx.mcp_servers.contains(server) {
+                            result.error(ValidationError::with_node(
+                                node_id,
+                                format!("llm node references unknown MCP server 'mcp:{server}'"),
+                            ));
+                        }
+                    } else if !ctx.tool_names.contains(entry) {
+                        result.error(ValidationError::with_node(
+                            node_id,
+                            format!("llm node references unknown tool '{entry}'"),
+                        ));
+                    }
+                }
+            }
+            if let Some(model_id) = &llm.model
+                && Model::retrieve_model(ctx.app_config.as_ref(), model_id, ModelType::Chat)
+                    .is_err()
+            {
+                result.error(ValidationError::with_node(
+                    node_id,
+                    format!("llm node references unknown model '{model_id}'"),
+                ));
+            }
+        }
+    }
+
     fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
         for (node_id, node) in &graph.nodes {
             for (target, label) in declared_targets(node) {
@@ -490,6 +572,132 @@ mod tests {
         assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
     }
 
+    fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
+        AgentValidationContext {
+            tool_names: tools.iter().map(|s| s.to_string()).collect(),
+            mcp_servers: mcp.iter().map(|s| s.to_string()).collect(),
+            app_config: Arc::new(AppConfig::default()),
+        }
+    }
+
+    fn llm_node_with(id: &str, tools: Option<Vec<&str>>, model: Option<&str>) -> Node {
+        let mut node = llm_node(id, None, Some("end"));
+        if let NodeType::Llm(ref mut n) = node.node_type {
+            n.tools = tools.map(|t| t.iter().map(|s| s.to_string()).collect());
+            n.model = model.map(String::from);
+        }
+        node
+    }
+
+    #[test]
+    fn llm_node_unknown_tool_is_an_error() {
+        let graph = graph_with(
+            vec![
+                ("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
+                ("end", end_node("end")),
+            ],
+            "l",
+        );
+        let result = validator()
+            .with_agent_context(agent_ctx(&["read_query"], &[]))
+            .validate(&graph);
+        assert!(!result.is_valid());
+        assert!(
+            result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("bogus_tool"))
+        );
+    }
+
+    #[test]
+    fn llm_node_known_tool_passes() {
+        let graph = graph_with(
+            vec![
+                ("l", llm_node_with("l", Some(vec!["read_query"]), None)),
+                ("end", end_node("end")),
+            ],
+            "l",
+        );
+        let result = validator()
+            .with_agent_context(agent_ctx(&["read_query"], &[]))
+            .validate(&graph);
+        assert!(result.is_valid());
+    }
+
+    #[test]
+    fn llm_node_unknown_mcp_server_is_an_error() {
+        let graph = graph_with(
+            vec![
+                ("l", llm_node_with("l", Some(vec!["mcp:bogus"]), None)),
+                ("end", end_node("end")),
+            ],
+            "l",
+        );
+        let result = validator()
+            .with_agent_context(agent_ctx(&[], &["pubmed-search"]))
+            .validate(&graph);
+        assert!(!result.is_valid());
+        assert!(
+            result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("mcp:bogus"))
+        );
+    }
+
+    #[test]
+    fn llm_node_known_mcp_server_passes() {
+        let graph = graph_with(
+            vec![
+                (
+                    "l",
+                    llm_node_with("l", Some(vec!["mcp:pubmed-search"]), None),
+                ),
+                ("end", end_node("end")),
+            ],
+            "l",
+        );
+        let result = validator()
+            .with_agent_context(agent_ctx(&[], &["pubmed-search"]))
+            .validate(&graph);
+        assert!(result.is_valid());
+    }
+
+    #[test]
+    fn llm_node_unknown_model_is_an_error() {
+        let graph = graph_with(
+            vec![
+                ("l", llm_node_with("l", None, Some("nonexistent:model"))),
+                ("end", end_node("end")),
+            ],
+            "l",
+        );
+        let result = validator()
+            .with_agent_context(agent_ctx(&[], &[]))
+            .validate(&graph);
+        assert!(!result.is_valid());
+        assert!(
+            result
+                .errors
+                .iter()
+                .any(|e| e.message.contains("nonexistent:model"))
+        );
+    }
+
+    #[test]
+    fn llm_node_validation_skipped_without_agent_context() {
+        let graph = graph_with(
+            vec![
+                ("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
+                ("end", end_node("end")),
+            ],
+            "l",
+        );
+        let result = validator().validate(&graph);
+        assert!(result.is_valid());
+    }
+
     #[test]
     fn rag_node_without_documents_errors() {
         let graph = graph_with(