Skip to content

Commit

Permalink
Embedding enable disable (#535)
Browse files Browse the repository at this point in the history
* update to auto load embedding model

- When embeddings are turned on, download the model if needed
- Update the embedding API to be swappable
- When embeddings are enabled, add all already-indexed documents to the embedding queue (if they have not already been embedded)

* remove unused custom default

---------

Co-authored-by: travolin <joel@spyglass.fyi>
  • Loading branch information
travolin and travolin authored Nov 9, 2024
1 parent 04aac30 commit 27642c6
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 48 deletions.
4 changes: 4 additions & 0 deletions apps/tauri/src/cmd/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ pub async fn save_user_settings(
current_settings.audio_settings.enable_audio_transcription =
serde_json::from_str(value).unwrap_or_default()
}
"embedding_settings.enable_embeddings" => {
current_settings.embedding_settings.enable_embeddings =
serde_json::from_str(value).unwrap_or_default()
}
_ => {}
}
}
Expand Down
14 changes: 12 additions & 2 deletions crates/entities/src/models/embedding_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ where
Entity::insert(model)
.on_conflict(
OnConflict::column(Column::DocumentId)
.update_columns([Column::Status, Column::Content, Column::IndexedDocumentId])
.update_columns([
Column::Status,
Column::Content,
Column::IndexedDocumentId,
Column::UpdatedAt,
])
.to_owned(),
)
.exec(db)
Expand All @@ -98,7 +103,12 @@ where
Entity::insert_many(to_add.to_vec())
.on_conflict(
OnConflict::column(Column::DocumentId)
.update_columns([Column::Status, Column::Content, Column::IndexedDocumentId])
.update_columns([
Column::Status,
Column::Content,
Column::IndexedDocumentId,
Column::UpdatedAt,
])
.to_owned(),
)
.exec_without_returning(db)
Expand Down
14 changes: 14 additions & 0 deletions crates/entities/src/models/indexed_document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,20 @@ pub async fn copy_table(
Ok(())
}

/// Returns every `indexed_document` row whose `id` is not referenced by an
/// `embedding_queue` row with a status other than `'Failed'`.
///
/// Documents whose only queue entries are marked `'Failed'` are therefore
/// included in the result as well.
///
/// # Errors
///
/// Propagates any [`DbErr`] raised while running the query.
pub async fn get_documents_missing_embeddings(
    db: &DatabaseConnection,
) -> Result<Vec<Model>, DbErr> {
    let backend = db.get_database_backend();
    let query = r#"
select * from indexed_document where id not in (
select indexed_document_id from embedding_queue where status != 'Failed')
"#;

    Model::find_by_statement(Statement::from_string(backend, query))
        .all(db)
        .await
}

#[cfg(test)]
mod test {
use std::collections::HashMap;
Expand Down
2 changes: 1 addition & 1 deletion crates/entities/src/models/vec_documents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ where
db.get_database_backend(),
r#"
update vec_documents set embedding = $2
where id = $1
where rowid = $1
"#,
vec![id.into(), embedding.into()],
)
Expand Down
3 changes: 2 additions & 1 deletion crates/shared/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::form::{FormType, SettingOpts};
use diff::Diff;
use directories::ProjectDirs;
use embeddings::EmbeddingSettings;
use embeddings::{embedding_setting_opts, EmbeddingSettings};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
Expand Down Expand Up @@ -234,6 +234,7 @@ impl From<UserSettings> for Vec<(String, SettingOpts)> {

config.extend(fs_setting_opts(&settings));
config.extend(audio_setting_opts(&settings));
config.extend(embedding_setting_opts(&settings));

config
}
Expand Down
12 changes: 2 additions & 10 deletions crates/shared/src/config/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,11 @@ use crate::form::{FormType, SettingOpts};

use super::UserSettings;

#[derive(Clone, Debug, Serialize, Deserialize, Diff)]
#[derive(Clone, Debug, Serialize, Deserialize, Diff, Default)]
pub struct EmbeddingSettings {
pub enable_embeddings: bool,
}

