// loki/src/rag/mod.rs (1086 lines, 35 KiB, Rust)
use self::splitter::*;
use crate::client::*;
use crate::config::*;
use crate::utils::*;
mod serde_vectors;
mod splitter;
use anyhow::{Context, Result, anyhow, bail};
use bm25::{Language, SearchEngine, SearchEngineBuilder};
use hnsw_rs::prelude::*;
use indexmap::{IndexMap, IndexSet};
use inquire::{Confirm, Select, Text, required, validator::Validation};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::{
collections::HashMap, env, fmt::Debug, fs, hash::Hash, path::Path, sync::Arc, time::Duration,
};
use tokio::time::sleep;
/// Default prompt template wrapping a user query with retrieved context.
/// `__CONTEXT__`, `__SOURCES__` and `__INPUT__` are substituted in
/// `Rag::search_with_template`; a user-supplied `rag_template` overrides it.
const RAG_TEMPLATE: &str = r#"Answer the query based on the context while respecting the rules. (user query, some textual context and rules, all inside xml tags)
<context>
__CONTEXT__
</context>
<sources>
__SOURCES__
</sources>
<rules>
- If you don't know, just say so.
- If you are not sure, ask for clarification.
- Answer in the same language as the user query.
- If the context appears unreadable or of poor quality, tell the user then answer as best as you can.
- If the answer is not in the context but you think you know the answer, explain that to the user then answer with your own knowledge.
- Answer directly and without using xml tags.
- When using information from the context, cite the relevant source from the <sources> section.
</rules>
<user_query>
__INPUT__
</user_query>"#;
/// An in-memory RAG store: embedded document chunks plus the vector (HNSW)
/// and keyword (BM25) indexes built over them.
pub struct Rag {
    app_config: Arc<AppConfig>,
    // Store name; equals `TEMP_RAG_NAME` for throwaway stores (see `is_temp`).
    name: String,
    // Location of the serialized `RagData` (YAML), as a display string.
    path: String,
    embedding_model: Model,
    // Vector index; rebuilt from `data`, never persisted directly.
    hnsw: Hnsw<'static, f32, DistCosine>,
    // Keyword index; rebuilt from `data`, never persisted directly.
    bm25: SearchEngine<DocumentId>,
    // Persisted state: settings, files, chunks and their vectors.
    data: RagData,
    // Sources that backed the most recent query, for later display.
    last_sources: RwLock<Option<String>>,
}
impl Debug for Rag {
    /// Debug output covers the identifying fields only; the `hnsw`/`bm25`
    /// indexes and `last_sources` are skipped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Rag");
        out.field("name", &self.name);
        out.field("path", &self.path);
        out.field("embedding_model", &self.embedding_model);
        out.field("data", &self.data);
        out.finish()
    }
}
impl Clone for Rag {
fn clone(&self) -> Self {
Self {
app_config: self.app_config.clone(),
name: self.name.clone(),
path: self.path.clone(),
embedding_model: self.embedding_model.clone(),
hnsw: self.data.build_hnsw(),
bm25: self.data.build_bm25(),
data: self.data.clone(),
last_sources: RwLock::new(None),
}
}
}
impl Rag {
    /// Build a client for the given model using this store's app config.
    fn create_embeddings_client(&self, model: Model) -> Result<Box<dyn Client>> {
        init_client(&self.app_config, model)
    }

    /// Interactively create a new RAG store: pick a config, load the
    /// documents, embed them, and save the result to `save_path`.
    pub async fn init(
        app: &AppConfig,
        name: &str,
        save_path: &Path,
        doc_paths: &[String],
        abort_signal: AbortSignal,
    ) -> Result<Self> {
        // Config selection below may prompt the user, so a terminal is required.
        if !*IS_STDOUT_TERMINAL {
            bail!("Failed to init rag in non-interactive mode");
        }
        println!("⚙ Initializing RAG...");
        let (embedding_model, chunk_size, chunk_overlap) = Self::create_config(app)?;
        let reranker_model = app.rag_reranker_model.clone();
        let top_k = app.rag_top_k;
        let data = RagData::new(
            embedding_model.id(),
            chunk_size,
            chunk_overlap,
            reranker_model,
            top_k,
            embedding_model.max_batch_size(),
        );
        let mut rag = Self::create(app, name, save_path, data)?;
        let mut paths = doc_paths.to_vec();
        if paths.is_empty() {
            // No paths supplied by the caller: ask for them interactively.
            paths = add_documents()?;
        };
        let loaders = app.document_loaders.clone();
        let (spinner, spinner_rx) = Spinner::create("");
        abortable_run_with_spinner_rx(
            rag.sync_documents(&paths, true, loaders, Some(spinner)),
            spinner_rx,
            abort_signal,
        )
        .await?;
        if rag.save()? {
            println!("✓ Saved RAG to '{}'.", save_path.display());
        }
        Ok(rag)
    }

