refactor: migrated llm nodes to use Roles to simplify instructions handling and to function like inline roles

This commit is contained in:
2026-05-14 13:24:34 -06:00
parent 99c6cff068
commit 5669830510
2 changed files with 37 additions and 19 deletions
+12 -7
View File
@@ -13,7 +13,7 @@ use super::state::StateManager;
use super::types::LlmNode;
use crate::config::RequestContext;
use crate::utils::dimmed_text;
use anyhow::{bail, Context, Result};
use anyhow::{Context, Result, bail};
use serde_json::Value;
const OUTPUT_KEY: &str = "output";
@@ -51,9 +51,14 @@ async fn run(
state_manager: &mut StateManager,
_parent_ctx: &mut RequestContext,
) -> Result<String> {
let _instructions = state_manager
.interpolate(&node.instructions)
.context("Failed to interpolate llm node instructions")?;
let _instructions: Option<String> = match &node.instructions {
Some(s) => Some(
state_manager
.interpolate(s)
.context("Failed to interpolate llm node instructions")?,
),
None => None,
};
let _prompt = state_manager
.interpolate(&node.prompt)
.context("Failed to interpolate llm node prompt")?;
@@ -119,7 +124,7 @@ fn apply_state_updates_with_output(node: &LlmNode, state_manager: &mut StateMana
fn describe_tools_filter(tools: Option<&[String]>) -> String {
match tools {
None => "<none>".into(),
None => "<inherit>".into(),
Some(t) if t.is_empty() => "<none>".into(),
Some(t) => t.join(","),
}
@@ -142,7 +147,7 @@ mod tests {
fn node_with(updates: Option<HashMap<String, String>>) -> LlmNode {
LlmNode {
instructions: "sys".into(),
instructions: Some("sys".into()),
prompt: "user".into(),
tools: None,
model: None,
@@ -208,7 +213,7 @@ mod tests {
#[test]
fn describe_tools_filter_renders_each_case() {
assert_eq!(describe_tools_filter(None), "<none>");
assert_eq!(describe_tools_filter(None), "<inherit>");
assert_eq!(describe_tools_filter(Some(&[])), "<none>");
let tools = vec!["a".to_string(), "b".to_string()];
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
+25 -12
View File
@@ -215,13 +215,21 @@ pub struct InputNode {
/// LLM's response on success, or to an error description on failure.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LlmNode {
pub instructions: String,
/// User-turn prompt. Templated against state. REQUIRED.
pub prompt: String,
/// Whitelist of tool names. Each entry is either an exact function
/// name or the shorthand `mcp:<server>` (expands to the three MCP
/// meta-functions for that server). Unset = no tools.
/// Optional system prompt. When set, the LLM call uses an inline
/// Role with `instructions` as `Role.prompt`. Templated against
/// state.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Whitelist narrowing the active agent's tool universe.
/// Each entry is either an exact function name (`global_tools`
/// entry or `tools.{sh,py,ts}` subcommand) or the shorthand
/// `mcp:<server>` (where `<server>` must be in the agent's
/// `mcp_servers`). Unset = inherit agent's full set; `[]` = no
/// tools.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<String>>,
@@ -636,7 +644,7 @@ next: review
NodeType::Llm(l) => l,
_ => panic!("expected Llm variant"),
};
assert_eq!(llm.instructions, "You are a classifier.");
assert_eq!(llm.instructions.as_deref(), Some("You are a classifier."));
assert_eq!(llm.prompt, "Classify: {{input_text}}");
let tools = llm.tools.unwrap();
assert_eq!(tools, vec!["read_query", "mcp:pubmed-search"]);
@@ -668,7 +676,7 @@ next: done
NodeType::Llm(l) => l,
_ => panic!("expected Llm variant"),
};
assert_eq!(llm.instructions, "System.");
assert_eq!(llm.instructions.as_deref(), Some("System."));
assert_eq!(llm.prompt, "User.");
assert!(llm.tools.is_none());
assert!(llm.model.is_none());
@@ -678,14 +686,19 @@ next: done
}
#[test]
fn llm_node_missing_instructions_fails() {
fn llm_node_with_just_prompt_succeeds() {
let yaml = r#"
id: bad
id: pure
type: llm
prompt: "User only — no system prompt."
prompt: "User-only — no system prompt."
"#;
let result: std::result::Result<Node, _> = serde_yaml::from_str(yaml);
assert!(result.is_err());
let node: Node = serde_yaml::from_str(yaml).unwrap();
let llm = match node.node_type {
NodeType::Llm(l) => l,
_ => panic!("expected Llm variant"),
};
assert!(llm.instructions.is_none());
assert_eq!(llm.prompt, "User-only — no system prompt.");
}
#[test]