From 0094be475fddbff6562bb77bcd446b83ea89a57a Mon Sep 17 00:00:00 2001 From: Alex Clarke Date: Fri, 15 May 2026 16:38:52 -0600 Subject: [PATCH] feat: added additional support for all RAG-configuration fields in RAG nodes --- src/config/agent.rs | 62 +++++++++++++++------ src/graph/llm.rs | 33 +++++------ src/graph/types.rs | 33 ++++++++--- src/graph/validator.rs | 6 +- src/rag/mod.rs | 124 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 212 insertions(+), 46 deletions(-) diff --git a/src/config/agent.rs b/src/config/agent.rs index af01c51..2a831e8 100644 --- a/src/config/agent.rs +++ b/src/config/agent.rs @@ -12,6 +12,7 @@ use crate::config::prompts::{ DEFAULT_USER_INTERACTION_INSTRUCTIONS, }; use crate::graph::{Graph, GraphParser, NodeType}; +use crate::rag::RagInitConfig; use crate::vault::SECRET_RE; use anyhow::{Context, Result}; use fancy_regex::Captures; @@ -673,7 +674,6 @@ impl AgentConfig { model_id: graph.model.clone(), temperature: graph.temperature, top_p: graph.top_p, - agent_session: graph.agent_session.clone(), description: graph.description.clone(), global_tools: graph.global_tools.clone(), mcp_servers: graph.mcp_servers.clone(), @@ -833,6 +833,9 @@ async fn init_graph_rags( abort_signal: AbortSignal, ) -> Result>> { let mut rags = HashMap::new(); + if info_flag { + return Ok(rags); + } for (node_id, node) in &graph.nodes { let NodeType::Rag(rag_node) = &node.node_type else { continue; @@ -852,23 +855,38 @@ async fn init_graph_rags( Rag::load(&app_clone, &name_clone, &path_clone) }) .await? - } else if info_flag || !*IS_STDOUT_TERMINAL { - bail!( - "Agent '{agent_name}' requires RAG for rag node '{node_id}', but its \ - knowledge base has not been built. Run the agent once interactively \ - to initialize it." - ); } else { - let ans = Confirm::new(&format!( - "Initialize RAG knowledge base for rag node '{node_id}'?" - )) - .with_default(true) - .prompt()?; - if !ans { - bail!( - "Agent '{agent_name}' has rag node '{node_id}' but its RAG was not \ - initialized. RAG initialization is required for this agent." - ); + let config = RagInitConfig { + embedding_model: rag_node.embedding_model.clone(), + chunk_size: rag_node.chunk_size, + chunk_overlap: rag_node.chunk_overlap, + reranker_model: rag_node.reranker_model.clone(), + top_k: rag_node.top_k, + batch_size: rag_node.batch_size, + }; + let fully_specified = config.embedding_model.is_some() + && config.chunk_size.is_some() + && config.chunk_overlap.is_some(); + if !fully_specified { + if !*IS_STDOUT_TERMINAL { + bail!( + "Agent '{agent_name}' requires RAG for rag node '{node_id}', but its \ + knowledge base is not built and the node does not fully specify how \ + to build it. Set `embedding_model`, `chunk_size`, and `chunk_overlap` \ + on the node, or run the agent once interactively." + ); + } + let ans = Confirm::new(&format!( + "Initialize RAG knowledge base for rag node '{node_id}'?" + )) + .with_default(true) + .prompt()?; + if !ans { + bail!( + "Agent '{agent_name}' has rag node '{node_id}' but its RAG was not \ + initialized. RAG initialization is required for this agent." + ); + } } let document_paths = resolve_document_paths(&rag_node.documents, loaders, agent_data_dir)?; @@ -879,7 +897,15 @@ async fn init_graph_rags( app_state .rag_cache .load_with(key, || async move { - Rag::init(&app_clone, &name_clone, &path_clone, &document_paths, abort).await + Rag::init_with_config( + &app_clone, + &name_clone, + &path_clone, + &document_paths, + &config, + abort, + ) + .await }) .await? }; diff --git a/src/graph/llm.rs b/src/graph/llm.rs index b8376ba..87a9937 100644 --- a/src/graph/llm.rs +++ b/src/graph/llm.rs @@ -198,21 +198,19 @@ fn build_inline_role( role.set_top_p(Some(p)); } - if let Some(tool_entries) = &node.tools { - if tool_entries.is_empty() { - role.set_enabled_tools(Some(String::new())); - role.set_enabled_mcp_servers(Some(String::new())); + if node.tools.as_deref().unwrap_or_default().is_empty() { + role.set_enabled_tools(Some(String::new())); + role.set_enabled_mcp_servers(Some(String::new())); + } else { + if !regular_tools.is_empty() { + role.set_enabled_tools(Some(regular_tools.join(","))); } else { - if !regular_tools.is_empty() { - role.set_enabled_tools(Some(regular_tools.join(","))); - } else { - role.set_enabled_tools(Some(String::new())); - } - if !mcp_servers.is_empty() { - role.set_enabled_mcp_servers(Some(mcp_servers.join(","))); - } else { - role.set_enabled_mcp_servers(Some(String::new())); - } + role.set_enabled_tools(Some(String::new())); + } + if !mcp_servers.is_empty() { + role.set_enabled_mcp_servers(Some(mcp_servers.join(","))); + } else { + role.set_enabled_mcp_servers(Some(String::new())); } } @@ -370,9 +368,8 @@ fn format_schema_hint(schema: &Value) -> String { fn describe_tools_filter(tools: Option<&[String]>) -> String { match tools { - None => "".into(), - Some(t) if t.is_empty() => "".into(), - Some(t) => t.join(","), + Some(t) if !t.is_empty() => t.join(","), + _ => "".into(), } } @@ -531,7 +528,7 @@ mod tests { #[test] fn describe_tools_filter_renders_each_case() { - assert_eq!(describe_tools_filter(None), ""); + assert_eq!(describe_tools_filter(None), ""); assert_eq!(describe_tools_filter(Some(&[])), ""); let tools = vec!["a".to_string(), "b".to_string()]; assert_eq!(describe_tools_filter(Some(&tools)), "a,b"); diff --git a/src/graph/types.rs b/src/graph/types.rs index 163b768..d471a1b 100644 --- a/src/graph/types.rs +++ b/src/graph/types.rs @@ -31,10 +31,6 @@ pub struct Graph { #[serde(default, skip_serializing_if = "Option::is_none")] pub top_p: Option, - /// Session to start the agent in (e.g. `temp`). Single-file mode only. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub agent_session: Option, - /// Global tools available to the agent's nodes. Single-file mode only. #[serde(default)] pub global_tools: Vec, @@ -282,8 +278,8 @@ pub struct LlmNode { /// Each entry is either an exact function name (`global_tools` /// entry or `tools.{sh,py,ts}` subcommand) or the shorthand /// `mcp:` (where `` must be in the agent's - /// `mcp_servers`). Unset = inherit agent's full set; `[]` = no - /// tools. + /// `mcp_servers`). Unset or `[]` = no tools — tools are strictly + /// opt-in. #[serde(default, skip_serializing_if = "Option::is_none")] pub tools: Option>, @@ -351,6 +347,28 @@ pub struct RagNode { #[serde(default, skip_serializing_if = "Option::is_none")] pub top_k: Option, + /// Embedding model for building the knowledge base. When this plus + /// `chunk_size` and `chunk_overlap` are all set, knowledge-base + /// construction runs non-interactively (no prompts). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub embedding_model: Option, + + /// Chunk size for splitting documents at build time. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub chunk_size: Option, + + /// Chunk overlap for splitting documents at build time. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub chunk_overlap: Option, + + /// Reranker model applied to hybrid-search results. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reranker_model: Option, + + /// Embedding-request batch size at build time. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub batch_size: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] pub state_updates: Option>, @@ -812,7 +830,6 @@ start: e model: anthropic:claude-sonnet-4-6 temperature: 0.2 top_p: 0.9 -agent_session: temp global_tools: - web_search_loki.sh mcp_servers: @@ -829,7 +846,6 @@ nodes: assert_eq!(graph.model.as_deref(), Some("anthropic:claude-sonnet-4-6")); assert_eq!(graph.temperature, Some(0.2)); assert_eq!(graph.top_p, Some(0.9)); - assert_eq!(graph.agent_session.as_deref(), Some("temp")); assert_eq!(graph.global_tools, vec!["web_search_loki.sh"]); assert_eq!(graph.mcp_servers, vec!["pubmed-search"]); assert_eq!(graph.conversation_starters, vec!["Look up 2160-0"]); @@ -842,7 +858,6 @@ nodes: assert!(graph.model.is_none()); assert!(graph.temperature.is_none()); assert!(graph.top_p.is_none()); - assert!(graph.agent_session.is_none()); assert!(graph.global_tools.is_empty()); assert!(graph.mcp_servers.is_empty()); assert!(graph.conversation_starters.is_empty()); diff --git a/src/graph/validator.rs b/src/graph/validator.rs index a4d6de9..147b125 100644 --- a/src/graph/validator.rs +++ b/src/graph/validator.rs @@ -382,7 +382,6 @@ mod tests { model: None, temperature: None, top_p: None, - agent_session: None, global_tools: Vec::new(), mcp_servers: Vec::new(), conversation_starters: Vec::new(), @@ -453,6 +452,11 @@ mod tests { documents: documents.iter().map(|s| (*s).into()).collect(), query: None, top_k: None, + embedding_model: None, + chunk_size: None, + chunk_overlap: None, + reranker_model: None, + batch_size: None, state_updates, timeout: None, }), diff --git a/src/rag/mod.rs b/src/rag/mod.rs index 91d69d6..6ed70d4 100644 --- a/src/rag/mod.rs +++ b/src/rag/mod.rs @@ -82,11 +82,135 @@ impl Clone for Rag { } } +/// Caller-supplied overrides for building a RAG knowledge base. Each field +/// takes precedence over the app-level `rag_*` config; a field left `None` +/// falls back to app config and then, if still unset, an interactive prompt. +#[derive(Debug, Clone, Default)] +pub struct RagInitConfig { + pub embedding_model: Option, + pub chunk_size: Option, + pub chunk_overlap: Option, + pub reranker_model: Option, + pub top_k: Option, + pub batch_size: Option, +} + impl Rag { fn create_embeddings_client(&self, model: Model) -> Result> { init_client(&self.app_config, model) } + /// Build a RAG knowledge base using caller-supplied config overrides. + /// Unlike [`Rag::init`], this does not bail outright in non-interactive + /// mode: it only requires a terminal when a needed value is missing + /// from both `config` and app config. When `config` fully specifies + /// `embedding_model`, `chunk_size`, and `chunk_overlap`, the build runs + /// with no prompts. + pub async fn init_with_config( + app: &AppConfig, + name: &str, + save_path: &Path, + doc_paths: &[String], + config: &RagInitConfig, + abort_signal: AbortSignal, + ) -> Result { + if doc_paths.is_empty() { + bail!("Cannot build RAG knowledge base '{name}' with no documents"); + } + println!("⚙ Initializing RAG..."); + let data = Self::resolve_init_data(app, config)?; + let mut rag = Self::create(app, name, save_path, data)?; + let loaders = app.document_loaders.clone(); + let (spinner, spinner_rx) = Spinner::create(""); + abortable_run_with_spinner_rx( + rag.sync_documents(doc_paths, true, loaders, Some(spinner)), + spinner_rx, + abort_signal, + ) + .await?; + if rag.save()? { + println!("✓ Saved RAG to '{}'.", save_path.display()); + } + Ok(rag) + } + + fn resolve_init_data(app: &AppConfig, config: &RagInitConfig) -> Result { + let embedding_model_id = config + .embedding_model + .clone() + .or_else(|| app.rag_embedding_model.clone()); + let embedding_model_id = match embedding_model_id { + Some(value) => { + println!("Embedding model: {value}"); + value + } + None => { + if !*IS_STDOUT_TERMINAL { + bail!( + "RAG knowledge base needs an embedding model. Set `embedding_model` \ + on the rag node, or run the agent interactively once." + ); + } + let models = list_models(app, ModelType::Embedding); + if models.is_empty() { + bail!("No available embedding model"); + } + select_embedding_model(&models)? + } + }; + let embedding_model = + Model::retrieve_model(app, &embedding_model_id, ModelType::Embedding)?; + + let chunk_size = match config.chunk_size.or(app.rag_chunk_size) { + Some(value) => { + println!("Chunk size: {value}"); + value + } + None => { + if !*IS_STDOUT_TERMINAL { + bail!( + "RAG knowledge base needs a chunk_size. Set `chunk_size` on the \ + rag node, or run the agent interactively once." + ); + } + set_chunk_size(&embedding_model)? + } + }; + let chunk_overlap = match config.chunk_overlap.or(app.rag_chunk_overlap) { + Some(value) => { + println!("Chunk overlap: {value}"); + value + } + None => { + if !*IS_STDOUT_TERMINAL { + bail!( + "RAG knowledge base needs a chunk_overlap. Set `chunk_overlap` on \ + the rag node, or run the agent interactively once." + ); + } + set_chunk_overlay(chunk_size / 20)? + } + }; + + let reranker_model = config + .reranker_model + .clone() + .or_else(|| app.rag_reranker_model.clone()); + let top_k = config.top_k.unwrap_or(app.rag_top_k); + let batch_size = config + .batch_size + .or_else(|| embedding_model.max_batch_size()); + + Ok(RagData::new( + embedding_model.id(), + chunk_size, + chunk_overlap, + reranker_model, + top_k, + batch_size, + )) + } + pub async fn init( app: &AppConfig, name: &str,