From 27642c6b873781926644ae61cda3123c46907177 Mon Sep 17 00:00:00 2001 From: travolin Date: Fri, 8 Nov 2024 20:31:45 -0800 Subject: [PATCH] Embedding enable disable (#535) * update to auto load embedding model - When embeddings are turned on, download model if needed - Update embedding api to be swappable - When embeddings are enabled add all already indexed documents to the embedding queue (if they have not already been embedded) * remove unused custom default --------- Co-authored-by: travolin --- apps/tauri/src/cmd/settings.rs | 4 + crates/entities/src/models/embedding_queue.rs | 14 +++- .../entities/src/models/indexed_document.rs | 14 ++++ crates/entities/src/models/vec_documents.rs | 2 +- crates/shared/src/config.rs | 3 +- crates/shared/src/config/embeddings.rs | 12 +-- crates/spyglass-model-interface/src/lib.rs | 3 +- crates/spyglass/src/api/handler/search.rs | 2 +- crates/spyglass/src/documents/embeddings.rs | 2 +- crates/spyglass/src/documents/mod.rs | 2 +- crates/spyglass/src/state.rs | 62 +++++++++------ crates/spyglass/src/task.rs | 79 +++++++++++++++++-- 12 files changed, 151 insertions(+), 48 deletions(-) diff --git a/apps/tauri/src/cmd/settings.rs b/apps/tauri/src/cmd/settings.rs index 2c8a978da..efab3d79c 100644 --- a/apps/tauri/src/cmd/settings.rs +++ b/apps/tauri/src/cmd/settings.rs @@ -90,6 +90,10 @@ pub async fn save_user_settings( current_settings.audio_settings.enable_audio_transcription = serde_json::from_str(value).unwrap_or_default() } + "embedding_settings.enable_embeddings" => { + current_settings.embedding_settings.enable_embeddings = + serde_json::from_str(value).unwrap_or_default() + } _ => {} } } diff --git a/crates/entities/src/models/embedding_queue.rs b/crates/entities/src/models/embedding_queue.rs index 6193764a4..4cb24f4ba 100644 --- a/crates/entities/src/models/embedding_queue.rs +++ b/crates/entities/src/models/embedding_queue.rs @@ -84,7 +84,12 @@ where Entity::insert(model) .on_conflict( OnConflict::column(Column::DocumentId) - 
.update_columns([Column::Status, Column::Content, Column::IndexedDocumentId]) + .update_columns([ + Column::Status, + Column::Content, + Column::IndexedDocumentId, + Column::UpdatedAt, + ]) .to_owned(), ) .exec(db) @@ -98,7 +103,12 @@ where Entity::insert_many(to_add.to_vec()) .on_conflict( OnConflict::column(Column::DocumentId) - .update_columns([Column::Status, Column::Content, Column::IndexedDocumentId]) + .update_columns([ + Column::Status, + Column::Content, + Column::IndexedDocumentId, + Column::UpdatedAt, + ]) .to_owned(), ) .exec_without_returning(db) diff --git a/crates/entities/src/models/indexed_document.rs b/crates/entities/src/models/indexed_document.rs index db1e6de2e..b4c2c6c7e 100644 --- a/crates/entities/src/models/indexed_document.rs +++ b/crates/entities/src/models/indexed_document.rs @@ -509,6 +509,20 @@ pub async fn copy_table( Ok(()) } +pub async fn get_documents_missing_embeddings( + db: &DatabaseConnection, +) -> Result, DbErr> { + let statement = Statement::from_string( + db.get_database_backend(), + r#" + select * from indexed_document where id not in ( + select indexed_document_id from embedding_queue where status != 'Failed') + "#, + ); + + Model::find_by_statement(statement).all(db).await +} + #[cfg(test)] mod test { use std::collections::HashMap; diff --git a/crates/entities/src/models/vec_documents.rs b/crates/entities/src/models/vec_documents.rs index 34dc3c983..daf0ef55b 100644 --- a/crates/entities/src/models/vec_documents.rs +++ b/crates/entities/src/models/vec_documents.rs @@ -40,7 +40,7 @@ where db.get_database_backend(), r#" update vec_documents set embedding = $2 - where id = $1 + where rowid = $1 "#, vec![id.into(), embedding.into()], ) diff --git a/crates/shared/src/config.rs b/crates/shared/src/config.rs index 48d3cd52a..0cf83d43b 100644 --- a/crates/shared/src/config.rs +++ b/crates/shared/src/config.rs @@ -1,7 +1,7 @@ use crate::form::{FormType, SettingOpts}; use diff::Diff; use directories::ProjectDirs; -use 
embeddings::EmbeddingSettings; +use embeddings::{embedding_setting_opts, EmbeddingSettings}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; @@ -234,6 +234,7 @@ impl From for Vec<(String, SettingOpts)> { config.extend(fs_setting_opts(&settings)); config.extend(audio_setting_opts(&settings)); + config.extend(embedding_setting_opts(&settings)); config } diff --git a/crates/shared/src/config/embeddings.rs b/crates/shared/src/config/embeddings.rs index 406c142db..b4d810c5d 100644 --- a/crates/shared/src/config/embeddings.rs +++ b/crates/shared/src/config/embeddings.rs @@ -5,19 +5,11 @@ use crate::form::{FormType, SettingOpts}; use super::UserSettings; -#[derive(Clone, Debug, Serialize, Deserialize, Diff)] +#[derive(Clone, Debug, Serialize, Deserialize, Diff, Default)] pub struct EmbeddingSettings { pub enable_embeddings: bool, } -impl Default for EmbeddingSettings { - fn default() -> Self { - EmbeddingSettings { - enable_embeddings: true, - } - } -} - #[allow(dead_code)] pub fn embedding_setting_opts(settings: &UserSettings) -> Vec<(String, SettingOpts)> { vec![( @@ -26,7 +18,7 @@ pub fn embedding_setting_opts(settings: &UserSettings) -> Vec<(String, SettingOp label: "Beta: Enable Similarity Search".into(), value: settings.embedding_settings.enable_embeddings.to_string(), form_type: FormType::Bool, - restart_required: true, + restart_required: false, help_text: Some( r#"Embeddings are generated for documents and search will check for semantic similarity as well as standard search."# diff --git a/crates/spyglass-model-interface/src/lib.rs b/crates/spyglass-model-interface/src/lib.rs index 5b84f04db..e2e01ca17 100644 --- a/crates/spyglass-model-interface/src/lib.rs +++ b/crates/spyglass-model-interface/src/lib.rs @@ -421,7 +421,8 @@ impl WrapErr for Result { pub fn load_tokenizer(model_root: &Path) -> anyhow::Result { // Load tokenizer let tokenizer_path = model_root.join("tokenizer.json"); - let mut tokenizer = 
Tokenizer::from_file(tokenizer_path).expect("tokenizer.json not found"); + let mut tokenizer = Tokenizer::from_file(tokenizer_path) + .map_err(|error| anyhow::format_err!("Error loading tokenizer {:?}", error))?; // See https://github.com/huggingface/tokenizers/pull/1357 if let Some(pre_tokenizer) = tokenizer.get_pre_tokenizer() { if let PreTokenizerWrapper::Metaspace(m) = pre_tokenizer { diff --git a/crates/spyglass/src/api/handler/search.rs b/crates/spyglass/src/api/handler/search.rs index 008611c29..4c3568aac 100644 --- a/crates/spyglass/src/api/handler/search.rs +++ b/crates/spyglass/src/api/handler/search.rs @@ -62,7 +62,7 @@ pub async fn search_docs( })); } - if let Some(embedding_api) = state.embedding_api.as_ref() { + if let Some(embedding_api) = state.embedding_api.load_full().as_ref() { match embedding_api.embed(&query, EmbeddingContentType::Query) { Ok(embedding) => { let mut distances = diff --git a/crates/spyglass/src/documents/embeddings.rs b/crates/spyglass/src/documents/embeddings.rs index 32822e904..749d43952 100644 --- a/crates/spyglass/src/documents/embeddings.rs +++ b/crates/spyglass/src/documents/embeddings.rs @@ -18,7 +18,7 @@ pub async fn processing_embedding(state: AppState, job_id: i64) { Ok(Some(job)) => { match job.content { Some(content) => { - let embedding = if let Some(api) = state.embedding_api.as_ref() { + let embedding = if let Some(api) = state.embedding_api.load_full().as_ref() { api.embed(&content, EmbeddingContentType::Document) } else { Err(anyhow::format_err!( diff --git a/crates/spyglass/src/documents/mod.rs b/crates/spyglass/src/documents/mod.rs index 104570cc2..a594099be 100644 --- a/crates/spyglass/src/documents/mod.rs +++ b/crates/spyglass/src/documents/mod.rs @@ -194,7 +194,7 @@ pub async fn process_crawl_results( ) .await?; - if crawl_result.content.is_some() && state.embedding_api.as_ref().is_some() { + if crawl_result.content.is_some() && state.embedding_api.load().as_ref().is_some() { 
embedding_map.insert(doc_id.clone(), crawl_result.content.clone().unwrap()); } diff --git a/crates/spyglass/src/state.rs b/crates/spyglass/src/state.rs index ffa513fc4..e1c8817c5 100644 --- a/crates/spyglass/src/state.rs +++ b/crates/spyglass/src/state.rs @@ -57,7 +57,7 @@ impl FetchLimitType { #[derive(Clone)] pub struct AppState { pub db: DatabaseConnection, - pub embedding_api: Arc>, + pub embedding_api: Arc>>, pub app_state: Arc>, pub lenses: Arc>, pub pipelines: Arc>, @@ -120,6 +120,11 @@ impl AppState { self.config = config; } + pub fn reload_model(&mut self) { + let embedding_api = load_model(self.user_settings.load_full().as_ref()); + self.embedding_api.store(Arc::new(embedding_api)); + } + pub fn builder() -> AppStateBuilder { AppStateBuilder::new() } @@ -184,27 +189,7 @@ impl AppStateBuilder { UserSettings::default() }; - let mut embedding_api = None; - if self - .user_settings - .as_ref() - .is_some_and(|settings| settings.embedding_settings.enable_embeddings) - { - let mut model_root = user_settings.data_directory.clone(); - model_root.push("models"); - model_root.push("embeddings"); - - let mut tokenizer_file = model_root.clone(); - tokenizer_file.push("tokenizer.json"); - let mut model = model_root.clone(); - model.push("model.safetensors"); - - if tokenizer_file.exists() && model.exists() { - embedding_api = EmbeddingApi::new(model_root.clone()).ok(); - } else { - log::warn!("Model does not exist"); - } - } + let embedding_api = load_model(&user_settings); let (shutdown_tx, _) = broadcast::channel::(16); let (config_tx, _) = broadcast::channel::(16); @@ -231,7 +216,7 @@ impl AppStateBuilder { user_settings: Arc::new(ArcSwap::from_pointee(user_settings)), fetch_limits: Arc::new(DashMap::new()), readonly_mode: self.readonly_mode.unwrap_or_default(), - embedding_api: Arc::new(embedding_api), + embedding_api: Arc::new(ArcSwap::from_pointee(embedding_api)), } } @@ -280,3 +265,34 @@ impl AppStateBuilder { self } } + +fn load_model(user_settings: 
&UserSettings) -> Option { + if user_settings.embedding_settings.enable_embeddings { + let mut model_root = user_settings.data_directory.clone(); + model_root.push("models"); + model_root.push("embeddings"); + + let mut tokenizer_file = model_root.clone(); + tokenizer_file.push("tokenizer.json"); + let mut model = model_root.clone(); + model.push("model.safetensors"); + + if tokenizer_file.exists() && model.exists() { + match EmbeddingApi::new(model_root.clone()) { + Ok(embedding_api) => { + log::info!("Embedding Model Loaded"); + Some(embedding_api) + } + Err(error) => { + log::error!("Error Loading Embedding Model {:?}", error); + None + } + } + } else { + log::warn!("Model does not exist"); + None + } + } else { + None + } +} diff --git a/crates/spyglass/src/task.rs b/crates/spyglass/src/task.rs index 28da8315a..6a251be5a 100644 --- a/crates/spyglass/src/task.rs +++ b/crates/spyglass/src/task.rs @@ -1,12 +1,16 @@ use anyhow::anyhow; use entities::models::crawl_queue::CrawlStatus; -use entities::models::{bootstrap_queue, connection, crawl_queue, embedding_queue}; +use entities::models::{ + bootstrap_queue, connection, crawl_queue, embedding_queue, indexed_document, +}; +use entities::sea_orm::Set; use entities::sea_orm::{sea_query::Expr, ColumnTrait, Condition, EntityTrait, QueryFilter}; use futures::StreamExt; use notify::event::ModifyKind; use notify::{EventKind, RecursiveMode, Watcher}; use shared::config::{Config, LensConfig, UserSettings, UserSettingsDiff}; use spyglass_rpc::{ModelDownloadStatusPayload, RpcEvent, RpcEventType}; +use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::PathBuf; @@ -23,6 +27,7 @@ use crate::filesystem; use crate::state::AppState; use crate::task::worker::FetchResult; use diff::Diff; +use entities::sea_orm::ActiveModelBehavior; use spyglass_processor::utils::extensions::AudioExt; pub mod lens; @@ -265,18 +270,37 @@ pub async fn config_task(mut state: AppState) { } if 
new_settings.embedding_settings.enable_embeddings { + let model_dir = state.config.embedding_model_dir(); let model_path = state.config.embedding_model_dir().join("model.safetensors"); let tokenizer_path = state.config.embedding_model_dir().join("tokenizer.json"); let model_config_path = state.config.embedding_model_dir().join("config.json"); if !model_path.exists() || !tokenizer_path.exists() || !model_config_path.exists() { - let state_clone = state.clone(); + log::debug!("Loading Embedding Models..."); + let mut state_clone = state.clone(); + + if !model_dir.exists() { + let _ = std::fs::create_dir_all(model_dir); + } + tokio::spawn(async move { - let _ = download_model(&state_clone, "Embedding Model", model_path, shared::constants::EMBEDDING_MODEL).await; - let _ = download_model(&state_clone, "Embedding Model Config", tokenizer_path, shared::constants::EMBEDDING_MODEL_CONFIG).await; - let _ = download_model(&state_clone, "Embedding Model Tokenizer", model_config_path, shared::constants::EMBEDDING_MODEL_TOKENIZER).await; - //TODO Embed current documents + if let Err(error) = download_model(&state_clone, "Embedding Model", model_path, shared::constants::EMBEDDING_MODEL).await { + log::error!("Error downloading Embedding model {:?}", error); + } + if let Err(error) = download_model(&state_clone, "Embedding Model Config", model_config_path, shared::constants::EMBEDDING_MODEL_CONFIG).await { + log::error!("Error downloading Embedding model config {:?}", error); + } + if let Err(error) = download_model(&state_clone, "Embedding Model Tokenizer", tokenizer_path, shared::constants::EMBEDDING_MODEL_TOKENIZER).await { + log::error!("Error downloading Embedding model tokenizer config {:?}", error); + } + + state_clone.reload_model(); + + add_missing_embeddings(&state_clone).await; }); + } else { + state.reload_model(); + add_missing_embeddings(&state).await; } } } @@ -290,6 +314,48 @@ pub async fn config_task(mut state: AppState) { } } +async fn 
add_missing_embeddings(state: &AppState) { + match indexed_document::get_documents_missing_embeddings(&state.db).await { + Ok(missing_embeddings) => { + // could be a very large set of documents + for missing_embeddings in missing_embeddings.chunks(1000) { + let doc_ids = missing_embeddings + .iter() + .map(|doc| doc.doc_id.clone()) + .collect::>(); + + let docs = state + .index + .search_by_query(None, Some(doc_ids), &[], &[]) + .await; + let mut content_map: HashMap = HashMap::new(); + for (_, result) in docs { + content_map.insert(result.doc_id.to_owned(), result.content.to_owned()); + } + + let updates = missing_embeddings + .iter() + .map(|doc| { + let mut model = embedding_queue::ActiveModel::new(); + let content = content_map.get(&doc.doc_id).cloned(); + model.document_id = Set(doc.doc_id.clone()); + model.content = Set(content); + model.indexed_document_id = Set(doc.id); + model + }) + .collect::>(); + + if let Err(error) = embedding_queue::add_to_queue(&state.db, &updates).await { + log::error!("Error adding documents to embedding queue {:?}", error); + } + } + } + Err(error) => { + log::error!("Error getting missing document embeddings. {:?}", error); + } + } +} + /// Downloads a model from our assets S3 bucket async fn download_model( state: &AppState, @@ -297,7 +363,6 @@ async fn download_model( model_path: PathBuf, model_url: &str, ) -> anyhow::Result<()> { - // Currently we only have the audio model :) match reqwest::get(model_url).await { Ok(res) => { let total_size = res.content_length().expect("Unable to get content length");