feat: llm graph nodes support skills
This commit is contained in:
@@ -352,6 +352,14 @@ impl Agent {
|
|||||||
self.config.enabled_skills.as_deref()
|
self.config.enabled_skills.as_deref()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_skills_enabled(&mut self, value: Option<bool>) {
|
||||||
|
self.config.skills_enabled = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_enabled_skills(&mut self, value: Option<Vec<String>>) {
|
||||||
|
self.config.enabled_skills = value;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn conversation_starters(&self) -> Vec<String> {
|
pub fn conversation_starters(&self) -> Vec<String> {
|
||||||
self.config
|
self.config
|
||||||
.conversation_starters
|
.conversation_starters
|
||||||
@@ -696,6 +704,8 @@ impl AgentConfig {
|
|||||||
description: graph.description.clone(),
|
description: graph.description.clone(),
|
||||||
global_tools: graph.global_tools.clone(),
|
global_tools: graph.global_tools.clone(),
|
||||||
mcp_servers: graph.mcp_servers.clone(),
|
mcp_servers: graph.mcp_servers.clone(),
|
||||||
|
skills_enabled: graph.skills_enabled,
|
||||||
|
enabled_skills: graph.enabled_skills.clone(),
|
||||||
conversation_starters: graph.conversation_starters.clone(),
|
conversation_starters: graph.conversation_starters.clone(),
|
||||||
variables: graph.variables.clone(),
|
variables: graph.variables.clone(),
|
||||||
can_spawn_agents: graph.has_agent_node(),
|
can_spawn_agents: graph.has_agent_node(),
|
||||||
|
|||||||
@@ -113,6 +113,8 @@ async fn run(
|
|||||||
parent_ctx,
|
parent_ctx,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
let saved_agent_skill_state = swap_in_node_skill_policy(node, parent_ctx);
|
||||||
|
|
||||||
let saved_role = parent_ctx.role.clone();
|
let saved_role = parent_ctx.role.clone();
|
||||||
parent_ctx.role = Some(role);
|
parent_ctx.role = Some(role);
|
||||||
let result = match node.timeout {
|
let result = match node.timeout {
|
||||||
@@ -128,9 +130,46 @@ async fn run(
|
|||||||
None => run_with_retries(node, &prompt, parent_ctx).await,
|
None => run_with_retries(node, &prompt, parent_ctx).await,
|
||||||
};
|
};
|
||||||
parent_ctx.role = saved_role;
|
parent_ctx.role = saved_role;
|
||||||
|
restore_agent_skill_policy(parent_ctx, saved_agent_skill_state);
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct SavedAgentSkillPolicy {
|
||||||
|
skills_enabled: Option<bool>,
|
||||||
|
enabled_skills: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn swap_in_node_skill_policy(
|
||||||
|
node: &LlmNode,
|
||||||
|
ctx: &mut RequestContext,
|
||||||
|
) -> Option<SavedAgentSkillPolicy> {
|
||||||
|
let agent = ctx.agent.as_mut()?;
|
||||||
|
let saved = SavedAgentSkillPolicy {
|
||||||
|
skills_enabled: agent.skills_enabled(),
|
||||||
|
enabled_skills: agent.enabled_skills().map(|s| s.to_vec()),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(b) = node.skills_enabled {
|
||||||
|
agent.set_skills_enabled(Some(b));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(names) = &node.enabled_skills {
|
||||||
|
agent.set_enabled_skills(Some(names.clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(saved)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn restore_agent_skill_policy(ctx: &mut RequestContext, saved: Option<SavedAgentSkillPolicy>) {
|
||||||
|
let Some(saved) = saved else { return };
|
||||||
|
let Some(agent) = ctx.agent.as_mut() else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
agent.set_skills_enabled(saved.skills_enabled);
|
||||||
|
agent.set_enabled_skills(saved.enabled_skills);
|
||||||
|
}
|
||||||
|
|
||||||
async fn run_with_retries(
|
async fn run_with_retries(
|
||||||
node: &LlmNode,
|
node: &LlmNode,
|
||||||
prompt: &str,
|
prompt: &str,
|
||||||
@@ -389,6 +428,8 @@ mod tests {
|
|||||||
state_updates: updates,
|
state_updates: updates,
|
||||||
output_schema: None,
|
output_schema: None,
|
||||||
timeout: None,
|
timeout: None,
|
||||||
|
skills_enabled: None,
|
||||||
|
enabled_skills: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ pub struct Graph {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub mcp_servers: Vec<String>,
|
pub mcp_servers: Vec<String>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub skills_enabled: Option<bool>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub enabled_skills: Option<Vec<String>>,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub conversation_starters: Vec<String>,
|
pub conversation_starters: Vec<String>,
|
||||||
|
|
||||||
@@ -293,6 +299,12 @@ pub struct LlmNode {
|
|||||||
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub timeout: Option<u64>,
|
pub timeout: Option<u64>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub skills_enabled: Option<bool>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub enabled_skills: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_llm_max_attempts() -> u32 {
|
fn default_llm_max_attempts() -> u32 {
|
||||||
|
|||||||
@@ -119,6 +119,7 @@ impl GraphValidator {
|
|||||||
self.validate_approval_routes(graph, &mut result);
|
self.validate_approval_routes(graph, &mut result);
|
||||||
self.validate_rag_nodes(graph, &mut result);
|
self.validate_rag_nodes(graph, &mut result);
|
||||||
self.validate_llm_nodes(graph, &mut result);
|
self.validate_llm_nodes(graph, &mut result);
|
||||||
|
self.validate_llm_skills(graph, &mut result);
|
||||||
self.validate_max_concurrency(graph, &mut result);
|
self.validate_max_concurrency(graph, &mut result);
|
||||||
self.validate_map_branches(graph, &mut result);
|
self.validate_map_branches(graph, &mut result);
|
||||||
self.validate_parallel_user_interaction(graph, &mut result);
|
self.validate_parallel_user_interaction(graph, &mut result);
|
||||||
@@ -189,6 +190,39 @@ impl GraphValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_llm_skills(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
|
for (node_id, node) in &graph.nodes {
|
||||||
|
let NodeType::Llm(llm) = &node.node_type else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let Some(node_skills) = &llm.enabled_skills else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
for name in node_skills {
|
||||||
|
if name.trim().is_empty() {
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
"llm node 'enabled_skills' contains an empty skill name",
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Some(graph_skills) = &graph.enabled_skills
|
||||||
|
&& !graph_skills.iter().any(|g| g == name)
|
||||||
|
{
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
format!(
|
||||||
|
"llm node 'enabled_skills' references '{name}' which is not in \
|
||||||
|
graph-level 'enabled_skills' ({})",
|
||||||
|
graph_skills.join(", ")
|
||||||
|
),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
for (node_id, node) in &graph.nodes {
|
for (node_id, node) in &graph.nodes {
|
||||||
for (target, label) in declared_targets(node) {
|
for (target, label) in declared_targets(node) {
|
||||||
@@ -847,6 +881,8 @@ mod tests {
|
|||||||
top_p: None,
|
top_p: None,
|
||||||
global_tools: Vec::new(),
|
global_tools: Vec::new(),
|
||||||
mcp_servers: Vec::new(),
|
mcp_servers: Vec::new(),
|
||||||
|
skills_enabled: None,
|
||||||
|
enabled_skills: None,
|
||||||
conversation_starters: Vec::new(),
|
conversation_starters: Vec::new(),
|
||||||
variables: Vec::new(),
|
variables: Vec::new(),
|
||||||
settings: GraphSettings::default(),
|
settings: GraphSettings::default(),
|
||||||
@@ -946,6 +982,8 @@ mod tests {
|
|||||||
state_updates: None,
|
state_updates: None,
|
||||||
output_schema: None,
|
output_schema: None,
|
||||||
timeout: None,
|
timeout: None,
|
||||||
|
skills_enabled: None,
|
||||||
|
enabled_skills: None,
|
||||||
}),
|
}),
|
||||||
next: next.map(NextTargets::from),
|
next: next.map(NextTargets::from),
|
||||||
}
|
}
|
||||||
@@ -967,6 +1005,99 @@ mod tests {
|
|||||||
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
|
assert!(result.errors.iter().any(|e| e.message.contains("ghost")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_skill_in_graph_set_passes() {
|
||||||
|
let mut graph = graph_with(
|
||||||
|
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
graph.enabled_skills = Some(vec!["code-review".into(), "git-master".into()]);
|
||||||
|
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||||
|
n.enabled_skills = Some(vec!["code-review".into()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("enabled_skills")),
|
||||||
|
"unexpected enabled_skills error: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_skill_not_in_graph_set_errors() {
|
||||||
|
let mut graph = graph_with(
|
||||||
|
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
graph.enabled_skills = Some(vec!["code-review".into()]);
|
||||||
|
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||||
|
n.enabled_skills = Some(vec!["git-master".into()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
|
||||||
|
assert!(!result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result.errors.iter().any(|e| e
|
||||||
|
.message
|
||||||
|
.contains("'git-master'")
|
||||||
|
&& e.message.contains("graph-level")),
|
||||||
|
"expected git-master subset error, got: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_empty_skill_name_errors() {
|
||||||
|
let mut graph = graph_with(
|
||||||
|
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
graph.enabled_skills = Some(vec!["code-review".into()]);
|
||||||
|
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||||
|
n.enabled_skills = Some(vec!["".into()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
|
||||||
|
assert!(!result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("empty skill name")),
|
||||||
|
"expected empty-skill-name error, got: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn llm_node_skill_when_no_graph_set_is_permitted_by_validator() {
|
||||||
|
let mut graph = graph_with(
|
||||||
|
vec![("l", llm_node("l", None, Some("end"))), ("end", end_node("end"))],
|
||||||
|
"l",
|
||||||
|
);
|
||||||
|
if let NodeType::Llm(ref mut n) = graph.nodes.get_mut("l").unwrap().node_type {
|
||||||
|
n.enabled_skills = Some(vec!["anything".into()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("enabled_skills")),
|
||||||
|
"validator should not block when graph.enabled_skills is None: {:?}",
|
||||||
|
result.errors
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
|
fn agent_ctx(tools: &[&str], mcp: &[&str]) -> AgentValidationContext {
|
||||||
AgentValidationContext {
|
AgentValidationContext {
|
||||||
tool_names: tools.iter().map(|s| s.to_string()).collect(),
|
tool_names: tools.iter().map(|s| s.to_string()).collect(),
|
||||||
|
|||||||
Reference in New Issue
Block a user