feat: add support for all RAG-configuration fields in RAG nodes
This commit is contained in:
+44
-18
@@ -12,6 +12,7 @@ use crate::config::prompts::{
|
|||||||
DEFAULT_USER_INTERACTION_INSTRUCTIONS,
|
DEFAULT_USER_INTERACTION_INSTRUCTIONS,
|
||||||
};
|
};
|
||||||
use crate::graph::{Graph, GraphParser, NodeType};
|
use crate::graph::{Graph, GraphParser, NodeType};
|
||||||
|
use crate::rag::RagInitConfig;
|
||||||
use crate::vault::SECRET_RE;
|
use crate::vault::SECRET_RE;
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use fancy_regex::Captures;
|
use fancy_regex::Captures;
|
||||||
@@ -673,7 +674,6 @@ impl AgentConfig {
|
|||||||
model_id: graph.model.clone(),
|
model_id: graph.model.clone(),
|
||||||
temperature: graph.temperature,
|
temperature: graph.temperature,
|
||||||
top_p: graph.top_p,
|
top_p: graph.top_p,
|
||||||
agent_session: graph.agent_session.clone(),
|
|
||||||
description: graph.description.clone(),
|
description: graph.description.clone(),
|
||||||
global_tools: graph.global_tools.clone(),
|
global_tools: graph.global_tools.clone(),
|
||||||
mcp_servers: graph.mcp_servers.clone(),
|
mcp_servers: graph.mcp_servers.clone(),
|
||||||
@@ -833,6 +833,9 @@ async fn init_graph_rags(
|
|||||||
abort_signal: AbortSignal,
|
abort_signal: AbortSignal,
|
||||||
) -> Result<HashMap<String, Arc<Rag>>> {
|
) -> Result<HashMap<String, Arc<Rag>>> {
|
||||||
let mut rags = HashMap::new();
|
let mut rags = HashMap::new();
|
||||||
|
if info_flag {
|
||||||
|
return Ok(rags);
|
||||||
|
}
|
||||||
for (node_id, node) in &graph.nodes {
|
for (node_id, node) in &graph.nodes {
|
||||||
let NodeType::Rag(rag_node) = &node.node_type else {
|
let NodeType::Rag(rag_node) = &node.node_type else {
|
||||||
continue;
|
continue;
|
||||||
@@ -852,23 +855,38 @@ async fn init_graph_rags(
|
|||||||
Rag::load(&app_clone, &name_clone, &path_clone)
|
Rag::load(&app_clone, &name_clone, &path_clone)
|
||||||
})
|
})
|
||||||
.await?
|
.await?
|
||||||
} else if info_flag || !*IS_STDOUT_TERMINAL {
|
|
||||||
bail!(
|
|
||||||
"Agent '{agent_name}' requires RAG for rag node '{node_id}', but its \
|
|
||||||
knowledge base has not been built. Run the agent once interactively \
|
|
||||||
to initialize it."
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
let ans = Confirm::new(&format!(
|
let config = RagInitConfig {
|
||||||
"Initialize RAG knowledge base for rag node '{node_id}'?"
|
embedding_model: rag_node.embedding_model.clone(),
|
||||||
))
|
chunk_size: rag_node.chunk_size,
|
||||||
.with_default(true)
|
chunk_overlap: rag_node.chunk_overlap,
|
||||||
.prompt()?;
|
reranker_model: rag_node.reranker_model.clone(),
|
||||||
if !ans {
|
top_k: rag_node.top_k,
|
||||||
bail!(
|
batch_size: rag_node.batch_size,
|
||||||
"Agent '{agent_name}' has rag node '{node_id}' but its RAG was not \
|
};
|
||||||
initialized. RAG initialization is required for this agent."
|
let fully_specified = config.embedding_model.is_some()
|
||||||
);
|
&& config.chunk_size.is_some()
|
||||||
|
&& config.chunk_overlap.is_some();
|
||||||
|
if !fully_specified {
|
||||||
|
if !*IS_STDOUT_TERMINAL {
|
||||||
|
bail!(
|
||||||
|
"Agent '{agent_name}' requires RAG for rag node '{node_id}', but its \
|
||||||
|
knowledge base is not built and the node does not fully specify how \
|
||||||
|
to build it. Set `embedding_model`, `chunk_size`, and `chunk_overlap` \
|
||||||
|
on the node, or run the agent once interactively."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let ans = Confirm::new(&format!(
|
||||||
|
"Initialize RAG knowledge base for rag node '{node_id}'?"
|
||||||
|
))
|
||||||
|
.with_default(true)
|
||||||
|
.prompt()?;
|
||||||
|
if !ans {
|
||||||
|
bail!(
|
||||||
|
"Agent '{agent_name}' has rag node '{node_id}' but its RAG was not \
|
||||||
|
initialized. RAG initialization is required for this agent."
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let document_paths =
|
let document_paths =
|
||||||
resolve_document_paths(&rag_node.documents, loaders, agent_data_dir)?;
|
resolve_document_paths(&rag_node.documents, loaders, agent_data_dir)?;
|
||||||
@@ -879,7 +897,15 @@ async fn init_graph_rags(
|
|||||||
app_state
|
app_state
|
||||||
.rag_cache
|
.rag_cache
|
||||||
.load_with(key, || async move {
|
.load_with(key, || async move {
|
||||||
Rag::init(&app_clone, &name_clone, &path_clone, &document_paths, abort).await
|
Rag::init_with_config(
|
||||||
|
&app_clone,
|
||||||
|
&name_clone,
|
||||||
|
&path_clone,
|
||||||
|
&document_paths,
|
||||||
|
&config,
|
||||||
|
abort,
|
||||||
|
)
|
||||||
|
.await
|
||||||
})
|
})
|
||||||
.await?
|
.await?
|
||||||
};
|
};
|
||||||
|
|||||||
+15
-18
@@ -198,21 +198,19 @@ fn build_inline_role(
|
|||||||
role.set_top_p(Some(p));
|
role.set_top_p(Some(p));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(tool_entries) = &node.tools {
|
if node.tools.as_deref().unwrap_or_default().is_empty() {
|
||||||
if tool_entries.is_empty() {
|
role.set_enabled_tools(Some(String::new()));
|
||||||
role.set_enabled_tools(Some(String::new()));
|
role.set_enabled_mcp_servers(Some(String::new()));
|
||||||
role.set_enabled_mcp_servers(Some(String::new()));
|
} else {
|
||||||
|
if !regular_tools.is_empty() {
|
||||||
|
role.set_enabled_tools(Some(regular_tools.join(",")));
|
||||||
} else {
|
} else {
|
||||||
if !regular_tools.is_empty() {
|
role.set_enabled_tools(Some(String::new()));
|
||||||
role.set_enabled_tools(Some(regular_tools.join(",")));
|
}
|
||||||
} else {
|
if !mcp_servers.is_empty() {
|
||||||
role.set_enabled_tools(Some(String::new()));
|
role.set_enabled_mcp_servers(Some(mcp_servers.join(",")));
|
||||||
}
|
} else {
|
||||||
if !mcp_servers.is_empty() {
|
role.set_enabled_mcp_servers(Some(String::new()));
|
||||||
role.set_enabled_mcp_servers(Some(mcp_servers.join(",")));
|
|
||||||
} else {
|
|
||||||
role.set_enabled_mcp_servers(Some(String::new()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,9 +368,8 @@ fn format_schema_hint(schema: &Value) -> String {
|
|||||||
|
|
||||||
/// Renders a node's `tools` filter as a short human-readable label.
///
/// Tools are strictly opt-in, so an unset filter (`None`) and an explicitly
/// empty list both render as `"<none>"`. A non-empty list renders as its
/// entries joined with commas, in their original order.
fn describe_tools_filter(tools: Option<&[String]>) -> String {
    match tools {
        // Unset and empty are deliberately indistinguishable: neither grants tools.
        None | Some([]) => "<none>".to_string(),
        Some(names) => names.join(","),
    }
}
|
||||||
|
|
||||||
@@ -531,7 +528,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn describe_tools_filter_renders_each_case() {
|
fn describe_tools_filter_renders_each_case() {
|
||||||
assert_eq!(describe_tools_filter(None), "<inherit>");
|
assert_eq!(describe_tools_filter(None), "<none>");
|
||||||
assert_eq!(describe_tools_filter(Some(&[])), "<none>");
|
assert_eq!(describe_tools_filter(Some(&[])), "<none>");
|
||||||
let tools = vec!["a".to_string(), "b".to_string()];
|
let tools = vec!["a".to_string(), "b".to_string()];
|
||||||
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
|
assert_eq!(describe_tools_filter(Some(&tools)), "a,b");
|
||||||
|
|||||||
+24
-9
@@ -31,10 +31,6 @@ pub struct Graph {
|
|||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub top_p: Option<f64>,
|
pub top_p: Option<f64>,
|
||||||
|
|
||||||
/// Session to start the agent in (e.g. `temp`). Single-file mode only.
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub agent_session: Option<String>,
|
|
||||||
|
|
||||||
/// Global tools available to the agent's nodes. Single-file mode only.
|
/// Global tools available to the agent's nodes. Single-file mode only.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub global_tools: Vec<String>,
|
pub global_tools: Vec<String>,
|
||||||
@@ -282,8 +278,8 @@ pub struct LlmNode {
|
|||||||
/// Each entry is either an exact function name (`global_tools`
|
/// Each entry is either an exact function name (`global_tools`
|
||||||
/// entry or `tools.{sh,py,ts}` subcommand) or the shorthand
|
/// entry or `tools.{sh,py,ts}` subcommand) or the shorthand
|
||||||
/// `mcp:<server>` (where `<server>` must be in the agent's
|
/// `mcp:<server>` (where `<server>` must be in the agent's
|
||||||
/// `mcp_servers`). Unset = inherit agent's full set; `[]` = no
|
/// `mcp_servers`). Unset or `[]` = no tools — tools are strictly
|
||||||
/// tools.
|
/// opt-in.
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<String>>,
|
pub tools: Option<Vec<String>>,
|
||||||
|
|
||||||
@@ -351,6 +347,28 @@ pub struct RagNode {
|
|||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub top_k: Option<usize>,
|
pub top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// Embedding model for building the knowledge base. When this plus
|
||||||
|
/// `chunk_size` and `chunk_overlap` are all set, knowledge-base
|
||||||
|
/// construction runs non-interactively (no prompts).
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub embedding_model: Option<String>,
|
||||||
|
|
||||||
|
/// Chunk size for splitting documents at build time.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub chunk_size: Option<usize>,
|
||||||
|
|
||||||
|
/// Chunk overlap for splitting documents at build time.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub chunk_overlap: Option<usize>,
|
||||||
|
|
||||||
|
/// Reranker model applied to hybrid-search results.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reranker_model: Option<String>,
|
||||||
|
|
||||||
|
/// Embedding-request batch size at build time.
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub batch_size: Option<usize>,
|
||||||
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
pub state_updates: Option<HashMap<String, String>>,
|
pub state_updates: Option<HashMap<String, String>>,
|
||||||
|
|
||||||
@@ -812,7 +830,6 @@ start: e
|
|||||||
model: anthropic:claude-sonnet-4-6
|
model: anthropic:claude-sonnet-4-6
|
||||||
temperature: 0.2
|
temperature: 0.2
|
||||||
top_p: 0.9
|
top_p: 0.9
|
||||||
agent_session: temp
|
|
||||||
global_tools:
|
global_tools:
|
||||||
- web_search_loki.sh
|
- web_search_loki.sh
|
||||||
mcp_servers:
|
mcp_servers:
|
||||||
@@ -829,7 +846,6 @@ nodes:
|
|||||||
assert_eq!(graph.model.as_deref(), Some("anthropic:claude-sonnet-4-6"));
|
assert_eq!(graph.model.as_deref(), Some("anthropic:claude-sonnet-4-6"));
|
||||||
assert_eq!(graph.temperature, Some(0.2));
|
assert_eq!(graph.temperature, Some(0.2));
|
||||||
assert_eq!(graph.top_p, Some(0.9));
|
assert_eq!(graph.top_p, Some(0.9));
|
||||||
assert_eq!(graph.agent_session.as_deref(), Some("temp"));
|
|
||||||
assert_eq!(graph.global_tools, vec!["web_search_loki.sh"]);
|
assert_eq!(graph.global_tools, vec!["web_search_loki.sh"]);
|
||||||
assert_eq!(graph.mcp_servers, vec!["pubmed-search"]);
|
assert_eq!(graph.mcp_servers, vec!["pubmed-search"]);
|
||||||
assert_eq!(graph.conversation_starters, vec!["Look up 2160-0"]);
|
assert_eq!(graph.conversation_starters, vec!["Look up 2160-0"]);
|
||||||
@@ -842,7 +858,6 @@ nodes:
|
|||||||
assert!(graph.model.is_none());
|
assert!(graph.model.is_none());
|
||||||
assert!(graph.temperature.is_none());
|
assert!(graph.temperature.is_none());
|
||||||
assert!(graph.top_p.is_none());
|
assert!(graph.top_p.is_none());
|
||||||
assert!(graph.agent_session.is_none());
|
|
||||||
assert!(graph.global_tools.is_empty());
|
assert!(graph.global_tools.is_empty());
|
||||||
assert!(graph.mcp_servers.is_empty());
|
assert!(graph.mcp_servers.is_empty());
|
||||||
assert!(graph.conversation_starters.is_empty());
|
assert!(graph.conversation_starters.is_empty());
|
||||||
|
|||||||
@@ -382,7 +382,6 @@ mod tests {
|
|||||||
model: None,
|
model: None,
|
||||||
temperature: None,
|
temperature: None,
|
||||||
top_p: None,
|
top_p: None,
|
||||||
agent_session: None,
|
|
||||||
global_tools: Vec::new(),
|
global_tools: Vec::new(),
|
||||||
mcp_servers: Vec::new(),
|
mcp_servers: Vec::new(),
|
||||||
conversation_starters: Vec::new(),
|
conversation_starters: Vec::new(),
|
||||||
@@ -453,6 +452,11 @@ mod tests {
|
|||||||
documents: documents.iter().map(|s| (*s).into()).collect(),
|
documents: documents.iter().map(|s| (*s).into()).collect(),
|
||||||
query: None,
|
query: None,
|
||||||
top_k: None,
|
top_k: None,
|
||||||
|
embedding_model: None,
|
||||||
|
chunk_size: None,
|
||||||
|
chunk_overlap: None,
|
||||||
|
reranker_model: None,
|
||||||
|
batch_size: None,
|
||||||
state_updates,
|
state_updates,
|
||||||
timeout: None,
|
timeout: None,
|
||||||
}),
|
}),
|
||||||
|
|||||||
+124
@@ -82,11 +82,135 @@ impl Clone for Rag {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Caller-supplied overrides for building a RAG knowledge base.
///
/// Every field is optional, and a `Some` value always wins over any other
/// source. What a `None` falls back to depends on the field (the resolution
/// itself happens in `Rag::resolve_init_data`):
///
/// * `embedding_model`, `chunk_size`, `chunk_overlap`: the app-level
///   `rag_embedding_model` / `rag_chunk_size` / `rag_chunk_overlap` setting,
///   and if that is unset too, an interactive prompt. Without a terminal on
///   stdout the build fails instead of prompting.
/// * `reranker_model`: the app-level `rag_reranker_model`; it may stay unset.
/// * `top_k`: the app-level `rag_top_k`.
/// * `batch_size`: the embedding model's own maximum batch size.
#[derive(Debug, Clone, Default)]
pub struct RagInitConfig {
    /// Identifier of the embedding model used to build the knowledge base.
    pub embedding_model: Option<String>,
    /// Chunk size used when splitting documents.
    pub chunk_size: Option<usize>,
    /// Overlap between consecutive chunks when splitting documents.
    pub chunk_overlap: Option<usize>,
    /// Identifier of the reranker model, if one should be used.
    pub reranker_model: Option<String>,
    /// Overrides the app-level `rag_top_k` value.
    pub top_k: Option<usize>,
    /// Batch size for embedding requests.
    pub batch_size: Option<usize>,
}
|
||||||
|
|
||||||
impl Rag {
|
impl Rag {
|
||||||
fn create_embeddings_client(&self, model: Model) -> Result<Box<dyn Client>> {
|
fn create_embeddings_client(&self, model: Model) -> Result<Box<dyn Client>> {
|
||||||
init_client(&self.app_config, model)
|
init_client(&self.app_config, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a RAG knowledge base using caller-supplied config overrides.
|
||||||
|
/// Unlike [`Rag::init`], this does not bail outright in non-interactive
|
||||||
|
/// mode: it only requires a terminal when a needed value is missing
|
||||||
|
/// from both `config` and app config. When `config` fully specifies
|
||||||
|
/// `embedding_model`, `chunk_size`, and `chunk_overlap`, the build runs
|
||||||
|
/// with no prompts.
|
||||||
|
pub async fn init_with_config(
|
||||||
|
app: &AppConfig,
|
||||||
|
name: &str,
|
||||||
|
save_path: &Path,
|
||||||
|
doc_paths: &[String],
|
||||||
|
config: &RagInitConfig,
|
||||||
|
abort_signal: AbortSignal,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if doc_paths.is_empty() {
|
||||||
|
bail!("Cannot build RAG knowledge base '{name}' with no documents");
|
||||||
|
}
|
||||||
|
println!("⚙ Initializing RAG...");
|
||||||
|
let data = Self::resolve_init_data(app, config)?;
|
||||||
|
let mut rag = Self::create(app, name, save_path, data)?;
|
||||||
|
let loaders = app.document_loaders.clone();
|
||||||
|
let (spinner, spinner_rx) = Spinner::create("");
|
||||||
|
abortable_run_with_spinner_rx(
|
||||||
|
rag.sync_documents(doc_paths, true, loaders, Some(spinner)),
|
||||||
|
spinner_rx,
|
||||||
|
abort_signal,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
if rag.save()? {
|
||||||
|
println!("✓ Saved RAG to '{}'.", save_path.display());
|
||||||
|
}
|
||||||
|
Ok(rag)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_init_data(app: &AppConfig, config: &RagInitConfig) -> Result<RagData> {
|
||||||
|
let embedding_model_id = config
|
||||||
|
.embedding_model
|
||||||
|
.clone()
|
||||||
|
.or_else(|| app.rag_embedding_model.clone());
|
||||||
|
let embedding_model_id = match embedding_model_id {
|
||||||
|
Some(value) => {
|
||||||
|
println!("Embedding model: {value}");
|
||||||
|
value
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if !*IS_STDOUT_TERMINAL {
|
||||||
|
bail!(
|
||||||
|
"RAG knowledge base needs an embedding model. Set `embedding_model` \
|
||||||
|
on the rag node, or run the agent interactively once."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let models = list_models(app, ModelType::Embedding);
|
||||||
|
if models.is_empty() {
|
||||||
|
bail!("No available embedding model");
|
||||||
|
}
|
||||||
|
select_embedding_model(&models)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let embedding_model =
|
||||||
|
Model::retrieve_model(app, &embedding_model_id, ModelType::Embedding)?;
|
||||||
|
|
||||||
|
let chunk_size = match config.chunk_size.or(app.rag_chunk_size) {
|
||||||
|
Some(value) => {
|
||||||
|
println!("Chunk size: {value}");
|
||||||
|
value
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if !*IS_STDOUT_TERMINAL {
|
||||||
|
bail!(
|
||||||
|
"RAG knowledge base needs a chunk_size. Set `chunk_size` on the \
|
||||||
|
rag node, or run the agent interactively once."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
set_chunk_size(&embedding_model)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let chunk_overlap = match config.chunk_overlap.or(app.rag_chunk_overlap) {
|
||||||
|
Some(value) => {
|
||||||
|
println!("Chunk overlap: {value}");
|
||||||
|
value
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
if !*IS_STDOUT_TERMINAL {
|
||||||
|
bail!(
|
||||||
|
"RAG knowledge base needs a chunk_overlap. Set `chunk_overlap` on \
|
||||||
|
the rag node, or run the agent interactively once."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
set_chunk_overlay(chunk_size / 20)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let reranker_model = config
|
||||||
|
.reranker_model
|
||||||
|
.clone()
|
||||||
|
.or_else(|| app.rag_reranker_model.clone());
|
||||||
|
let top_k = config.top_k.unwrap_or(app.rag_top_k);
|
||||||
|
let batch_size = config
|
||||||
|
.batch_size
|
||||||
|
.or_else(|| embedding_model.max_batch_size());
|
||||||
|
|
||||||
|
Ok(RagData::new(
|
||||||
|
embedding_model.id(),
|
||||||
|
chunk_size,
|
||||||
|
chunk_overlap,
|
||||||
|
reranker_model,
|
||||||
|
top_k,
|
||||||
|
batch_size,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn init(
|
pub async fn init(
|
||||||
app: &AppConfig,
|
app: &AppConfig,
|
||||||
name: &str,
|
name: &str,
|
||||||
|
|||||||
Reference in New Issue
Block a user