use self::splitter::*; use crate::client::*; use crate::config::*; use crate::utils::*; mod serde_vectors; mod splitter; use anyhow::{Context, Result, anyhow, bail}; use bm25::{Language, SearchEngine, SearchEngineBuilder}; use hnsw_rs::prelude::*; use indexmap::{IndexMap, IndexSet}; use inquire::{Confirm, Select, Text, required, validator::Validation}; use parking_lot::RwLock; use serde::{Deserialize, Serialize}; use serde_json::json; use std::{collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, time::Duration}; use tokio::time::sleep; pub struct Rag { config: GlobalConfig, name: String, path: String, embedding_model: Model, hnsw: Hnsw<'static, f32, DistCosine>, bm25: SearchEngine, data: RagData, last_sources: RwLock>, } impl Debug for Rag { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Rag") .field("name", &self.name) .field("path", &self.path) .field("embedding_model", &self.embedding_model) .field("data", &self.data) .finish() } } impl Clone for Rag { fn clone(&self) -> Self { Self { config: self.config.clone(), name: self.name.clone(), path: self.path.clone(), embedding_model: self.embedding_model.clone(), hnsw: self.data.build_hnsw(), bm25: self.data.build_bm25(), data: self.data.clone(), last_sources: RwLock::new(None), } } } impl Rag { pub async fn init( config: &GlobalConfig, name: &str, save_path: &Path, doc_paths: &[String], abort_signal: AbortSignal, ) -> Result { if !*IS_STDOUT_TERMINAL { bail!("Failed to init rag in non-interactive mode"); } println!("⚙ Initializing RAG..."); let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(config)?; let (reranker_model, top_k) = { let config = config.read(); (config.rag_reranker_model.clone(), config.rag_top_k) }; let data = RagData::new( embedding_model.id(), chunk_size, chunk_overlap, reranker_model, top_k, embedding_model.max_batch_size(), ); let mut rag = Self::create(config, name, save_path, data)?; let mut paths = doc_paths.to_vec(); if paths.is_empty() { paths = add_documents()?; }; let loaders = config.read().document_loaders.clone(); let (spinner, spinner_rx) = Spinner::create(""); abortable_run_with_spinner_rx( rag.sync_documents(&paths, true, loaders, Some(spinner)), spinner_rx, abort_signal, ) .await?; if rag.save()? { println!("✓ Saved RAG to '{}'.", save_path.display()); } Ok(rag) } pub fn load(config: &GlobalConfig, name: &str, path: &Path) -> Result { let err = || format!("Failed to load rag '{name}' at '{}'", path.display()); let content = fs::read_to_string(path).with_context(err)?; let data: RagData = serde_yaml::from_str(&content).with_context(err)?; Self::create(config, name, path, data) } pub fn create(config: &GlobalConfig, name: &str, path: &Path, data: RagData) -> Result { let hnsw = data.build_hnsw(); let bm25 = data.build_bm25(); let embedding_model = Model::retrieve_model(&config.read(), &data.embedding_model, ModelType::Embedding)?; let rag = Rag { config: config.clone(), name: name.to_string(), path: path.display().to_string(), data, embedding_model, hnsw, bm25, last_sources: RwLock::new(None), }; Ok(rag) } pub fn document_paths(&self) -> &[String] { &self.data.document_paths } pub async fn refresh_document_paths( &mut self, document_paths: &[String], refresh: bool, config: &GlobalConfig, abort_signal: AbortSignal, ) -> Result<()> { let loaders = config.read().document_loaders.clone(); let (spinner, spinner_rx) = Spinner::create(""); abortable_run_with_spinner_rx( self.sync_documents(document_paths, refresh, loaders, Some(spinner)), spinner_rx, abort_signal, ) .await?; if self.save()? { println!("✓ Saved rag to '{}'.", self.path); } Ok(()) } pub fn create_config(config: &GlobalConfig) -> Result<(Model, usize, usize)> { let (embedding_model_id, chunk_size, chunk_overlap) = { let config = config.read(); ( config.rag_embedding_model.clone(), config.rag_chunk_size, config.rag_chunk_overlap, ) }; let embedding_model_id = match embedding_model_id { Some(value) => { println!("Select embedding model: {value}"); value } None => { let models = list_models(&config.read(), ModelType::Embedding); if models.is_empty() { bail!("No available embedding model"); } select_embedding_model(&models)? } }; let embedding_model = Model::retrieve_model(&config.read(), &embedding_model_id, ModelType::Embedding)?; let chunk_size = match chunk_size { Some(value) => { println!("Set chunk size: {value}"); value } None => set_chunk_size(&embedding_model)?, }; let chunk_overlap = match chunk_overlap { Some(value) => { println!("Set chunk overlay: {value}"); value } None => { let value = chunk_size / 20; set_chunk_overlay(value)? } }; Ok((embedding_model, chunk_size, chunk_overlap)) } pub fn get_config(&self) -> (Option, usize) { (self.data.reranker_model.clone(), self.data.top_k) } pub fn get_last_sources(&self) -> Option { self.last_sources.read().clone() } pub fn set_last_sources(&self, ids: &[DocumentId]) { let mut sources: IndexMap> = IndexMap::new(); for id in ids { let (file_index, _) = id.split(); if let Some(file) = self.data.files.get(&file_index) { sources .entry(file.path.clone()) .or_default() .push(format!("{id:?}")); } } let sources = if sources.is_empty() { None } else { Some( sources .into_iter() .map(|(path, ids)| format!("{path} ({})", ids.join(","))) .collect::>() .join("\n"), ) }; *self.last_sources.write() = sources; } pub fn set_reranker_model(&mut self, reranker_model: Option) -> Result<()> { self.data.reranker_model = reranker_model; self.save()?; Ok(()) } pub fn set_top_k(&mut self, top_k: usize) -> Result<()> { self.data.top_k = top_k; self.save()?; Ok(()) } pub fn save(&self) -> Result { if self.is_temp() { return Ok(false); } let path = Path::new(&self.path); ensure_parent_exists(path)?; let content = serde_yaml::to_string(&self.data) .with_context(|| format!("Failed to serde rag '{}'", self.name))?; fs::write(path, content).with_context(|| { format!("Failed to save rag '{}' to '{}'", self.name, path.display()) })?; Ok(true) } pub fn export(&self) -> Result { let files: Vec<_> = self .data .files .iter() .map(|(_, v)| { json!({ "path": v.path, "num_chunks": v.documents.len(), }) }) .collect(); let data = json!({ "path": self.path, "embedding_model": self.embedding_model.id(), "chunk_size": self.data.chunk_size, "chunk_overlap": self.data.chunk_overlap, "reranker_model": self.data.reranker_model, "top_k": self.data.top_k, "batch_size": self.data.batch_size, "document_paths": self.data.document_paths, "files": files, }); let output = serde_yaml::to_string(&data) .with_context(|| format!("Unable to show info about rag '{}'", self.name))?; Ok(output) } pub fn name(&self) -> &str { &self.name } pub fn is_temp(&self) -> bool { self.name == TEMP_RAG_NAME } pub async fn search( &self, text: &str, top_k: usize, rerank_model: Option<&str>, abort_signal: AbortSignal, ) -> Result<(String, String, Vec)> { let ret = abortable_run_with_spinner( self.hybird_search(text, top_k, rerank_model), "Searching", abort_signal, ) .await; let results = ret?; let ids: Vec<_> = results.iter().map(|(id, _)| *id).collect(); let embeddings = results .iter() .map(|(id, content)| { let source = self.resolve_source(id); format!("[Source: {source}]\n{content}") }) .collect::>() .join("\n\n"); let sources = self.format_sources(&ids); Ok((embeddings, sources, ids)) } fn resolve_source(&self, id: &DocumentId) -> String { let (file_index, _) = id.split(); self.data .files .get(&file_index) .map(|f| f.path.clone()) .unwrap_or_else(|| "unknown".to_string()) } fn format_sources(&self, ids: &[DocumentId]) -> String { let mut seen = IndexSet::new(); for id in ids { let (file_index, _) = id.split(); if let Some(file) = self.data.files.get(&file_index) { seen.insert(file.path.clone()); } } seen.into_iter() .map(|path| format!("- {path}")) .collect::>() .join("\n") } pub async fn sync_documents( &mut self, paths: &[String], refresh: bool, loaders: HashMap, spinner: Option, ) -> Result<()> { if let Some(spinner) = &spinner { let _ = spinner.set_message(String::new()); } let (document_paths, mut recursive_urls, mut urls, mut protocol_paths, mut local_paths) = resolve_paths(&loaders, paths).await?; let mut to_deleted: IndexMap> = Default::default(); if refresh { for (file_id, file) in &self.data.files { to_deleted .entry(file.hash.clone()) .or_default() .push(*file_id); } } else { let recursive_urls_cloned = recursive_urls.clone(); let match_recursive_url = |v: &str| { recursive_urls_cloned .iter() .any(|start_url| v.starts_with(start_url)) }; recursive_urls = recursive_urls .into_iter() .filter(|v| !self.data.document_paths.contains(&format!("{v}**"))) .collect(); let protocol_paths_cloned = protocol_paths.clone(); let match_protocol_path = |v: &str| protocol_paths_cloned.iter().any(|root| v.starts_with(root)); protocol_paths = protocol_paths .into_iter() .filter(|v| !self.data.document_paths.contains(v)) .collect(); for (file_id, file) in &self.data.files { if is_url(&file.path) { if !urls.swap_remove(&file.path) && !match_recursive_url(&file.path) { to_deleted .entry(file.hash.clone()) .or_default() .push(*file_id); } } else if is_loader_protocol(&loaders, &file.path) { if !match_protocol_path(&file.path) { to_deleted .entry(file.hash.clone()) .or_default() .push(*file_id); } } else if !local_paths.swap_remove(&file.path) { to_deleted .entry(file.hash.clone()) .or_default() .push(*file_id); } } } let mut loaded_documents = vec![]; let mut has_error = false; let mut index = 0; let total = recursive_urls.len() + urls.len() + protocol_paths.len() + local_paths.len(); let handle_error = |error: anyhow::Error, has_error: &mut bool| { println!("{}", warning_text(&format!("⚠️ {error}"))); *has_error = true; }; for start_url in recursive_urls { index += 1; println!("Load {start_url}** [{index}/{total}]"); match load_recursive_url(&loaders, &start_url).await { Ok(v) => loaded_documents.extend(v), Err(err) => handle_error(err, &mut has_error), } } for url in urls { index += 1; println!("Load {url} [{index}/{total}]"); match load_url(&loaders, &url).await { Ok(v) => loaded_documents.push(v), Err(err) => handle_error(err, &mut has_error), } } for protocol_path in protocol_paths { index += 1; println!("Load {protocol_path} [{index}/{total}]"); match load_protocol_path(&loaders, &protocol_path) { Ok(v) => loaded_documents.extend(v), Err(err) => handle_error(err, &mut has_error), } } for local_path in local_paths { index += 1; println!("Load {local_path} [{index}/{total}]"); match load_file(&loaders, &local_path).await { Ok(v) => loaded_documents.push(v), Err(err) => handle_error(err, &mut has_error), } } if has_error { let mut aborted = true; if *IS_STDOUT_TERMINAL && total > 0 { let ans = Confirm::new("Some documents failed to load. Continue?") .with_default(false) .prompt()?; aborted = !ans; } if aborted { bail!("Aborted"); } } let mut rag_files = vec![]; for LoadedDocument { path, contents, mut metadata, } in loaded_documents { let hash = sha256(&contents); if let Some(file_ids) = to_deleted.get_mut(&hash) && let Some((i, _)) = file_ids .iter() .enumerate() .find(|(_, v)| self.data.files[*v].path == path) { if file_ids.len() == 1 { to_deleted.swap_remove(&hash); } else { file_ids.remove(i); } continue; } let extension = metadata .swap_remove(EXTENSION_METADATA) .unwrap_or_else(|| DEFAULT_EXTENSION.into()); let separator = get_separators(&extension); let splitter = RecursiveCharacterTextSplitter::new( self.data.chunk_size, self.data.chunk_overlap, &separator, ); let split_options = SplitterChunkHeaderOptions::default(); let document = RagDocument::new(contents); let split_documents = splitter.split_documents(&[document], &split_options); rag_files.push(RagFile { hash: hash.clone(), path, documents: split_documents, }); } let mut next_file_id = self.data.next_file_id; let mut files = vec![]; let mut document_ids = vec![]; let mut embeddings = vec![]; if !rag_files.is_empty() { let mut texts = vec![]; for file in rag_files.into_iter() { for (document_index, document) in file.documents.iter().enumerate() { document_ids.push(DocumentId::new(next_file_id, document_index)); texts.push(document.page_content.clone()) } files.push((next_file_id, file)); next_file_id += 1; } let embeddings_data = EmbeddingsData::new(texts, false); embeddings = self .create_embeddings(embeddings_data, spinner.clone()) .await?; } let to_delete_file_ids: Vec<_> = to_deleted.values().flatten().copied().collect(); self.data.del(to_delete_file_ids); self.data.add(next_file_id, files, document_ids, embeddings); self.data.document_paths = document_paths.into_iter().collect(); if self.data.files.is_empty() { bail!("No RAG files"); } progress(&spinner, "Building store".into()); self.hnsw = self.data.build_hnsw(); self.bm25 = self.data.build_bm25(); Ok(()) } async fn hybird_search( &self, query: &str, top_k: usize, rerank_model: Option<&str>, ) -> Result> { let (vector_search_results, keyword_search_results) = tokio::join!( self.vector_search(query, top_k, 0.0), self.keyword_search(query, top_k, 0.0), ); let vector_search_results = vector_search_results?; debug!("vector_search_results: {vector_search_results:?}",); let vector_search_ids: Vec = vector_search_results.into_iter().map(|(v, _)| v).collect(); let keyword_search_results = keyword_search_results?; debug!("keyword_search_results: {keyword_search_results:?}",); let keyword_search_ids: Vec = keyword_search_results.into_iter().map(|(v, _)| v).collect(); let ids = match rerank_model { Some(model_id) => { let model = Model::retrieve_model(&self.config.read(), model_id, ModelType::Reranker)?; let client = init_client(&self.config, Some(model))?; let ids: IndexSet = [vector_search_ids, keyword_search_ids] .concat() .into_iter() .collect(); let mut documents = vec![]; let mut documents_ids = vec![]; for id in ids { if let Some(document) = self.data.get(id) { documents_ids.push(id); documents.push(document.page_content.to_string()); } } let data = RerankData::new(query.to_string(), documents, top_k); let list = client.rerank(&data).await.context("Failed to rerank")?; let ids: Vec<_> = list .into_iter() .take(top_k) .filter_map(|item| documents_ids.get(item.index).cloned()) .collect(); debug!("rerank_ids: {ids:?}"); ids } None => { let ids = reciprocal_rank_fusion( vec![vector_search_ids, keyword_search_ids], vec![1.125, 1.0], top_k, ); debug!("rrf_ids: {ids:?}"); ids } }; let output = ids .into_iter() .filter_map(|id| { let document = self.data.get(id)?; Some((id, document.page_content.clone())) }) .collect(); Ok(output) } async fn vector_search( &self, query: &str, top_k: usize, min_score: f32, ) -> Result> { let splitter = RecursiveCharacterTextSplitter::new( self.data.chunk_size, self.data.chunk_overlap, &DEFAULT_SEPARATORS, ); let texts = splitter.split_text(query); let embeddings_data = EmbeddingsData::new(texts, true); let embeddings = self.create_embeddings(embeddings_data, None).await?; let output = self .hnsw .parallel_search(&embeddings, top_k, 30) .into_iter() .flat_map(|list| { list.into_iter() .filter_map(|v| { let score = 1.0 - v.distance; if score > min_score { Some((DocumentId(v.d_id), score)) } else { None } }) .collect::>() }) .collect(); Ok(output) } async fn keyword_search( &self, query: &str, top_k: usize, min_score: f32, ) -> Result> { let results = self.bm25.search(query, top_k); let output: Vec<(DocumentId, f32)> = results .into_iter() .filter_map(|v| { let score = v.score; if score > min_score { Some((v.document.id, score)) } else { None } }) .collect(); Ok(output) } async fn create_embeddings( &self, data: EmbeddingsData, spinner: Option, ) -> Result { let embedding_client = init_client(&self.config, Some(self.embedding_model.clone()))?; let EmbeddingsData { texts, query } = data; let batch_size = self .data .batch_size .or_else(|| self.embedding_model.max_batch_size()); let batch_size = match self.embedding_model.max_input_tokens() { Some(max_input_tokens) => { let x = max_input_tokens / self.data.chunk_size; match batch_size { Some(y) => x.min(y), None => x, } } None => batch_size.unwrap_or(1), }; let mut output = vec![]; let batch_chunks = texts.chunks(batch_size.max(1)); let batch_chunks_len = batch_chunks.len(); let retry_limit = env::var(get_env_name("embeddings_retry_limit")) .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(2); for (index, texts) in batch_chunks.enumerate() { progress( &spinner, format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1), ); let chunk_data = EmbeddingsData { texts: texts.to_vec(), query, }; let mut retry = 0; let chunk_output = loop { retry += 1; match embedding_client.embeddings(&chunk_data).await { Ok(v) => break v, Err(e) if retry < retry_limit => { debug!("retry {retry} failed: {e}"); sleep(Duration::from_secs(2u64.pow(retry - 1))).await; continue; } Err(e) => { return Err(e).with_context(|| { format!("Failed to create embedding after {retry_limit} attempts") })?; } } }; output.extend(chunk_output); } Ok(output) } } #[derive(Clone, Serialize, Deserialize)] pub struct RagData { pub embedding_model: String, pub chunk_size: usize, pub chunk_overlap: usize, pub reranker_model: Option, pub top_k: usize, pub batch_size: Option, pub next_file_id: FileId, pub document_paths: Vec, pub files: IndexMap, #[serde(with = "serde_vectors")] pub vectors: IndexMap>, } impl Debug for RagData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RagData") .field("embedding_model", &self.embedding_model) .field("chunk_size", &self.chunk_size) .field("chunk_overlap", &self.chunk_overlap) .field("reranker_model", &self.reranker_model) .field("top_k", &self.top_k) .field("batch_size", &self.batch_size) .field("next_file_id", &self.next_file_id) .field("document_paths", &self.document_paths) .field("files", &self.files) .finish() } } impl RagData { pub fn new( embedding_model: String, chunk_size: usize, chunk_overlap: usize, reranker_model: Option, top_k: usize, batch_size: Option, ) -> Self { Self { embedding_model, chunk_size, chunk_overlap, reranker_model, top_k, batch_size, next_file_id: 0, document_paths: Default::default(), files: Default::default(), vectors: Default::default(), } } pub fn get(&self, id: DocumentId) -> Option<&RagDocument> { let (file_index, document_index) = id.split(); let file = self.files.get(&file_index)?; let document = file.documents.get(document_index)?; Some(document) } pub fn del(&mut self, file_ids: Vec) { for file_id in file_ids { if let Some(file) = self.files.swap_remove(&file_id) { for (document_index, _) in file.documents.iter().enumerate() { let document_id = DocumentId::new(file_id, document_index); self.vectors.swap_remove(&document_id); } } } } pub fn add( &mut self, next_file_id: FileId, files: Vec<(FileId, RagFile)>, document_ids: Vec, embeddings: EmbeddingsOutput, ) { self.next_file_id = next_file_id; self.files.extend(files); self.vectors .extend(document_ids.into_iter().zip(embeddings)); } pub fn build_hnsw(&self) -> Hnsw<'static, f32, DistCosine> { let hnsw = Hnsw::new(32, self.vectors.len(), 16, 200, DistCosine {}); let list: Vec<_> = self.vectors.iter().map(|(k, v)| (v, k.0)).collect(); hnsw.parallel_insert(&list); hnsw } pub fn build_bm25(&self) -> SearchEngine { let mut documents = vec![]; for (file_index, file) in self.files.iter() { for (document_index, document) in file.documents.iter().enumerate() { let id = DocumentId::new(*file_index, document_index); documents.push(bm25::Document::new(id, &document.page_content)) } } SearchEngineBuilder::::with_documents(Language::English, documents) .k1(1.5) .b(0.75) .build() } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RagFile { hash: String, path: String, documents: Vec, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RagDocument { pub page_content: String, pub metadata: DocumentMetadata, } impl RagDocument { pub fn new>(page_content: S) -> Self { RagDocument { page_content: page_content.into(), metadata: IndexMap::new(), } } } impl Default for RagDocument { fn default() -> Self { RagDocument { page_content: "".to_string(), metadata: IndexMap::new(), } } } pub type FileId = usize; #[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)] pub struct DocumentId(usize); impl Debug for DocumentId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let (file_index, document_index) = self.split(); f.write_fmt(format_args!("{file_index}-{document_index}")) } } impl DocumentId { pub fn new(file_index: usize, document_index: usize) -> Self { let value = (file_index << (usize::BITS / 2)) | document_index; Self(value) } pub fn split(self) -> (usize, usize) { let value = self.0; let low_mask = (1 << (usize::BITS / 2)) - 1; let low = value & low_mask; let high = value >> (usize::BITS / 2); (high, low) } } fn select_embedding_model(models: &[&Model]) -> Result { let models: Vec<_> = models .iter() .map(|v| SelectOption::new(v.id(), v.description())) .collect(); let result = Select::new("Select embedding model:", models).prompt()?; Ok(result.value) } #[derive(Debug)] struct SelectOption { pub value: String, pub description: String, } impl SelectOption { pub fn new(value: String, description: String) -> Self { Self { value, description } } } impl std::fmt::Display for SelectOption { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{} ({})", self.value, self.description) } } fn set_chunk_size(model: &Model) -> Result { let default_value = model.default_chunk_size().to_string(); let help_message = model .max_tokens_per_chunk() .map(|v| format!("The model's max_tokens is {v}")); let mut text = Text::new("Set chunk size:") .with_default(&default_value) .with_validator(move |text: &str| { let out = match text.parse::() { Ok(_) => Validation::Valid, Err(_) => Validation::Invalid("Must be a integer".into()), }; Ok(out) }); if let Some(help_message) = &help_message { text = text.with_help_message(help_message); } let value = text.prompt()?; value.parse().map_err(|_| anyhow!("Invalid chunk_size")) } fn set_chunk_overlay(default_value: usize) -> Result { let value = Text::new("Set chunk overlay:") .with_default(&default_value.to_string()) .with_validator(move |text: &str| { let out = match text.parse::() { Ok(_) => Validation::Valid, Err(_) => Validation::Invalid("Must be a integer".into()), }; Ok(out) }) .prompt()?; value.parse().map_err(|_| anyhow!("Invalid chunk_overlay")) } fn add_documents() -> Result> { let text = Text::new("Add documents:") .with_validator(required!("This field is required")) .with_help_message("e.g. file;dir/;dir/**/*.{md,mdx};loader:resource;url;website/**") .prompt()?; let paths = text .split(';') .filter_map(|v| { let v = v.trim().to_string(); if v.is_empty() { None } else { Some(v) } }) .collect(); Ok(paths) } async fn resolve_paths>( loaders: &HashMap, paths: &[T], ) -> Result<( IndexSet, IndexSet, IndexSet, IndexSet, IndexSet, )> { let mut document_paths = IndexSet::new(); let mut recursive_urls = IndexSet::new(); let mut urls = IndexSet::new(); let mut protocol_paths = IndexSet::new(); let mut absolute_paths = vec![]; for path in paths { let path = path.as_ref().trim(); if is_url(path) { if let Some(start_url) = path.strip_suffix("**") { recursive_urls.insert(start_url.to_string()); } else { urls.insert(path.to_string()); } document_paths.insert(path.to_string()); } else if is_loader_protocol(loaders, path) { protocol_paths.insert(path.to_string()); document_paths.insert(path.to_string()); } else { let resolved_path = resolve_home_dir(path); let absolute_path = to_absolute_path(&resolved_path) .with_context(|| format!("Invalid path '{path}'"))?; absolute_paths.push(resolved_path); document_paths.insert(absolute_path); } } let local_paths = expand_glob_paths(&absolute_paths, false).await?; Ok(( document_paths, recursive_urls, urls, protocol_paths, local_paths, )) } fn progress(spinner: &Option, message: String) { if let Some(spinner) = spinner { let _ = spinner.set_message(message); } } fn reciprocal_rank_fusion( list_of_document_ids: Vec>, list_of_weights: Vec, top_k: usize, ) -> Vec { let rrf_k = top_k * 2; let mut map: IndexMap = IndexMap::new(); for (document_ids, weight) in list_of_document_ids .into_iter() .zip(list_of_weights.into_iter()) { for (index, &item) in document_ids.iter().enumerate() { *map.entry(item).or_default() += (1.0 / ((rrf_k + index + 1) as f32)) * weight; } } let mut sorted_items: Vec<(DocumentId, f32)> = map.into_iter().collect(); sorted_items.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); sorted_items .into_iter() .take(top_k) .map(|(v, _)| v) .collect() }