impl Default for EmbeddingSettings {
fn default() -> Self {
EmbeddingSettings {
enable_embeddings: true,
}
}
}

#[allow(dead_code)]
pub fn embedding_setting_opts(settings: &UserSettings) -> Vec<(String, SettingOpts)> {
vec![(
Expand All @@ -26,7 +18,7 @@ pub fn embedding_setting_opts(settings: &UserSettings) -> Vec<(String, SettingOp
label: "Beta: Enable Similarity Search".into(),
value: settings.embedding_settings.enable_embeddings.to_string(),
form_type: FormType::Bool,
restart_required: true,
restart_required: false,
help_text: Some(
r#"Embeddings are generated for documents and search will check for
semantic similarity as well as standard search."#
Expand Down
3 changes: 2 additions & 1 deletion crates/spyglass-model-interface/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,8 @@ impl<O> WrapErr<O> for Result<O, candle::Error> {
pub fn load_tokenizer(model_root: &Path) -> anyhow::Result<Tokenizer> {
// Load tokenizer
let tokenizer_path = model_root.join("tokenizer.json");
let mut tokenizer = Tokenizer::from_file(tokenizer_path).expect("tokenizer.json not found");
let mut tokenizer = Tokenizer::from_file(tokenizer_path)
.map_err(|error| anyhow::format_err!("Error loading tokenizer {:?}", error))?;
// See https://github.com/huggingface/tokenizers/pull/1357
if let Some(pre_tokenizer) = tokenizer.get_pre_tokenizer() {
if let PreTokenizerWrapper::Metaspace(m) = pre_tokenizer {
Expand Down
2 changes: 1 addition & 1 deletion crates/spyglass/src/api/handler/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub async fn search_docs(
}));
}

if let Some(embedding_api) = state.embedding_api.as_ref() {
if let Some(embedding_api) = state.embedding_api.load_full().as_ref() {
match embedding_api.embed(&query, EmbeddingContentType::Query) {
Ok(embedding) => {
let mut distances =
Expand Down
2 changes: 1 addition & 1 deletion crates/spyglass/src/documents/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub async fn processing_embedding(state: AppState, job_id: i64) {
Ok(Some(job)) => {
match job.content {
Some(content) => {
let embedding = if let Some(api) = state.embedding_api.as_ref() {
let embedding = if let Some(api) = state.embedding_api.load_full().as_ref() {
api.embed(&content, EmbeddingContentType::Document)
} else {
Err(anyhow::format_err!(
Expand Down
2 changes: 1 addition & 1 deletion crates/spyglass/src/documents/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ pub async fn process_crawl_results(
)
.await?;

if crawl_result.content.is_some() && state.embedding_api.as_ref().is_some() {
if crawl_result.content.is_some() && state.embedding_api.load().as_ref().is_some() {
embedding_map.insert(doc_id.clone(), crawl_result.content.clone().unwrap());
}

Expand Down
62 changes: 39 additions & 23 deletions crates/spyglass/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl FetchLimitType {
#[derive(Clone)]
pub struct AppState {
pub db: DatabaseConnection,
pub embedding_api: Arc<Option<EmbeddingApi>>,
pub embedding_api: Arc<ArcSwap<Option<EmbeddingApi>>>,
pub app_state: Arc<DashMap<String, String>>,
pub lenses: Arc<DashMap<String, LensConfig>>,
pub pipelines: Arc<DashMap<String, PipelineConfiguration>>,
Expand Down Expand Up @@ -120,6 +120,11 @@ impl AppState {
self.config = config;
}

/// Re-runs `load_model` against the currently stored user settings and
/// swaps the result (which may be `None`) into `embedding_api`.
pub fn reload_model(&mut self) {
    let settings = self.user_settings.load_full();
    let reloaded = load_model(&settings);
    self.embedding_api.store(Arc::new(reloaded));
}

pub fn builder() -> AppStateBuilder {
AppStateBuilder::new()
}
Expand Down Expand Up @@ -184,27 +189,7 @@ impl AppStateBuilder {
UserSettings::default()
};

let mut embedding_api = None;
if self
.user_settings
.as_ref()
.is_some_and(|settings| settings.embedding_settings.enable_embeddings)
{
let mut model_root = user_settings.data_directory.clone();
model_root.push("models");
model_root.push("embeddings");

let mut tokenizer_file = model_root.clone();
tokenizer_file.push("tokenizer.json");
let mut model = model_root.clone();
model.push("model.safetensors");

if tokenizer_file.exists() && model.exists() {
embedding_api = EmbeddingApi::new(model_root.clone()).ok();
} else {
log::warn!("Model does not exist");
}
}
let embedding_api = load_model(&user_settings);

let (shutdown_tx, _) = broadcast::channel::<AppShutdown>(16);
let (config_tx, _) = broadcast::channel::<UserSettingsChange>(16);
Expand All @@ -231,7 +216,7 @@ impl AppStateBuilder {
user_settings: Arc::new(ArcSwap::from_pointee(user_settings)),
fetch_limits: Arc::new(DashMap::new()),
readonly_mode: self.readonly_mode.unwrap_or_default(),
embedding_api: Arc::new(embedding_api),
embedding_api: Arc::new(ArcSwap::from_pointee(embedding_api)),
}
}

Expand Down Expand Up @@ -280,3 +265,34 @@ impl AppStateBuilder {
self
}
}

/// Tries to build the embedding API from the model files stored under
/// `<data_directory>/models/embeddings`.
///
/// Returns `None` when embeddings are disabled in `user_settings`, when
/// `tokenizer.json` or `model.safetensors` is missing from that directory
/// (a warning is logged), or when `EmbeddingApi::new` fails (the error is
/// logged).
fn load_model(user_settings: &UserSettings) -> Option<EmbeddingApi> {
    if !user_settings.embedding_settings.enable_embeddings {
        return None;
    }

    let model_root = user_settings
        .data_directory
        .join("models")
        .join("embeddings");
    let tokenizer_file = model_root.join("tokenizer.json");
    let weights_file = model_root.join("model.safetensors");

    // Both files must be present before attempting to load anything.
    if !(tokenizer_file.exists() && weights_file.exists()) {
        log::warn!("Model does not exist");
        return None;
    }

    match EmbeddingApi::new(model_root) {
        Ok(embedding_api) => {
            log::info!("Embedding Model Loaded");
            Some(embedding_api)
        }
        Err(error) => {
            log::error!("Error Loading Embedding Model {:?}", error);
            None
        }
    }
}
79 changes: 72 additions & 7 deletions crates/spyglass/src/task.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use anyhow::anyhow;
use entities::models::crawl_queue::CrawlStatus;
use entities::models::{bootstrap_queue, connection, crawl_queue, embedding_queue};
use entities::models::{
bootstrap_queue, connection, crawl_queue, embedding_queue, indexed_document,
};
use entities::sea_orm::Set;
use entities::sea_orm::{sea_query::Expr, ColumnTrait, Condition, EntityTrait, QueryFilter};
use futures::StreamExt;
use notify::event::ModifyKind;
use notify::{EventKind, RecursiveMode, Watcher};
use shared::config::{Config, LensConfig, UserSettings, UserSettingsDiff};
use spyglass_rpc::{ModelDownloadStatusPayload, RpcEvent, RpcEventType};
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
Expand All @@ -23,6 +27,7 @@ use crate::filesystem;
use crate::state::AppState;
use crate::task::worker::FetchResult;
use diff::Diff;
use entities::sea_orm::ActiveModelBehavior;
use spyglass_processor::utils::extensions::AudioExt;

pub mod lens;
Expand Down Expand Up @@ -265,18 +270,37 @@ pub async fn config_task(mut state: AppState) {
}

if new_settings.embedding_settings.enable_embeddings {
let model_dir = state.config.embedding_model_dir();
let model_path = state.config.embedding_model_dir().join("model.safetensors");
let tokenizer_path = state.config.embedding_model_dir().join("tokenizer.json");
let model_config_path = state.config.embedding_model_dir().join("config.json");
if !model_path.exists() || !tokenizer_path.exists() || !model_config_path.exists() {
let state_clone = state.clone();
log::debug!("Loading Embedding Models...");
let mut state_clone = state.clone();

if !model_dir.exists() {
let _ = std::fs::create_dir_all(model_dir);
}

tokio::spawn(async move {
let _ = download_model(&state_clone, "Embedding Model", model_path, shared::constants::EMBEDDING_MODEL).await;
let _ = download_model(&state_clone, "Embedding Model Config", tokenizer_path, shared::constants::EMBEDDING_MODEL_CONFIG).await;
let _ = download_model(&state_clone, "Embedding Model Tokenizer", model_config_path, shared::constants::EMBEDDING_MODEL_TOKENIZER).await;

//TODO Embed current documents
if let Err(error) = download_model(&state_clone, "Embedding Model", model_path, shared::constants::EMBEDDING_MODEL).await {
log::error!("Error downloading Embedding model {:?}", error);
}
if let Err(error) = download_model(&state_clone, "Embedding Model Config", model_config_path, shared::constants::EMBEDDING_MODEL_CONFIG).await {
log::error!("Error downloading Embedding model config {:?}", error);
}
if let Err(error) = download_model(&state_clone, "Embedding Model Tokenizer", tokenizer_path, shared::constants::EMBEDDING_MODEL_TOKENIZER).await {
log::error!("Error downloading Embedding model tokenizer config {:?}", error);
}

state_clone.reload_model();

add_missing_embeddings(&state_clone).await;
});
} else {
state.reload_model();
add_missing_embeddings(&state).await;
}
}
}
Expand All @@ -290,14 +314,55 @@ pub async fn config_task(mut state: AppState) {
}
}

/// Looks up the documents reported by
/// `indexed_document::get_documents_missing_embeddings` and hands them to
/// `embedding_queue::add_to_queue` in batches of 1000.
///
/// Content for each batch is fetched from the search index by document id;
/// a document with no matching index result is queued with `content` set to
/// `None`. Failures are logged and never returned to the caller: a failed
/// lookup ends the function, while a failed batch insert only skips that
/// batch.
async fn add_missing_embeddings(state: &AppState) {
    let pending = match indexed_document::get_documents_missing_embeddings(&state.db).await {
        Ok(documents) => documents,
        Err(error) => {
            log::error!("Error getting missing document embeddings. {:?}", error);
            return;
        }
    };

    // The backlog could be very large, so it is processed in bounded batches.
    for batch in pending.chunks(1000) {
        let doc_ids: Vec<String> = batch.iter().map(|doc| doc.doc_id.clone()).collect();

        // Map each document id to the content returned by the index.
        let content_map: HashMap<String, String> = state
            .index
            .search_by_query(None, Some(doc_ids), &[], &[])
            .await
            .into_iter()
            .map(|(_, result)| (result.doc_id.to_owned(), result.content.to_owned()))
            .collect();

        let updates: Vec<embedding_queue::ActiveModel> = batch
            .iter()
            .map(|doc| embedding_queue::ActiveModel {
                document_id: Set(doc.doc_id.clone()),
                content: Set(content_map.get(&doc.doc_id).cloned()),
                indexed_document_id: Set(doc.id),
                ..embedding_queue::ActiveModel::new()
            })
            .collect();

        if let Err(error) = embedding_queue::add_to_queue(&state.db, &updates).await {
            log::error!("Error adding documents to embedding queue {:?}", error);
        }
    }
}

/// Downloads a model from our assets S3 bucket
async fn download_model(
state: &AppState,
model_name: &str,
model_path: PathBuf,
model_url: &str,
) -> anyhow::Result<()> {
// Currently we only have the audio model :)
match reqwest::get(model_url).await {
Ok(res) => {
let total_size = res.content_length().expect("Unable to get content length");
Expand Down

0 comments on commit 27642c6

Please sign in to comment.