    /// Load a previously saved store (YAML `RagData`) from `path`.
    pub fn load(app: &AppConfig, name: &str, path: &Path) -> Result<Self> {
        let err = || format!("Failed to load rag '{name}' at '{}'", path.display());
        let content = fs::read_to_string(path).with_context(err)?;
        let data: RagData = serde_yaml::from_str(&content).with_context(err)?;
        Self::create(app, name, path, data)
    }

    /// Assemble a `Rag` from its persisted data, (re)building the in-memory
    /// vector (HNSW) and keyword (BM25) indexes.
    pub fn create(app: &AppConfig, name: &str, path: &Path, data: RagData) -> Result<Self> {
        let hnsw = data.build_hnsw();
        let bm25 = data.build_bm25();
        let embedding_model =
            Model::retrieve_model(app, &data.embedding_model, ModelType::Embedding)?;
        let rag = Rag {
            app_config: Arc::new(app.clone()),
            name: name.to_string(),
            path: path.display().to_string(),
            data,
            embedding_model,
            hnsw,
            bm25,
            last_sources: RwLock::new(None),
        };
        Ok(rag)
    }

    /// The raw document paths/urls/globs registered with this store.
    pub fn document_paths(&self) -> &[String] {
        &self.data.document_paths
    }

    /// Re-sync the store against `document_paths` (full reload when
    /// `refresh` is true) and save it.
    pub async fn refresh_document_paths(
        &mut self,
        document_paths: &[String],
        refresh: bool,
        app: &AppConfig,
        abort_signal: AbortSignal,
    ) -> Result<()> {
        let loaders = app.document_loaders.clone();
        let (spinner, spinner_rx) = Spinner::create("");
        abortable_run_with_spinner_rx(
            self.sync_documents(document_paths, refresh, loaders, Some(spinner)),
            spinner_rx,
            abort_signal,
        )
        .await?;
        if self.save()? {
            println!("✓ Saved rag to '{}'.", self.path);
        }
        Ok(())
    }

    /// Resolve (embedding model, chunk_size, chunk_overlap), taking each
    /// value from the app config when set and prompting otherwise.
    pub fn create_config(app: &AppConfig) -> Result<(Model, usize, usize)> {
        let embedding_model_id = app.rag_embedding_model.clone();
        let chunk_size = app.rag_chunk_size;
        let chunk_overlap = app.rag_chunk_overlap;
        let embedding_model_id = match embedding_model_id {
            Some(value) => {
                // Echo the pre-configured model instead of prompting.
                println!("Select embedding model: {value}");
                value
            }
            None => {
                let models = list_models(app, ModelType::Embedding);
                if models.is_empty() {
                    bail!("No available embedding model");
                }
                select_embedding_model(&models)?
            }
        };
        let embedding_model =
            Model::retrieve_model(app, &embedding_model_id, ModelType::Embedding)?;
        let chunk_size = match chunk_size {
            Some(value) => {
                println!("Set chunk size: {value}");
                value
            }
            None => set_chunk_size(&embedding_model)?,
        };
        let chunk_overlap = match chunk_overlap {
            Some(value) => {
                // NOTE(review): "overlay" looks like it should read "overlap".
                println!("Set chunk overlay: {value}");
                value
            }
            None => {
                // Default overlap suggestion: 5% of the chunk size.
                let value = chunk_size / 20;
                set_chunk_overlay(value)?
            }
        };
        Ok((embedding_model, chunk_size, chunk_overlap))
    }

    /// Current (reranker model, top_k) search settings.
    pub fn get_config(&self) -> (Option<String>, usize) {
        (self.data.reranker_model.clone(), self.data.top_k)
    }

    /// The formatted sources recorded by the last search, if any.
    pub fn get_last_sources(&self) -> Option<String> {
        self.last_sources.read().clone()
    }

    /// Record the source files behind `ids` for later display, grouped as
    /// one "path (id1,id2,...)" line per file.
    pub fn set_last_sources(&self, ids: &[DocumentId]) {
        let mut sources: IndexMap<String, Vec<String>> = IndexMap::new();
        for id in ids {
            let (file_index, _) = id.split();
            if let Some(file) = self.data.files.get(&file_index) {
                sources
                    .entry(file.path.clone())
                    .or_default()
                    .push(format!("{id:?}"));
            }
        }
        let sources = if sources.is_empty() {
            None
        } else {
            Some(
                sources
                    .into_iter()
                    .map(|(path, ids)| format!("{path} ({})", ids.join(",")))
                    .collect::<Vec<_>>()
                    .join("\n"),
            )
        };
        *self.last_sources.write() = sources;
    }

    /// Change the reranker model and persist the store.
    pub fn set_reranker_model(&mut self, reranker_model: Option<String>) -> Result<()> {
        self.data.reranker_model = reranker_model;
        self.save()?;
        Ok(())
    }

    /// Change `top_k` and persist the store.
    pub fn set_top_k(&mut self, top_k: usize) -> Result<()> {
        self.data.top_k = top_k;
        self.save()?;
        Ok(())
    }

