From fb8633dc75615f8b340afe57950ef57f6305c559 Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Tue, 2 Jun 2026 12:39:43 -0600 Subject: [PATCH] feat: llm graph nodes support skills --- src/config/agent.rs | 10 ++++ src/graph/llm.rs | 41 +++++++++++++ src/graph/types.rs | 12 ++++ src/graph/validator.rs | 131 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+) diff --git a/src/config/agent.rs b/src/config/agent.rs index b80ce98..e93a513 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -352,6 +352,14 @@ impl Agent { self.config.enabled_skills.as_deref() } + pub fn set_skills_enabled(&mut self, value: Option) { + self.config.skills_enabled = value; + } + + pub fn set_enabled_skills(&mut self, value: Option>) { + self.config.enabled_skills = value; + } + pub fn conversation_starters(&self) -> Vec { self.config .conversation_starters @@ -696,6 +704,8 @@ impl AgentConfig { description: graph.description.clone(), global_tools: graph.global_tools.clone(), mcp_servers: graph.mcp_servers.clone(), + skills_enabled: graph.skills_enabled, + enabled_skills: graph.enabled_skills.clone(), conversation_starters: graph.conversation_starters.clone(), variables: graph.variables.clone(), can_spawn_agents: graph.has_agent_node(), diff --git a/src/graph/llm.rs b/src/graph/llm.rs index a4411a1..53d24c9 100644 --- a/src/graph/llm.rs +++ b/src/graph/llm.rs @@ -113,6 +113,8 @@ async fn run( parent_ctx, )?; + let saved_agent_skill_state = swap_in_node_skill_policy(node, parent_ctx); + let saved_role = parent_ctx.role.clone(); parent_ctx.role = Some(role); let result = match node.timeout { @@ -128,9 +130,46 @@ async fn run( None => run_with_retries(node, &prompt, parent_ctx).await, }; parent_ctx.role = saved_role; + restore_agent_skill_policy(parent_ctx, saved_agent_skill_state); result } +struct SavedAgentSkillPolicy { + skills_enabled: Option, + enabled_skills: Option>, +} + +fn swap_in_node_skill_policy( + node: &LlmNode, + ctx: &mut RequestContext, +) -> Option { + let agent = ctx.agent.as_mut()?; + let saved = SavedAgentSkillPolicy { + skills_enabled: agent.skills_enabled(), + enabled_skills: agent.enabled_skills().map(|s| s.to_vec()), + }; + + if let Some(b) = node.skills_enabled { + agent.set_skills_enabled(Some(b)); + } + + if let Some(names) = &node.enabled_skills { + agent.set_enabled_skills(Some(names.clone())); + } + + Some(saved) +} + +fn restore_agent_skill_policy(ctx: &mut RequestContext, saved: Option) { + let Some(saved) = saved else { return }; + let Some(agent) = ctx.agent.as_mut() else { + return; + }; + + agent.set_skills_enabled(saved.skills_enabled); + agent.set_enabled_skills(saved.enabled_skills); +} + async fn run_with_retries( node: &LlmNode, prompt: &str, @@ -389,6 +428,8 @@ mod tests { state_updates: updates, output_schema: None, timeout: None, + skills_enabled: None, + enabled_skills: None, } } diff --git a/src/graph/types.rs b/src/graph/types.rs index ef0342f..35fd52c 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -31,6 +31,12 @@ pub struct Graph { #[serde(default)] pub mcp_servers: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub skills_enabled: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub enabled_skills: Option>, + #[serde(default)] pub conversation_starters: Vec, @@ -293,6 +299,12 @@ pub struct LlmNode { #[serde(default, skip_serializing_if = "Option::is_none")] pub timeout: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub skills_enabled: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub enabled_skills: Option>, } fn default_llm_max_attempts() -> u32 { diff --git a/src/graph/validator.rs b/src/graph/validator.rs index ca4e749..d92dc20 100644 --- a/src/graph/validator.rs +++ b/src/graph/validator.rs @@ -119,6 +119,7 @@ impl GraphValidator { self.validate_approval_routes(graph, &mut result); self.validate_rag_nodes(graph, &mut result); self.validate_llm_nodes(graph, &mut result); + self.validate_llm_skills(graph, &mut result); self.validate_max_concurrency(graph, &mut result); self.validate_map_branches(graph, &mut result); self.validate_parallel_user_interaction(graph, &mut result); @@ -189,6 +190,39 @@ impl GraphValidator { } } + fn validate_llm_skills(&self, graph: &Graph, result: &mut ValidationResult) { + for (node_id, node) in &graph.nodes { + let NodeType::Llm(llm) = &node.node_type else { + continue; + }; + let Some(node_skills) = &llm.enabled_skills else { + continue; + }; + + for name in node_skills { + if name.trim().is_empty() { + result.error(ValidationError::with_node( + node_id, + "llm node 'enabled_skills' contains an empty skill name", + )); + continue; + } + if let Some(graph_skills) = &graph.enabled_skills + && !graph_skills.iter().any(|g| g == name) + { + result.error(ValidationError::with_node( + node_id, + format!( + "llm node 'enabled_skills' references '{name}' which is not in \ + graph-level 'enabled_skills' ({})", + graph_skills.join(", ") + ), + )); + } + } + } + } + fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) { for (node_id, node) in &graph.nodes { for (target, label) in declared_targets(node) { @@ -847,6 +881,8 @@ mod tests { top_p: None, global_tools: Vec::new(), mcp_servers: Vec::new(), + skills_enabled: None, + enabled_skills: None, conversation_starters: Vec::new(), variables: Vec::new(), settings: GraphSettings::default(), @@ -946,6 +982,8 @@ mod tests { state_updates: None, output_schema: None, timeout: None, + skills_enabled: None, + enabled_skills: None, }), next: next.map(NextTargets::from), } @@ -967,6 +1005,99 @@ mod tests { assert!(result.errors.iter().any(|e| e.message.contains("ghost"))); } + #[test] + fn llm_node_skill_in_graph_set_passes() { + let mut graph = graph_with( + vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))], + "l", + ); + graph.enabled_skills = Some(vec!["code-review".into(), "git-master".into()]); + if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type { + n.enabled_skills = Some(vec!["code-review".into()]); + } + + let result = validator().validate(&graph); + + assert!( + !result + .errors + .iter() + .any(|e| e.message.contains("enabled_skills")), + "unexpected enabled_skills error: {:?}", + result.errors + ); + } + + #[test] + fn llm_node_skill_not_in_graph_set_errors() { + let mut graph = graph_with( + vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))], + "l", + ); + graph.enabled_skills = Some(vec!["code-review".into()]); + if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type { + n.enabled_skills = Some(vec!["git-master".into()]); + } + + let result = validator().validate(&graph); + + assert!(!result.is_valid()); + assert!( + result.errors.iter().any(|e| e + .message + .contains("'git-master'") + && e.message.contains("graph-level")), + "expected git-master subset error, got: {:?}", + result.errors + ); + } + + #[test] + fn llm_node_empty_skill_name_errors() { + let mut graph = graph_with( + vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))], + "l", + ); + graph.enabled_skills = Some(vec!["code-review".into()]); + if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type { + n.enabled_skills = Some(vec!["".into()]); + } + + let result = validator().validate(&graph); + + assert!(!result.is_valid()); + assert!( + result + .errors + .iter() + .any(|e| e.message.contains("empty skill name")), + "expected empty-skill-name error, got: {:?}", + result.errors + ); + } + + #[test] + fn llm_node_skill_when_no_graph_set_is_permitted_by_validator() { + let mut graph = graph_with( + vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))], + "l", + ); + if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type { + n.enabled_skills = Some(vec!["anything".into()]); + } + + let result = validator().validate(&graph); + + assert!( + !result + .errors + .iter() + .any(|e| e.message.contains("enabled_skills")), + "validator should not block when graph.enabled_skills is None: {:?}", + result.errors + ); + } + fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext { AgentValidationContext { tool_names: tools.iter().map(|s| s.to_string()).collect(),