feat: migrated llm node validation to graph loading time instead of graph runtime
This commit is contained in:
@@ -15,13 +15,14 @@ use super::script::ScriptExecutor;
|
|||||||
use super::state::StateManager;
|
use super::state::StateManager;
|
||||||
use super::types::{EndNode, Graph, Node, NodeType};
|
use super::types::{EndNode, Graph, Node, NodeType};
|
||||||
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
use super::user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
||||||
use super::validator::GraphValidator;
|
use super::validator::{AgentValidationContext, GraphValidator};
|
||||||
use crate::config::RequestContext;
|
use crate::config::RequestContext;
|
||||||
use crate::utils::AbortSignal;
|
use crate::utils::AbortSignal;
|
||||||
use anyhow::{Context, Result, anyhow, bail};
|
use anyhow::{Context, Result, anyhow, bail};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
pub struct GraphExecutor {
|
pub struct GraphExecutor {
|
||||||
@@ -72,7 +73,13 @@ impl GraphExecutor {
|
|||||||
let GraphExecutor { graph, base_dir } = self;
|
let GraphExecutor { graph, base_dir } = self;
|
||||||
|
|
||||||
if graph.settings.validate_before_run {
|
if graph.settings.validate_before_run {
|
||||||
let validator = GraphValidator::new(&base_dir);
|
let mut validator = GraphValidator::new(&base_dir);
|
||||||
|
if let Some(agent) = &ctx.agent {
|
||||||
|
validator = validator.with_agent_context(AgentValidationContext::from_agent(
|
||||||
|
agent,
|
||||||
|
Arc::clone(&ctx.app.config),
|
||||||
|
));
|
||||||
|
}
|
||||||
let result = validator.validate(&graph);
|
let result = validator.validate(&graph);
|
||||||
for w in &result.warnings {
|
for w in &result.warnings {
|
||||||
logger.validation_warning(w.node_id.as_deref(), &w.message);
|
logger.validation_warning(w.node_id.as_deref(), &w.message);
|
||||||
|
|||||||
+198
-1
@@ -10,10 +10,12 @@
|
|||||||
//! against dynamically-routed graphs.
|
//! against dynamically-routed graphs.
|
||||||
|
|
||||||
use super::types::{Graph, Node, NodeType};
|
use super::types::{Graph, Node, NodeType};
|
||||||
use crate::config::paths;
|
use crate::client::{Model, ModelType};
|
||||||
|
use crate::config::{Agent, AppConfig, paths};
|
||||||
use anyhow::{Result, bail};
|
use anyhow::{Result, bail};
|
||||||
use std::collections::{HashSet, VecDeque};
|
use std::collections::{HashSet, VecDeque};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// A single validation finding, optionally scoped to a node.
|
/// A single validation finding, optionally scoped to a node.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@@ -81,19 +83,50 @@ impl ValidationResult {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Agent-level context that `llm`-node `tools` and `model` references are
|
||||||
|
/// validated against. Supplying it lets a typo in a node's `tools` or
|
||||||
|
/// `model` be caught at graph load time instead of when the node runs.
|
||||||
|
pub struct AgentValidationContext {
|
||||||
|
pub tool_names: HashSet<String>,
|
||||||
|
pub mcp_servers: HashSet<String>,
|
||||||
|
pub app_config: Arc<AppConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AgentValidationContext {
|
||||||
|
pub fn from_agent(agent: &Agent, app_config: Arc<AppConfig>) -> Self {
|
||||||
|
Self {
|
||||||
|
tool_names: agent
|
||||||
|
.functions()
|
||||||
|
.declarations()
|
||||||
|
.iter()
|
||||||
|
.map(|d| d.name.clone())
|
||||||
|
.collect(),
|
||||||
|
mcp_servers: agent.mcp_server_names().iter().cloned().collect(),
|
||||||
|
app_config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Validator for graph structures. `base_dir` is used to resolve relative
|
/// Validator for graph structures. `base_dir` is used to resolve relative
|
||||||
/// script paths (typically the owning agent's data directory).
|
/// script paths (typically the owning agent's data directory).
|
||||||
pub struct GraphValidator {
|
pub struct GraphValidator {
|
||||||
base_dir: PathBuf,
|
base_dir: PathBuf,
|
||||||
|
agent_ctx: Option<AgentValidationContext>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GraphValidator {
|
impl GraphValidator {
|
||||||
pub fn new(base_dir: impl Into<PathBuf>) -> Self {
|
pub fn new(base_dir: impl Into<PathBuf>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
base_dir: base_dir.into(),
|
base_dir: base_dir.into(),
|
||||||
|
agent_ctx: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn with_agent_context(mut self, ctx: AgentValidationContext) -> Self {
|
||||||
|
self.agent_ctx = Some(ctx);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn validate(&self, graph: &Graph) -> ValidationResult {
|
pub fn validate(&self, graph: &Graph) -> ValidationResult {
|
||||||
let mut result = ValidationResult::default();
|
let mut result = ValidationResult::default();
|
||||||
self.validate_node_references(graph, &mut result);
|
self.validate_node_references(graph, &mut result);
|
||||||
@@ -104,6 +137,7 @@ impl GraphValidator {
|
|||||||
self.validate_agents(graph, &mut result);
|
self.validate_agents(graph, &mut result);
|
||||||
self.validate_approval_routes(graph, &mut result);
|
self.validate_approval_routes(graph, &mut result);
|
||||||
self.validate_rag_nodes(graph, &mut result);
|
self.validate_rag_nodes(graph, &mut result);
|
||||||
|
self.validate_llm_nodes(graph, &mut result);
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,6 +162,43 @@ impl GraphValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_llm_nodes(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
|
let Some(ctx) = &self.agent_ctx else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
for (node_id, node) in &graph.nodes {
|
||||||
|
let NodeType::Llm(llm) = &node.node_type else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
if let Some(tools) = &llm.tools {
|
||||||
|
for entry in tools {
|
||||||
|
if let Some(server) = entry.strip_prefix("mcp:") {
|
||||||
|
if !ctx.mcp_servers.contains(server) {
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
format!("llm node references unknown MCP server 'mcp:{server}'"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
} else if !ctx.tool_names.contains(entry) {
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
format!("llm node references unknown tool '{entry}'"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(model_id) = &llm.model
|
||||||
|
&& Model::retrieve_model(ctx.app_config.as_ref(), model_id, ModelType::Chat)
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
format!("llm node references unknown model '{model_id}'"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
for (node_id, node) in &graph.nodes {
|
for (node_id, node) in &graph.nodes {
|
||||||
for (target, label) in declared_targets(node) {
|
for (target, label) in declared_targets(node) {
|
||||||
@@ -490,6 +561,132 @@ mod tests {
|
|||||||
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
|
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
|
||||||
|
AgentValidationContext {
|
||||||
|
tool_names: tools.iter().map(|s| s.to_string()).collect(),
|
||||||
|
mcp_servers: mcp.iter().map(|s| s.to_string()).collect(),
|
||||||
|
app_config: Arc::new(AppConfig::default()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn llm_node_with(id: &str, tools: Option<Vec<&str>>, model: Option<&str>) -> Node {
|
||||||
|
let mut node = llm_node(id, None, Some("end"));
|
||||||
|
if let NodeType::Llm(ref mut n) = node.node_type {
|
||||||
|
n.tools = tools.map(|t| t.iter().map(|s| s.to_string()).collect());
|
||||||
|
n.model = model.map(String::from);
|
||||||
|
}
|
||||||
|
node
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_unknown_tool_is_an_error() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
let result = validator()
|
||||||
|
.with_agent_context(agent_ctx(&["read_query"], &[]))
|
||||||
|
.validate(&graph);
|
||||||
|
assert!(!result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("bogus_tool"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_known_tool_passes() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("l", llm_node_with("l", Some(vec!["read_query"]), None)),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
let result = validator()
|
||||||
|
.with_agent_context(agent_ctx(&["read_query"], &[]))
|
||||||
|
.validate(&graph);
|
||||||
|
assert!(result.is_valid());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_unknown_mcp_server_is_an_error() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("l", llm_node_with("l", Some(vec!["mcp:bogus"]), None)),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
let result = validator()
|
||||||
|
.with_agent_context(agent_ctx(&[], &["pubmed-search"]))
|
||||||
|
.validate(&graph);
|
||||||
|
assert!(!result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("mcp:bogus"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_known_mcp_server_passes() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
(
|
||||||
|
"l",
|
||||||
|
llm_node_with("l", Some(vec!["mcp:pubmed-search"]), None),
|
||||||
|
),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
let result = validator()
|
||||||
|
.with_agent_context(agent_ctx(&[], &["pubmed-search"]))
|
||||||
|
.validate(&graph);
|
||||||
|
assert!(result.is_valid());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_unknown_model_is_an_error() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("l", llm_node_with("l", None, Some("nonexistent:model"))),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
let result = validator()
|
||||||
|
.with_agent_context(agent_ctx(&[], &[]))
|
||||||
|
.validate(&graph);
|
||||||
|
assert!(!result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("nonexistent:model"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_validation_skipped_without_agent_context() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("l", llm_node_with("l", Some(vec!["bogus_tool"]), None)),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
assert!(result.is_valid());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn rag_node_without_documents_errors() {
|
fn rag_node_without_documents_errors() {
|
||||||
let graph = graph_with(
|
let graph = graph_with(
|
||||||
|
|||||||
Reference in New Issue
Block a user