feat: initial support for RAG nodes in the graph execution system
This commit is contained in:
+138
-34
@@ -11,7 +11,7 @@ use crate::config::prompts::{
|
|||||||
DEFAULT_SPAWN_INSTRUCTIONS, DEFAULT_TEAMMATE_INSTRUCTIONS, DEFAULT_TODO_INSTRUCTIONS,
|
DEFAULT_SPAWN_INSTRUCTIONS, DEFAULT_TEAMMATE_INSTRUCTIONS, DEFAULT_TODO_INSTRUCTIONS,
|
||||||
DEFAULT_USER_INTERACTION_INSTRUCTIONS,
|
DEFAULT_USER_INTERACTION_INSTRUCTIONS,
|
||||||
};
|
};
|
||||||
use crate::graph::{Graph, GraphParser};
|
use crate::graph::{Graph, GraphParser, NodeType};
|
||||||
use crate::vault::SECRET_RE;
|
use crate::vault::SECRET_RE;
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use fancy_regex::Captures;
|
use fancy_regex::Captures;
|
||||||
@@ -38,6 +38,7 @@ pub struct Agent {
|
|||||||
session_dynamic_instructions: Option<String>,
|
session_dynamic_instructions: Option<String>,
|
||||||
functions: Functions,
|
functions: Functions,
|
||||||
rag: Option<Arc<Rag>>,
|
rag: Option<Arc<Rag>>,
|
||||||
|
graph_rags: HashMap<String, Arc<Rag>>,
|
||||||
model: Model,
|
model: Model,
|
||||||
vault: GlobalVault,
|
vault: GlobalVault,
|
||||||
}
|
}
|
||||||
@@ -99,6 +100,7 @@ impl Agent {
|
|||||||
let rag_path = paths::agent_rag_file(name, DEFAULT_AGENT_NAME);
|
let rag_path = paths::agent_rag_file(name, DEFAULT_AGENT_NAME);
|
||||||
let config_path = paths::agent_config_file(name);
|
let config_path = paths::agent_config_file(name);
|
||||||
let graph_path = paths::agent_graph_file(name);
|
let graph_path = paths::agent_graph_file(name);
|
||||||
|
let mut graph_for_rag: Option<Graph> = None;
|
||||||
let mut agent_config = match (config_path.exists(), graph_path.exists()) {
|
let mut agent_config = match (config_path.exists(), graph_path.exists()) {
|
||||||
(true, true) => bail!(
|
(true, true) => bail!(
|
||||||
"Agent '{name}' has both config.yaml and graph.yaml. A graph agent \
|
"Agent '{name}' has both config.yaml and graph.yaml. A graph agent \
|
||||||
@@ -111,7 +113,9 @@ impl Agent {
|
|||||||
let graph = parser
|
let graph = parser
|
||||||
.load_from_file(&graph_path)
|
.load_from_file(&graph_path)
|
||||||
.with_context(|| format!("Failed to load graph.yaml for agent '{name}'"))?;
|
.with_context(|| format!("Failed to load graph.yaml for agent '{name}'"))?;
|
||||||
AgentConfig::from_graph(name, &graph)
|
let config = AgentConfig::from_graph(name, &graph);
|
||||||
|
graph_for_rag = Some(graph);
|
||||||
|
config
|
||||||
}
|
}
|
||||||
(false, false) => bail!(
|
(false, false) => bail!(
|
||||||
"Agent '{name}' has neither a config.yaml nor a graph.yaml at '{}'",
|
"Agent '{name}' has neither a config.yaml nor a graph.yaml at '{}'",
|
||||||
@@ -154,44 +158,16 @@ impl Agent {
|
|||||||
.prompt()?;
|
.prompt()?;
|
||||||
}
|
}
|
||||||
if ans {
|
if ans {
|
||||||
let mut document_paths = vec![];
|
let document_paths =
|
||||||
for path in &agent_config.documents {
|
resolve_document_paths(&agent_config.documents, &loaders, &agent_data_dir)?;
|
||||||
if is_url(path) {
|
|
||||||
document_paths.push(path.to_string());
|
|
||||||
} else if is_loader_protocol(&loaders, path) {
|
|
||||||
let (protocol, document_path) = path
|
|
||||||
.split_once(':')
|
|
||||||
.with_context(|| "Invalid loader protocol path")?;
|
|
||||||
let resolved_path = resolve_home_dir(document_path);
|
|
||||||
let new_path = if Path::new(&resolved_path).is_relative() {
|
|
||||||
safe_join_path(&agent_data_dir, resolved_path)
|
|
||||||
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?
|
|
||||||
} else {
|
|
||||||
PathBuf::from(&resolved_path)
|
|
||||||
};
|
|
||||||
document_paths.push(format!("{}:{}", protocol, new_path.display()));
|
|
||||||
} else if Path::new(&resolve_home_dir(path)).is_relative() {
|
|
||||||
let new_path = safe_join_path(&agent_data_dir, path)
|
|
||||||
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
|
|
||||||
document_paths.push(new_path.display().to_string())
|
|
||||||
} else {
|
|
||||||
document_paths.push(path.to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let key = RagKey::Agent(name.to_string());
|
let key = RagKey::Agent(name.to_string());
|
||||||
let app_clone = app.clone();
|
let app_clone = app.clone();
|
||||||
let rag_path_clone = rag_path.clone();
|
let rag_path_clone = rag_path.clone();
|
||||||
|
let abort = abort_signal.clone();
|
||||||
let rag = app_state
|
let rag = app_state
|
||||||
.rag_cache
|
.rag_cache
|
||||||
.load_with(key, || async move {
|
.load_with(key, || async move {
|
||||||
Rag::init(
|
Rag::init(&app_clone, "rag", &rag_path_clone, &document_paths, abort).await
|
||||||
&app_clone,
|
|
||||||
"rag",
|
|
||||||
&rag_path_clone,
|
|
||||||
&document_paths,
|
|
||||||
abort_signal,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
})
|
})
|
||||||
.await?;
|
.await?;
|
||||||
Some(rag)
|
Some(rag)
|
||||||
@@ -202,6 +178,23 @@ impl Agent {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let graph_rags = match &graph_for_rag {
|
||||||
|
Some(graph) => {
|
||||||
|
init_graph_rags(
|
||||||
|
app,
|
||||||
|
app_state,
|
||||||
|
name,
|
||||||
|
graph,
|
||||||
|
&agent_data_dir,
|
||||||
|
&loaders,
|
||||||
|
info_flag,
|
||||||
|
abort_signal.clone(),
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
}
|
||||||
|
None => HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
if agent_config.auto_continue {
|
if agent_config.auto_continue {
|
||||||
functions.append_todo_functions();
|
functions.append_todo_functions();
|
||||||
}
|
}
|
||||||
@@ -224,6 +217,7 @@ impl Agent {
|
|||||||
session_dynamic_instructions: None,
|
session_dynamic_instructions: None,
|
||||||
functions,
|
functions,
|
||||||
rag,
|
rag,
|
||||||
|
graph_rags,
|
||||||
model,
|
model,
|
||||||
vault: app_state.vault.clone(),
|
vault: app_state.vault.clone(),
|
||||||
})
|
})
|
||||||
@@ -330,6 +324,10 @@ impl Agent {
|
|||||||
self.rag.clone()
|
self.rag.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn graph_rag(&self, node_id: &str) -> Option<Arc<Rag>> {
|
||||||
|
self.graph_rags.get(node_id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
|
pub fn append_mcp_meta_functions(&mut self, mcp_servers: Vec<String>) {
|
||||||
self.functions.append_mcp_meta_functions(mcp_servers);
|
self.functions.append_mcp_meta_functions(mcp_servers);
|
||||||
}
|
}
|
||||||
@@ -784,6 +782,112 @@ pub struct AgentVariable {
|
|||||||
pub value: String,
|
pub value: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolve document path specs (URLs, loader-protocol paths, relative or
|
||||||
|
/// absolute file paths) into the concrete paths `Rag::init` expects.
|
||||||
|
/// Relative paths are joined against the agent's data directory.
|
||||||
|
fn resolve_document_paths(
|
||||||
|
documents: &[String],
|
||||||
|
loaders: &HashMap<String, String>,
|
||||||
|
agent_data_dir: &Path,
|
||||||
|
) -> Result<Vec<String>> {
|
||||||
|
let mut document_paths = vec![];
|
||||||
|
for path in documents {
|
||||||
|
if is_url(path) {
|
||||||
|
document_paths.push(path.to_string());
|
||||||
|
} else if is_loader_protocol(loaders, path) {
|
||||||
|
let (protocol, document_path) = path
|
||||||
|
.split_once(':')
|
||||||
|
.with_context(|| "Invalid loader protocol path")?;
|
||||||
|
let resolved_path = resolve_home_dir(document_path);
|
||||||
|
let new_path = if Path::new(&resolved_path).is_relative() {
|
||||||
|
safe_join_path(agent_data_dir, resolved_path)
|
||||||
|
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?
|
||||||
|
} else {
|
||||||
|
PathBuf::from(&resolved_path)
|
||||||
|
};
|
||||||
|
document_paths.push(format!("{}:{}", protocol, new_path.display()));
|
||||||
|
} else if Path::new(&resolve_home_dir(path)).is_relative() {
|
||||||
|
let new_path = safe_join_path(agent_data_dir, path)
|
||||||
|
.ok_or_else(|| anyhow!("Invalid document path: '{path}'"))?;
|
||||||
|
document_paths.push(new_path.display().to_string())
|
||||||
|
} else {
|
||||||
|
document_paths.push(path.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(document_paths)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build or load a knowledge base for every `rag` node in the graph. Each
|
||||||
|
/// node's RAG lives in `<agent>/<node-id>.yaml`. A missing knowledge base is
|
||||||
|
/// a hard error (interactive: after a declined confirm; non-interactive:
|
||||||
|
/// immediately) — a graph with an uninitialized `rag` node cannot run.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
async fn init_graph_rags(
|
||||||
|
app: &AppConfig,
|
||||||
|
app_state: &AppState,
|
||||||
|
agent_name: &str,
|
||||||
|
graph: &Graph,
|
||||||
|
agent_data_dir: &Path,
|
||||||
|
loaders: &HashMap<String, String>,
|
||||||
|
info_flag: bool,
|
||||||
|
abort_signal: AbortSignal,
|
||||||
|
) -> Result<HashMap<String, Arc<Rag>>> {
|
||||||
|
let mut rags = HashMap::new();
|
||||||
|
for (node_id, node) in &graph.nodes {
|
||||||
|
let NodeType::Rag(rag_node) = &node.node_type else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let rag_path = paths::agent_rag_file(agent_name, node_id);
|
||||||
|
let key = RagKey::GraphNode {
|
||||||
|
agent: agent_name.to_string(),
|
||||||
|
node: node_id.clone(),
|
||||||
|
};
|
||||||
|
let rag = if rag_path.exists() {
|
||||||
|
let app_clone = app.clone();
|
||||||
|
let path_clone = rag_path.clone();
|
||||||
|
let name_clone = node_id.clone();
|
||||||
|
app_state
|
||||||
|
.rag_cache
|
||||||
|
.load_with(key, || async move {
|
||||||
|
Rag::load(&app_clone, &name_clone, &path_clone)
|
||||||
|
})
|
||||||
|
.await?
|
||||||
|
} else if info_flag || !*IS_STDOUT_TERMINAL {
|
||||||
|
bail!(
|
||||||
|
"Agent '{agent_name}' requires RAG for rag node '{node_id}', but its \
|
||||||
|
knowledge base has not been built. Run the agent once interactively \
|
||||||
|
to initialize it."
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
let ans = Confirm::new(&format!(
|
||||||
|
"Initialize RAG knowledge base for rag node '{node_id}'?"
|
||||||
|
))
|
||||||
|
.with_default(true)
|
||||||
|
.prompt()?;
|
||||||
|
if !ans {
|
||||||
|
bail!(
|
||||||
|
"Agent '{agent_name}' has rag node '{node_id}' but its RAG was not \
|
||||||
|
initialized. RAG initialization is required for this agent."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let document_paths =
|
||||||
|
resolve_document_paths(&rag_node.documents, loaders, agent_data_dir)?;
|
||||||
|
let app_clone = app.clone();
|
||||||
|
let path_clone = rag_path.clone();
|
||||||
|
let name_clone = node_id.clone();
|
||||||
|
let abort = abort_signal.clone();
|
||||||
|
app_state
|
||||||
|
.rag_cache
|
||||||
|
.load_with(key, || async move {
|
||||||
|
Rag::init(&app_clone, &name_clone, &path_clone, &document_paths, abort).await
|
||||||
|
})
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
rags.insert(node_id.clone(), rag);
|
||||||
|
}
|
||||||
|
Ok(rags)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn list_agents() -> Vec<String> {
|
pub fn list_agents() -> Vec<String> {
|
||||||
let agents_data_dir = paths::agents_data_dir();
|
let agents_data_dir = paths::agents_data_dir();
|
||||||
if !agents_data_dir.exists() {
|
if !agents_data_dir.exists() {
|
||||||
|
|||||||
@@ -9,6 +9,12 @@ use std::sync::{Arc, Weak};
|
|||||||
pub enum RagKey {
|
pub enum RagKey {
|
||||||
Named(String),
|
Named(String),
|
||||||
Agent(String),
|
Agent(String),
|
||||||
|
/// A `rag` node's per-node knowledge base, keyed by owning agent name
|
||||||
|
/// and node id.
|
||||||
|
GraphNode {
|
||||||
|
agent: String,
|
||||||
|
node: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use super::agent::AgentNodeExecutor;
|
|||||||
use super::llm::LlmNodeExecutor;
|
use super::llm::LlmNodeExecutor;
|
||||||
use super::logging::GraphLogger;
|
use super::logging::GraphLogger;
|
||||||
use super::parser::GraphParser;
|
use super::parser::GraphParser;
|
||||||
|
use super::rag::RagNodeExecutor;
|
||||||
use super::script::ScriptExecutor;
|
use super::script::ScriptExecutor;
|
||||||
use super::state::StateManager;
|
use super::state::StateManager;
|
||||||
use super::types::{EndNode, Graph, Node, NodeType};
|
use super::types::{EndNode, Graph, Node, NodeType};
|
||||||
@@ -204,6 +205,12 @@ async fn step(
|
|||||||
let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?;
|
let next = LlmNodeExecutor::execute(llm_node, node.next.as_deref(), state, ctx).await?;
|
||||||
Ok(StepResult::Continue(next))
|
Ok(StepResult::Continue(next))
|
||||||
}
|
}
|
||||||
|
NodeType::Rag(rag_node) => {
|
||||||
|
let next =
|
||||||
|
RagNodeExecutor::execute(rag_node, current, node.next.as_deref(), state, ctx)
|
||||||
|
.await?;
|
||||||
|
Ok(StepResult::Continue(next))
|
||||||
|
}
|
||||||
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
NodeType::End(end_node) => Ok(StepResult::End(resolve_end_output(end_node, state))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,11 +9,11 @@
|
|||||||
//! The logger also accumulates per-node wall-clock timings and emits a
|
//! The logger also accumulates per-node wall-clock timings and emits a
|
||||||
//! performance summary (slowest-first) when the graph completes.
|
//! performance summary (slowest-first) when the graph completes.
|
||||||
|
|
||||||
use std::cmp::Reverse;
|
|
||||||
use super::state::StateManager;
|
use super::state::StateManager;
|
||||||
use super::types::{Node, NodeType};
|
use super::types::{Node, NodeType};
|
||||||
use crate::utils::dimmed_text;
|
use crate::utils::dimmed_text;
|
||||||
use indexmap::IndexMap;
|
use indexmap::IndexMap;
|
||||||
|
use std::cmp::Reverse;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
@@ -161,6 +161,7 @@ fn node_type_label(node: &Node) -> &'static str {
|
|||||||
NodeType::Approval(_) => "approval",
|
NodeType::Approval(_) => "approval",
|
||||||
NodeType::Input(_) => "input",
|
NodeType::Input(_) => "input",
|
||||||
NodeType::Llm(_) => "llm",
|
NodeType::Llm(_) => "llm",
|
||||||
|
NodeType::Rag(_) => "rag",
|
||||||
NodeType::End(_) => "end",
|
NodeType::End(_) => "end",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+3
-1
@@ -7,6 +7,7 @@ pub mod executor;
|
|||||||
pub mod llm;
|
pub mod llm;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
pub mod parser;
|
pub mod parser;
|
||||||
|
pub mod rag;
|
||||||
pub mod script;
|
pub mod script;
|
||||||
pub mod state;
|
pub mod state;
|
||||||
pub mod structured;
|
pub mod structured;
|
||||||
@@ -20,11 +21,12 @@ pub use executor::GraphExecutor;
|
|||||||
pub use llm::LlmNodeExecutor;
|
pub use llm::LlmNodeExecutor;
|
||||||
pub use logging::GraphLogger;
|
pub use logging::GraphLogger;
|
||||||
pub use parser::{GraphParser, agent_has_graph};
|
pub use parser::{GraphParser, agent_has_graph};
|
||||||
|
pub use rag::RagNodeExecutor;
|
||||||
pub use script::ScriptExecutor;
|
pub use script::ScriptExecutor;
|
||||||
pub use state::{StateManager, StateRepresentation};
|
pub use state::{StateManager, StateRepresentation};
|
||||||
pub use types::{
|
pub use types::{
|
||||||
AgentNode, ApprovalNode, EndNode, Graph, GraphSettings, GraphState, InputNode, LlmNode, Node,
|
AgentNode, ApprovalNode, EndNode, Graph, GraphSettings, GraphState, InputNode, LlmNode, Node,
|
||||||
NodeType, ScriptNode,
|
NodeType, RagNode, ScriptNode,
|
||||||
};
|
};
|
||||||
pub use user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
pub use user_interaction::{ApprovalNodeExecutor, InputNodeExecutor};
|
||||||
pub use validator::{GraphValidator, ValidationError, ValidationResult};
|
pub use validator::{GraphValidator, ValidationError, ValidationResult};
|
||||||
|
|||||||
+4
-2
@@ -104,12 +104,14 @@ fn enhance_yaml_error(error: serde_yaml::Error) -> Error {
|
|||||||
Each node requires `type` plus that type's fields:\n\
|
Each node requires `type` plus that type's fields:\n\
|
||||||
- agent: `agent`, `prompt`\n\
|
- agent: `agent`, `prompt`\n\
|
||||||
- script: `script`\n\
|
- script: `script`\n\
|
||||||
- approval: `question`, `options`, `routes`\n\
|
- approval: `question`, `options`, `routes`, `on_other`\n\
|
||||||
- input: `question`\n\
|
- input: `question`\n\
|
||||||
|
- llm: `prompt`\n\
|
||||||
|
- rag: `documents`\n\
|
||||||
- end: (no required fields)"
|
- end: (no required fields)"
|
||||||
} else if msg.contains("unknown field") || msg.contains("unknown variant") {
|
} else if msg.contains("unknown field") || msg.contains("unknown variant") {
|
||||||
"\n\nHint: Check for typos in field names or `type:` values.\n\
|
"\n\nHint: Check for typos in field names or `type:` values.\n\
|
||||||
Valid node types: agent, script, approval, input, end."
|
Valid node types: agent, script, approval, input, llm, rag, end."
|
||||||
} else if msg.contains("invalid type") {
|
} else if msg.contains("invalid type") {
|
||||||
"\n\nHint: Check that field values have the correct type.\n\
|
"\n\nHint: Check that field values have the correct type.\n\
|
||||||
- Strings should be quoted if they contain special characters\n\
|
- Strings should be quoted if they contain special characters\n\
|
||||||
|
|||||||
@@ -0,0 +1,148 @@
|
|||||||
|
//! Execution of `rag`-type graph nodes.
|
||||||
|
//!
|
||||||
|
//! A `rag` node runs a hybrid (vector + keyword) retrieval against the
|
||||||
|
//! per-node knowledge base built at agent-load time, and writes the result
|
||||||
|
//! into graph state. The result is exposed to `state_updates` as
|
||||||
|
//! `{{output}}` — a JSON object `{ context, sources }` where `sources` is
|
||||||
|
//! an array of source paths.
|
||||||
|
|
||||||
|
use super::state::StateManager;
|
||||||
|
use super::types::RagNode;
|
||||||
|
use crate::config::RequestContext;
|
||||||
|
use crate::utils::{create_abort_signal, dimmed_text};
|
||||||
|
use anyhow::{Context, Result, anyhow};
|
||||||
|
use serde_json::{Map, Value};
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::time::timeout;
|
||||||
|
|
||||||
|
const OUTPUT_KEY: &str = "output";
|
||||||
|
const DEFAULT_QUERY: &str = "{{initial_prompt}}";
|
||||||
|
const DEFAULT_RAG_TIMEOUT_SECS: u64 = 120;
|
||||||
|
|
||||||
|
pub struct RagNodeExecutor;
|
||||||
|
|
||||||
|
impl RagNodeExecutor {
|
||||||
|
/// Interpolate the node's query, run the retrieval against this node's
|
||||||
|
/// knowledge base, expose the result as `{{output}}` for `state_updates`,
|
||||||
|
/// and return `node_next`.
|
||||||
|
pub async fn execute(
|
||||||
|
node: &RagNode,
|
||||||
|
node_id: &str,
|
||||||
|
node_next: Option<&str>,
|
||||||
|
state_manager: &mut StateManager,
|
||||||
|
ctx: &mut RequestContext,
|
||||||
|
) -> Result<String> {
|
||||||
|
let query_template = node.query.as_deref().unwrap_or(DEFAULT_QUERY);
|
||||||
|
let query = state_manager
|
||||||
|
.interpolate(query_template)
|
||||||
|
.context("Failed to interpolate rag node query")?;
|
||||||
|
|
||||||
|
let rag = ctx
|
||||||
|
.agent
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|a| a.graph_rag(node_id))
|
||||||
|
.ok_or_else(|| anyhow!("rag node '{node_id}' has no initialized knowledge base"))?;
|
||||||
|
|
||||||
|
let top_k = node.top_k.unwrap_or_else(|| rag.configured_top_k());
|
||||||
|
let rerank = rag.configured_reranker();
|
||||||
|
|
||||||
|
eprintln!(
|
||||||
|
"{}",
|
||||||
|
dimmed_text(&format!("▸ rag lookup: node={node_id} top_k={top_k}"))
|
||||||
|
);
|
||||||
|
|
||||||
|
let timeout_dur = Duration::from_secs(node.timeout.unwrap_or(DEFAULT_RAG_TIMEOUT_SECS));
|
||||||
|
let abort = create_abort_signal();
|
||||||
|
let (context, sources_str, _ids) =
|
||||||
|
timeout(timeout_dur, rag.search(&query, top_k, rerank, abort))
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
|
format!(
|
||||||
|
"rag node '{node_id}' timed out after {}s",
|
||||||
|
timeout_dur.as_secs()
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.with_context(|| format!("rag node '{node_id}' retrieval failed"))?;
|
||||||
|
|
||||||
|
let output = build_rag_output(context, &sources_str);
|
||||||
|
apply_state_updates(node, state_manager, &output);
|
||||||
|
|
||||||
|
node_next
|
||||||
|
.map(String::from)
|
||||||
|
.ok_or_else(|| anyhow!("rag node '{node_id}' has no `next` set"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Assemble the `{{output}}` value as `{ "context": <ctx>, "sources": [...] }`.
|
||||||
|
/// `Rag::search` returns sources as a `- {path}` bullet list; it is split
|
||||||
|
/// into a JSON array so downstream templates can index `{{output.sources[0]}}`.
|
||||||
|
fn build_rag_output(context: String, sources_str: &str) -> Value {
|
||||||
|
let sources: Vec<Value> = sources_str
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim().trim_start_matches("- ").trim())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.map(|s| Value::String(s.to_string()))
|
||||||
|
.collect();
|
||||||
|
let mut obj = Map::new();
|
||||||
|
obj.insert("context".into(), Value::String(context));
|
||||||
|
obj.insert("sources".into(), Value::Array(sources));
|
||||||
|
Value::Object(obj)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expose the retrieval result as `{{output}}` for the duration of
|
||||||
|
/// `state_updates` evaluation, then restore the prior value. Same scoping
|
||||||
|
/// pattern as `llm`/`agent` nodes.
|
||||||
|
fn apply_state_updates(node: &RagNode, state_manager: &mut StateManager, output: &Value) {
|
||||||
|
let Some(updates) = &node.state_updates else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
let prev_output = state_manager.state().get(OUTPUT_KEY).cloned();
|
||||||
|
state_manager
|
||||||
|
.state_mut()
|
||||||
|
.set(OUTPUT_KEY.into(), output.clone());
|
||||||
|
|
||||||
|
for (key, template) in updates {
|
||||||
|
let value = state_manager.interpolate_lenient(template);
|
||||||
|
state_manager
|
||||||
|
.state_mut()
|
||||||
|
.set(key.clone(), Value::String(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
match prev_output {
|
||||||
|
Some(v) => state_manager.state_mut().set(OUTPUT_KEY.into(), v),
|
||||||
|
None => state_manager
|
||||||
|
.state_mut()
|
||||||
|
.set(OUTPUT_KEY.into(), Value::Null),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_rag_output_splits_bullet_sources_into_array() {
|
||||||
|
let out = build_rag_output("ctx".into(), "- a.md\n- https://x.com/spec");
|
||||||
|
assert_eq!(out["context"], json!("ctx"));
|
||||||
|
assert_eq!(out["sources"], json!(["a.md", "https://x.com/spec"]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_rag_output_handles_empty_sources() {
|
||||||
|
let out = build_rag_output("ctx".into(), "");
|
||||||
|
assert_eq!(out["sources"], json!([]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_rag_output_ignores_blank_lines() {
|
||||||
|
let out = build_rag_output("c".into(), "- a\n\n- b\n");
|
||||||
|
assert_eq!(out["sources"], json!(["a", "b"]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_rag_output_tolerates_unprefixed_lines() {
|
||||||
|
let out = build_rag_output("c".into(), "plain/path");
|
||||||
|
assert_eq!(out["sources"], json!(["plain/path"]));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -152,6 +152,7 @@ pub enum NodeType {
|
|||||||
Approval(ApprovalNode),
|
Approval(ApprovalNode),
|
||||||
Input(InputNode),
|
Input(InputNode),
|
||||||
Llm(LlmNode),
|
Llm(LlmNode),
|
||||||
|
Rag(RagNode),
|
||||||
End(EndNode),
|
End(EndNode),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,6 +329,35 @@ fn default_llm_max_iterations() -> u32 {
|
|||||||
10
|
10
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// `rag`-type node: run a hybrid (vector + keyword) retrieval against a
|
||||||
|
/// per-node knowledge base and write the result into state. The retrieved
|
||||||
|
/// context and the list of source paths are exposed to `state_updates` via
|
||||||
|
/// `{{output.context}}` and `{{output.sources}}` (the whole result is
|
||||||
|
/// `{{output}}`, a JSON object). The knowledge base is built once at agent
|
||||||
|
/// load time into `<agent>/<node-id>.yaml`.
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct RagNode {
|
||||||
|
/// Knowledge sources (files, directories, URLs, loader-protocol paths).
|
||||||
|
/// REQUIRED — this is what makes the node a RAG node.
|
||||||
|
pub documents: Vec<String>,
|
||||||
|
|
||||||
|
/// Retrieval query, templated against state. Defaults to
|
||||||
|
/// `{{initial_prompt}}` when omitted.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub query: Option<String>,
|
||||||
|
|
||||||
|
/// Number of chunks to retrieve. Defaults to the knowledge base's own
|
||||||
|
/// configured `top_k` when omitted.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<usize>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub state_updates: Option<HashMap<String, String>>,
|
||||||
|
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub timeout: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
/// `end`-type node: terminate execution; `output` (templated) is returned
|
/// `end`-type node: terminate execution; `output` (templated) is returned
|
||||||
/// as the graph's final result.
|
/// as the graph's final result.
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
|||||||
+99
-1
@@ -103,9 +103,31 @@ impl GraphValidator {
|
|||||||
self.validate_scripts(graph, &mut result);
|
self.validate_scripts(graph, &mut result);
|
||||||
self.validate_agents(graph, &mut result);
|
self.validate_agents(graph, &mut result);
|
||||||
self.validate_approval_routes(graph, &mut result);
|
self.validate_approval_routes(graph, &mut result);
|
||||||
|
self.validate_rag_nodes(graph, &mut result);
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn validate_rag_nodes(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
|
for (node_id, node) in &graph.nodes {
|
||||||
|
if let NodeType::Rag(r) = &node.node_type {
|
||||||
|
if r.documents.is_empty() {
|
||||||
|
result.error(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
"RAG node has no 'documents'; at least one knowledge source \
|
||||||
|
is required",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
if r.state_updates.is_none() {
|
||||||
|
result.warning(ValidationError::with_node(
|
||||||
|
node_id,
|
||||||
|
"RAG node has no 'state_updates'; its retrieval result will \
|
||||||
|
not be written to state",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
fn validate_node_references(&self, graph: &Graph, result: &mut ValidationResult) {
|
||||||
for (node_id, node) in &graph.nodes {
|
for (node_id, node) in &graph.nodes {
|
||||||
for (target, label) in declared_targets(node) {
|
for (target, label) in declared_targets(node) {
|
||||||
@@ -272,7 +294,9 @@ fn declared_targets(node: &Node) -> Vec<(String, &'static str)> {
|
|||||||
out.push((t.clone(), "llm 'fallback'"));
|
out.push((t.clone(), "llm 'fallback'"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
NodeType::Agent(_) | NodeType::End(_) => {}
|
// `agent`/`rag` route only via `next` (already collected above);
|
||||||
|
// `end` is terminal. No type-specific routing edges to add.
|
||||||
|
NodeType::Agent(_) | NodeType::Rag(_) | NodeType::End(_) => {}
|
||||||
}
|
}
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
@@ -416,6 +440,80 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rag_node(id: &str, documents: &[&str], with_state_updates: bool) -> Node {
|
||||||
|
let state_updates = with_state_updates.then(|| {
|
||||||
|
let mut m: HashMap<String, String> = HashMap::new();
|
||||||
|
m.insert("ctx".into(), "{{output.context}}".into());
|
||||||
|
m
|
||||||
|
});
|
||||||
|
Node {
|
||||||
|
id: id.into(),
|
||||||
|
description: String::new(),
|
||||||
|
node_type: NodeType::Rag(RagNode {
|
||||||
|
documents: documents.iter().map(|s| (*s).into()).collect(),
|
||||||
|
query: None,
|
||||||
|
top_k: None,
|
||||||
|
state_updates,
|
||||||
|
timeout: None,
|
||||||
|
}),
|
||||||
|
next: Some("end".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rag_node_without_documents_errors() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![("r", rag_node("r", &[], true)), ("end", end_node("end"))],
|
||||||
|
"r",
|
||||||
|
);
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
assert!(!result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.errors
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.message.contains("no 'documents'") && e.node_id.as_deref() == Some("r"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rag_node_without_state_updates_warns() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("r", rag_node("r", &["./docs"], false)),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"r",
|
||||||
|
);
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
assert!(result.is_valid());
|
||||||
|
assert!(
|
||||||
|
result
|
||||||
|
.warnings
|
||||||
|
.iter()
|
||||||
|
.any(|w| w.message.contains("no 'state_updates'"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_rag_node_produces_no_findings() {
|
||||||
|
let graph = graph_with(
|
||||||
|
vec![
|
||||||
|
("r", rag_node("r", &["./docs"], true)),
|
||||||
|
("end", end_node("end")),
|
||||||
|
],
|
||||||
|
"r",
|
||||||
|
);
|
||||||
|
let result = validator().validate(&graph);
|
||||||
|
assert!(result.is_valid());
|
||||||
|
assert!(
|
||||||
|
!result
|
||||||
|
.warnings
|
||||||
|
.iter()
|
||||||
|
.any(|w| w.message.contains("RAG node"))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
fn agent_node(id: &str, agent: &str, next: Option<&str>) -> Node {
|
fn agent_node(id: &str, agent: &str, next: Option<&str>) -> Node {
|
||||||
Node {
|
Node {
|
||||||
id: id.into(),
|
id: id.into(),
|
||||||
|
|||||||
+17
-8
@@ -16,7 +16,8 @@ use parking_lot::RwLock;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, sync::Arc, time::Duration,
|
collections::HashMap, env, fmt, fmt::Debug, fs, hash::Hash, path::Path, sync::Arc,
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ pub struct Rag {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for Rag {
|
impl Debug for Rag {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("Rag")
|
f.debug_struct("Rag")
|
||||||
.field("name", &self.name)
|
.field("name", &self.name)
|
||||||
.field("path", &self.path)
|
.field("path", &self.path)
|
||||||
@@ -315,6 +316,14 @@ impl Rag {
|
|||||||
self.name == TEMP_RAG_NAME
|
self.name == TEMP_RAG_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn configured_top_k(&self) -> usize {
|
||||||
|
self.data.top_k
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn configured_reranker(&self) -> Option<&str> {
|
||||||
|
self.data.reranker_model.as_deref()
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn search(
|
pub async fn search(
|
||||||
&self,
|
&self,
|
||||||
text: &str,
|
text: &str,
|
||||||
@@ -323,7 +332,7 @@ impl Rag {
|
|||||||
abort_signal: AbortSignal,
|
abort_signal: AbortSignal,
|
||||||
) -> Result<(String, String, Vec<DocumentId>)> {
|
) -> Result<(String, String, Vec<DocumentId>)> {
|
||||||
let ret = abortable_run_with_spinner(
|
let ret = abortable_run_with_spinner(
|
||||||
self.hybird_search(text, top_k, rerank_model),
|
self.hybrid_search(text, top_k, rerank_model),
|
||||||
"Searching",
|
"Searching",
|
||||||
abort_signal,
|
abort_signal,
|
||||||
)
|
)
|
||||||
@@ -583,7 +592,7 @@ impl Rag {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn hybird_search(
|
async fn hybrid_search(
|
||||||
&self,
|
&self,
|
||||||
query: &str,
|
query: &str,
|
||||||
top_k: usize,
|
top_k: usize,
|
||||||
@@ -781,7 +790,7 @@ pub struct RagData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for RagData {
|
impl Debug for RagData {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("RagData")
|
f.debug_struct("RagData")
|
||||||
.field("embedding_model", &self.embedding_model)
|
.field("embedding_model", &self.embedding_model)
|
||||||
.field("chunk_size", &self.chunk_size)
|
.field("chunk_size", &self.chunk_size)
|
||||||
@@ -909,7 +918,7 @@ pub type FileId = usize;
|
|||||||
pub struct DocumentId(usize);
|
pub struct DocumentId(usize);
|
||||||
|
|
||||||
impl Debug for DocumentId {
|
impl Debug for DocumentId {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
let (file_index, document_index) = self.split();
|
let (file_index, document_index) = self.split();
|
||||||
f.write_fmt(format_args!("{file_index}-{document_index}"))
|
f.write_fmt(format_args!("{file_index}-{document_index}"))
|
||||||
}
|
}
|
||||||
@@ -951,8 +960,8 @@ impl SelectOption {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Display for SelectOption {
|
impl fmt::Display for SelectOption {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
write!(f, "{} ({})", self.value, self.description)
|
write!(f, "{} ({})", self.value, self.description)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user