From 49e83c6b6b813fb2322169f783f20f6fddb89e93 Mon Sep 17 00:00:00 2001 From: "Nathan.fooo" <86001920+appflowy@users.noreply.github.com> Date: Mon, 6 Jan 2025 01:19:00 +0800 Subject: [PATCH] chore: format chat response (#1125) * chore: chat response with format * chore: update prompt * chore: update test * chore: update test * chore: fix stress test * chore: fix test * chore: test * chore: test * chore: fix stress test * chore: fix test --- .github/workflows/stress_test.yml | 35 +-- dev.env | 4 +- docker-compose-dev.yml | 19 -- docker-compose-stress-test.yml | 97 --------- libs/appflowy-ai-client/src/client.rs | 17 +- libs/appflowy-ai-client/src/dto.rs | 74 +++++++ .../tests/chat_test/qa_test.rs | 59 ------ libs/client-api/src/http_chat.rs | 31 ++- src/api/chat.rs | 54 ++++- tests/ai_test/chat_test.rs | 200 +++++++++++------- tests/collab/stress_test.rs | 2 +- xtask/src/main.rs | 10 +- 12 files changed, 321 insertions(+), 281 deletions(-) delete mode 100644 docker-compose-stress-test.yml diff --git a/.github/workflows/stress_test.yml b/.github/workflows/stress_test.yml index 267a15eea..742ffc53c 100644 --- a/.github/workflows/stress_test.yml +++ b/.github/workflows/stress_test.yml @@ -4,14 +4,15 @@ on: [ pull_request ] concurrency: group: stress-test-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: false + cancel-in-progress: true env: + SQLX_OFFLINE: true + RUST_TOOLCHAIN: "1.80" POSTGRES_HOST: localhost REDIS_HOST: localhost - MINIO_HOST: localhost - SQLX_OFFLINE: true - RUST_TOOLCHAIN: "1.78" + LOCALHOST_GOTRUE: http://localhost/gotrue + DATABASE_URL: postgres://postgres:password@localhost:5432/postgres jobs: test: @@ -25,24 +26,34 @@ jobs: - name: Install Rust Toolchain uses: dtolnay/rust-toolchain@stable - - name: Copy and Rename deploy.env to .env - run: cp deploy.env .env + - name: Copy and Rename dev.env to .env + run: cp dev.env .env + + - name: Install Prerequisites + run: | + brew update + brew install libpq + brew install sqlx-cli + brew install protobuf - name: Replace Values in .env run: | sed -i '' 's|RUST_LOG=.*|RUST_LOG=debug|' .env sed -i '' 's|API_EXTERNAL_URL=.*|API_EXTERNAL_URL=http://localhost:9999|' .env + sed -i '' 's|APPFLOWY_INDEXER_ENABLED=.*|APPFLOWY_INDEXER_ENABLED=false|' .env + sed -i '' 's|APPFLOWY_GOTRUE_BASE_URL=.*|APPFLOWY_GOTRUE_BASE_URL=http://localhost:9999|' .env + sed -i '' 's|GOTRUE_MAILER_AUTOCONFIRM=.*|GOTRUE_MAILER_AUTOCONFIRM=false|' .env + sed -i '' 's|APPFLOWY_DATABASE_URL=.*|APPFLOWY_DATABASE_URL=postgres://postgres:password@localhost:5432/postgres|' .env + + cat .env shell: bash - name: Start Docker Compose Services run: | - docker compose -f docker-compose-stress-test.yml up -d - docker ps -a - - - name: Install Prerequisites - run: | - brew install protobuf + docker compose -f docker-compose-dev.yml up -d + ./script/code_gen.sh + cargo sqlx database create && cargo sqlx migrate run - name: Run Server and Test run: | diff --git a/dev.env b/dev.env index 34737c958..9c9bd76c8 100644 --- a/dev.env +++ b/dev.env @@ -116,13 +116,13 @@ CLOUDFLARE_TUNNEL_TOKEN= APPFLOWY_AI_OPENAI_API_KEY= APPFLOWY_AI_SERVER_PORT=5001 APPFLOWY_AI_SERVER_HOST=localhost -APPFLOWY_AI_DATABASE_URL=postgresql+psycopg://postgres:password@postgres:5432/postgres +APPFLOWY_AI_DATABASE_URL=postgresql+psycopg://postgres:password@localhost:5432/postgres APPFLOWY_AI_REDIS_URL=redis://redis:6379 APPFLOWY_LOCAL_AI_TEST_ENABLED=false # AppFlowy Indexer APPFLOWY_INDEXER_ENABLED=true -APPFLOWY_INDEXER_DATABASE_URL=postgres://postgres:password@postgres:5432/postgres +APPFLOWY_INDEXER_DATABASE_URL=postgres://postgres:password@localhost:5432/postgres APPFLOWY_INDEXER_REDIS_URL=redis://redis:6379 APPFLOWY_INDEXER_EMBEDDING_BUFFER_SIZE=5000 diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 4f6cbcd9e..2c4803ee9 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -98,14 +98,6 @@ services: ports: - 9999:9999 - portainer: - restart: on-failure - image: portainer/portainer-ce:latest - ports: - - 9442:9000 - volumes: - - /var/run/docker.sock:/var/run/docker.sock - pgadmin: restart: on-failure image: dpage/pgadmin4 @@ -119,16 +111,5 @@ services: volumes: - ./docker/pgadmin/servers.json:/pgadmin4/servers.json - ai: - restart: on-failure - image: appflowyinc/appflowy_ai:${APPFLOWY_AI_VERSION:-latest} - ports: - - 5001:5001 - environment: - - OPENAI_API_KEY=${APPFLOWY_AI_OPENAI_API_KEY} - - APPFLOWY_AI_SERVER_PORT=${APPFLOWY_AI_SERVER_PORT} - - APPFLOWY_AI_DATABASE_URL=${APPFLOWY_AI_DATABASE_URL} - - APPFLOWY_AI_REDIS_URL=${APPFLOWY_AI_REDIS_URL} - volumes: postgres_data: diff --git a/docker-compose-stress-test.yml b/docker-compose-stress-test.yml deleted file mode 100644 index 58a65e641..000000000 --- a/docker-compose-stress-test.yml +++ /dev/null @@ -1,97 +0,0 @@ -services: - nginx: - restart: on-failure - image: nginx - ports: - - 80:80 # Disable this if you are using TLS - - 443:443 - volumes: - - ./nginx/nginx.conf:/etc/nginx/nginx.conf - - ./nginx/ssl/certificate.crt:/etc/nginx/ssl/certificate.crt - - ./nginx/ssl/private_key.key:/etc/nginx/ssl/private_key.key - minio: - restart: on-failure - image: minio/minio - ports: - - 9000:9000 - - 9001:9001 - environment: - - MINIO_BROWSER_REDIRECT_URL=http://localhost:9001 - command: server /data --console-address ":9001" - - postgres: - restart: on-failure - image: pgvector/pgvector:pg16 - environment: - - POSTGRES_USER=${POSTGRES_USER:-postgres} - - POSTGRES_DB=${POSTGRES_DB:-postgres} - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - - POSTGRES_HOST=${POSTGRES_HOST:-postgres} - - SUPABASE_USER=${SUPABASE_USER:-supabase_auth_admin} - - SUPABASE_PASSWORD=${SUPABASE_PASSWORD:-root} - ports: - - 5432:5432 - volumes: - - ./migrations/before:/docker-entrypoint-initdb.d - # comment out the following line if you want to persist data when restarting docker - #- postgres_data:/var/lib/postgresql/data - - redis: - restart: on-failure - image: redis - ports: - - 6379:6379 - - gotrue: - restart: on-failure - image: supabase/gotrue:v2.159.1 - depends_on: - - postgres - environment: - # Gotrue config: https://github.com/supabase/gotrue/blob/master/example.env - - GOTRUE_SITE_URL=appflowy-flutter:// # redirected to AppFlowy application - - URI_ALLOW_LIST=* # adjust restrict if necessary - - GOTRUE_JWT_SECRET=${GOTRUE_JWT_SECRET} # authentication secret - - GOTRUE_JWT_EXP=${GOTRUE_JWT_EXP} - - GOTRUE_DB_DRIVER=postgres - - API_EXTERNAL_URL=${API_EXTERNAL_URL} - - DATABASE_URL=${GOTRUE_DATABASE_URL} - - PORT=9999 - - GOTRUE_MAILER_URLPATHS_CONFIRMATION=/verify - - GOTRUE_SMTP_HOST=${GOTRUE_SMTP_HOST} # e.g. smtp.gmail.com - - GOTRUE_SMTP_PORT=${GOTRUE_SMTP_PORT} # e.g. 465 - - GOTRUE_SMTP_USER=${GOTRUE_SMTP_USER} # email sender, e.g. noreply@appflowy.io - - GOTRUE_SMTP_PASS=${GOTRUE_SMTP_PASS} # email password - - GOTRUE_SMTP_ADMIN_EMAIL=${GOTRUE_SMTP_ADMIN_EMAIL} # email with admin privileges e.g. internal@appflowy.io - - GOTRUE_SMTP_MAX_FREQUENCY=${GOTRUE_SMTP_MAX_FREQUENCY:-1ns} # set to 1ns for running tests - - GOTRUE_RATE_LIMIT_EMAIL_SENT=${GOTRUE_RATE_LIMIT_EMAIL_SENT:-100} # number of email sendable per minute - - GOTRUE_MAILER_AUTOCONFIRM=${GOTRUE_MAILER_AUTOCONFIRM:-false} # change this to true to skip email confirmation - # Google OAuth config - - GOTRUE_EXTERNAL_GOOGLE_ENABLED=${GOTRUE_EXTERNAL_GOOGLE_ENABLED} - - GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID=${GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID} - - GOTRUE_EXTERNAL_GOOGLE_SECRET=${GOTRUE_EXTERNAL_GOOGLE_SECRET} - - GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI=${GOTRUE_EXTERNAL_GOOGLE_REDIRECT_URI} - # Apple OAuth config - - GOTRUE_EXTERNAL_APPLE_ENABLED=${GOTRUE_EXTERNAL_APPLE_ENABLED} - - GOTRUE_EXTERNAL_APPLE_CLIENT_ID=${GOTRUE_EXTERNAL_APPLE_CLIENT_ID} - - GOTRUE_EXTERNAL_APPLE_SECRET=${GOTRUE_EXTERNAL_APPLE_SECRET} - - GOTRUE_EXTERNAL_APPLE_REDIRECT_URI=${GOTRUE_EXTERNAL_APPLE_REDIRECT_URI} - # GITHUB OAuth config - - GOTRUE_EXTERNAL_GITHUB_ENABLED=${GOTRUE_EXTERNAL_GITHUB_ENABLED} - - GOTRUE_EXTERNAL_GITHUB_CLIENT_ID=${GOTRUE_EXTERNAL_GITHUB_CLIENT_ID} - - GOTRUE_EXTERNAL_GITHUB_SECRET=${GOTRUE_EXTERNAL_GITHUB_SECRET} - - GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI=${GOTRUE_EXTERNAL_GITHUB_REDIRECT_URI} - # Discord OAuth config - - GOTRUE_EXTERNAL_DISCORD_ENABLED=${GOTRUE_EXTERNAL_DISCORD_ENABLED} - - GOTRUE_EXTERNAL_DISCORD_CLIENT_ID=${GOTRUE_EXTERNAL_DISCORD_CLIENT_ID} - - GOTRUE_EXTERNAL_DISCORD_SECRET=${GOTRUE_EXTERNAL_DISCORD_SECRET} - - GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI=${GOTRUE_EXTERNAL_DISCORD_REDIRECT_URI} - # Prometheus Metrics - - GOTRUE_METRICS_ENABLED=true - - GOTRUE_METRICS_EXPORTER=prometheus - - GOTRUE_MAILER_TEMPLATES_CONFIRMATION=${GOTRUE_MAILER_TEMPLATES_CONFIRMATION} - ports: - - 9999:9999 - -volumes: - postgres_data: diff --git a/libs/appflowy-ai-client/src/client.rs b/libs/appflowy-ai-client/src/client.rs index 8b9af236d..a88b82969 100644 --- a/libs/appflowy-ai-client/src/client.rs +++ b/libs/appflowy-ai-client/src/client.rs @@ -1,8 +1,8 @@ use crate::dto::{ AIModel, CalculateSimilarityParams, ChatAnswer, ChatQuestion, CompleteTextResponse, CompletionType, CreateChatContext, CustomPrompt, Document, LocalAIConfig, MessageData, - RepeatedLocalAIPackage, RepeatedRelatedQuestion, SearchDocumentsRequest, SimilarityResponse, - SummarizeRowResponse, TranslateRowData, TranslateRowResponse, + RepeatedLocalAIPackage, RepeatedRelatedQuestion, ResponseFormat, SearchDocumentsRequest, + SimilarityResponse, SummarizeRowResponse, TranslateRowData, TranslateRowResponse, }; use crate::error::AIError; @@ -187,6 +187,7 @@ impl AppFlowyAIClient { rag_ids: vec![], message_id: Some(question_id.to_string()), }, + format: Default::default(), }; let url = format!("{}/chat/message", self.url); let resp = self @@ -216,6 +217,7 @@ impl AppFlowyAIClient { rag_ids, message_id: None, }, + format: Default::default(), }; let url = format!("{}/chat/message/stream", self.url); let resp = self @@ -245,12 +247,21 @@ impl AppFlowyAIClient { rag_ids, message_id: Some(question_id.to_string()), }, + format: ResponseFormat::default(), }; + self.stream_question_v3(model, json).await + } + + pub async fn stream_question_v3( + &self, + model: &AIModel, + question: ChatQuestion, + ) -> Result>, AIError> { let url = format!("{}/v2/chat/message/stream", self.url); let resp = self .async_http_client(Method::POST, &url)? .header(AI_MODEL_HEADER_KEY, model.to_str()) - .json(&json) + .json(&question) .timeout(Duration::from_secs(30)) .send() .await?; diff --git a/libs/appflowy-ai-client/src/dto.rs b/libs/appflowy-ai-client/src/dto.rs index c8a6cbe69..dad3d9a18 100644 --- a/libs/appflowy-ai-client/src/dto.rs +++ b/libs/appflowy-ai-client/src/dto.rs @@ -7,15 +7,89 @@ use std::str::FromStr; pub const STREAM_METADATA_KEY: &str = "0"; pub const STREAM_ANSWER_KEY: &str = "1"; +pub const STREAM_IMAGE_KEY: &str = "2"; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SummarizeRowResponse { pub text: String, } +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ChatQuestionQuery { + pub chat_id: String, + pub question_id: i64, + pub format: ResponseFormat, +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ChatQuestion { pub chat_id: String, pub data: MessageData, + #[serde(default)] + pub format: ResponseFormat, +} + +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct ResponseFormat { + pub output_layout: OutputLayout, + pub output_content: OutputContent, + pub output_content_metadata: Option, +} + +#[derive(Clone, Debug, Default, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum OutputLayout { + #[default] + Paragraph = 0, + BulletList = 1, + NumberedList = 2, + SimpleTable = 3, +} + +#[derive(Clone, Debug, Default, Serialize_repr, Deserialize_repr)] +#[repr(u8)] +pub enum OutputContent { + #[default] + TEXT = 0, + IMAGE = 1, + RichTextImage = 2, +} + +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct OutputContentMetadata { + /// Custom prompt for image generation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub custom_image_prompt: Option, + + /// The image model to use for generation (default: "dall-e-2"). + #[serde(default = "default_image_model")] + pub image_model: String, + + /// Size of the image (default: "256x256"). + #[serde( + default = "default_image_size", + skip_serializing_if = "Option::is_none" + )] + pub size: Option, + + /// Quality of the image (default: "standard"). + #[serde( + default = "default_image_quality", + skip_serializing_if = "Option::is_none" + )] + pub quality: Option, +} + +// Default values for the fields +fn default_image_model() -> String { + "dall-e-2".to_string() +} + +fn default_image_size() -> Option { + Some("256x256".to_string()) +} + +fn default_image_quality() -> Option { + Some("standard".to_string()) } #[derive(Clone, Debug, Serialize, Deserialize)] diff --git a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs index 2aac663ae..31ddc5f1b 100644 --- a/libs/appflowy-ai-client/tests/chat_test/qa_test.rs +++ b/libs/appflowy-ai-client/tests/chat_test/qa_test.rs @@ -24,66 +24,7 @@ async fn qa_test() { println!("questions: {:?}", questions); assert_eq!(questions.len(), 3) } -#[tokio::test] -async fn stop_stream_test() { - let client = appflowy_ai_client(); - client.health_check().await.unwrap(); - let chat_id = uuid::Uuid::new_v4().to_string(); - let mut stream = client - .stream_question(&chat_id, "I feel hungry", None, vec![], &AIModel::GPT4oMini) - .await - .unwrap(); - - let mut count = 0; - while let Some(message) = stream.next().await { - if count > 1 { - break; - } - count += 1; - println!("message: {:?}", message); - } - - assert_ne!(count, 0); -} -#[tokio::test] -async fn stream_test() { - let client = appflowy_ai_client(); - client.health_check().await.expect("Health check failed"); - let chat_id = uuid::Uuid::new_v4().to_string(); - let stream = client - .stream_question_v2( - &chat_id, - 1, - "I feel hungry", - None, - vec![], - &AIModel::GPT4oMini, - ) - .await - .expect("Failed to initiate question stream"); - - // Wrap the stream in JsonStream with appropriate type parameters - let json_stream = JsonStream::::new(stream); - - // Collect messages from the stream - let messages: Vec = json_stream - .filter_map(|item| async { - match item { - Ok(value) => value - .get(STREAM_ANSWER_KEY) - .and_then(|s| s.as_str().map(ToString::to_string)), - Err(err) => { - eprintln!("Error during streaming: {:?}", err); // Log the error for better debugging - None - }, - } - }) - .collect() - .await; - - println!("final answer: {}", messages.join("")); -} #[tokio::test] async fn download_package_test() { let client = appflowy_ai_client(); diff --git a/libs/client-api/src/http_chat.rs b/libs/client-api/src/http_chat.rs index 9c09aa2c4..b9b5c575b 100644 --- a/libs/client-api/src/http_chat.rs +++ b/libs/client-api/src/http_chat.rs @@ -11,8 +11,8 @@ use pin_project::pin_project; use reqwest::Method; use serde_json::Value; use shared_entity::dto::ai_dto::{ - CalculateSimilarityParams, RepeatedRelatedQuestion, SimilarityResponse, STREAM_ANSWER_KEY, - STREAM_METADATA_KEY, + CalculateSimilarityParams, ChatQuestionQuery, RepeatedRelatedQuestion, SimilarityResponse, + STREAM_ANSWER_KEY, STREAM_IMAGE_KEY, STREAM_METADATA_KEY, }; use shared_entity::dto::chat_dto::{ChatSettings, UpdateChatParams}; use shared_entity::response::{AppResponse, AppResponseError}; @@ -171,6 +171,26 @@ impl Client { Ok(QuestionStream::new(stream)) } + pub async fn stream_answer_v3( + &self, + workspace_id: &str, + query: ChatQuestionQuery, + ) -> Result { + let url = format!( + "{}/api/chat/{workspace_id}/{}/answer/stream", + self.base_url, query.chat_id + ); + let resp = self + .http_client_with_auth(Method::POST, &url) + .await? + .json(&query) + .send() + .await?; + log_request_id(&resp); + let stream = AppResponse::::json_response_stream(resp).await?; + Ok(QuestionStream::new(stream)) + } + pub async fn get_answer( &self, workspace_id: &str, @@ -367,6 +387,13 @@ impl Stream for QuestionStream { return Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: answer }))); } + if let Some(image) = value + .remove(STREAM_IMAGE_KEY) + .and_then(|s| s.as_str().map(ToString::to_string)) + { + return Poll::Ready(Some(Ok(QuestionStreamValue::Answer { value: image }))); + } + error!("Invalid streaming value: {:?}", value); Poll::Ready(None) }, diff --git a/src/api/chat.rs b/src/api/chat.rs index 62e465dba..f48e0b078 100644 --- a/src/api/chat.rs +++ b/src/api/chat.rs @@ -9,7 +9,9 @@ use serde::Deserialize; use crate::api::util::ai_model_from_header; use app_error::AppError; -use appflowy_ai_client::dto::{CreateChatContext, RepeatedRelatedQuestion}; +use appflowy_ai_client::dto::{ + ChatQuestion, ChatQuestionQuery, CreateChatContext, MessageData, RepeatedRelatedQuestion, +}; use authentication::jwt::UserUuid; use bytes::Bytes; use database::chat; @@ -88,6 +90,10 @@ pub fn chat_scope() -> Scope { web::resource("/{chat_id}/{message_id}/v2/answer/stream") .route(web::get().to(answer_stream_v2_handler)) ) + .service( + web::resource("/{chat_id}/answer/stream") + .route(web::post().to(answer_stream_v3_handler)) + ) // Additional functionality .service( @@ -325,6 +331,52 @@ async fn answer_stream_v2_handler( } } +#[instrument(level = "debug", skip_all, err)] +async fn answer_stream_v3_handler( + payload: Json, + state: Data, + req: HttpRequest, +) -> actix_web::Result { + let payload = payload.into_inner(); + let (content, metadata) = + chat::chat_ops::select_chat_message_content(&state.pg_pool, payload.question_id).await?; + let rag_ids = chat::chat_ops::select_chat_rag_ids(&state.pg_pool, &payload.chat_id).await?; + let ai_model = ai_model_from_header(&req); + + let question = ChatQuestion { + chat_id: payload.chat_id, + data: MessageData { + content: content.to_string(), + metadata: Some(metadata), + rag_ids, + message_id: Some(payload.question_id.to_string()), + }, + format: payload.format, + }; + trace!("[Chat] stream v3 {:?}", question); + match state + .ai_client + .stream_question_v3(&ai_model, question) + .await + { + Ok(answer_stream) => { + let new_answer_stream = answer_stream.map_err(AppError::from); + Ok( + HttpResponse::Ok() + .content_type("text/event-stream") + .streaming(new_answer_stream), + ) + }, + Err(err) => Ok( + HttpResponse::ServiceUnavailable() + .content_type("text/event-stream") + .streaming(stream::once(async move { + Err(AppError::AIServiceUnavailable(err.to_string())) + })), + ), + } +} + #[instrument(level = "debug", skip_all, err)] async fn get_chat_message_handler( path: web::Path<(String, String)>, diff --git a/tests/ai_test/chat_test.rs b/tests/ai_test/chat_test.rs index e9f5db04f..1006aa33c 100644 --- a/tests/ai_test/chat_test.rs +++ b/tests/ai_test/chat_test.rs @@ -1,6 +1,9 @@ use crate::ai_test::util::read_text_from_asset; -use assert_json_diff::{assert_json_eq, assert_json_include}; +use appflowy_ai_client::dto::{ + ChatQuestionQuery, OutputContent, OutputContentMetadata, OutputLayout, ResponseFormat, +}; +use assert_json_diff::assert_json_include; use client_api::entity::{QuestionStream, QuestionStreamValue}; use client_api_test::{ai_test_enabled, TestClient}; use futures_util::StreamExt; @@ -220,35 +223,17 @@ async fn chat_qa_test() { .create_question(&workspace_id, &chat_id, params) .await .unwrap(); - assert_json_include!( - actual: question.meta_data, - expected: json!([ - { - "id": "123", - "name": "test context", - "source": "user added", - "extra": { - "created_at": 123 - } + let expected = json!({ + "id": "123", + "name": "test context", + "source": "user added", + "extra": { + "created_at": 123 } - ]) - ); - - let answer = test_client - .api_client - .get_answer(&workspace_id, &chat_id, question.message_id) - .await - .unwrap(); - assert!(answer.content.contains("Singapore")); - assert_json_eq!( - answer.meta_data, - json!([ - { - "id": "123", - "name": "test context", - "source": "user added", - } - ]) + }); + assert_json_include!( + actual: json!(question.meta_data[0]), + expected: expected ); let related_questions = test_client @@ -294,59 +279,110 @@ async fn generate_chat_message_answer_test() { assert!(!answer.is_empty()); } -// #[tokio::test] -// async fn update_chat_message_test() { -// if !ai_test_enabled() { -// return; -// } - -// let test_client = TestClient::new_user_without_ws_conn().await; -// let workspace_id = test_client.workspace_id().await; -// let chat_id = uuid::Uuid::new_v4().to_string(); -// let params = CreateChatParams { -// chat_id: chat_id.clone(), -// name: "my second chat".to_string(), -// rag_ids: vec![], -// }; - -// test_client -// .api_client -// .create_chat(&workspace_id, params) -// .await -// .unwrap(); - -// let params = CreateChatMessageParams::new_user("where is singapore?"); -// let stream = test_client -// .api_client -// .create_chat_message(&workspace_id, &chat_id, params) -// .await -// .unwrap(); -// let messages: Vec = stream.map(|message| message.unwrap()).collect().await; -// assert_eq!(messages.len(), 2); - -// let params = UpdateChatMessageContentParams { -// chat_id: chat_id.clone(), -// message_id: messages[0].message_id, -// content: "where is China?".to_string(), -// }; -// test_client -// .api_client -// .update_chat_message(&workspace_id, &chat_id, params) -// .await -// .unwrap(); - -// let remote_messages = test_client -// .api_client -// .get_chat_messages(&workspace_id, &chat_id, MessageCursor::NextBack, 2) -// .await -// .unwrap() -// .messages; -// assert_eq!(remote_messages[0].content, "where is China?"); -// assert_eq!(remote_messages.len(), 2); - -// // when the question was updated, the answer should be different -// assert_ne!(remote_messages[1].content, messages[1].content); -// } +#[tokio::test] +async fn get_format_question_message_test() { + if !ai_test_enabled() { + return; + } + + let test_client = TestClient::new_user_without_ws_conn().await; + let workspace_id = test_client.workspace_id().await; + let chat_id = uuid::Uuid::new_v4().to_string(); + let params = CreateChatParams { + chat_id: chat_id.clone(), + name: "my ai chat".to_string(), + rag_ids: vec![], + }; + + test_client + .api_client + .create_chat(&workspace_id, params) + .await + .unwrap(); + + let params = CreateChatMessageParams::new_user( + "what is the different between Rust and c++? Give me three points", + ); + let question = test_client + .api_client + .create_question(&workspace_id, &chat_id, params) + .await + .unwrap(); + + let query = ChatQuestionQuery { + chat_id, + question_id: question.message_id, + format: ResponseFormat { + output_layout: OutputLayout::SimpleTable, + output_content: OutputContent::TEXT, + output_content_metadata: None, + }, + }; + + let answer_stream = test_client + .api_client + .stream_answer_v3(&workspace_id, query) + .await + .unwrap(); + let answer = collect_answer(answer_stream).await; + println!("answer:\n{}", answer); + assert!(!answer.is_empty()); +} + +#[tokio::test] +async fn get_text_with_image_message_test() { + if !ai_test_enabled() { + return; + } + + let test_client = TestClient::new_user_without_ws_conn().await; + let workspace_id = test_client.workspace_id().await; + let chat_id = uuid::Uuid::new_v4().to_string(); + let params = CreateChatParams { + chat_id: chat_id.clone(), + name: "my ai chat".to_string(), + rag_ids: vec![], + }; + + test_client + .api_client + .create_chat(&workspace_id, params) + .await + .unwrap(); + + let params = CreateChatMessageParams::new_user( + "I have a little cat. It is black with big eyes, short legs and a long tail", + ); + let question = test_client + .api_client + .create_question(&workspace_id, &chat_id, params) + .await + .unwrap(); + + let query = ChatQuestionQuery { + chat_id, + question_id: question.message_id, + format: ResponseFormat { + output_layout: OutputLayout::SimpleTable, + output_content: OutputContent::RichTextImage, + output_content_metadata: Some(OutputContentMetadata { + custom_image_prompt: None, + image_model: "dall-e-3".to_string(), + size: None, + quality: None, + }), + }, + }; + + let answer_stream = test_client + .api_client + .stream_answer_v3(&workspace_id, query) + .await + .unwrap(); + let answer = collect_answer(answer_stream).await; + println!("answer:\n{}", answer); + assert!(!answer.is_empty()); +} #[tokio::test] async fn get_question_message_test() { diff --git a/tests/collab/stress_test.rs b/tests/collab/stress_test.rs index 7dacb1d9d..7142f1f19 100644 --- a/tests/collab/stress_test.rs +++ b/tests/collab/stress_test.rs @@ -18,7 +18,7 @@ async fn stress_test_run_multiple_text_edits() { )); // create writer let mut writer = TestClient::new_user().await; - sleep(Duration::from_secs(2)).await; // sleep 2 secs to make sure it do not trigger register user too fast in gotrue + sleep(Duration::from_secs(5)).await; // sleep 5 secs to make sure it do not trigger register user too fast in gotrue let object_id = Uuid::new_v4().to_string(); let workspace_id = writer.workspace_id().await; diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 8b9d86815..02899098e 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -14,6 +14,7 @@ use tokio::time::{sleep, Duration}; #[tokio::main] async fn main() -> Result<()> { let is_stress_test = std::env::args().any(|arg| arg == "--stress-test"); + let disable_log = std::env::args().any(|arg| arg == "--disable-log"); let target_dir = "./target"; std::env::set_var("CARGO_TARGET_DIR", target_dir); @@ -30,7 +31,7 @@ async fn main() -> Result<()> { "cargo", &["run", "--features", "history"], appflowy_cloud_bin_name, - is_stress_test, + disable_log, )?; wait_for_readiness(appflowy_cloud_bin_name).await?; @@ -43,7 +44,7 @@ async fn main() -> Result<()> { "./services/appflowy-worker/Cargo.toml", ], worker_bin_name, - is_stress_test, + disable_log, )?; wait_for_readiness(worker_bin_name).await?; @@ -94,7 +95,10 @@ fn spawn_server( name: &str, suppress_output: bool, ) -> Result { - println!("Spawning {} process...", name); + println!( + "Spawning {} process..., log enabled:{}", + name, suppress_output + ); let mut cmd = Command::new(command); cmd.args(args);