feat: llm graph nodes support skills
This commit is contained in:
@@ -352,6 +352,14 @@ impl Agent {
|
||||
self.config.enabled_skills.as_deref()
|
||||
}
|
||||
|
||||
pub fn set_skills_enabled(&mut self, value: Option<bool>) {
|
||||
self.config.skills_enabled = value;
|
||||
}
|
||||
|
||||
pub fn set_enabled_skills(&mut self, value: Option<Vec<String>>) {
|
||||
self.config.enabled_skills = value;
|
||||
}
|
||||
|
||||
pub fn conversation_starters(&self) -> Vec<String> {
|
||||
self.config
|
||||
.conversation_starters
|
||||
@@ -696,6 +704,8 @@ impl AgentConfig {
|
||||
description: graph.description.clone(),
|
||||
global_tools: graph.global_tools.clone(),
|
||||
mcp_servers: graph.mcp_servers.clone(),
|
||||
skills_enabled: graph.skills_enabled,
|
||||
enabled_skills: graph.enabled_skills.clone(),
|
||||
conversation_starters: graph.conversation_starters.clone(),
|
||||
variables: graph.variables.clone(),
|
||||
can_spawn_agents: graph.has_agent_node(),
|
||||
|
||||
@@ -113,6 +113,8 @@ async fn run(
|
||||
parent_ctx,
|
||||
)?;
|
||||
|
||||
let saved_agent_skill_state = swap_in_node_skill_policy(node, parent_ctx);
|
||||
|
||||
let saved_role = parent_ctx.role.clone();
|
||||
parent_ctx.role = Some(role);
|
||||
let result = match node.timeout {
|
||||
@@ -128,9 +130,46 @@ async fn run(
|
||||
None => run_with_retries(node, &prompt, parent_ctx).await,
|
||||
};
|
||||
parent_ctx.role = saved_role;
|
||||
restore_agent_skill_policy(parent_ctx, saved_agent_skill_state);
|
||||
result
|
||||
}
|
||||
|
||||
struct SavedAgentSkillPolicy {
|
||||
skills_enabled: Option<bool>,
|
||||
enabled_skills: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn swap_in_node_skill_policy(
|
||||
node: &LlmNode,
|
||||
ctx: &mut RequestContext,
|
||||
) -> Option<SavedAgentSkillPolicy> {
|
||||
let agent = ctx.agent.as_mut()?;
|
||||
let saved = SavedAgentSkillPolicy {
|
||||
skills_enabled: agent.skills_enabled(),
|
||||
enabled_skills: agent.enabled_skills().map(|s| s.to_vec()),
|
||||
};
|
||||
|
||||
if let Some(b) = node.skills_enabled {
|
||||
agent.set_skills_enabled(Some(b));
|
||||
}
|
||||
|
||||
if let Some(names) = &node.enabled_skills {
|
||||
agent.set_enabled_skills(Some(names.clone()));
|
||||
}
|
||||
|
||||
Some(saved)
|
||||
}
|
||||
|
||||
fn restore_agent_skill_policy(ctx: &mut RequestContext, saved: Option<SavedAgentSkillPolicy>) {
|
||||
let Some(saved) = saved else { return };
|
||||
let Some(agent) = ctx.agent.as_mut() else {
|
||||
return;
|
||||
};
|
||||
|
||||
agent.set_skills_enabled(saved.skills_enabled);
|
||||
agent.set_enabled_skills(saved.enabled_skills);
|
||||
}
|
||||
|
||||
async fn run_with_retries(
|
||||
node: &LlmNode,
|
||||
prompt: &str,
|
||||
@@ -389,6 +428,8 @@ mod tests {
|
||||
state_updates: updates,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
skills_enabled: None,
|
||||
enabled_skills: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,12 @@ pub struct Graph {
|
||||
#[serde(default)]
|
||||
pub mcp_servers: Vec<String>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub skills_enabled: Option<bool>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub enabled_skills: Option<Vec<String>>,
|
||||
|
||||
#[serde(default)]
|
||||
pub conversation_starters: Vec<String>,
|
||||
|
||||
@@ -293,6 +299,12 @@ pub struct LlmNode {
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub timeout: Option<u64>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub skills_enabled: Option<bool>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub enabled_skills: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
fn default_llm_max_attempts() -> u32 {
|
||||
|
||||
@@ -119,6 +119,7 @@ impl GraphValidator {
|
||||
self.validate_approval_routes(graph, &mut result);
|
||||
self.validate_rag_nodes(graph, &mut result);
|
||||
self.validate_llm_nodes(graph, &mut result);
|
||||
self.validate_llm_skills(graph, &mut result);
|
||||
self.validate_max_concurrency(graph, &mut result);
|
||||
self.validate_map_branches(graph, &mut result);
|
||||
self.validate_parallel_user_interaction(graph, &mut result);
|
||||
@@ -189,6 +190,39 @@ impl GraphValidator {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_llm_skills(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||
for (node_id, node) in &graph.nodes {
|
||||
let NodeType::Llm(llm) = &node.node_type else {
|
||||
continue;
|
||||
};
|
||||
let Some(node_skills) = &llm.enabled_skills else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for name in node_skills {
|
||||
if name.trim().is_empty() {
|
||||
result.error(ValidationError::with_node(
|
||||
node_id,
|
||||
"llm node 'enabled_skills' contains an empty skill name",
|
||||
));
|
||||
continue;
|
||||
}
|
||||
if let Some(graph_skills) = &graph.enabled_skills
|
||||
&& !graph_skills.iter().any(|g| g == name)
|
||||
{
|
||||
result.error(ValidationError::with_node(
|
||||
node_id,
|
||||
format!(
|
||||
"llm node 'enabled_skills' references '{name}' which is not in \
|
||||
graph-level 'enabled_skills' ({})",
|
||||
graph_skills.join(", ")
|
||||
),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||
for (node_id, node) in &graph.nodes {
|
||||
for (target, label) in declared_targets(node) {
|
||||
@@ -847,6 +881,8 @@ mod tests {
|
||||
top_p: None,
|
||||
global_tools: Vec::new(),
|
||||
mcp_servers: Vec::new(),
|
||||
skills_enabled: None,
|
||||
enabled_skills: None,
|
||||
conversation_starters: Vec::new(),
|
||||
variables: Vec::new(),
|
||||
settings: GraphSettings::default(),
|
||||
@@ -946,6 +982,8 @@ mod tests {
|
||||
state_updates: None,
|
||||
output_schema: None,
|
||||
timeout: None,
|
||||
skills_enabled: None,
|
||||
enabled_skills: None,
|
||||
}),
|
||||
next: next.map(NextTargets::from),
|
||||
}
|
||||
@@ -967,6 +1005,99 @@ mod tests {
|
||||
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_skill_in_graph_set_passes() {
|
||||
let mut graph = graph_with(
|
||||
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||
"l",
|
||||
);
|
||||
graph.enabled_skills = Some(vec!["code-review".into(), "git-master".into()]);
|
||||
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||
n.enabled_skills = Some(vec!["code-review".into()]);
|
||||
}
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
!result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("enabled_skills")),
|
||||
"unexpected enabled_skills error: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_skill_not_in_graph_set_errors() {
|
||||
let mut graph = graph_with(
|
||||
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||
"l",
|
||||
);
|
||||
graph.enabled_skills = Some(vec!["code-review".into()]);
|
||||
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||
n.enabled_skills = Some(vec!["git-master".into()]);
|
||||
}
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(!result.is_valid());
|
||||
assert!(
|
||||
result.errors.iter().any(|e| e
|
||||
.message
|
||||
.contains("'git-master'")
|
||||
&& e.message.contains("graph-level")),
|
||||
"expected git-master subset error, got: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_empty_skill_name_errors() {
|
||||
let mut graph = graph_with(
|
||||
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||
"l",
|
||||
);
|
||||
graph.enabled_skills = Some(vec!["code-review".into()]);
|
||||
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||
n.enabled_skills = Some(vec!["".into()]);
|
||||
}
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(!result.is_valid());
|
||||
assert!(
|
||||
result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("empty skill name")),
|
||||
"expected empty-skill-name error, got: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn llm_node_skill_when_no_graph_set_is_permitted_by_validator() {
|
||||
let mut graph = graph_with(
|
||||
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||
"l",
|
||||
);
|
||||
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||
n.enabled_skills = Some(vec!["anything".into()]);
|
||||
}
|
||||
|
||||
let result = validator().validate(&graph);
|
||||
|
||||
assert!(
|
||||
!result
|
||||
.errors
|
||||
.iter()
|
||||
.any(|e| e.message.contains("enabled_skills")),
|
||||
"validator should not block when graph.enabled_skills is None: {:?}",
|
||||
result.errors
|
||||
);
|
||||
}
|
||||
|
||||
fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
|
||||
AgentValidationContext {
|
||||
tool_names: tools.iter().map(|s| s.to_string()).collect(),
|
||||
|
||||
Reference in New Issue
Block a user