    /// Serialize `data` to YAML at `self.path`. Returns `false` (and writes
    /// nothing) for temporary stores.
    pub fn save(&self) -> Result<bool> {
        if self.is_temp() {
            return Ok(false);
        }
        let path = Path::new(&self.path);
        ensure_parent_exists(path)?;
        let content = serde_yaml::to_string(&self.data)
            .with_context(|| format!("Failed to serde rag '{}'", self.name))?;
        fs::write(path, content).with_context(|| {
            format!("Failed to save rag '{}' to '{}'", self.name, path.display())
        })?;
        Ok(true)
    }

    /// Human-readable YAML summary of the store's settings and files
    /// (chunk counts only, no content).
    pub fn export(&self) -> Result<String> {
        let files: Vec<_> = self
            .data
            .files
            .iter()
            .map(|(_, v)| {
                json!({
                    "path": v.path,
                    "num_chunks": v.documents.len(),
                })
            })
            .collect();
        let data = json!({
            "path": self.path,
            "embedding_model": self.embedding_model.id(),
            "chunk_size": self.data.chunk_size,
            "chunk_overlap": self.data.chunk_overlap,
            "reranker_model": self.data.reranker_model,
            "top_k": self.data.top_k,
            "batch_size": self.data.batch_size,
            "document_paths": self.data.document_paths,
            "files": files,
        });
        let output = serde_yaml::to_string(&data)
            .with_context(|| format!("Unable to show info about rag '{}'", self.name))?;
        Ok(output)
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    /// Whether this is the unnamed, non-persisted temporary store.
    pub fn is_temp(&self) -> bool {
        self.name == TEMP_RAG_NAME
    }

    /// Run a hybrid search and return `(context, sources, ids)`:
    /// the matched chunks tagged with their source path, a bullet list of
    /// the distinct source files, and the matched document ids.
    pub async fn search(
        &self,
        text: &str,
        top_k: usize,
        rerank_model: Option<&str>,
        abort_signal: AbortSignal,
    ) -> Result<(String, String, Vec<DocumentId>)> {
        let ret = abortable_run_with_spinner(
            self.hybird_search(text, top_k, rerank_model),
            "Searching",
            abort_signal,
        )
        .await;
        let results = ret?;
        let ids: Vec<_> = results.iter().map(|(id, _)| *id).collect();
        let embeddings = results
            .iter()
            .map(|(id, content)| {
                let source = self.resolve_source(id);
                format!("[Source: {source}]\n{content}")
            })
            .collect::<Vec<_>>()
            .join("\n\n");
        let sources = self.format_sources(&ids);
        Ok((embeddings, sources, ids))
    }

    /// Search, then wrap the user query in the RAG prompt template
    /// (config override or `RAG_TEMPLATE`). Falls back to the raw query
    /// when nothing matched; records the matched sources either way.
    pub async fn search_with_template(
        &self,
        app: &AppConfig,
        text: &str,
        abort_signal: AbortSignal,
    ) -> Result<String> {
        let (reranker_model, top_k) = self.get_config();
        let (embeddings, sources, ids) = self
            .search(text, top_k, reranker_model.as_deref(), abort_signal)
            .await?;
        let rag_template = app.rag_template.as_deref().unwrap_or(RAG_TEMPLATE);
        let text = if embeddings.is_empty() {
            text.to_string()
        } else {
            rag_template
                .replace("__CONTEXT__", &embeddings)
                .replace("__SOURCES__", &sources)
                .replace("__INPUT__", text)
        };
        self.set_last_sources(&ids);
        Ok(text)
    }

    /// The source file path behind a document id, or "unknown".
    fn resolve_source(&self, id: &DocumentId) -> String {
        let (file_index, _) = id.split();
        self.data
            .files
            .get(&file_index)
            .map(|f| f.path.clone())
            .unwrap_or_else(|| "unknown".to_string())
    }

    /// Distinct source paths behind `ids`, one "- path" bullet per line.
    fn format_sources(&self, ids: &[DocumentId]) -> String {
        let mut seen = IndexSet::new();
        for id in ids {
            let (file_index, _) = id.split();
            if let Some(file) = self.data.files.get(&file_index) {
                seen.insert(file.path.clone());
            }
        }
        seen.into_iter()
            .map(|path| format!("- {path}"))
            .collect::<Vec<_>>()
            .join("\n")
    }

    /// Synchronize the store with `paths`: load new/changed documents, drop
    /// stored ones that are no longer referenced (all of them when
    /// `refresh` is true), embed the additions and rebuild both indexes.
    pub async fn sync_documents(
        &mut self,
        paths: &[String],
        refresh: bool,
        loaders: HashMap<String, String>,
        spinner: Option<Spinner>,
    ) -> Result<()> {
        if let Some(spinner) = &spinner {
            let _ = spinner.set_message(String::new());
        }
        let (document_paths, mut recursive_urls, mut urls, mut protocol_paths, mut local_paths) =
            resolve_paths(&loaders, paths).await?;
        // Deletion candidates keyed by content hash, so that reloaded files
        // with unchanged content can be matched up and kept further below.
        let mut to_deleted: IndexMap<String, Vec<FileId>> = Default::default();
        if refresh {
            // Full refresh: every stored file is a deletion candidate.
            for (file_id, file) in &self.data.files {
                to_deleted
                    .entry(file.hash.clone())
                    .or_default()
                    .push(*file_id);
            }
        } else {
            let recursive_urls_cloned = recursive_urls.clone();
            let match_recursive_url = |v: &str| {
                recursive_urls_cloned
                    .iter()
                    .any(|start_url| v.starts_with(start_url))
            };
            // Skip recursive roots already registered (stored as "url**").
            recursive_urls = recursive_urls
                .into_iter()
                .filter(|v| !self.data.document_paths.contains(&format!("{v}**")))
                .collect();
            let protocol_paths_cloned = protocol_paths.clone();
            let match_protocol_path =
                |v: &str| protocol_paths_cloned.iter().any(|root| v.starts_with(root));
            protocol_paths = protocol_paths
                .into_iter()
                .filter(|v| !self.data.document_paths.contains(v))
                .collect();
            // Stored files no longer covered by the requested paths become
            // deletion candidates; exact matches are consumed from the sets
            // (swap_remove) so only new entries get loaded below.
            for (file_id, file) in &self.data.files {
                if is_url(&file.path) {
                    if !urls.swap_remove(&file.path) && !match_recursive_url(&file.path) {
                        to_deleted
                            .entry(file.hash.clone())
                            .or_default()
                            .push(*file_id);
                    }
                } else if is_loader_protocol(&loaders, &file.path) {
                    if !match_protocol_path(&file.path) {
                        to_deleted
                            .entry(file.hash.clone())
                            .or_default()
                            .push(*file_id);
                    }
                } else if !local_paths.swap_remove(&file.path) {
                    to_deleted
                        .entry(file.hash.clone())
                        .or_default()
                        .push(*file_id);
                }
            }
        }
        let mut loaded_documents = vec![];
        let mut has_error = false;
        let mut index = 0;
        let total = recursive_urls.len() + urls.len() + protocol_paths.len() + local_paths.len();
        // Loading is best-effort: failures are reported and confirmed later.
        let handle_error = |error: anyhow::Error, has_error: &mut bool| {
            println!("{}", warning_text(&format!("⚠️ {error}")));
            *has_error = true;
        };
        for start_url in recursive_urls {
            index += 1;
            println!("Load {start_url}** [{index}/{total}]");
            match load_recursive_url(&loaders, &start_url).await {
                Ok(v) => loaded_documents.extend(v),
                Err(err) => handle_error(err, &mut has_error),
            }
        }
        for url in urls {
            index += 1;
            println!("Load {url} [{index}/{total}]");
            match load_url(&loaders, &url).await {
                Ok(v) => loaded_documents.push(v),
                Err(err) => handle_error(err, &mut has_error),
            }
        }
        for protocol_path in protocol_paths {
            index += 1;
            println!("Load {protocol_path} [{index}/{total}]");
            match load_protocol_path(&loaders, &protocol_path) {
                Ok(v) => loaded_documents.extend(v),
                Err(err) => handle_error(err, &mut has_error),
            }
        }
        for local_path in local_paths {
            index += 1;
            println!("Load {local_path} [{index}/{total}]");
            match load_file(&loaders, &local_path).await {
                Ok(v) => loaded_documents.push(v),
                Err(err) => handle_error(err, &mut has_error),
            }
        }
        if has_error {
            // On a terminal, let the user decide whether partial data is ok;
            // otherwise abort.
            let mut aborted = true;
            if *IS_STDOUT_TERMINAL && total > 0 {
                let ans = Confirm::new("Some documents failed to load. Continue?")
                    .with_default(false)
                    .prompt()?;
                aborted = !ans;
            }
            if aborted {
                bail!("Aborted");
            }
        }
        let mut rag_files = vec![];
        for LoadedDocument {
            path,
            contents,
            mut metadata,
        } in loaded_documents
        {
            let hash = sha256(&contents);
            // Unchanged file (same hash + path): keep the stored version and
            // take it off the deletion list.
            if let Some(file_ids) = to_deleted.get_mut(&hash)
                && let Some((i, _)) = file_ids
                    .iter()
                    .enumerate()
                    .find(|(_, v)| self.data.files[*v].path == path)
            {
                if file_ids.len() == 1 {
                    to_deleted.swap_remove(&hash);
                } else {
                    file_ids.remove(i);
                }
                continue;
            }
            // Split the document into chunks using extension-specific separators.
            let extension = metadata
                .swap_remove(EXTENSION_METADATA)
                .unwrap_or_else(|| DEFAULT_EXTENSION.into());
            let separator = get_separators(&extension);
            let splitter = RecursiveCharacterTextSplitter::new(
                self.data.chunk_size,
                self.data.chunk_overlap,
                &separator,
            );
            let split_options = SplitterChunkHeaderOptions::default();
            let document = RagDocument::new(contents);
            let split_documents = splitter.split_documents(&[document], &split_options);
            rag_files.push(RagFile {
                hash: hash.clone(),
                path,
                documents: split_documents,
            });
        }
        // Assign file ids and embed all new chunks in one batch run.
        let mut next_file_id = self.data.next_file_id;
        let mut files = vec![];
        let mut document_ids = vec![];
        let mut embeddings = vec![];
        if !rag_files.is_empty() {
            let mut texts = vec![];
            for file in rag_files.into_iter() {
                for (document_index, document) in file.documents.iter().enumerate() {
                    document_ids.push(DocumentId::new(next_file_id, document_index));
                    texts.push(document.page_content.clone())
                }
                files.push((next_file_id, file));
                next_file_id += 1;
            }
            let embeddings_data = EmbeddingsData::new(texts, false);
            embeddings = self
                .create_embeddings(embeddings_data, spinner.clone())
                .await?;
        }
        let to_delete_file_ids: Vec<_> = to_deleted.values().flatten().copied().collect();
        self.data.del(to_delete_file_ids);
        self.data.add(next_file_id, files, document_ids, embeddings);
        self.data.document_paths = document_paths.into_iter().collect();
        if self.data.files.is_empty() {
            bail!("No RAG files");
        }
        progress(&spinner, "Building store".into());
        self.hnsw = self.data.build_hnsw();
        self.bm25 = self.data.build_bm25();
        Ok(())
    }

    /// Hybrid search (name sic): run vector and keyword search concurrently,
    /// then combine with a reranker model when one is set, or with weighted
    /// reciprocal rank fusion otherwise.
    async fn hybird_search(
        &self,
        query: &str,
        top_k: usize,
        rerank_model: Option<&str>,
    ) -> Result<Vec<(DocumentId, String)>> {
        let (vector_search_results, keyword_search_results) = tokio::join!(
            self.vector_search(query, top_k, 0.0),
            self.keyword_search(query, top_k, 0.0),
        );
        let vector_search_results = vector_search_results?;
        debug!("vector_search_results: {vector_search_results:?}",);
        let vector_search_ids: Vec<DocumentId> =
            vector_search_results.into_iter().map(|(v, _)| v).collect();
        let keyword_search_results = keyword_search_results?;
        debug!("keyword_search_results: {keyword_search_results:?}",);
        let keyword_search_ids: Vec<DocumentId> =
            keyword_search_results.into_iter().map(|(v, _)| v).collect();
        let ids = match rerank_model {
            Some(model_id) => {
                let model = Model::retrieve_model(&self.app_config, model_id, ModelType::Reranker)?;
                let client = self.create_embeddings_client(model)?;
                // Dedup the union of both result lists, preserving order.
                let ids: IndexSet<DocumentId> = [vector_search_ids, keyword_search_ids]
                    .concat()
                    .into_iter()
                    .collect();
                let mut documents = vec![];
                let mut documents_ids = vec![];
                for id in ids {
                    if let Some(document) = self.data.get(id) {
                        documents_ids.push(id);
                        documents.push(document.page_content.to_string());
                    }
                }
                let data = RerankData::new(query.to_string(), documents, top_k);
                let list = client.rerank(&data).await.context("Failed to rerank")?;
                let ids: Vec<_> = list
                    .into_iter()
                    .take(top_k)
                    .filter_map(|item| documents_ids.get(item.index).cloned())
                    .collect();
                debug!("rerank_ids: {ids:?}");
                ids
            }
            None => {
                // Vector results get a slightly higher fusion weight.
                let ids = reciprocal_rank_fusion(
                    vec![vector_search_ids, keyword_search_ids],
                    vec![1.125, 1.0],
                    top_k,
                );
                debug!("rrf_ids: {ids:?}");
                ids
            }
        };
        let output = ids
            .into_iter()
            .filter_map(|id| {
                let document = self.data.get(id)?;
                Some((id, document.page_content.clone()))
            })
            .collect();
        Ok(output)
    }

    /// Embed the query (split into chunks first) and return HNSW neighbors
    /// scored as `1 - cosine distance`, keeping scores above `min_score`.
    async fn vector_search(
        &self,
        query: &str,
        top_k: usize,
        min_score: f32,
    ) -> Result<Vec<(DocumentId, f32)>> {
        let splitter = RecursiveCharacterTextSplitter::new(
            self.data.chunk_size,
            self.data.chunk_overlap,
            &DEFAULT_SEPARATORS,
        );
        let texts = splitter.split_text(query);
        let embeddings_data = EmbeddingsData::new(texts, true);
        let embeddings = self.create_embeddings(embeddings_data, None).await?;
        let output = self
            .hnsw
            .parallel_search(&embeddings, top_k, 30)
            .into_iter()
            .flat_map(|list| {
                list.into_iter()
                    .filter_map(|v| {
                        let score = 1.0 - v.distance;
                        if score > min_score {
                            Some((DocumentId(v.d_id), score))
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .collect();
        Ok(output)
    }

    /// BM25 keyword search, keeping scores above `min_score`.
    async fn keyword_search(
        &self,
        query: &str,
        top_k: usize,
        min_score: f32,
    ) -> Result<Vec<(DocumentId, f32)>> {
        let results = self.bm25.search(query, top_k);
        let output: Vec<(DocumentId, f32)> = results
            .into_iter()
            .filter_map(|v| {
                let score = v.score;
                if score > min_score {
                    Some((v.document.id, score))
                } else {
                    None
                }
            })
            .collect();
        Ok(output)
    }

    /// Embed `data.texts` in batches, with progress reporting and
    /// exponential-backoff retries per batch.
    async fn create_embeddings(
        &self,
        data: EmbeddingsData,
        spinner: Option<Spinner>,
    ) -> Result<EmbeddingsOutput> {
        let embedding_client = self.create_embeddings_client(self.embedding_model.clone())?;
        let EmbeddingsData { texts, query } = data;
        let batch_size = self
            .data
            .batch_size
            .or_else(|| self.embedding_model.max_batch_size());
        // Also cap the batch so `batch * chunk_size` stays within the
        // model's input-token budget.
        let batch_size = match self.embedding_model.max_input_tokens() {
            Some(max_input_tokens) => {
                let x = max_input_tokens / self.data.chunk_size;
                match batch_size {
                    Some(y) => x.min(y),
                    None => x,
                }
            }
            None => batch_size.unwrap_or(1),
        };
        let mut output = vec![];
        let batch_chunks = texts.chunks(batch_size.max(1));
        let batch_chunks_len = batch_chunks.len();
        // Retry budget per batch, overridable via environment variable.
        let retry_limit = env::var(get_env_name("embeddings_retry_limit"))
            .ok()
            .and_then(|v| v.parse::<u32>().ok())
            .unwrap_or(2);
        for (index, texts) in batch_chunks.enumerate() {
            progress(
                &spinner,
                format!("Creating embeddings [{}/{batch_chunks_len}]", index + 1),
            );
            let chunk_data = EmbeddingsData {
                texts: texts.to_vec(),
                query,
            };
            let mut retry = 0;
            let chunk_output = loop {
                retry += 1;
                match embedding_client.embeddings(&chunk_data).await {
                    Ok(v) => break v,
                    Err(e) if retry < retry_limit => {
                        debug!("retry {retry} failed: {e}");
                        // Exponential backoff: 1s, 2s, 4s, ...
                        sleep(Duration::from_secs(2u64.pow(retry - 1))).await;
                        continue;
                    }
                    Err(e) => {
                        return Err(e).with_context(|| {
                            format!("Failed to create embedding after {retry_limit} attempts")
                        })?;
                    }
                }
            };
            output.extend(chunk_output);
        }
        Ok(output)
    }
}
/// Serializable state of a RAG store (persisted as YAML).
#[derive(Clone, Serialize, Deserialize)]
pub struct RagData {
    // Id of the embedding model the vectors were produced with.
    pub embedding_model: String,
    pub chunk_size: usize,
    pub chunk_overlap: usize,
    pub reranker_model: Option<String>,
    pub top_k: usize,
    // Optional cap on the embedding request batch size.
    pub batch_size: Option<usize>,
    // Next id to hand out when adding files (monotonically increasing).
    pub next_file_id: FileId,
    // The raw paths/urls/globs the user registered.
    pub document_paths: Vec<String>,
    pub files: IndexMap<FileId, RagFile>,
    // One embedding vector per document chunk (custom serde encoding).
    #[serde(with = "serde_vectors")]
    pub vectors: IndexMap<DocumentId, Vec<f32>>,
}
impl Debug for RagData {
    /// Debug output lists every field except `vectors`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RagData");
        out.field("embedding_model", &self.embedding_model);
        out.field("chunk_size", &self.chunk_size);
        out.field("chunk_overlap", &self.chunk_overlap);
        out.field("reranker_model", &self.reranker_model);
        out.field("top_k", &self.top_k);
        out.field("batch_size", &self.batch_size);
        out.field("next_file_id", &self.next_file_id);
        out.field("document_paths", &self.document_paths);
        out.field("files", &self.files);
        out.finish()
    }
}
impl RagData {
pub fn new(
embedding_model: String,
chunk_size: usize,
chunk_overlap: usize,
reranker_model: Option<String>,
top_k: usize,
batch_size: Option<usize>,
) -> Self {
Self {
embedding_model,
chunk_size,
chunk_overlap,
reranker_model,
top_k,
batch_size,
next_file_id: 0,
document_paths: Default::default(),
files: Default::default(),
vectors: Default::default(),
}
}
pub fn get(&self, id: DocumentId) -> Option<&RagDocument> {
let (file_index, document_index) = id.split();
let file = self.files.get(&file_index)?;
let document = file.documents.get(document_index)?;
Some(document)
}
pub fn del(&mut self, file_ids: Vec<FileId>) {
for file_id in file_ids {
if let Some(file) = self.files.swap_remove(&file_id) {
for (document_index, _) in file.documents.iter().enumerate() {
let document_id = DocumentId::new(file_id, document_index);
self.vectors.swap_remove(&document_id);
}
}
}
}
pub fn add(
&mut self,
next_file_id: FileId,
files: Vec<(FileId, RagFile)>,
document_ids: Vec<DocumentId>,
embeddings: EmbeddingsOutput,
) {
self.next_file_id = next_file_id;
self.files.extend(files);
self.vectors
.extend(document_ids.into_iter().zip(embeddings));
}
pub fn build_hnsw(&self) -> Hnsw<'static, f32, DistCosine> {
let hnsw = Hnsw::new(32, self.vectors.len(), 16, 200, DistCosine {});
let list: Vec<_> = self.vectors.iter().map(|(k, v)| (v, k.0)).collect();
hnsw.parallel_insert(&list);
hnsw
}
pub fn build_bm25(&self) -> SearchEngine<DocumentId> {
let mut documents = vec![];
for (file_index, file) in self.files.iter() {
for (document_index, document) in file.documents.iter().enumerate() {
let id = DocumentId::new(*file_index, document_index);
documents.push(bm25::Document::new(id, &document.page_content))
}
}
SearchEngineBuilder::<DocumentId>::with_documents(Language::English, documents)
.k1(1.5)
.b(0.75)
.build()
}
}
/// A loaded source document: content hash, origin path/url, and the chunks
/// it was split into.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagFile {
    // SHA-256 of the raw contents; used to detect unchanged files on sync.
    hash: String,
    path: String,
    documents: Vec<RagDocument>,
}
/// One chunk of a source document, plus splitter-provided metadata.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RagDocument {
    pub page_content: String,
    pub metadata: DocumentMetadata,
}
impl RagDocument {
    /// Build a document from the given content, with empty metadata.
    pub fn new<S: Into<String>>(page_content: S) -> Self {
        let page_content = page_content.into();
        RagDocument {
            page_content,
            metadata: IndexMap::new(),
        }
    }
}
impl Default for RagDocument {
    /// An empty document with no metadata.
    fn default() -> Self {
        RagDocument {
            // `String::new()` is the idiomatic (allocation-free) empty string,
            // equivalent to the previous `"".to_string()`.
            page_content: String::new(),
            metadata: IndexMap::new(),
        }
    }
}
pub type FileId = usize;

/// Identifier packing a file index into the high half of a `usize` and a
/// chunk (document) index into the low half.
#[derive(Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct DocumentId(usize);

impl Debug for DocumentId {
    /// Renders as "file_index-document_index", e.g. "3-7".
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (file_index, document_index) = self.split();
        write!(f, "{file_index}-{document_index}")
    }
}

impl DocumentId {
    /// Pack the two indices. Each must fit in half of `usize::BITS`;
    /// otherwise the halves would silently bleed into each other, so that
    /// invariant is checked in debug builds.
    pub fn new(file_index: usize, document_index: usize) -> Self {
        debug_assert!(file_index < (1 << (usize::BITS / 2)));
        debug_assert!(document_index < (1 << (usize::BITS / 2)));
        let value = (file_index << (usize::BITS / 2)) | document_index;
        Self(value)
    }

