feat: initial support for RAG nodes in the graph execution system

This commit is contained in:
2026-05-15 14:11:23 -06:00
parent c70ac98223
commit 8a2f18204f
10 changed files with 454 additions and 47 deletions
+7
View File
@@ -10,6 +10,7 @@ use super::agent::AgentNodeExecutor;
use super::llm::LlmNodeExecutor;
use super::logging::GraphLogger;
use super::parser::GraphParser;
use super::rag::RagNodeExecutor;
use super::script::ScriptExecutor;
use super::state::StateManager;
use super::types::{EndNode, Graph, Node, NodeType};
@@ -204,6 +205,12 @@ async fn step(
let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?;
Ok(StepResult::Continue(next))
}
NodeType::Rag(rag_node) => {
let next =
RagNodeExecutor::execute(rag_node, current, node.next.as_deref(), state, ctx)
.await?;
Ok(StepResult::Continue(next))
}
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
}
}
+2 -1
View File
@@ -9,11 +9,11 @@
//! The logger also accumulates per-node wall-clock timings and emits a
//! performance summary (slowest-first) when the graph completes.
use std::cmp::Reverse;
use super::state::StateManager;
use super::types::{Node, NodeType};
use crate::utils::dimmed_text;
use indexmap::IndexMap;
use std::cmp::Reverse;
use std::time::Duration;
#[derive(Debug, Clone, Default)]
@@ -161,6 +161,7 @@ fn node_type_label(node: &Node) -> &'static str {
NodeType::Approval(_) => "approval",
NodeType::Input(_) => "input",
NodeType::Llm(_) => "llm",
NodeType::Rag(_) => "rag",
NodeType::End(_) => "end",
}
}
+3 -1
View File
@@ -7,6 +7,7 @@ pub mod executor;
pub mod llm;
pub mod logging;
pub mod parser;
pub mod rag;
pub mod script;
pub mod state;
pub mod structured;
@@ -20,11 +21,12 @@ pub use executor::GraphExecutor;
pub use llm::LlmNodeExecutor;
pub use logging::GraphLogger;
pub use parser::{GraphParser, agent_has_graph};
pub use rag::RagNodeExecutor;
pub use script::ScriptExecutor;
pub use state::{StateManager, StateRepresentation};
pub use types::{
AgentNode, ApprovalNode, EndNode, Graph, GraphSettings, GraphState, InputNode, LlmNode, Node,
NodeType, ScriptNode,
NodeType, RagNode, ScriptNode,
};
pub use user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
pub use validator::{GraphValidator, ValidationError, ValidationResult};
+4 -2
View File
@@ -104,12 +104,14 @@ fn enhance_yaml_error(error: serde_yaml::Error) -> Error {
Each node requires `type` plus that type's fields:\n\
- agent: `agent`, `prompt`\n\
- script: `script`\n\
- approval: `question`, `options`, `routes`\n\
- approval: `question`, `options`, `routes`, `on_other`\n\
- input: `question`\n\
- llm: `prompt`\n\
- rag: `documents`\n\
- end: (no required fields)"
} else if msg.contains("unknown field") || msg.contains("unknown variant") {
"\n\nHint: Check for typos in field names or `type:` values.\n\
Valid node types: agent, script, approval, input, end."
Valid node types: agent, script, approval, input, llm, rag, end."
} else if msg.contains("invalid type") {
"\n\nHint: Check that field values have the correct type.\n\
- Strings should be quoted if they contain special characters\n\
+148
View File
@@ -0,0 +1,148 @@
//! Execution of `rag`-type graph nodes.
//!
//! A `rag` node runs a hybrid (vector + keyword) retrieval against the
//! per-node knowledge base built at agent-load time, and writes the result
//! into graph state. The result is exposed to `state_updates` as
//! `{{output}}` — a JSON object `{ context, sources }` where `sources` is
//! an array of source paths.
use super::state::StateManager;
use super::types::RagNode;
use crate::config::RequestContext;
use crate::utils::{create_abort_signal, dimmed_text};
use anyhow::{Context, Result, anyhow};
use serde_json::{Map, Value};
use std::time::Duration;
use tokio::time::timeout;
const OUTPUT_KEY: &str = "output";
const DEFAULT_QUERY: &str = "{{initial_prompt}}";
const DEFAULT_RAG_TIMEOUT_SECS: u64 = 120;
pub struct RagNodeExecutor;
impl RagNodeExecutor {
/// Interpolate the node's query, run the retrieval against this node's
/// knowledge base, expose the result as `{{output}}` for `state_updates`,
/// and return `node_next`.
pub async fn execute(
node: &RagNode,
node_id: &str,
node_next: Option<&str>,
state_manager: &mut StateManager,
ctx: &mut RequestContext,
) -> Result<String> {
let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY);
let query = state_manager
.interpolate(query_template)
.context("Failed to interpolate rag node query")?;
let rag = ctx
.agent
.as_ref()
.and_then(|a| a.graph_rag(node_id))
.ok_or_else(|| anyhow!("rag node '{node_id}' has no initialized knowledge base"))?;
let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k());
let rerank = rag.configured_reranker();
eprintln!(
"{}",
dimmed_text(&format!("▸ rag lookup: node={node_id} top_k={top_k}"))
);
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS));
let abort = create_abort_signal();
let (context, sources_str, _ids) =
timeout(timeout_dur, rag.search(&query, top_k, rerank, abort))
.await
.with_context(|| {
format!(
"rag node '{node_id}' timed out after {}s",
timeout_dur.as_secs()
)
})?
.with_context(|| format!("rag node '{node_id}' retrieval failed"))?;
let output = build_rag_output(context, &sources_str);
apply_state_updates(node, state_manager, &output);
node_next
.map(String::from)
.ok_or_else(|| anyhow!("rag node '{node_id}' has no `next` set"))
}
}
/// Assemble the `{{output}}` value as `{ "context": <ctx>, "sources": [...] }`.
/// `Rag::search` returns sources as a `- {path}` bullet list; it is split
/// into a JSON array so downstream templates can index `{{output.sources[0]}}`.
fn build_rag_output(context: String, sources_str: &str) -> Value {
let sources: Vec<Value> = sources_str
.lines()
.map(|line| line.trim().trim_start_matches("- ").trim())
.filter(|s| !s.is_empty())
.map(|s| Value::String(s.to_string()))
.collect();
let mut obj = Map::new();
obj.insert("context".into(), Value::String(context));
obj.insert("sources".into(), Value::Array(sources));
Value::Object(obj)
}
/// Expose the retrieval result as `{{output}}` for the duration of
/// `state_updates` evaluation, then restore the prior value. Same scoping
/// pattern as `llm`/`agent` nodes.
fn apply_state_updates(node: &RagNode, state_manager: &mut StateManager, output: &Value) {
let Some(updates) = &node.state_updates else {
return;
};
let prev_output = state_manager.state().get(OUTPUT_KEY).cloned();
state_manager
.state_mut()
.set(OUTPUT_KEY.into(), output.clone());
for (key, template) in updates {
let value = state_manager.interpolate_lenient(template);
state_manager
.state_mut()
.set(key.clone(), Value::String(value));
}
match prev_output {
Some(v) => state_manager.state_mut().set(OUTPUT_KEY.into(), v),
None => state_manager
.state_mut()
.set(OUTPUT_KEY.into(), Value::Null),
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn build_rag_output_splits_bullet_sources_into_array() {
let out = build_rag_output("ctx".into(), "- a.md\n- https://x.com/spec");
assert_eq!(out["context"], json!("ctx"));
assert_eq!(out["sources"], json!(["a.md", "https://x.com/spec"]));
}
#[test]
fn build_rag_output_handles_empty_sources() {
let out = build_rag_output("ctx".into(), "");
assert_eq!(out["sources"], json!([]));
}
#[test]
fn build_rag_output_ignores_blank_lines() {
let out = build_rag_output("c".into(), "- a\n\n- b\n");
assert_eq!(out["sources"], json!(["a", "b"]));
}
#[test]
fn build_rag_output_tolerates_unprefixed_lines() {
let out = build_rag_output("c".into(), "plain/path");
assert_eq!(out["sources"], json!(["plain/path"]));
}
}
+30
View File
@@ -152,6 +152,7 @@ pub enum NodeType {
Approval(ApprovalNode),
Input(InputNode),
Llm(LlmNode),
Rag(RagNode),
End(EndNode),
}
@@ -328,6 +329,35 @@ fn default_llm_max_iterations() -> u32 {
10
}
/// `rag`-type node: run a hybrid (vector + keyword) retrieval against a
/// per-node knowledge base and write the result into state. The retrieved
/// context and the list of source paths are exposed to `state_updates` via
/// `{{output.context}}` and `{{output.sources}}` (the whole result is
/// `{{output}}`, a JSON object). The knowledge base is built once at agent
/// load time into `<agent>/<node-id>.yaml`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RagNode {
/// Knowledge sources (files, directories, URLs, loader-protocol paths).
/// REQUIRED — this is what makes the node a RAG node.
pub documents: Vec<String>,
/// Retrieval query, templated against state. Defaults to
/// `{{initial_prompt}}` when omitted.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub query: Option<String>,
/// Number of chunks to retrieve. Defaults to the knowledge base's own
/// configured `top_k` when omitted.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub top_k: Option<usize>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub state_updates: Option<HashMap<String, String>>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,
}
/// `end`-type node: terminate execution; `output` (templated) is returned
/// as the graph's final result.
#[derive(Debug, Clone, Deserialize, Serialize)]
+99 -1
View File
@@ -103,9 +103,31 @@ impl GraphValidator {
self.validate_scripts(graph, &mut result);
self.validate_agents(graph, &mut result);
self.validate_approval_routes(graph, &mut result);
self.validate_rag_nodes(graph, &mut result);
result
}
fn validate_rag_nodes(&self, graph: &Graph, result: &mut ValidationResult) {
for (node_id, node) in &graph.nodes {
if let NodeType::Rag(r) = &node.node_type {
if r.documents.is_empty() {
result.error(ValidationError::with_node(
node_id,
"RAG node has no 'documents'; at least one knowledge source \
is required",
));
}
if r.state_updates.is_none() {
result.warning(ValidationError::with_node(
node_id,
"RAG node has no 'state_updates'; its retrieval result will \
not be written to state",
));
}
}
}
}
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
for (node_id, node) in &graph.nodes {
for (target, label) in declared_targets(node) {
@@ -272,7 +294,9 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
out.push((t.clone(), "llm 'fallback'"));
}
}
NodeType::Agent(_) | NodeType::End(_) => {}
// `agent`/`rag` route only via `next` (already collected above);
// `end` is terminal. No type-specific routing edges to add.
NodeType::Agent(_) | NodeType::Rag(_) | NodeType::End(_) => {}
}
out
}
@@ -416,6 +440,80 @@ mod tests {
}
}
fn rag_node(id: &str, documents: &[&str], with_state_updates: bool) -> Node {
let state_updates = with_state_updates.then(|| {
let mut m: HashMap<String, String> = HashMap::new();
m.insert("ctx".into(), "{{output.context}}".into());
m
});
Node {
id: id.into(),
description: String::new(),
node_type: NodeType::Rag(RagNode {
documents: documents.iter().map(|s| (*s).into()).collect(),
query: None,
top_k: None,
state_updates,
timeout: None,
}),
next: Some("end".into()),
}
}
#[test]
fn rag_node_without_documents_errors() {
let graph = graph_with(
vec![("r", rag_node("r", &[], true)), ("end", end_node("end"))],
"r",
);
let result = validator().validate(&graph);
assert!(!result.is_valid());
assert!(
result
.errors
.iter()
.any(|e| e.message.contains("no 'documents'") && e.node_id.as_deref() == Some("r"))
);
}
#[test]
fn rag_node_without_state_updates_warns() {
let graph = graph_with(
vec![
("r", rag_node("r", &["./docs"], false)),
("end", end_node("end")),
],
"r",
);
let result = validator().validate(&graph);
assert!(result.is_valid());
assert!(
result
.warnings
.iter()
.any(|w| w.message.contains("no 'state_updates'"))
);
}
#[test]
fn valid_rag_node_produces_no_findings() {
let graph = graph_with(
vec![
("r", rag_node("r", &["./docs"], true)),
("end", end_node("end")),
],
"r",
);
let result = validator().validate(&graph);
assert!(result.is_valid());
assert!(
!result
.warnings
.iter()
.any(|w| w.message.contains("RAG node"))
);
}
fn agent_node(id: &str, agent: &str, next: Option<&str>) -> Node {
Node {
id: id.into(),