feat: added additional support for all RAG-configuration fields in RAG nodes
This commit is contained in:
+124
@@ -82,11 +82,135 @@ impl Clone for Rag {
|
||||
}
|
||||
}
|
||||
|
||||
/// Caller-supplied overrides for building a RAG knowledge base. Each field
|
||||
/// takes precedence over the app-level `rag_*` config; a field left `None`
|
||||
/// falls back to app config and then, if still unset, an interactive prompt.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct RagInitConfig {
|
||||
pub embedding_model: Option<String>,
|
||||
pub chunk_size: Option<usize>,
|
||||
pub chunk_overlap: Option<usize>,
|
||||
pub reranker_model: Option<String>,
|
||||
pub top_k: Option<usize>,
|
||||
pub batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl Rag {
|
||||
fn create_embeddings_client(&self, model: Model) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.app_config, model)
|
||||
}
|
||||
|
||||
/// Build a RAG knowledge base using caller-supplied config overrides.
|
||||
/// Unlike [`Rag::init`], this does not bail outright in non-interactive
|
||||
/// mode: it only requires a terminal when a needed value is missing
|
||||
/// from both `config` and app config. When `config` fully specifies
|
||||
/// `embedding_model`, `chunk_size`, and `chunk_overlap`, the build runs
|
||||
/// with no prompts.
|
||||
pub async fn init_with_config(
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
save_path: &Path,
|
||||
doc_paths: &[String],
|
||||
config: &RagInitConfig,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<Self> {
|
||||
if doc_paths.is_empty() {
|
||||
bail!("Cannot build RAG knowledge base '{name}' with no documents");
|
||||
}
|
||||
println!("⚙ Initializing RAG...");
|
||||
let data = Self::resolve_init_data(app, config)?;
|
||||
let mut rag = Self::create(app, name, save_path, data)?;
|
||||
let loaders = app.document_loaders.clone();
|
||||
let (spinner, spinner_rx) = Spinner::create("");
|
||||
abortable_run_with_spinner_rx(
|
||||
rag.sync_documents(doc_paths, true, loaders, Some(spinner)),
|
||||
spinner_rx,
|
||||
abort_signal,
|
||||
)
|
||||
.await?;
|
||||
if rag.save()? {
|
||||
println!("✓ Saved RAG to '{}'.", save_path.display());
|
||||
}
|
||||
Ok(rag)
|
||||
}
|
||||
|
||||
fn resolve_init_data(app: &AppConfig, config: &RagInitConfig) -> Result<RagData> {
|
||||
let embedding_model_id = config
|
||||
.embedding_model
|
||||
.clone()
|
||||
.or_else(|| app.rag_embedding_model.clone());
|
||||
let embedding_model_id = match embedding_model_id {
|
||||
Some(value) => {
|
||||
println!("Embedding model: {value}");
|
||||
value
|
||||
}
|
||||
None => {
|
||||
if !*IS_STDOUT_TERMINAL {
|
||||
bail!(
|
||||
"RAG knowledge base needs an embedding model. Set `embedding_model` \
|
||||
on the rag node, or run the agent interactively once."
|
||||
);
|
||||
}
|
||||
let models = list_models(app, ModelType::Embedding);
|
||||
if models.is_empty() {
|
||||
bail!("No available embedding model");
|
||||
}
|
||||
select_embedding_model(&models)?
|
||||
}
|
||||
};
|
||||
let embedding_model =
|
||||
Model::retrieve_model(app, &embedding_model_id, ModelType::Embedding)?;
|
||||
|
||||
let chunk_size = match config.chunk_size.or(app.rag_chunk_size) {
|
||||
Some(value) => {
|
||||
println!("Chunk size: {value}");
|
||||
value
|
||||
}
|
||||
None => {
|
||||
if !*IS_STDOUT_TERMINAL {
|
||||
bail!(
|
||||
"RAG knowledge base needs a chunk_size. Set `chunk_size` on the \
|
||||
rag node, or run the agent interactively once."
|
||||
);
|
||||
}
|
||||
set_chunk_size(&embedding_model)?
|
||||
}
|
||||
};
|
||||
let chunk_overlap = match config.chunk_overlap.or(app.rag_chunk_overlap) {
|
||||
Some(value) => {
|
||||
println!("Chunk overlap: {value}");
|
||||
value
|
||||
}
|
||||
None => {
|
||||
if !*IS_STDOUT_TERMINAL {
|
||||
bail!(
|
||||
"RAG knowledge base needs a chunk_overlap. Set `chunk_overlap` on \
|
||||
the rag node, or run the agent interactively once."
|
||||
);
|
||||
}
|
||||
set_chunk_overlay(chunk_size / 20)?
|
||||
}
|
||||
};
|
||||
|
||||
let reranker_model = config
|
||||
.reranker_model
|
||||
.clone()
|
||||
.or_else(|| app.rag_reranker_model.clone());
|
||||
let top_k = config.top_k.unwrap_or(app.rag_top_k);
|
||||
let batch_size = config
|
||||
.batch_size
|
||||
.or_else(|| embedding_model.max_batch_size());
|
||||
|
||||
Ok(RagData::new(
|
||||
embedding_model.id(),
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
reranker_model,
|
||||
top_k,
|
||||
batch_size,
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn init(
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
|
||||
Reference in New Issue
Block a user