From abf827f2a6bebc70ef8829909dcf6acb3ce24715 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Tue, 17 Dec 2024 23:23:43 +0800 Subject: [PATCH] chore: query multiple collab embedding state (#1081) * chore: query multiple collab embedding state * chore: clippy --- .github/workflows/integration_test.yml | 2 +- ...db181b062747d6463e400158cfdc753a82c5b.json | 41 +++ ...7de29e2e5df569faefa098254f1afd3aa662d.json | 41 +++ ...4c281da43d140f3156efcb56e1b908d5f013a.json | 29 -- libs/client-api-test/src/test_client.rs | 42 ++- libs/client-api/src/http_collab.rs | 26 +- libs/collab-rt-entity/src/message.rs | 33 ++- libs/database-entity/src/dto.rs | 7 +- libs/database/src/collab/collab_db_ops.rs | 75 ++++- .../src/index/collab_embeddings_ops.rs | 1 - libs/infra/src/env_util.rs | 30 +- libs/shared-entity/src/dto/workspace_dto.rs | 9 + services/appflowy-collaborate/src/config.rs | 21 +- .../src/group/group_init.rs | 2 +- .../src/indexer/document_indexer.rs | 27 -- .../src/indexer/indexer_scheduler.rs | 261 ++++++++++++------ .../src/indexer/metrics.rs | 1 + .../src/indexer/provider.rs | 8 - src/api/workspace.rs | 21 +- tests/ai_test/chat_with_selected_doc_test.rs | 33 ++- 20 files changed, 502 insertions(+), 208 deletions(-) create mode 100644 .sqlx/query-968c7a6f13255220b3d497d9a1edb181b062747d6463e400158cfdc753a82c5b.json create mode 100644 .sqlx/query-cdbbea42600d61b6541808867397de29e2e5df569faefa098254f1afd3aa662d.json delete mode 100644 .sqlx/query-fba3f52fcaf7463aafcf0c157d84c281da43d140f3156efcb56e1b908d5f013a.json diff --git a/.github/workflows/integration_test.yml b/.github/workflows/integration_test.yml index 298afcb3c..fc539dc29 100644 --- a/.github/workflows/integration_test.yml +++ b/.github/workflows/integration_test.yml @@ -138,7 +138,7 @@ jobs: docker ps -a docker compose -f docker-compose-ci.yml logs - - name: Docker Logs + - name: AI Logs if: always() run: | docker logs appflowy-cloud-ai-1 diff --git a/.sqlx/query-968c7a6f13255220b3d497d9a1edb181b062747d6463e400158cfdc753a82c5b.json b/.sqlx/query-968c7a6f13255220b3d497d9a1edb181b062747d6463e400158cfdc753a82c5b.json new file mode 100644 index 000000000..0f5012072 --- /dev/null +++ b/.sqlx/query-968c7a6f13255220b3d497d9a1edb181b062747d6463e400158cfdc753a82c5b.json @@ -0,0 +1,41 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n ac.oid AS object_id,\n ac.partition_key,\n ac.indexed_at,\n ace.updated_at\n FROM af_collab_embeddings ac\n JOIN af_collab ace\n ON ac.oid = ace.oid\n AND ac.partition_key = ace.partition_key\n WHERE ac.oid = $1 AND ac.partition_key = $2\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "object_id", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "partition_key", + "type_info": "Int4" + }, + { + "ordinal": 2, + "name": "indexed_at", + "type_info": "Timestamp" + }, + { + "ordinal": 3, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text", + "Int4" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "968c7a6f13255220b3d497d9a1edb181b062747d6463e400158cfdc753a82c5b" +} diff --git a/.sqlx/query-cdbbea42600d61b6541808867397de29e2e5df569faefa098254f1afd3aa662d.json b/.sqlx/query-cdbbea42600d61b6541808867397de29e2e5df569faefa098254f1afd3aa662d.json new file mode 100644 index 000000000..8face929f --- /dev/null +++ b/.sqlx/query-cdbbea42600d61b6541808867397de29e2e5df569faefa098254f1afd3aa662d.json @@ -0,0 +1,41 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n ac.oid AS object_id,\n ac.partition_key,\n ac.indexed_at,\n ace.updated_at\n FROM af_collab_embeddings ac\n JOIN af_collab ace\n ON ac.oid = ace.oid\n AND ac.partition_key = ace.partition_key\n WHERE ac.oid = ANY($1) AND ac.partition_key = ANY($2)\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "object_id", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "partition_key", + "type_info": "Int4" + }, + { + "ordinal": 2, + "name": "indexed_at", + "type_info": "Timestamp" + }, + { + "ordinal": 3, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "TextArray", + "Int4Array" + ] + }, + "nullable": [ + false, + false, + false, + false + ] + }, + "hash": "cdbbea42600d61b6541808867397de29e2e5df569faefa098254f1afd3aa662d" +} diff --git a/.sqlx/query-fba3f52fcaf7463aafcf0c157d84c281da43d140f3156efcb56e1b908d5f013a.json b/.sqlx/query-fba3f52fcaf7463aafcf0c157d84c281da43d140f3156efcb56e1b908d5f013a.json deleted file mode 100644 index 3dd819564..000000000 --- a/.sqlx/query-fba3f52fcaf7463aafcf0c157d84c281da43d140f3156efcb56e1b908d5f013a.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT oid AS object_id,indexed_at\n FROM af_collab_embeddings\n WHERE oid = $1 AND partition_key = $2\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "object_id", - "type_info": "Text" - }, - { - "ordinal": 1, - "name": "indexed_at", - "type_info": "Timestamp" - } - ], - "parameters": { - "Left": [ - "Text", - "Int4" - ] - }, - "nullable": [ - false, - false - ] - }, - "hash": "fba3f52fcaf7463aafcf0c157d84c281da43d140f3156efcb56e1b908d5f013a" -} diff --git a/libs/client-api-test/src/test_client.rs b/libs/client-api-test/src/test_client.rs index 9c2f18fb9..1a63536b1 100644 --- a/libs/client-api-test/src/test_client.rs +++ b/libs/client-api-test/src/test_client.rs @@ -37,14 +37,15 @@ use client_api::entity::{ }; use client_api::ws::{WSClient, WSClientConfig}; use database_entity::dto::{ - AFRole, AFSnapshotMeta, AFSnapshotMetas, AFUserProfile, AFUserWorkspaceInfo, AFWorkspace, - AFWorkspaceInvitationStatus, AFWorkspaceMember, BatchQueryCollabResult, CollabParams, - CreateCollabParams, QueryCollab, QueryCollabParams, QuerySnapshotParams, SnapshotData, + AFCollabEmbedInfo, AFRole, AFSnapshotMeta, AFSnapshotMetas, AFUserProfile, AFUserWorkspaceInfo, + AFWorkspace, AFWorkspaceInvitationStatus, AFWorkspaceMember, BatchQueryCollabResult, + CollabParams, CreateCollabParams, QueryCollab, QueryCollabParams, QuerySnapshotParams, + SnapshotData, }; use shared_entity::dto::ai_dto::CalculateSimilarityParams; use shared_entity::dto::search_dto::SearchDocumentResponseItem; use shared_entity::dto::workspace_dto::{ - BlobMetadata, CollabResponse, PublishedDuplicate, WorkspaceMemberChangeset, + BlobMetadata, CollabResponse, EmbeddedCollabQuery, PublishedDuplicate, WorkspaceMemberChangeset, WorkspaceMemberInvitation, WorkspaceSpaceUsage, }; use shared_entity::response::AppResponseError; @@ -555,6 +556,34 @@ impl TestClient { self.api_client.get_profile().await.unwrap() } + pub async fn wait_until_all_embedding( + &self, + workspace_id: &str, + query: Vec, + ) -> Vec { + let timeout_duration = Duration::from_secs(30); + let poll_interval = Duration::from_millis(2000); + let poll_fut = async { + loop { + match self + .api_client + .batch_get_collab_embed_info(workspace_id, query.clone()) + .await + { + Ok(items) if items.len() == query.len() => return Ok::<_, Error>(items), + _ => tokio::time::sleep(poll_interval).await, + } + } + }; + + // Enforce timeout + match timeout(timeout_duration, poll_fut).await { + Ok(Ok(items)) => items, + Ok(Err(e)) => panic!("Test failed: {}", e), + Err(_) => panic!("Test failed: Timeout after 30 seconds."), + } + } + pub async fn wait_until_get_embedding(&self, workspace_id: &str, object_id: &str) { let result = timeout(Duration::from_secs(30), async { while self @@ -620,10 +649,11 @@ impl TestClient { let resp = self.api_client.calculate_similarity(params).await.unwrap(); assert!( resp.score > score, - "Similarity score is too low: {}.\nexpected: {},\ninput: {}", + "Similarity score is too low: {}.\nexpected: {},\ninput: {},\nexpected:{}", resp.score, score, - input + input, + expected ); } diff --git a/libs/client-api/src/http_collab.rs b/libs/client-api/src/http_collab.rs index e4c77a33e..f4f20aa5b 100644 --- a/libs/client-api/src/http_collab.rs +++ b/libs/client-api/src/http_collab.rs @@ -12,7 +12,7 @@ use client_api_entity::workspace_dto::{ use client_api_entity::{ AFCollabEmbedInfo, BatchQueryCollabParams, BatchQueryCollabResult, CollabParams, CreateCollabParams, DeleteCollabParams, PublishCollabItem, QueryCollab, QueryCollabParams, - UpdateCollabWebParams, + RepeatedAFCollabEmbedInfo, UpdateCollabWebParams, }; use collab_rt_entity::collab_proto::{CollabDocStateParams, PayloadCompressionType}; use collab_rt_entity::HttpRealtimeMessage; @@ -22,7 +22,7 @@ use prost::Message; use rayon::prelude::*; use reqwest::{Body, Method}; use serde::Serialize; -use shared_entity::dto::workspace_dto::{CollabResponse, CollabTypeParam}; +use shared_entity::dto::workspace_dto::{CollabResponse, CollabTypeParam, EmbeddedCollabQuery}; use shared_entity::response::{AppResponse, AppResponseError}; use std::future::Future; use std::io::Cursor; @@ -432,6 +432,28 @@ impl Client { .into_data() } + pub async fn batch_get_collab_embed_info( + &self, + workspace_id: &str, + params: Vec, + ) -> Result, AppResponseError> { + let url = format!( + "{}/api/workspace/{workspace_id}/collab/embed-info/list", + self.base_url + ); + let resp = self + .http_client_with_auth(Method::POST, &url) + .await? + .json(¶ms) + .send() + .await?; + log_request_id(&resp); + let data = AppResponse::::from_response(resp) + .await? + .into_data()?; + Ok(data.0) + } + pub async fn collab_full_sync( &self, workspace_id: &str, diff --git a/libs/collab-rt-entity/src/message.rs b/libs/collab-rt-entity/src/message.rs index 17abc2125..37010b5e1 100644 --- a/libs/collab-rt-entity/src/message.rs +++ b/libs/collab-rt-entity/src/message.rs @@ -94,13 +94,35 @@ impl RealtimeMessage { } } + fn object_id(&self) -> Option { + match self { + RealtimeMessage::Collab(msg) => Some(msg.object_id().to_string()), + RealtimeMessage::ClientCollabV1(msgs) => msgs.first().map(|msg| msg.object_id().to_string()), + RealtimeMessage::ClientCollabV2(msgs) => { + if let Some((object_id, _)) = msgs.iter().next() { + Some(object_id.to_string()) + } else { + None + } + }, + _ => None, + } + } + #[cfg(feature = "rt_compress")] pub fn encode(&self) -> Result, Error> { let data = DefaultOptions::new() .with_fixint_encoding() .allow_trailing_bytes() .with_limit(MAXIMUM_REALTIME_MESSAGE_SIZE) - .serialize(self)?; + .serialize(self) + .map_err(|e| { + anyhow!( + "Failed to encode realtime message: {}, object_id:{:?}", + e, + self.object_id() + ) + })?; let mut compressor = CompressorReader::new(&*data, 4096, 4, 22); let mut compressed_data = Vec::new(); @@ -117,7 +139,14 @@ impl RealtimeMessage { .with_fixint_encoding() .allow_trailing_bytes() .with_limit(MAXIMUM_REALTIME_MESSAGE_SIZE) - .serialize(self)?; + .serialize(self) + .map_err(|e| { + anyhow!( + "Failed to encode realtime message: {}, object_id:{:?}", + e, + self.object_id() + ) + })?; Ok(data) } diff --git a/libs/database-entity/src/dto.rs b/libs/database-entity/src/dto.rs index 253b68b3e..432f2b381 100644 --- a/libs/database-entity/src/dto.rs +++ b/libs/database-entity/src/dto.rs @@ -417,10 +417,15 @@ pub struct AFCollabMember { #[derive(Debug, Serialize, Deserialize)] pub struct AFCollabEmbedInfo { pub object_id: String, - /// The timestamp when the object embeddings updated + /// The timestamp when the object's embeddings updated pub indexed_at: DateTime, + /// The timestamp when the object's data updated + pub updated_at: DateTime, } +#[derive(Debug, Serialize, Deserialize)] +pub struct RepeatedAFCollabEmbedInfo(pub Vec); + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct PublishInfo { pub namespace: String, diff --git a/libs/database/src/collab/collab_db_ops.rs b/libs/database/src/collab/collab_db_ops.rs index ef710e2fd..2d2e52a09 100644 --- a/libs/database/src/collab/collab_db_ops.rs +++ b/libs/database/src/collab/collab_db_ops.rs @@ -2,9 +2,9 @@ use anyhow::{anyhow, Context}; use collab_entity::CollabType; use database_entity::dto::{ AFAccessLevel, AFCollabEmbedInfo, AFCollabMember, AFPermission, AFSnapshotMeta, AFSnapshotMetas, - CollabParams, QueryCollab, QueryCollabResult, RawData, + CollabParams, QueryCollab, QueryCollabResult, RawData, RepeatedAFCollabEmbedInfo, }; -use shared_entity::dto::workspace_dto::DatabaseRowUpdatedItem; +use shared_entity::dto::workspace_dto::{DatabaseRowUpdatedItem, EmbeddedCollabQuery}; use crate::collab::{partition_key_from_collab_type, SNAPSHOT_PER_HOUR}; use crate::pg_row::AFCollabRowMeta; @@ -714,10 +714,17 @@ where let partition_key = partition_key_from_collab_type(&collab_type); let record = sqlx::query!( r#" - SELECT oid AS object_id,indexed_at - FROM af_collab_embeddings - WHERE oid = $1 AND partition_key = $2 - "#, + SELECT + ac.oid AS object_id, + ac.partition_key, + ac.indexed_at, + ace.updated_at + FROM af_collab_embeddings ac + JOIN af_collab ace + ON ac.oid = ace.oid + AND ac.partition_key = ace.partition_key + WHERE ac.oid = $1 AND ac.partition_key = $2 + "#, object_id, partition_key ) @@ -727,7 +734,63 @@ where let result = record.map(|row| AFCollabEmbedInfo { object_id: row.object_id, indexed_at: DateTime::::from_naive_utc_and_offset(row.indexed_at, Utc), + updated_at: row.updated_at, }); Ok(result) } + +pub async fn batch_select_collab_embed<'a, E>( + executor: E, + embedded_collab: Vec, +) -> Result +where + E: Executor<'a, Database = Postgres>, +{ + let collab_types: Vec = embedded_collab + .iter() + .map(|query| query.collab_type.clone()) + .collect(); + let object_ids: Vec = embedded_collab + .into_iter() + .map(|query| query.object_id) + .collect(); + + // Collect the partition keys for each collab_type + let partition_keys: Vec = collab_types + .iter() + .map(partition_key_from_collab_type) + .collect(); + + // Execute the query to fetch all matching rows + let records = sqlx::query!( + r#" + SELECT + ac.oid AS object_id, + ac.partition_key, + ac.indexed_at, + ace.updated_at + FROM af_collab_embeddings ac + JOIN af_collab ace + ON ac.oid = ace.oid + AND ac.partition_key = ace.partition_key + WHERE ac.oid = ANY($1) AND ac.partition_key = ANY($2) + "#, + &object_ids, + &partition_keys + ) + .fetch_all(executor) + .await?; + + // Organize the results by object_id + let mut items = vec![]; + for row in records { + let embed_info = AFCollabEmbedInfo { + object_id: row.object_id.clone(), + indexed_at: DateTime::::from_naive_utc_and_offset(row.indexed_at, Utc), + updated_at: row.updated_at, + }; + items.push(embed_info); + } + Ok(RepeatedAFCollabEmbedInfo(items)) +} diff --git a/libs/database/src/index/collab_embeddings_ops.rs b/libs/database/src/index/collab_embeddings_ops.rs index b2b0e1066..874c271da 100644 --- a/libs/database/src/index/collab_embeddings_ops.rs +++ b/libs/database/src/index/collab_embeddings_ops.rs @@ -85,7 +85,6 @@ impl PgHasArrayType for Fragment { pub async fn upsert_collab_embeddings( transaction: &mut Transaction<'_, Postgres>, workspace_id: &Uuid, - _object_id: &str, tokens_used: u32, records: Vec, ) -> Result<(), sqlx::Error> { diff --git a/libs/infra/src/env_util.rs b/libs/infra/src/env_util.rs index f4dce7dbe..ab09ebb86 100644 --- a/libs/infra/src/env_util.rs +++ b/libs/infra/src/env_util.rs @@ -1,11 +1,19 @@ +use std::env::VarError; + pub fn get_env_var(key: &str, default: &str) -> String { - std::env::var(key).unwrap_or_else(|e| { - tracing::debug!( - "failed to read environment variable:{}:{}, using default value: {}", - key, - e, - default - ); + std::env::var(key).unwrap_or_else(|err| { + match err { + VarError::NotPresent => { + tracing::info!("using default environment variable {}:{}", key, default) + }, + VarError::NotUnicode(_) => { + tracing::error!( + "{} is not a valid UTF-8 string, use default value:{}", + key, + default + ); + }, + } default.to_owned() }) } @@ -21,12 +29,8 @@ pub fn get_env_var_opt(key: &str) -> Option { Some(val) } }, - Err(e) => { - tracing::warn!( - "failed to read environment variable {}:{}, None set", - key, - e - ); + Err(_) => { + tracing::info!("using default environment variable {}:None", key); None }, } diff --git a/libs/shared-entity/src/dto/workspace_dto.rs b/libs/shared-entity/src/dto/workspace_dto.rs index 7a072791d..f8becc835 100644 --- a/libs/shared-entity/src/dto/workspace_dto.rs +++ b/libs/shared-entity/src/dto/workspace_dto.rs @@ -125,6 +125,15 @@ pub struct CollabTypeParam { pub collab_type: CollabType, } +#[derive(Debug, Serialize, Deserialize)] +pub struct RepeatedEmbeddedCollabQuery(pub Vec); + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct EmbeddedCollabQuery { + pub collab_type: CollabType, + pub object_id: String, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CollabResponse { #[serde(flatten)] diff --git a/services/appflowy-collaborate/src/config.rs b/services/appflowy-collaborate/src/config.rs index 502e01bcc..19c4ecad7 100644 --- a/services/appflowy-collaborate/src/config.rs +++ b/services/appflowy-collaborate/src/config.rs @@ -3,6 +3,7 @@ use secrecy::Secret; use semver::Version; use serde::Deserialize; use sqlx::postgres::{PgConnectOptions, PgSslMode}; +use std::env::VarError; use std::fmt::Display; use std::str::FromStr; @@ -132,13 +133,19 @@ pub struct CollabSetting { } pub fn get_env_var(key: &str, default: &str) -> String { - std::env::var(key).unwrap_or_else(|e| { - tracing::warn!( - "failed to read environment variable {}:{}, using default value: {}", - key, - e, - default - ); + std::env::var(key).unwrap_or_else(|err| { + match err { + VarError::NotPresent => { + tracing::info!("using default environment variable {}:{}", key, default) + }, + VarError::NotUnicode(_) => { + tracing::error!( + "{} is not a valid UTF-8 string, use default value:{}", + key, + default + ); + }, + } default.to_owned() }) } diff --git a/services/appflowy-collaborate/src/group/group_init.rs b/services/appflowy-collaborate/src/group/group_init.rs index 58cbc59d9..221edc317 100644 --- a/services/appflowy-collaborate/src/group/group_init.rs +++ b/services/appflowy-collaborate/src/group/group_init.rs @@ -123,7 +123,7 @@ impl CollabGroup { }, Err(err) => { trace!( - "failed to index embeddings for document {} {}/{}: {}", + "failed to index embeddings for collab {} {}/{}: {}", self.collab_type, self.workspace_id, self.object_id, diff --git a/services/appflowy-collaborate/src/indexer/document_indexer.rs b/services/appflowy-collaborate/src/indexer/document_indexer.rs index d779f21bd..1ce9c8b54 100644 --- a/services/appflowy-collaborate/src/indexer/document_indexer.rs +++ b/services/appflowy-collaborate/src/indexer/document_indexer.rs @@ -1,7 +1,6 @@ use crate::indexer::open_ai::split_text_by_max_content_len; use crate::indexer::vector::embedder::Embedder; use crate::indexer::Indexer; -use crate::thread_pool_no_abort::ThreadPoolNoAbort; use anyhow::anyhow; use app_error::AppError; use appflowy_ai_client::dto::{ @@ -57,11 +56,6 @@ impl Indexer for DocumentIndexer { return Ok(None); } - let object_id = match content.first() { - None => return Ok(None), - Some(first) => first.object_id.clone(), - }; - let contents: Vec<_> = content .iter() .map(|fragment| fragment.content.clone()) @@ -92,32 +86,11 @@ impl Indexer for DocumentIndexer { param.embedding = Some(embedding); } - tracing::info!( - "received {} embeddings for document {} - tokens used: {}", - content.len(), - object_id, - resp.usage.total_tokens - ); Ok(Some(AFCollabEmbeddings { tokens_consumed: resp.usage.total_tokens as u32, params: content, })) } - - fn embed_in_thread_pool( - &self, - embedder: &Embedder, - content: Vec, - thread_pool: &ThreadPoolNoAbort, - ) -> Result, AppError> { - if content.is_empty() { - return Ok(None); - } - - thread_pool - .install(|| self.embed(embedder, content)) - .map_err(|e| AppError::Unhandled(e.to_string()))? - } } fn split_text_into_chunks( diff --git a/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs b/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs index c007feb3e..adb7791a2 100644 --- a/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs +++ b/services/appflowy-collaborate/src/indexer/indexer_scheduler.rs @@ -16,6 +16,7 @@ use collab::entity::EncodedCollab; use collab::lock::RwLock; use collab::preclude::Collab; use collab_entity::CollabType; +use dashmap::DashMap; use database::collab::{CollabStorage, GetCollabOrigin}; use database::index::{get_collabs_without_embeddings, upsert_collab_embeddings}; use database::workspace::select_workspace_settings; @@ -27,7 +28,7 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Instant; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; -use tracing::{error, info, trace, warn}; +use tracing::{debug, error, info, trace, warn}; use uuid::Uuid; pub struct IndexerScheduler { @@ -39,6 +40,22 @@ pub struct IndexerScheduler { metrics: Arc, schedule_tx: UnboundedSender, config: IndexerConfiguration, + active_tasks: Arc>, +} + +#[derive(Clone, Debug, Eq, PartialEq)] +struct ActiveTask { + object_id: String, + created_at: i64, +} + +impl ActiveTask { + fn new(object_id: String) -> Self { + Self { + object_id, + created_at: chrono::Utc::now().timestamp(), + } + } } pub struct IndexerConfiguration { @@ -76,9 +93,14 @@ impl IndexerScheduler { metrics, schedule_tx, config, + active_tasks: Arc::new(Default::default()), }); - info!("Indexer scheduler is enabled: {}", this.index_enabled()); + info!( + "Indexer scheduler is enabled: {}, num threads: {}", + this.index_enabled(), + num_thread + ); if this.index_enabled() { tokio::spawn(spawn_write_indexing(rx, this.pg_pool.clone())); tokio::spawn(handle_unindexed_collabs(this.clone())); @@ -138,25 +160,41 @@ impl IndexerScheduler { let tx = self.schedule_tx.clone(); let metrics = self.metrics.clone(); + let active_task = self.active_tasks.clone(); + let task = ActiveTask::new(indexed_collab.object_id.clone()); + let task_created_at = task.created_at; + active_task.insert(indexed_collab.object_id.clone(), task); + let threads = self.threads.clone(); + rayon::spawn(move || { - match process_collab(&embedder, &indexer_provider, &indexed_collab, &metrics) { - Ok(Some((tokens_used, contents))) => { - if let Err(err) = tx.send(EmbeddingRecord { - workspace_id, - object_id: indexed_collab.object_id, - tokens_used, - contents, - }) { - error!("Failed to send embedding record: {}", err); - } - }, - Ok(None) => trace!("No embedding for collab:{}", indexed_collab.object_id), - Err(err) => { - warn!( - "Failed to create embeddings content for collab:{}, error:{}", - indexed_collab.object_id, err - ); - }, + let result = threads.install(|| { + if !should_embed(&active_task, &indexed_collab.object_id, task_created_at) { + return; + } + + match process_collab(&embedder, &indexer_provider, &indexed_collab, &metrics) { + Ok(Some((tokens_used, contents))) => { + if let Err(err) = tx.send(EmbeddingRecord { + workspace_id, + object_id: indexed_collab.object_id, + tokens_used, + contents, + }) { + error!("Failed to send embedding record: {}", err); + } + }, + Ok(None) => debug!("No embedding for collab:{}", indexed_collab.object_id), + Err(err) => { + warn!( + "Failed to create embeddings content for collab:{}, error:{}", + indexed_collab.object_id, err + ); + }, + } + }); + + if let Err(err) = result { + error!("Failed to spawn a task to index collab: {}", err); } }); Ok(()) @@ -177,35 +215,40 @@ impl IndexerScheduler { let threads = self.threads.clone(); let tx = self.schedule_tx.clone(); let metrics = self.metrics.clone(); + let active_task = self.active_tasks.clone(); rayon::spawn(move || { - let results = threads.install(|| { - indexed_collabs - .into_par_iter() - .filter_map(|collab| process_collab(&embedder, &indexer_provider, &collab, &metrics).ok()) - .filter_map(|result| result.map(|r| (r.0, r.1))) - .collect::>() - }); - - match results { - Ok(embeddings_list) => { - for (tokens_used, contents) in embeddings_list { - if contents.is_empty() { - continue; - } - let object_id = contents[0].object_id.clone(); - if let Err(err) = tx.send(EmbeddingRecord { - workspace_id, - object_id, - tokens_used, - contents, - }) { - error!("Failed to send embedding record: {}", err); - } - } - }, - Err(err) => { - error!("Failed to process batch indexing: {}", err); - }, + let embeddings_list = indexed_collabs + .into_par_iter() + .filter_map(|collab| { + let task = ActiveTask::new(collab.object_id.clone()); + let task_created_at = task.created_at; + active_task.insert(collab.object_id.clone(), task); + threads + .install(|| { + if !should_embed(&active_task, &collab.object_id, task_created_at) { + return None; + } + process_collab(&embedder, &indexer_provider, &collab, &metrics).ok() + }) + .ok() + }) + .filter_map(|result| result) + .filter_map(|result| result.map(|r| (r.0, r.1))) + .collect::>(); + + for (tokens_used, contents) in embeddings_list { + if contents.is_empty() { + continue; + } + let object_id = contents[0].object_id.clone(); + if let Err(err) = tx.send(EmbeddingRecord { + workspace_id, + object_id, + tokens_used, + contents, + }) { + error!("Failed to send embedding record: {}", err); + } } }); @@ -223,8 +266,6 @@ impl IndexerScheduler { return Ok(()); } - let workspace_id = Uuid::parse_str(workspace_id)?; - let embedder = self.create_embedder()?; let indexer = self .indexer_provider .indexer_for(collab_type) @@ -234,6 +275,8 @@ impl IndexerScheduler { collab_type )) })?; + let workspace_id = Uuid::parse_str(workspace_id)?; + let embedder = self.create_embedder()?; let lock = collab.read().await; let chunks = indexer.create_embedded_chunks(&lock, embedder.model())?; @@ -243,32 +286,46 @@ impl IndexerScheduler { let tx = self.schedule_tx.clone(); let object_id = object_id.to_string(); let metrics = self.metrics.clone(); + let active_tasks = self.active_tasks.clone(); + let task = ActiveTask::new(object_id.clone()); + let task_created_at = task.created_at; + active_tasks.insert(object_id.clone(), task); + rayon::spawn(move || { let start = Instant::now(); metrics.record_embed_count(1); - - let result = indexer.embed_in_thread_pool(&embedder, chunks, &threads); + let result = threads.install(|| { + if !should_embed(&active_tasks, &object_id, task_created_at) { + return Ok(None); + } + indexer.embed(&embedder, chunks) + }); let duration = start.elapsed(); metrics.record_processing_time(duration.as_millis()); match result { - Ok(Some(data)) => { - if let Err(err) = tx.send(EmbeddingRecord { - workspace_id, - object_id: object_id.to_string(), - tokens_used: data.tokens_consumed, - contents: data.params, - }) { - error!("Failed to send embedding record: {}", err); - } + Ok(embed_result) => match embed_result { + Ok(Some(data)) => { + if let Err(err) = tx.send(EmbeddingRecord { + workspace_id, + object_id: object_id.to_string(), + tokens_used: data.tokens_consumed, + contents: data.params, + }) { + error!("Failed to send embedding record: {}", err); + } + }, + Ok(None) => debug!("No embedding for collab:{}", object_id), + Err(err) => { + metrics.record_failed_embed_count(1); + error!( + "Failed to create embeddings content for collab:{}, error:{}", + object_id, err + ); + }, }, - Ok(None) => warn!("No embedding for collab:{}", object_id), Err(err) => { - metrics.record_failed_embed_count(1); - error!( - "Failed to create embeddings content for collab:{}, error:{}", - object_id, err - ); + error!("Failed to spawn a task to index collab: {}", err); }, } }); @@ -290,6 +347,30 @@ impl IndexerScheduler { } } +/// Determines whether an object (Collab) should be processed for embedding. +/// +/// it ensures that duplicate or unnecessary indexing tasks are avoided +/// by checking if the object is already in the active task list. If the object is +/// already being indexed, it prevents re-processing the same object. The function +/// compares the current task's timestamp with any existing active task for the same object +/// to ensure tasks are processed in order and without overlap. +#[inline] +fn should_embed( + active_tasks: &DashMap, + object_id: &str, + created_at: i64, +) -> bool { + let should_embed = active_tasks + .get(object_id) + .map(|t| t.created_at) + .unwrap_or(0) + >= created_at; + if !should_embed { + trace!("[Embedding] Skipping embedding for object: {} because a newer task is already in progress. Previous task with the same object ID has been overridden.", object_id); + } + should_embed +} + async fn handle_unindexed_collabs(scheduler: Arc) { // wait for 30 seconds before starting indexing tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; @@ -389,15 +470,21 @@ async fn index_unindexd_collab( false, ) { if let Ok(chunks) = indexer.create_embedded_chunks(&collab, embedder.model()) { - if let Ok(Some(embeddings)) = indexer.embed_in_thread_pool(&embedder, chunks, &threads) { - if let Err(err) = record_tx.send(EmbeddingRecord { - workspace_id, - object_id: object_id.clone(), - tokens_used: embeddings.tokens_consumed, - contents: embeddings.params, - }) { - error!("Failed to send embedding record: {}", err); + let result = threads.install(|| { + if let Ok(Some(embeddings)) = indexer.embed(&embedder, chunks) { + if let Err(err) = record_tx.send(EmbeddingRecord { + workspace_id, + object_id: object_id.clone(), + tokens_used: embeddings.tokens_consumed, + contents: embeddings.params, + }) { + error!("Failed to send embedding record: {}", err); + } } + }); + + if let Err(err) = result { + error!("Failed to spawn a task to index collab: {}", err); } } } @@ -417,9 +504,15 @@ async fn spawn_write_indexing(mut rx: UnboundedReceiver, pg_poo } let records = buf.drain(..n).collect::>(); + for record in records.iter() { + info!( + "[Embedding] generate collab:{} embeddings, tokens used: {}", + record.object_id, record.tokens_used + ); + } match batch_insert_records(&pg_pool, records).await { - Ok(_) => info!("wrote {} embedding records", n), - Err(err) => error!("Failed to index collab {}", err), + Ok(_) => trace!("[Embedding] save {} embeddings to disk", n), + Err(err) => error!("Failed to write collab embedding to disk:{}", err), } } } @@ -442,7 +535,6 @@ async fn batch_insert_records( upsert_collab_embeddings( &mut txn, &record.workspace_id, - &record.object_id, record.tokens_used, record.contents, ) @@ -452,6 +544,7 @@ async fn batch_insert_records( Ok(()) } +/// This function must be called within the rayon thread pool. fn process_collab( embdder: &Embedder, indexer_provider: &IndexerProvider, @@ -459,9 +552,7 @@ fn process_collab( metrics: &EmbeddingMetrics, ) -> Result)>, AppError> { if let Some(indexer) = indexer_provider.indexer_for(&indexed_collab.collab_type) { - let start_time = Instant::now(); metrics.record_embed_count(1); - let encode_collab = EncodedCollab::decode_from_bytes(&indexed_collab.encoded_collab)?; let collab = Collab::new_with_source( CollabOrigin::Empty, @@ -472,20 +563,14 @@ fn process_collab( ) .map_err(|err| AppError::Internal(err.into()))?; + let start_time = Instant::now(); let chunks = indexer.create_embedded_chunks(&collab, embdder.model())?; let result = indexer.embed(embdder, chunks); let duration = start_time.elapsed(); metrics.record_processing_time(duration.as_millis()); match result { - Ok(Some(embeddings)) => { - trace!( - "Indexed collab {}, tokens: {}", - indexed_collab.object_id, - embeddings.tokens_consumed - ); - Ok(Some((embeddings.tokens_consumed, embeddings.params))) - }, + Ok(Some(embeddings)) => Ok(Some((embeddings.tokens_consumed, embeddings.params))), Ok(None) => Ok(None), Err(err) => { metrics.record_failed_embed_count(1); diff --git a/services/appflowy-collaborate/src/indexer/metrics.rs b/services/appflowy-collaborate/src/indexer/metrics.rs index 34502727f..e410b5086 100644 --- a/services/appflowy-collaborate/src/indexer/metrics.rs +++ b/services/appflowy-collaborate/src/indexer/metrics.rs @@ -49,6 +49,7 @@ impl EmbeddingMetrics { } pub fn record_processing_time(&self, millis: u128) { + tracing::trace!("[Embedding]: processing time: {}ms", millis); self.processing_time_histogram.observe(millis as f64); } } diff --git a/services/appflowy-collaborate/src/indexer/provider.rs b/services/appflowy-collaborate/src/indexer/provider.rs index f3345fcad..fea570411 100644 --- a/services/appflowy-collaborate/src/indexer/provider.rs +++ b/services/appflowy-collaborate/src/indexer/provider.rs @@ -1,7 +1,6 @@ use crate::config::get_env_var; use crate::indexer::vector::embedder::Embedder; use crate::indexer::DocumentIndexer; -use crate::thread_pool_no_abort::ThreadPoolNoAbort; use app_error::AppError; use appflowy_ai_client::dto::EmbeddingModel; use collab::preclude::Collab; @@ -23,13 +22,6 @@ pub trait Indexer: Send + Sync { embedder: &Embedder, content: Vec, ) -> Result, AppError>; - - fn embed_in_thread_pool( - &self, - embedder: &Embedder, - content: Vec, - thread_pool: &ThreadPoolNoAbort, - ) -> Result, AppError>; } /// A structure responsible for resolving different [Indexer] types for different [CollabType]s, diff --git a/src/api/workspace.rs b/src/api/workspace.rs index ba95eb618..c41e76a5b 100644 --- a/src/api/workspace.rs +++ b/src/api/workspace.rs @@ -144,6 +144,10 @@ pub fn workspace_scope() -> Scope { web::resource("/{workspace_id}/collab/{object_id}/embed-info") .route(web::get().to(get_collab_embed_info_handler)), ) + .service( + web::resource("/{workspace_id}/collab/embed-info/list") + .route(web::post().to(batch_get_collab_embed_info_handler)), + ) .service(web::resource("/{workspace_id}/space").route(web::post().to(post_space_handler))) .service( web::resource("/{workspace_id}/space/{view_id}").route(web::patch().to(update_space_handler)), @@ -2212,11 +2216,26 @@ async fn get_collab_embed_info_handler( .await .map_err(AppResponseError::from)? .ok_or_else(|| { - AppError::RecordNotFound(format!("Collab with object_id {} not found", object_id)) + AppError::RecordNotFound(format!( + "Embedding for given object:{} not found", + object_id + )) })?; Ok(Json(AppResponse::Ok().with_data(info))) } +#[instrument(level = "debug", skip_all)] +async fn batch_get_collab_embed_info_handler( + state: Data, + payload: Json, +) -> Result>> { + let payload = payload.into_inner(); + let info = database::collab::batch_select_collab_embed(&state.pg_pool, payload.0) + .await + .map_err(AppResponseError::from)?; + Ok(Json(AppResponse::Ok().with_data(info))) +} + #[instrument(level = "debug", skip_all, err)] async fn collab_full_sync_handler( user_uuid: UserUuid, diff --git a/tests/ai_test/chat_with_selected_doc_test.rs b/tests/ai_test/chat_with_selected_doc_test.rs index 333a33755..690ef65f7 100644 --- a/tests/ai_test/chat_with_selected_doc_test.rs +++ b/tests/ai_test/chat_with_selected_doc_test.rs @@ -7,6 +7,7 @@ use collab_entity::CollabType; use database_entity::dto::CreateCollabParams; use futures_util::future::join_all; use shared_entity::dto::chat_dto::{CreateChatMessageParams, CreateChatParams, UpdateChatParams}; +use shared_entity::dto::workspace_dto::EmbeddedCollabQuery; use std::sync::Arc; use uuid::Uuid; @@ -49,9 +50,6 @@ async fn chat_with_multiple_selected_source_test() { encoded_collab_v1: doc.editor.encode_collab().encode_to_bytes().unwrap(), collab_type: CollabType::Document, }; - - let object_id = doc.object_id.clone(); - let cloned_workspace_id = workspace_id.clone(); let cloned_test_client = Arc::clone(&test_client); async move { // Create collaboration and wait for embedding in parallel @@ -60,16 +58,23 @@ async fn chat_with_multiple_selected_source_test() { .create_collab(params) .await .unwrap(); - cloned_test_client - .wait_until_get_embedding(&cloned_workspace_id, &object_id) - .await; } }) .collect(); - - // Run all tasks concurrently join_all(tasks).await; + // batch query the collab embedding info + let query = docs + .iter() + .map(|doc| EmbeddedCollabQuery { + collab_type: CollabType::Document, + object_id: doc.object_id.clone(), + }) + .collect(); + test_client + .wait_until_all_embedding(&workspace_id, query) + .await; + // create chat let chat_id = uuid::Uuid::new_v4().to_string(); let params = CreateChatParams { @@ -101,14 +106,12 @@ async fn chat_with_multiple_selected_source_test() { &test_client, &workspace_id, &chat_id, - "When do we take off to Japan? Just tell me the date, and if you’re not sure, please let me know you don’t know", + "When do we take off to Japan? Just tell me the date, and if you don't know, Just say you don’t know", ) .await; - let expected_unknown_japan_answer = r#" - I'm sorry, but I don't know the date for your trip to Japan. - "#; + let expected_unknown_japan_answer = r#"I don’t know"#; test_client - .assert_similarity(&workspace_id, &answer, expected_unknown_japan_answer, 0.8) + .assert_similarity(&workspace_id, &answer, expected_unknown_japan_answer, 0.7) .await; // update chat context to snowboarding_in_japan_plan @@ -165,11 +168,11 @@ async fn chat_with_multiple_selected_source_test() { &test_client, &workspace_id, &chat_id, - "When do we take off to Japan? Just tell me the date, and if you’re not sure, please let me know you don’t know", + "When do we take off to Japan? Just tell me the date, and if you don't know, Just say you don’t know", ) .await; test_client - .assert_similarity(&workspace_id, &answer, expected_unknown_japan_answer, 0.8) + .assert_similarity(&workspace_id, &answer, expected_unknown_japan_answer, 0.7) .await; }