feat: 99% complete migration to new state structs to get away from God-Config struct; i.e. AppConfig, AppState, and RequestContext
This commit is contained in:
+40
-33
@@ -15,11 +15,13 @@ use inquire::{Confirm, Select, Text, required, validator::Validation};
|
||||
use parking_lot::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::{collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, time::Duration};
|
||||
use std::{
|
||||
collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, sync::Arc, time::Duration,
|
||||
};
|
||||
use tokio::time::sleep;
|
||||
|
||||
pub struct Rag {
|
||||
config: GlobalConfig,
|
||||
app_config: Arc<AppConfig>,
|
||||
name: String,
|
||||
path: String,
|
||||
embedding_model: Model,
|
||||
@@ -43,7 +45,7 @@ impl Debug for Rag {
|
||||
impl Clone for Rag {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
config: self.config.clone(),
|
||||
app_config: self.app_config.clone(),
|
||||
name: self.name.clone(),
|
||||
path: self.path.clone(),
|
||||
embedding_model: self.embedding_model.clone(),
|
||||
@@ -56,8 +58,12 @@ impl Clone for Rag {
|
||||
}
|
||||
|
||||
impl Rag {
|
||||
/// Builds a client for `model` by delegating to `init_client` with this
/// RAG's stored `app_config`.
///
/// Despite the name, this helper is not limited to embedding models: the
/// rerank path in this file also calls it with a model retrieved as
/// `ModelType::Reranker`.
///
/// # Errors
/// Returns whatever error `init_client` produces.
fn create_embeddings_client(&self, model: Model) -> Result<Box<dyn Client>> {
|
||||
init_client(&self.app_config, model)
|
||||
}
|
||||
|
||||
pub async fn init(
|
||||
config: &GlobalConfig,
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
save_path: &Path,
|
||||
doc_paths: &[String],
|
||||
@@ -67,11 +73,9 @@ impl Rag {
|
||||
bail!("Failed to init rag in non-interactive mode");
|
||||
}
|
||||
println!("⚙ Initializing RAG...");
|
||||
let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(config)?;
|
||||
let (reranker_model, top_k) = {
|
||||
let config = config.read();
|
||||
(config.rag_reranker_model.clone(), config.rag_top_k)
|
||||
};
|
||||
let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(app)?;
|
||||
let reranker_model = app.rag_reranker_model.clone();
|
||||
let top_k = app.rag_top_k;
|
||||
let data = RagData::new(
|
||||
embedding_model.id(),
|
||||
chunk_size,
|
||||
@@ -80,12 +84,12 @@ impl Rag {
|
||||
top_k,
|
||||
embedding_model.max_batch_size(),
|
||||
);
|
||||
let mut rag = Self::create(config, name, save_path, data)?;
|
||||
let mut rag = Self::create(app, name, save_path, data)?;
|
||||
let mut paths = doc_paths.to_vec();
|
||||
if paths.is_empty() {
|
||||
paths = add_documents()?;
|
||||
};
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let loaders = app.document_loaders.clone();
|
||||
let (spinner, spinner_rx) = Spinner::create("");
|
||||
abortable_run_with_spinner_rx(
|
||||
rag.sync_documents(&paths, true, loaders, Some(spinner)),
|
||||
@@ -99,20 +103,29 @@ impl Rag {
|
||||
Ok(rag)
|
||||
}
|
||||
|
||||
pub fn load(config: &GlobalConfig, name: &str, path: &Path) -> Result<Self> {
|
||||
/// Loads a previously saved RAG named `name` from the file at `path`.
///
/// The file is read as a string, deserialized from YAML into `RagData`,
/// and then handed to `Self::create` together with `app`, `name` and `path`.
///
/// # Errors
/// A failure to read the file or to parse its YAML is wrapped with the
/// context message "Failed to load rag '<name>' at '<path>'". Errors from
/// `Self::create` are returned without that extra context.
pub fn load(
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
path: &Path,
|
||||
) -> Result<Self> {
|
||||
// One context-message closure shared by the read step and the parse step.
let err = || format!("Failed to load rag '{name}' at '{}'", path.display());
|
||||
let content = fs::read_to_string(path).with_context(err)?;
|
||||
let data: RagData = serde_yaml::from_str(&content).with_context(err)?;
|
||||
Self::create(config, name, path, data)
|
||||
Self::create(app, name, path, data)
|
||||
}
|
||||
|
||||
pub fn create(config: &GlobalConfig, name: &str, path: &Path, data: RagData) -> Result<Self> {
|
||||
pub fn create(
|
||||
app: &AppConfig,
|
||||
name: &str,
|
||||
path: &Path,
|
||||
data: RagData,
|
||||
) -> Result<Self> {
|
||||
let hnsw = data.build_hnsw();
|
||||
let bm25 = data.build_bm25();
|
||||
let embedding_model =
|
||||
Model::retrieve_model(&config.read(), &data.embedding_model, ModelType::Embedding)?;
|
||||
Model::retrieve_model(app, &data.embedding_model, ModelType::Embedding)?;
|
||||
let rag = Rag {
|
||||
config: config.clone(),
|
||||
app_config: Arc::new(app.clone()),
|
||||
name: name.to_string(),
|
||||
path: path.display().to_string(),
|
||||
data,
|
||||
@@ -132,10 +145,10 @@ impl Rag {
|
||||
&mut self,
|
||||
document_paths: &[String],
|
||||
refresh: bool,
|
||||
config: &GlobalConfig,
|
||||
app: &AppConfig,
|
||||
abort_signal: AbortSignal,
|
||||
) -> Result<()> {
|
||||
let loaders = config.read().document_loaders.clone();
|
||||
let loaders = app.document_loaders.clone();
|
||||
let (spinner, spinner_rx) = Spinner::create("");
|
||||
abortable_run_with_spinner_rx(
|
||||
self.sync_documents(document_paths, refresh, loaders, Some(spinner)),
|
||||
@@ -149,22 +162,17 @@ impl Rag {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn create_config(config: &GlobalConfig) -> Result<(Model, usize, usize)> {
|
||||
let (embedding_model_id, chunk_size, chunk_overlap) = {
|
||||
let config = config.read();
|
||||
(
|
||||
config.rag_embedding_model.clone(),
|
||||
config.rag_chunk_size,
|
||||
config.rag_chunk_overlap,
|
||||
)
|
||||
};
|
||||
pub fn create_config(app: &AppConfig) -> Result<(Model, usize, usize)> {
|
||||
let embedding_model_id = app.rag_embedding_model.clone();
|
||||
let chunk_size = app.rag_chunk_size;
|
||||
let chunk_overlap = app.rag_chunk_overlap;
|
||||
let embedding_model_id = match embedding_model_id {
|
||||
Some(value) => {
|
||||
println!("Select embedding model: {value}");
|
||||
value
|
||||
}
|
||||
None => {
|
||||
let models = list_models(&config.read(), ModelType::Embedding);
|
||||
let models = list_models(app, ModelType::Embedding);
|
||||
if models.is_empty() {
|
||||
bail!("No available embedding model");
|
||||
}
|
||||
@@ -172,7 +180,7 @@ impl Rag {
|
||||
}
|
||||
};
|
||||
let embedding_model =
|
||||
Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?;
|
||||
Model::retrieve_model(app, &embedding_model_id, ModelType::Embedding)?;
|
||||
|
||||
let chunk_size = match chunk_size {
|
||||
Some(value) => {
|
||||
@@ -560,9 +568,8 @@ impl Rag {
|
||||
|
||||
let ids = match rerank_model {
|
||||
Some(model_id) => {
|
||||
let model =
|
||||
Model::retrieve_model(&self.config.read(), model_id, ModelType::Reranker)?;
|
||||
let client = init_client(&self.config, Some(model))?;
|
||||
let model = Model::retrieve_model(&self.app_config, model_id, ModelType::Reranker)?;
|
||||
let client = self.create_embeddings_client(model)?;
|
||||
let ids: IndexSet<DocumentId> = [vector_search_ids, keyword_search_ids]
|
||||
.concat()
|
||||
.into_iter()
|
||||
@@ -665,7 +672,7 @@ impl Rag {
|
||||
data: EmbeddingsData,
|
||||
spinner: Option<Spinner>,
|
||||
) -> Result<EmbeddingsOutput> {
|
||||
let embedding_client = init_client(&self.config, Some(self.embedding_model.clone()))?;
|
||||
let embedding_client = self.create_embeddings_client(self.embedding_model.clone())?;
|
||||
let EmbeddingsData { texts, query } = data;
|
||||
let batch_size = self
|
||||
.data
|
||||
|
||||
Reference in New Issue
Block a user