    /// Unpack into `(file_index, document_index)`.
    pub fn split(self) -> (usize, usize) {
        let value = self.0;
        let low_mask = (1 << (usize::BITS / 2)) - 1;
        (value >> (usize::BITS / 2), value & low_mask)
    }
}
fn select_embedding_model(models: &[&Model]) -> Result<String> {
let models: Vec<_> = models
.iter()
.map(|v| SelectOption::new(v.id(), v.description()))
.collect();
let result = Select::new("Select embedding model:", models).prompt()?;
Ok(result.value)
}
/// A value/description pair shown in interactive select prompts as
/// "value (description)".
#[derive(Debug)]
struct SelectOption {
    pub value: String,
    pub description: String,
}

impl SelectOption {
    pub fn new(value: String, description: String) -> Self {
        SelectOption { value, description }
    }
}

impl std::fmt::Display for SelectOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { value, description } = self;
        write!(f, "{value} ({description})")
    }
}
/// Interactively prompt for the chunk size, defaulting to the model's
/// preferred value and showing the model's max_tokens as a hint when known.
fn set_chunk_size(model: &Model) -> Result<usize> {
    let default_value = model.default_chunk_size().to_string();
    let help_message = model
        .max_tokens_per_chunk()
        .map(|v| format!("The model's max_tokens is {v}"));
    let mut text = Text::new("Set chunk size:")
        .with_default(&default_value)
        .with_validator(move |text: &str| {
            // Only accept input that parses as an unsigned integer.
            let out = match text.parse::<usize>() {
                Ok(_) => Validation::Valid,
                // Fixed grammar: "a integer" -> "an integer".
                Err(_) => Validation::Invalid("Must be an integer".into()),
            };
            Ok(out)
        });
    if let Some(help_message) = &help_message {
        text = text.with_help_message(help_message);
    }
    let value = text.prompt()?;
    value.parse().map_err(|_| anyhow!("Invalid chunk_size"))
}
/// Interactively prompt for the chunk overlap, defaulting to `default_value`.
/// (The function name keeps the historical "overlay" spelling because it is
/// called elsewhere; the user-facing strings use the correct term "overlap".)
fn set_chunk_overlay(default_value: usize) -> Result<usize> {
    let value = Text::new("Set chunk overlap:")
        .with_default(&default_value.to_string())
        .with_validator(move |text: &str| {
            // Only accept input that parses as an unsigned integer.
            let out = match text.parse::<usize>() {
                Ok(_) => Validation::Valid,
                Err(_) => Validation::Invalid("Must be an integer".into()),
            };
            Ok(out)
        })
        .prompt()?;
    value.parse().map_err(|_| anyhow!("Invalid chunk_overlap"))
}
/// Interactively ask for a semicolon-separated list of document paths and
/// return the non-empty, trimmed entries.
fn add_documents() -> Result<Vec<String>> {
    let input = Text::new("Add documents:")
        .with_validator(required!("This field is required"))
        .with_help_message("e.g. file;dir/;dir/**/*.{md,mdx};loader:resource;url;website/**")
        .prompt()?;
    let mut paths = Vec::new();
    for part in input.split(';') {
        let part = part.trim();
        if !part.is_empty() {
            paths.push(part.to_string());
        }
    }
    Ok(paths)
}
/// Classify raw input paths into (all registered paths, recursive-crawl url
/// roots, plain urls, loader-protocol paths, expanded local file paths).
async fn resolve_paths<T: AsRef<str>>(
    loaders: &HashMap<String, String>,
    paths: &[T],
) -> Result<(
    IndexSet<String>,
    IndexSet<String>,
    IndexSet<String>,
    IndexSet<String>,
    IndexSet<String>,
)> {
    let mut document_paths = IndexSet::new();
    let mut recursive_urls = IndexSet::new();
    let mut urls = IndexSet::new();
    let mut protocol_paths = IndexSet::new();
    let mut glob_candidates = vec![];
    for raw in paths {
        let path = raw.as_ref().trim();
        if is_url(path) {
            // A trailing "**" marks a recursive crawl starting at that url.
            match path.strip_suffix("**") {
                Some(start_url) => {
                    recursive_urls.insert(start_url.to_string());
                }
                None => {
                    urls.insert(path.to_string());
                }
            }
            document_paths.insert(path.to_string());
        } else if is_loader_protocol(loaders, path) {
            protocol_paths.insert(path.to_string());
            document_paths.insert(path.to_string());
        } else {
            let resolved_path = resolve_home_dir(path);
            let absolute_path = to_absolute_path(&resolved_path)
                .with_context(|| format!("Invalid path '{path}'"))?;
            // NOTE(review): the home-resolved (pre-absolutized) path is what
            // gets glob-expanded, while the absolute form is recorded in
            // document_paths — confirm this asymmetry is intended.
            glob_candidates.push(resolved_path);
            document_paths.insert(absolute_path);
        }
    }
    let local_paths = expand_glob_paths(&glob_candidates, false).await?;
    Ok((
        document_paths,
        recursive_urls,
        urls,
        protocol_paths,
        local_paths,
    ))
}
/// Update the spinner message when a spinner is present; errors from the
/// spinner channel are ignored.
fn progress(spinner: &Option<Spinner>, message: String) {
    if let Some(sp) = spinner.as_ref() {
        let _ = sp.set_message(message);
    }
}
/// Merge several ranked id lists with weighted Reciprocal Rank Fusion:
/// score(d) = Σ weight / (rrf_k + rank + 1), with rrf_k = 2 * top_k.
/// Returns the `top_k` highest-scoring ids, best first.
fn reciprocal_rank_fusion(
    list_of_document_ids: Vec<Vec<DocumentId>>,
    list_of_weights: Vec<f32>,
    top_k: usize,
) -> Vec<DocumentId> {
    let rrf_k = top_k * 2;
    let mut map: IndexMap<DocumentId, f32> = IndexMap::new();
    for (document_ids, weight) in list_of_document_ids.into_iter().zip(list_of_weights) {
        for (index, &item) in document_ids.iter().enumerate() {
            *map.entry(item).or_default() += (1.0 / ((rrf_k + index + 1) as f32)) * weight;
        }
    }
    let mut sorted_items: Vec<(DocumentId, f32)> = map.into_iter().collect();
    // `total_cmp` removes the panic path of `partial_cmp().unwrap()` and
    // orders these finite scores identically; the stable sort keeps
    // first-seen order for equal scores.
    sorted_items.sort_by(|a, b| b.1.total_cmp(&a.1));
    sorted_items
        .into_iter()
        .take(top_k)
        .map(|(v, _)| v)
        .collect()
}