diff --git a/Cargo.lock b/Cargo.lock
index 383cdddc516ca0..aea6d723e8cbc2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -85,32 +85,6 @@ dependencies = [
  "memchr",
 ]
 
-[[package]]
-name = "ai"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "async-trait",
- "bincode",
- "futures 0.3.28",
- "gpui",
- "isahc",
- "language",
- "log",
- "matrixmultiply",
- "ordered-float 2.10.0",
- "parking_lot",
- "parse_duration",
- "postage",
- "rand 0.8.5",
- "rusqlite",
- "schemars",
- "serde",
- "serde_json",
- "tiktoken-rs",
- "util",
-]
-
 [[package]]
 name = "alacritty_terminal"
 version = "0.22.1-dev"
@@ -339,9 +313,9 @@ dependencies = [
 name = "assistant"
 version = "0.1.0"
 dependencies = [
- "ai",
  "anyhow",
  "chrono",
+ "client",
  "collections",
  "ctor",
  "editor",
@@ -354,13 +328,14 @@ dependencies = [
  "log",
  "menu",
  "multi_buffer",
+ "open_ai",
  "ordered-float 2.10.0",
+ "parking_lot",
  "project",
  "rand 0.8.5",
  "regex",
  "schemars",
  "search",
- "semantic_index",
  "serde",
  "serde_json",
  "settings",
@@ -1339,7 +1314,7 @@ version = "0.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa"
 dependencies = [
- "num-bigint 0.4.4",
+ "num-bigint",
  "num-integer",
  "num-traits",
 ]
@@ -2209,11 +2184,11 @@ dependencies = [
  "fs",
  "futures 0.3.28",
  "git",
+ "google_ai",
  "gpui",
  "hex",
  "indoc",
  "language",
- "lazy_static",
  "live_kit_client",
  "live_kit_server",
  "log",
@@ -2222,6 +2197,7 @@ dependencies = [
  "nanoid",
  "node_runtime",
  "notifications",
+ "open_ai",
  "parking_lot",
  "pretty_assertions",
  "project",
@@ -3554,24 +3530,12 @@ dependencies = [
  "workspace",
 ]
 
-[[package]]
-name = "fallible-iterator"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
-
 [[package]]
 name = "fallible-iterator"
 version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
 
-[[package]]
-name = "fallible-streaming-iterator"
-version = "0.1.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
-
 [[package]]
 name = "fancy-regex"
 version = "0.11.0"
@@ -4183,7 +4147,7 @@ version = "0.28.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0"
 dependencies = [
- "fallible-iterator 0.3.0",
+ "fallible-iterator",
  "indexmap 2.0.0",
  "stable_deref_trait",
 ]
@@ -4279,6 +4243,17 @@ dependencies = [
  "workspace",
 ]
 
+[[package]]
+name = "google_ai"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "serde",
+ "serde_json",
+ "util",
+]
+
 [[package]]
 name = "gpu-alloc"
 version = "0.6.0"
@@ -5667,16 +5642,6 @@ version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
 
-[[package]]
-name = "matrixmultiply"
-version = "0.3.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
-dependencies = [
- "autocfg",
- "rawpointer",
-]
-
 [[package]]
 name = "maybe-owned"
 version = "0.3.4"
@@ -5946,19 +5911,6 @@ dependencies = [
  "tempfile",
 ]
 
-[[package]]
-name = "ndarray"
-version = "0.15.6"
-source = "registry+https://github.com/rust-lang/crates.io-index"
"adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" -dependencies = [ - "matrixmultiply", - "num-complex 0.4.4", - "num-integer", - "num-traits", - "rawpointer", -] - [[package]] name = "ndk" version = "0.7.0" @@ -6111,45 +6063,20 @@ dependencies = [ "winapi", ] -[[package]] -name = "num" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36" -dependencies = [ - "num-bigint 0.2.6", - "num-complex 0.2.4", - "num-integer", - "num-iter", - "num-rational 0.2.4", - "num-traits", -] - [[package]] name = "num" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" dependencies = [ - "num-bigint 0.4.4", - "num-complex 0.4.4", + "num-bigint", + "num-complex", "num-integer", "num-iter", "num-rational 0.4.1", "num-traits", ] -[[package]] -name = "num-bigint" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090c7f9998ee0ff65aa5b723e4009f7b217707f1fb5ea551329cc4d6231fb304" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - [[package]] name = "num-bigint" version = "0.4.4" @@ -6196,16 +6123,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "num-complex" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6b19411a9719e753aff12e5187b74d60d3dc449ec3f4dc21e3989c3f554bc95" -dependencies = [ - "autocfg", - "num-traits", -] - [[package]] name = "num-complex" version = "0.4.4" @@ -6247,18 +6164,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "num-rational" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c000134b5dbf44adc5cb772486d335293351644b801551abe8f75c84cfa4aef" -dependencies = [ - "autocfg", - "num-bigint 0.2.6", - "num-integer", - "num-traits", -] - [[package]] name = "num-rational" version = "0.3.2" @@ -6277,7 +6182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" dependencies = [ "autocfg", - "num-bigint 0.4.4", + "num-bigint", "num-integer", "num-traits", ] @@ -6436,7 +6341,7 @@ dependencies = [ "futures-util", "hkdf", "hmac 0.12.1", - "num 0.4.1", + "num", "num-bigint-dig 0.8.4", "pbkdf2 0.12.2", "rand 0.8.5", @@ -6464,6 +6369,18 @@ dependencies = [ "pathdiff", ] +[[package]] +name = "open_ai" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "schemars", + "serde", + "serde_json", + "util", +] + [[package]] name = "openssl" version = "0.10.57" @@ -6679,17 +6596,6 @@ dependencies = [ "windows-targets 0.48.5", ] -[[package]] -name = "parse_duration" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7037e5e93e0172a5a96874380bf73bc6ecef022e26fa25f2be26864d6b3ba95d" -dependencies = [ - "lazy_static", - "num 0.2.1", - "regex", -] - [[package]] name = "password-hash" version = "0.2.1" @@ -7471,12 +7377,6 @@ dependencies = [ "raw-window-handle 0.5.2", ] -[[package]] -name = "rawpointer" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" - [[package]] name = "rayon" version = "1.8.0" @@ -7935,20 +7835,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "rusqlite" -version = "0.29.0" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "549b9d036d571d42e6e85d1c1425e2ac83491075078ca9a15be021c56b1641f2" -dependencies = [ - "bitflags 2.4.2", - "fallible-iterator 0.2.0", - "fallible-streaming-iterator", - "hashlink", - "libsqlite3-sys", - "smallvec", -] - [[package]] name = "rust-embed" version = "8.2.0" @@ -8378,7 +8264,6 @@ dependencies = [ "language", "menu", "project", - "semantic_index", "serde", "serde_json", "settings", @@ -8434,52 +8319,6 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58bf37232d3bb9a2c4e641ca2a11d83b5062066f88df7fed36c28772046d65ba" -[[package]] -name = "semantic_index" -version = "0.1.0" -dependencies = [ - "ai", - "anyhow", - "collections", - "ctor", - "env_logger", - "futures 0.3.28", - "gpui", - "language", - "lazy_static", - "log", - "ndarray", - "ordered-float 2.10.0", - "parking_lot", - "postage", - "pretty_assertions", - "project", - "rand 0.8.5", - "release_channel", - "rpc", - "rusqlite", - "schemars", - "serde", - "serde_json", - "settings", - "sha1", - "smol", - "tempfile", - "tree-sitter", - "tree-sitter-cpp", - "tree-sitter-elixir", - "tree-sitter-json 0.20.0", - "tree-sitter-lua", - "tree-sitter-php", - "tree-sitter-ruby", - "tree-sitter-rust", - "tree-sitter-toml", - "tree-sitter-typescript", - "unindent", - "util", - "workspace", -] - [[package]] name = "semver" version = "1.0.18" @@ -8766,7 +8605,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eb4ea60fb301dc81dfc113df680571045d375ab7345d171c5dc7d7e13107a80" dependencies = [ "chrono", - "num-bigint 0.4.4", + "num-bigint", "num-traits", "thiserror", ] @@ -9197,7 +9036,7 @@ dependencies = [ "log", "md-5", "memchr", - "num-bigint 0.4.4", + "num-bigint", "once_cell", "rand 0.8.5", "rust_decimal", @@ -12729,7 +12568,6 @@ dependencies = [ "release_channel", "rope", "search", - "semantic_index", "serde", "serde_json", "settings", diff --git a/Cargo.toml b/Cargo.toml index e30174873fe0e7..1678c3e21f113a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,6 @@ [workspace] members = [ "crates/activity_indicator", - "crates/ai", "crates/assets", "crates/assistant", "crates/audio", @@ -34,6 +33,7 @@ members = [ "crates/fuzzy", "crates/git", "crates/go_to_line", + "crates/google_ai", "crates/gpui", "crates/gpui_macros", "crates/image_viewer", @@ -52,6 +52,7 @@ members = [ "crates/multi_buffer", "crates/node_runtime", "crates/notifications", + "crates/open_ai", "crates/outline", "crates/picker", "crates/prettier", @@ -69,7 +70,6 @@ members = [ "crates/task", "crates/tasks_ui", "crates/search", - "crates/semantic_index", "crates/settings", "crates/snippet", "crates/sqlez", @@ -138,6 +138,7 @@ fsevent = { path = "crates/fsevent" } fuzzy = { path = "crates/fuzzy" } git = { path = "crates/git" } go_to_line = { path = "crates/go_to_line" } +google_ai = { path = "crates/google_ai" } gpui = { path = "crates/gpui" } gpui_macros = { path = "crates/gpui_macros" } install_cli = { path = "crates/install_cli" } @@ -156,6 +157,7 @@ menu = { path = "crates/menu" } multi_buffer = { path = "crates/multi_buffer" } node_runtime = { path = "crates/node_runtime" } notifications = { path = "crates/notifications" } +open_ai = { path = "crates/open_ai" } outline = { path = "crates/outline" } picker = { path = "crates/picker" } plugin = { path = "crates/plugin" } @@ -174,7 +176,6 @@ rpc = { path = "crates/rpc" } task = { path = "crates/task" } tasks_ui = { path = "crates/tasks_ui" } search = { path = "crates/search" } 
-semantic_index = { path = "crates/semantic_index" }
 settings = { path = "crates/settings" }
 snippet = { path = "crates/snippet" }
 sqlez = { path = "crates/sqlez" }
diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json
index 5620a0362f06a8..42fdc0d78c7053 100644
--- a/assets/keymaps/default-macos.json
+++ b/assets/keymaps/default-macos.json
@@ -251,7 +251,6 @@
       "alt-tab": "search::CycleMode",
       "cmd-shift-h": "search::ToggleReplace",
       "alt-cmd-g": "search::ActivateRegexMode",
-      "alt-cmd-s": "search::ActivateSemanticMode",
       "alt-cmd-x": "search::ActivateTextMode"
     }
   },
@@ -276,7 +275,6 @@
       "alt-tab": "search::CycleMode",
       "cmd-shift-h": "search::ToggleReplace",
       "alt-cmd-g": "search::ActivateRegexMode",
-      "alt-cmd-s": "search::ActivateSemanticMode",
       "alt-cmd-x": "search::ActivateTextMode"
     }
   },
@@ -302,7 +300,6 @@
       "alt-tab": "search::CycleMode",
       "alt-cmd-f": "project_search::ToggleFilters",
       "alt-cmd-g": "search::ActivateRegexMode",
-      "alt-cmd-s": "search::ActivateSemanticMode",
       "alt-cmd-x": "search::ActivateTextMode"
     }
   },
diff --git a/assets/settings/default.json b/assets/settings/default.json
index 6599298430a950..f93e22e059861b 100644
--- a/assets/settings/default.json
+++ b/assets/settings/default.json
@@ -237,6 +237,8 @@
     "default_width": 380
   },
   "assistant": {
+    // Version of this setting.
+    "version": "1",
     // Whether to show the assistant panel button in the status bar.
     "button": true,
     // Where to dock the assistant panel. Can be 'left', 'right' or 'bottom'.
@@ -245,28 +247,16 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
-    // Deprecated: Please use `provider.api_url` instead.
-    // The default OpenAI API endpoint to use when starting new conversations.
-    "openai_api_url": "https://api.openai.com/v1",
-    // Deprecated: Please use `provider.default_model` instead.
-    // The default OpenAI model to use when starting new conversations. This
-    // setting can take three values:
-    //
-    // 1. "gpt-3.5-turbo-0613""
-    // 2. "gpt-4-0613""
-    // 3. "gpt-4-1106-preview"
-    "default_open_ai_model": "gpt-4-1106-preview",
+    // AI provider.
     "provider": {
-      "type": "openai",
-      // The default OpenAI API endpoint to use when starting new conversations.
-      "api_url": "https://api.openai.com/v1",
-      // The default OpenAI model to use when starting new conversations. This
+      "name": "openai",
+      // The default model to use when starting new conversations. This
       // setting can take three values:
       //
-      // 1. "gpt-3.5-turbo-0613""
-      // 2. "gpt-4-0613""
-      // 3. "gpt-4-1106-preview"
-      "default_model": "gpt-4-1106-preview"
+      // 1. "gpt-3.5-turbo"
+      // 2. "gpt-4"
+      // 3. "gpt-4-turbo-preview"
+      "default_model": "gpt-4-turbo-preview"
     }
   },
   // Whether the screen sharing icon is shown in the os status bar.
@@ -505,10 +495,6 @@
     // Existing terminals will not pick up this change until they are recreated.
     // "max_scroll_history_lines": 10000,
   },
-  // Difference settings for semantic_index
-  "semantic_index": {
-    "enabled": true
-  },
   // Settings specific to our elixir integration
   "elixir": {
     // Change the LSP zed uses for elixir.
diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml
deleted file mode 100644
index 69c3b88e62a3c3..00000000000000
--- a/crates/ai/Cargo.toml
+++ /dev/null
@@ -1,41 +0,0 @@
-[package]
-name = "ai"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/ai.rs"
-doctest = false
-
-[features]
-test-support = []
-
-[dependencies]
-anyhow.workspace = true
-async-trait.workspace = true
-bincode = "1.3.3"
-futures.workspace = true
-gpui.workspace = true
-isahc.workspace = true
-language.workspace = true
-log.workspace = true
-matrixmultiply = "0.3.7"
-ordered-float.workspace = true
-parking_lot.workspace = true
-parse_duration = "2.1.1"
-postage.workspace = true
-rand.workspace = true
-rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-tiktoken-rs.workspace = true
-util.workspace = true
-
-[dev-dependencies]
-gpui = { workspace = true, features = ["test-support"] }
diff --git a/crates/ai/LICENSE-GPL b/crates/ai/LICENSE-GPL
deleted file mode 120000
index 89e542f750cd38..00000000000000
--- a/crates/ai/LICENSE-GPL
+++ /dev/null
@@ -1 +0,0 @@
-../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs
deleted file mode 100644
index dda22d2a1d04dd..00000000000000
--- a/crates/ai/src/ai.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-pub mod auth;
-pub mod completion;
-pub mod embedding;
-pub mod models;
-pub mod prompts;
-pub mod providers;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;
diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs
deleted file mode 100644
index 62556d718360a9..00000000000000
--- a/crates/ai/src/auth.rs
+++ /dev/null
@@ -1,23 +0,0 @@
-use futures::future::BoxFuture;
-use gpui::AppContext;
-
-#[derive(Clone, Debug)]
-pub enum ProviderCredential {
-    Credentials { api_key: String },
-    NoCredentials,
-    NotNeeded,
-}
-
-pub trait CredentialProvider: Send + Sync {
-    fn has_credentials(&self) -> bool;
-    #[must_use]
-    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential>;
-    #[must_use]
-    fn save_credentials(
-        &self,
-        cx: &mut AppContext,
-        credential: ProviderCredential,
-    ) -> BoxFuture<()>;
-    #[must_use]
-    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()>;
-}
diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs
deleted file mode 100644
index 30a60fcf1d5c5d..00000000000000
--- a/crates/ai/src/completion.rs
+++ /dev/null
@@ -1,23 +0,0 @@
-use anyhow::Result;
-use futures::{future::BoxFuture, stream::BoxStream};
-
-use crate::{auth::CredentialProvider, models::LanguageModel};
-
-pub trait CompletionRequest: Send + Sync {
-    fn data(&self) -> serde_json::Result<String>;
-}
-
-pub trait CompletionProvider: CredentialProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel>;
-    fn complete(
-        &self,
-        prompt: Box<dyn CompletionRequest>,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
-    fn box_clone(&self) -> Box<dyn CompletionProvider>;
-}
-
-impl Clone for Box<dyn CompletionProvider> {
-    fn clone(&self) -> Box<dyn CompletionProvider> {
-        self.box_clone()
-    }
-}
diff --git a/crates/ai/src/embedding.rs b/crates/ai/src/embedding.rs
deleted file mode 100644
index 49611e002af64c..00000000000000
--- a/crates/ai/src/embedding.rs
+++ /dev/null
@@ -1,121 +0,0 @@
-use std::time::Instant;
-
-use anyhow::Result;
-use async_trait::async_trait;
-use ordered_float::OrderedFloat;
-use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
-use rusqlite::ToSql;
-
-use crate::auth::CredentialProvider;
-use crate::models::LanguageModel;
-
-#[derive(Debug, PartialEq, Clone)]
-pub struct Embedding(pub Vec<f32>);
-
-// This is needed for semantic index functionality
-// Unfortunately it has to live wherever the "Embedding" struct is created.
-// Keeping this in here though, introduces a 'rusqlite' dependency into AI
-// which is less than ideal
-impl FromSql for Embedding {
-    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
-        let bytes = value.as_blob()?;
-        let embedding =
-            bincode::deserialize(bytes).map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
-        Ok(Embedding(embedding))
-    }
-}
-
-impl ToSql for Embedding {
-    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
-        let bytes = bincode::serialize(&self.0)
-            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
-        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
-    }
-}
-
-impl From<Vec<f32>> for Embedding {
-    fn from(value: Vec<f32>) -> Self {
-        Embedding(value)
-    }
-}
-
-impl Embedding {
-    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
-        let len = self.0.len();
-        assert_eq!(len, other.0.len());
-
-        let mut result = 0.0;
-        unsafe {
-            matrixmultiply::sgemm(
-                1,
-                len,
-                1,
-                1.0,
-                self.0.as_ptr(),
-                len as isize,
-                1,
-                other.0.as_ptr(),
-                1,
-                len as isize,
-                0.0,
-                &mut result as *mut f32,
-                1,
-                1,
-            );
-        }
-        OrderedFloat(result)
-    }
-}
-
-#[async_trait]
-pub trait EmbeddingProvider: CredentialProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel>;
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
-    fn max_tokens_per_batch(&self) -> usize;
-    fn rate_limit_expiration(&self) -> Option<Instant>;
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use rand::prelude::*;
-
-    #[gpui::test]
-    fn test_similarity(mut rng: StdRng) {
-        assert_eq!(
-            Embedding::from(vec![1., 0., 0., 0., 0.])
-                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
-            0.
-        );
-        assert_eq!(
-            Embedding::from(vec![2., 0., 0., 0., 0.])
-                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
-            6.
-        );
-
-        for _ in 0..100 {
-            let size = 1536;
-            let mut a = vec![0.; size];
-            let mut b = vec![0.; size];
-            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
-                *a = rng.gen();
-                *b = rng.gen();
-            }
-            let a = Embedding::from(a);
-            let b = Embedding::from(b);
-
-            assert_eq!(
-                round_to_decimals(a.similarity(&b), 1),
-                round_to_decimals(reference_dot(&a.0, &b.0), 1)
-            );
-        }
-
-        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
-            let factor = 10.0_f32.powi(decimal_places);
-            (n * factor).round() / factor
-        }
-
-        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
-            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
-        }
-    }
-}
diff --git a/crates/ai/src/models.rs b/crates/ai/src/models.rs
deleted file mode 100644
index 1db3d58c6f54ad..00000000000000
--- a/crates/ai/src/models.rs
+++ /dev/null
@@ -1,16 +0,0 @@
-pub enum TruncationDirection {
-    Start,
-    End,
-}
-
-pub trait LanguageModel {
-    fn name(&self) -> String;
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(
-        &self,
-        content: &str,
-        length: usize,
-        direction: TruncationDirection,
-    ) -> anyhow::Result<String>;
-    fn capacity(&self) -> anyhow::Result<usize>;
-}
diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs
deleted file mode 100644
index da7f7070404c9c..00000000000000
--- a/crates/ai/src/prompts/base.rs
+++ /dev/null
@@ -1,337 +0,0 @@
-use std::cmp::Reverse;
-use std::ops::Range;
-use std::sync::Arc;
-
-use language::BufferSnapshot;
-use util::ResultExt;
-
-use crate::models::LanguageModel;
-use crate::prompts::repository_context::PromptCodeSnippet;
-
-pub(crate) enum PromptFileType {
-    Text,
-    Code,
-}
-
-// TODO: Set this up to manage for defaults well
-pub struct PromptArguments {
-    pub model: Arc<dyn LanguageModel>,
-    pub user_prompt: Option<String>,
-    pub language_name: Option<String>,
-    pub project_name: Option<String>,
-    pub snippets: Vec<PromptCodeSnippet>,
-    pub reserved_tokens: usize,
-    pub buffer: Option<BufferSnapshot>,
-    pub selected_range: Option<Range<language::Anchor>>,
-}
-
-impl PromptArguments {
-    pub(crate) fn get_file_type(&self) -> PromptFileType {
-        if self
-            .language_name
-            .as_ref()
-            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
-            .unwrap_or(true)
-        {
-            PromptFileType::Code
-        } else {
-            PromptFileType::Text
-        }
-    }
-}
-
-pub trait PromptTemplate {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)>;
-}
-
-#[repr(i8)]
-#[derive(PartialEq, Eq)]
-pub enum PromptPriority {
-    /// Ignores truncation.
-    Mandatory,
-    /// Truncates based on priority.
-    Ordered { order: usize },
-}
-
-impl PartialOrd for PromptPriority {
-    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
-        Some(self.cmp(other))
-    }
-}
-
-impl Ord for PromptPriority {
-    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
-        match (self, other) {
-            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
-            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
-            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
-            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
-        }
-    }
-}
-
-pub struct PromptChain {
-    args: PromptArguments,
-    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
-}
-
-impl PromptChain {
-    pub fn new(
-        args: PromptArguments,
-        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
-    ) -> Self {
-        PromptChain { args, templates }
-    }
-
-    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
-        // Argsort based on Prompt Priority
-        let separator = "\n";
-        let separator_tokens = self.args.model.count_tokens(separator)?;
-        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
-        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
-
-        let mut tokens_outstanding = if truncate {
-            Some(self.args.model.capacity()? - self.args.reserved_tokens)
-        } else {
-            None
-        };
-
-        let mut prompts = vec!["".to_string(); sorted_indices.len()];
-        for idx in sorted_indices {
-            let (_, template) = &self.templates[idx];
-
-            if let Some((template_prompt, prompt_token_count)) =
-                template.generate(&self.args, tokens_outstanding).log_err()
-            {
-                if template_prompt != "" {
-                    prompts[idx] = template_prompt;
-
-                    if let Some(remaining_tokens) = tokens_outstanding {
-                        let new_tokens = prompt_token_count + separator_tokens;
-                        tokens_outstanding = if remaining_tokens > new_tokens {
-                            Some(remaining_tokens - new_tokens)
-                        } else {
-                            Some(0)
-                        };
-                    }
-                }
-            }
-        }
-
-        prompts.retain(|x| x != "");
-
-        let full_prompt = prompts.join(separator);
-        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
-        anyhow::Ok((prompts.join(separator), total_token_count))
-    }
-}
-
-#[cfg(test)]
-pub(crate) mod tests {
-    use crate::models::TruncationDirection;
-    use crate::test::FakeLanguageModel;
-
-    use super::*;
-
-    #[test]
-    pub fn test_prompt_chain() {
-        struct TestPromptTemplate {}
-        impl PromptTemplate for TestPromptTemplate {
-            fn generate(
-                &self,
-                args: &PromptArguments,
-                max_token_length: Option<usize>,
-            ) -> anyhow::Result<(String, usize)> {
-                let mut content = "This is a test prompt template".to_string();
-
-                let mut token_count = args.model.count_tokens(&content)?;
-                if let Some(max_token_length) = max_token_length {
-                    if token_count > max_token_length {
-                        content = args.model.truncate(
-                            &content,
-                            max_token_length,
-                            TruncationDirection::End,
-                        )?;
-                        token_count = max_token_length;
-                    }
-                }
-
-                anyhow::Ok((content, token_count))
-            }
-        }
-
-        struct TestLowPriorityTemplate {}
-        impl PromptTemplate for TestLowPriorityTemplate {
-            fn generate(
-                &self,
-                args: &PromptArguments,
-                max_token_length: Option<usize>,
-            ) -> anyhow::Result<(String, usize)> {
-                let mut content = "This is a low priority test prompt template".to_string();
-
-                let mut token_count = args.model.count_tokens(&content)?;
-                if let Some(max_token_length) = max_token_length {
-                    if token_count > max_token_length {
-                        content = args.model.truncate(
-                            &content,
-                            max_token_length,
-                            TruncationDirection::End,
-                        )?;
-                        token_count = max_token_length;
-                    }
-                }
-
-                anyhow::Ok((content, token_count))
-            }
-        }
-
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens: 0,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(false).unwrap();
-
-        assert_eq!(
-            prompt,
-            "This is a test prompt template\nThis is a low priority test prompt template"
-                .to_string()
-        );
-
-        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
-        // Testing with Truncation Off
-        // Should ignore capacity and return all prompts
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens: 0,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(false).unwrap();
-
-        assert_eq!(
-            prompt,
-            "This is a test prompt template\nThis is a low priority test prompt template"
-                .to_string()
-        );
-
-        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
-
-        // Testing with Truncation Off
-        // Should ignore capacity and return all prompts
-        let capacity = 20;
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens: 0,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 2 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(true).unwrap();
-
-        assert_eq!(prompt, "This is a test promp".to_string());
-        assert_eq!(token_count, capacity);
-
-        // Change Ordering of Prompts Based on Priority
-        let capacity = 120;
-        let reserved_tokens = 10;
-        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
-        let args = PromptArguments {
-            model: model.clone(),
-            language_name: None,
-            project_name: None,
-            snippets: Vec::new(),
-            reserved_tokens,
-            buffer: None,
-            selected_range: None,
-            user_prompt: None,
-        };
-        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
-            (
-                PromptPriority::Mandatory,
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 0 },
-                Box::new(TestPromptTemplate {}),
-            ),
-            (
-                PromptPriority::Ordered { order: 1 },
-                Box::new(TestLowPriorityTemplate {}),
-            ),
-        ];
-        let chain = PromptChain::new(args, templates);
-
-        let (prompt, token_count) = chain.generate(true).unwrap();
-
-        assert_eq!(
-            prompt,
-            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
-                .to_string()
-        );
-        assert_eq!(token_count, capacity - reserved_tokens);
-    }
-}
diff --git a/crates/ai/src/prompts/file_context.rs b/crates/ai/src/prompts/file_context.rs
deleted file mode 100644
index f108a62f6f0f82..00000000000000
--- a/crates/ai/src/prompts/file_context.rs
+++ /dev/null
@@ -1,164 +0,0 @@
-use anyhow::anyhow;
-use language::BufferSnapshot;
-use language::ToOffset;
-
-use crate::models::LanguageModel;
-use crate::models::TruncationDirection;
-use crate::prompts::base::PromptArguments;
-use crate::prompts::base::PromptTemplate;
-use std::fmt::Write;
-use std::ops::Range;
-use std::sync::Arc;
-
-fn retrieve_context(
-    buffer: &BufferSnapshot,
-    selected_range: &Option<Range<language::Anchor>>,
-    model: Arc<dyn LanguageModel>,
-    max_token_count: Option<usize>,
-) -> anyhow::Result<(String, usize, bool)> {
-    let mut prompt = String::new();
-    let mut truncated = false;
-    if let Some(selected_range) = selected_range {
-        let start = selected_range.start.to_offset(buffer);
-        let end = selected_range.end.to_offset(buffer);
-
-        let start_window = buffer.text_for_range(0..start).collect::<String>();
-
-        let mut selected_window = String::new();
-        if start == end {
-            write!(selected_window, "<|START|>").unwrap();
-        } else {
-            write!(selected_window, "<|START|").unwrap();
-        }
-
-        write!(
-            selected_window,
-            "{}",
-            buffer.text_for_range(start..end).collect::<String>()
-        )
-        .unwrap();
-
-        if start != end {
-            write!(selected_window, "|END|>").unwrap();
-        }
-
-        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
-
-        if let Some(max_token_count) = max_token_count {
-            let selected_tokens = model.count_tokens(&selected_window)?;
-            if selected_tokens > max_token_count {
-                return Err(anyhow!(
-                    "selected range is greater than model context window, truncation not possible"
-                ));
-            };
-
-            let mut remaining_tokens = max_token_count - selected_tokens;
-            let start_window_tokens = model.count_tokens(&start_window)?;
-            let end_window_tokens = model.count_tokens(&end_window)?;
-            let outside_tokens = start_window_tokens + end_window_tokens;
-            if outside_tokens > remaining_tokens {
-                let (start_goal_tokens, end_goal_tokens) =
-                    if start_window_tokens < end_window_tokens {
-                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
-                        remaining_tokens -= start_goal_tokens;
-                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
-                        (start_goal_tokens, end_goal_tokens)
-                    } else {
-                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
-                        remaining_tokens -= end_goal_tokens;
-                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
-                        (start_goal_tokens, end_goal_tokens)
-                    };
-
-                let truncated_start_window =
-                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
-                let truncated_end_window =
-                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
-                writeln!(
-                    prompt,
-                    "{truncated_start_window}{selected_window}{truncated_end_window}"
-                )
-                .unwrap();
-                truncated = true;
-            } else {
-                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
-            }
-        } else {
-            // If we dont have a selected range, include entire file.
-            writeln!(prompt, "{}", &buffer.text()).unwrap();
-
-            // Dumb truncation strategy
-            if let Some(max_token_count) = max_token_count {
-                if model.count_tokens(&prompt)? > max_token_count {
-                    truncated = true;
-                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
-                }
-            }
-        }
-    }
-
-    let token_count = model.count_tokens(&prompt)?;
-    anyhow::Ok((prompt, token_count, truncated))
-}
-
-pub struct FileContext {}
-
-impl PromptTemplate for FileContext {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        if let Some(buffer) = &args.buffer {
-            let mut prompt = String::new();
-            // Add Initial Preamble
-            // TODO: Do we want to add the path in here?
-            writeln!(
-                prompt,
-                "The file you are currently working on has the following content:"
-            )
-            .unwrap();
-
-            let language_name = args
-                .language_name
-                .clone()
-                .unwrap_or("".to_string())
-                .to_lowercase();
-
-            let (context, _, truncated) = retrieve_context(
-                buffer,
-                &args.selected_range,
-                args.model.clone(),
-                max_token_length,
-            )?;
-            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
-
-            if truncated {
-                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
-            }
-
-            if let Some(selected_range) = &args.selected_range {
-                let start = selected_range.start.to_offset(buffer);
-                let end = selected_range.end.to_offset(buffer);
-
-                if start == end {
-                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
-                } else {
-                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
-                }
-            }
-
-            // Really dumb truncation strategy
-            if let Some(max_tokens) = max_token_length {
-                prompt = args
-                    .model
-                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
-            }
-
-            let token_count = args.model.count_tokens(&prompt)?;
-            anyhow::Ok((prompt, token_count))
-        } else {
-            Err(anyhow!("no buffer provided to retrieve file context from"))
-        }
-    }
-}
diff --git a/crates/ai/src/prompts/generate.rs b/crates/ai/src/prompts/generate.rs
deleted file mode 100644
index c7be620107ee4d..00000000000000
--- a/crates/ai/src/prompts/generate.rs
+++ /dev/null
@@ -1,99 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use anyhow::anyhow;
-use std::fmt::Write;
-
-pub fn capitalize(s: &str) -> String {
-    let mut c = s.chars();
-    match c.next() {
-        None => String::new(),
-        Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
-    }
-}
-
-pub struct GenerateInlineContent {}
-
-impl PromptTemplate for GenerateInlineContent {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        let Some(user_prompt) = &args.user_prompt else {
-            return Err(anyhow!("user prompt not provided"));
-        };
-
-        let file_type = args.get_file_type();
-        let content_type = match &file_type {
-            PromptFileType::Code => "code",
-            PromptFileType::Text => "text",
-        };
-
-        let mut prompt = String::new();
-
-        if let Some(selected_range) = &args.selected_range {
-            if selected_range.start == selected_range.end {
-                writeln!(
-                    prompt,
-                    "Assume the cursor is located where the `<|START|>` span is."
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
-                    capitalize(content_type)
-                )
-                .unwrap();
-                writeln!(
-                    prompt,
-                    "Generate {content_type} based on the users prompt: {user_prompt}",
-                )
-                .unwrap();
-            } else {
-                writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
-                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
-                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
-            }
-        } else {
-            writeln!(
-                prompt,
-                "Generate {content_type} based on the users prompt: {user_prompt}"
-            )
-            .unwrap();
-        }
-
-        if let Some(language_name) = &args.language_name {
-            writeln!(
-                prompt,
-                "Your answer MUST always and only be valid {}.",
-                language_name
-            )
-            .unwrap();
-        }
-        writeln!(prompt, "Never make remarks about the output.").unwrap();
-        writeln!(
-            prompt,
-            "Do not return anything else, except the generated {content_type}."
-        )
-        .unwrap();
-
-        match file_type {
-            PromptFileType::Code => {
-                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
-            }
-            _ => {}
-        }
-
-        // Really dumb truncation strategy
-        if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(
-                &prompt,
-                max_tokens,
-                crate::models::TruncationDirection::End,
-            )?;
-        }
-
-        let token_count = args.model.count_tokens(&prompt)?;
-
-        anyhow::Ok((prompt, token_count))
-    }
-}
diff --git a/crates/ai/src/prompts/mod.rs b/crates/ai/src/prompts/mod.rs
deleted file mode 100644
index 0025269a440d1e..00000000000000
--- a/crates/ai/src/prompts/mod.rs
+++ /dev/null
@@ -1,5 +0,0 @@
-pub mod base;
-pub mod file_context;
-pub mod generate;
-pub mod preamble;
-pub mod repository_context;
diff --git a/crates/ai/src/prompts/preamble.rs b/crates/ai/src/prompts/preamble.rs
deleted file mode 100644
index 92e0edeb78b481..00000000000000
--- a/crates/ai/src/prompts/preamble.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
-use std::fmt::Write;
-
-pub struct EngineerPreamble {}
-
-impl PromptTemplate for EngineerPreamble {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        let mut prompts = Vec::new();
-
-        match args.get_file_type() {
-            PromptFileType::Code => {
-                prompts.push(format!(
-                    "You are an expert {}engineer.",
-                    args.language_name.clone().unwrap_or("".to_string()) + " "
-                ));
-            }
-            PromptFileType::Text => {
-                prompts.push("You are an expert engineer.".to_string());
-            }
-        }
-
-        if let Some(project_name) = args.project_name.clone() {
-            prompts.push(format!(
-                "You are currently working inside the '{project_name}' project in code editor Zed."
-            ));
-        }
-
-        if let Some(mut remaining_tokens) = max_token_length {
-            let mut prompt = String::new();
-            let mut total_count = 0;
-            for prompt_piece in prompts {
-                let prompt_token_count =
-                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
-                if remaining_tokens > prompt_token_count {
-                    writeln!(prompt, "{prompt_piece}").unwrap();
-                    remaining_tokens -= prompt_token_count;
-                    total_count += prompt_token_count;
-                }
-            }
-
-            anyhow::Ok((prompt, total_count))
-        } else {
-            let prompt = prompts.join("\n");
-            let token_count = args.model.count_tokens(&prompt)?;
-            anyhow::Ok((prompt, token_count))
-        }
-    }
-}
diff --git a/crates/ai/src/prompts/repository_context.rs b/crates/ai/src/prompts/repository_context.rs
deleted file mode 100644
index b31a3f63c2f92d..00000000000000
--- a/crates/ai/src/prompts/repository_context.rs
+++ /dev/null
@@ -1,96 +0,0 @@
-use crate::prompts::base::{PromptArguments, PromptTemplate};
-use std::fmt::Write;
-use std::{ops::Range, path::PathBuf};
-
-use gpui::{AsyncAppContext, Model};
-use language::{Anchor, Buffer};
-
-#[derive(Clone)]
-pub struct PromptCodeSnippet {
-    path: Option<PathBuf>,
-    language_name: Option<String>,
-    content: String,
-}
-
-impl PromptCodeSnippet {
-    pub fn new(
-        buffer: Model<Buffer>,
-        range: Range<Anchor>,
-        cx: &mut AsyncAppContext,
-    ) -> anyhow::Result<Self> {
-        let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
-            let snapshot = buffer.snapshot();
-            let content = snapshot.text_for_range(range.clone()).collect::<String>();
-
-            let language_name = buffer
-                .language()
-                .map(|language| language.name().to_string().to_lowercase());
-
-            let file_path = buffer.file().map(|file| file.path().to_path_buf());
-
-            (content, language_name, file_path)
-        })?;
-
-        anyhow::Ok(PromptCodeSnippet {
-            path: file_path,
-            language_name,
-            content,
-        })
-    }
-}
-
-impl ToString for PromptCodeSnippet {
-    fn to_string(&self) -> String {
-        let path = self
-            .path
-            .as_ref()
-            .map(|path| path.to_string_lossy().to_string())
-            .unwrap_or("".to_string());
-        let language_name = self.language_name.clone().unwrap_or("".to_string());
-        let content = self.content.clone();
-
-        format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
-    }
-}
-
-pub struct RepositoryContext {}
-
-impl PromptTemplate for RepositoryContext {
-    fn generate(
-        &self,
-        args: &PromptArguments,
-        max_token_length: Option<usize>,
-    ) -> anyhow::Result<(String, usize)> {
-        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
-        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
-        let mut prompt = String::new();
-
-        let mut remaining_tokens = max_token_length;
-        let separator_token_length = args.model.count_tokens("\n")?;
-        for snippet in &args.snippets {
-            let mut snippet_prompt = template.to_string();
-            let content = snippet.to_string();
-            writeln!(snippet_prompt, "{content}").unwrap();
-
-            let token_count = args.model.count_tokens(&snippet_prompt)?;
-            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
-                if let Some(tokens_left) = remaining_tokens {
-                    if tokens_left >= token_count {
-                        writeln!(prompt, "{snippet_prompt}").unwrap();
-                        remaining_tokens = if tokens_left >= (token_count + separator_token_length)
-                        {
-                            Some(tokens_left - token_count - separator_token_length)
-                        } else {
-                            Some(0)
-                        };
-                    }
-                } else {
-                    writeln!(prompt, "{snippet_prompt}").unwrap();
-                }
-            }
-        }
-
-        let total_token_count = args.model.count_tokens(&prompt)?;
-        anyhow::Ok((prompt, total_token_count))
-    }
-}
diff --git a/crates/ai/src/providers.rs b/crates/ai/src/providers.rs
deleted file mode 100644
index acd0f9d9105386..00000000000000
--- a/crates/ai/src/providers.rs
+++ /dev/null
@@ -1 +0,0 @@
-pub mod open_ai;
diff --git a/crates/ai/src/providers/open_ai.rs b/crates/ai/src/providers/open_ai.rs
deleted file mode 100644
index 8aff4877a86b80..00000000000000
--- a/crates/ai/src/providers/open_ai.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-pub mod completion;
-pub mod embedding;
-pub mod model;
-
-pub use completion::*;
-pub use embedding::*;
-pub use model::OpenAiLanguageModel;
-
-pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs
deleted file mode 100644
index 04cc3588948e00..00000000000000
--- a/crates/ai/src/providers/open_ai/completion.rs
+++ /dev/null
@@ -1,421 +0,0 @@
-use std::{
-    env,
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
-
-use anyhow::{anyhow, Result};
-use futures::{
-    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
-    Stream, StreamExt,
-};
-use gpui::{AppContext, BackgroundExecutor};
-use isahc::{http::StatusCode, Request, RequestExt};
-use parking_lot::RwLock;
-use schemars::JsonSchema;
-use serde::{Deserialize, Serialize};
-use util::ResultExt;
-
-use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
-use crate::{
-    auth::{CredentialProvider, ProviderCredential},
-    completion::{CompletionProvider, CompletionRequest},
-    models::LanguageModel,
-};
-
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
-}
-
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAiRequest {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-impl CompletionRequest for OpenAiRequest {
-    fn data(&self) -> serde_json::Result<String> {
-        serde_json::to_string(self)
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAiUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAiResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAiUsage>,
-}
-
-async fn stream_completion(
-    api_url: String,
-    kind: OpenAiCompletionProviderKind,
-    credential: ProviderCredential,
-    executor: BackgroundExecutor,
-    request: Box<dyn CompletionRequest>,
-) -> Result<impl Stream<Item = Result<OpenAiResponseStreamEvent>>> {
-    let api_key = match credential {
-        ProviderCredential::Credentials { api_key } => api_key,
-        _ => {
-            return Err(anyhow!("no credentials provider for completion"));
-        }
-    };
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
-
-    let (auth_header_name, auth_header_value) = kind.auth_header(api_key);
-    let json_data = request.data()?;
-    let mut response = Request::post(kind.completions_endpoint_url(&api_url))
.header("Content-Type", "application/json") - .header(auth_header_name, auth_header_value) - .body(json_data)? - .send_async() - .await?; - - let status = response.status(); - if status == StatusCode::OK { - executor - .spawn(async move { - let mut lines = BufReader::new(response.body_mut()).lines(); - - fn parse_line( - line: Result, - ) -> Result> { - if let Some(data) = line?.strip_prefix("data: ") { - let event = serde_json::from_str(data)?; - Ok(Some(event)) - } else { - Ok(None) - } - } - - while let Some(line) = lines.next().await { - if let Some(event) = parse_line(line).transpose() { - let done = event.as_ref().map_or(false, |event| { - event - .choices - .last() - .map_or(false, |choice| choice.finish_reason.is_some()) - }); - if tx.unbounded_send(event).is_err() { - break; - } - - if done { - break; - } - } - } - - anyhow::Ok(()) - }) - .detach(); - - Ok(rx) - } else { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAiResponse { - error: OpenAiError, - } - - #[derive(Deserialize)] - struct OpenAiError { - message: String, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "Failed to connect to OpenAI API: {}", - response.error.message, - )), - - _ => Err(anyhow!( - "Failed to connect to OpenAI API: {} {}", - response.status(), - body, - )), - } - } -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)] -pub enum AzureOpenAiApiVersion { - /// Retiring April 2, 2024. - #[serde(rename = "2023-03-15-preview")] - V2023_03_15Preview, - #[serde(rename = "2023-05-15")] - V2023_05_15, - /// Retiring April 2, 2024. - #[serde(rename = "2023-06-01-preview")] - V2023_06_01Preview, - /// Retiring April 2, 2024. - #[serde(rename = "2023-07-01-preview")] - V2023_07_01Preview, - /// Retiring April 2, 2024. - #[serde(rename = "2023-08-01-preview")] - V2023_08_01Preview, - /// Retiring April 2, 2024. - #[serde(rename = "2023-09-01-preview")] - V2023_09_01Preview, - #[serde(rename = "2023-12-01-preview")] - V2023_12_01Preview, - #[serde(rename = "2024-02-15-preview")] - V2024_02_15Preview, -} - -impl fmt::Display for AzureOpenAiApiVersion { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - match self { - Self::V2023_03_15Preview => "2023-03-15-preview", - Self::V2023_05_15 => "2023-05-15", - Self::V2023_06_01Preview => "2023-06-01-preview", - Self::V2023_07_01Preview => "2023-07-01-preview", - Self::V2023_08_01Preview => "2023-08-01-preview", - Self::V2023_09_01Preview => "2023-09-01-preview", - Self::V2023_12_01Preview => "2023-12-01-preview", - Self::V2024_02_15Preview => "2024-02-15-preview", - } - ) - } -} - -#[derive(Clone)] -pub enum OpenAiCompletionProviderKind { - OpenAi, - AzureOpenAi { - deployment_id: String, - api_version: AzureOpenAiApiVersion, - }, -} - -impl OpenAiCompletionProviderKind { - /// Returns the chat completion endpoint URL for this [`OpenAiCompletionProviderKind`]. 
-    fn completions_endpoint_url(&self, api_url: &str) -> String {
-        match self {
-            Self::OpenAi => {
-                // https://platform.openai.com/docs/api-reference/chat/create
-                format!("{api_url}/chat/completions")
-            }
-            Self::AzureOpenAi {
-                deployment_id,
-                api_version,
-            } => {
-                // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
-                format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
-            }
-        }
-    }
-
-    /// Returns the authentication header for this [`OpenAiCompletionProviderKind`].
-    fn auth_header(&self, api_key: String) -> (&'static str, String) {
-        match self {
-            Self::OpenAi => ("Authorization", format!("Bearer {api_key}")),
-            Self::AzureOpenAi { .. } => ("Api-Key", api_key),
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct OpenAiCompletionProvider {
-    api_url: String,
-    kind: OpenAiCompletionProviderKind,
-    model: OpenAiLanguageModel,
-    credential: Arc<RwLock<ProviderCredential>>,
-    executor: BackgroundExecutor,
-}
-
-impl OpenAiCompletionProvider {
-    pub async fn new(
-        api_url: String,
-        kind: OpenAiCompletionProviderKind,
-        model_name: String,
-        executor: BackgroundExecutor,
-    ) -> Self {
-        let model = executor
-            .spawn(async move { OpenAiLanguageModel::load(&model_name) })
-            .await;
-        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-        Self {
-            api_url,
-            kind,
-            model,
-            credential,
-            executor,
-        }
-    }
-}
-
-impl CredentialProvider for OpenAiCompletionProvider {
-    fn has_credentials(&self) -> bool {
-        match *self.credential.read() {
-            ProviderCredential::Credentials { .. } => true,
-            _ => false,
-        }
-    }
-
-    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
-        let existing_credential = self.credential.read().clone();
-        let retrieved_credential = match existing_credential {
-            ProviderCredential::Credentials { .. } => {
-                return async move { existing_credential }.boxed()
-            }
-            _ => {
-                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
-                    async move { ProviderCredential::Credentials { api_key } }.boxed()
-                } else {
-                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
-                    async move {
-                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
-                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
-                                ProviderCredential::Credentials { api_key }
-                            } else {
-                                ProviderCredential::NoCredentials
-                            }
-                        } else {
-                            ProviderCredential::NoCredentials
-                        }
-                    }
-                    .boxed()
-                }
-            }
-        };
-
-        async move {
-            let retrieved_credential = retrieved_credential.await;
-            *self.credential.write() = retrieved_credential.clone();
-            retrieved_credential
-        }
-        .boxed()
-    }
-
-    fn save_credentials(
-        &self,
-        cx: &mut AppContext,
-        credential: ProviderCredential,
-    ) -> BoxFuture<()> {
-        *self.credential.write() = credential.clone();
-        let credential = credential.clone();
-        let write_credentials = match credential {
-            ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
-            }
-            _ => None,
-        };
-
-        async move {
-            if let Some(write_credentials) = write_credentials {
-                write_credentials.await.log_err();
-            }
-        }
-        .boxed()
-    }
-
-    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
-        *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
-        async move {
-            delete_credentials.await.log_err();
-        }
-        .boxed()
-    }
-}
-
-impl CompletionProvider for OpenAiCompletionProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
-        model
-    }
-
-    fn complete(
-        &self,
-        prompt: Box<dyn CompletionRequest>,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        // Currently the CompletionRequest for OpenAI, includes a 'model' parameter
-        // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
-        // which is currently model based, due to the language model.
-        // At some point in the future we should rectify this.
-        let credential = self.credential.read().clone();
-        let api_url = self.api_url.clone();
-        let kind = self.kind.clone();
-        let request = stream_completion(api_url, kind, credential, self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-
-    fn box_clone(&self) -> Box<dyn CompletionProvider> {
-        Box::new((*self).clone())
-    }
-}
diff --git a/crates/ai/src/providers/open_ai/embedding.rs b/crates/ai/src/providers/open_ai/embedding.rs
deleted file mode 100644
index ddff082359df30..00000000000000
--- a/crates/ai/src/providers/open_ai/embedding.rs
+++ /dev/null
@@ -1,345 +0,0 @@
-use anyhow::{anyhow, Result};
-use async_trait::async_trait;
-use futures::future::BoxFuture;
-use futures::AsyncReadExt;
-use futures::FutureExt;
-use gpui::AppContext;
-use gpui::BackgroundExecutor;
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use parking_lot::{Mutex, RwLock};
-use parse_duration::parse;
-use postage::watch;
-use serde::{Deserialize, Serialize};
-use serde_json;
-use std::env;
-use std::ops::Add;
-use std::sync::{Arc, OnceLock};
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::auth::{CredentialProvider, ProviderCredential};
-use crate::embedding::{Embedding, EmbeddingProvider};
-use crate::models::LanguageModel;
-use crate::providers::open_ai::OpenAiLanguageModel;
-
-use crate::providers::open_ai::OPEN_AI_API_URL;
-
-pub(crate) fn open_ai_bpe_tokenizer() -> &'static CoreBPE {
-    static OPEN_AI_BPE_TOKENIZER: OnceLock<CoreBPE> = OnceLock::new();
-    OPEN_AI_BPE_TOKENIZER.get_or_init(|| cl100k_base().unwrap())
-}
-
-#[derive(Clone)]
-pub struct OpenAiEmbeddingProvider {
-    api_url: String,
-    model: OpenAiLanguageModel,
-    credential: Arc<RwLock<ProviderCredential>>,
-    pub client: Arc<dyn HttpClient>,
-    pub executor: BackgroundExecutor,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAiEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAiEmbeddingResponse {
-    data: Vec<OpenAiEmbedding>,
-    usage: OpenAiEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAiEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
-impl OpenAiEmbeddingProvider {
-    pub async fn new(
-        api_url: String,
-        client: Arc<dyn HttpClient>,
-        executor: BackgroundExecutor,
-    ) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        // Loading the model is expensive, so ensure this runs off the main thread.
-        let model = executor
-            .spawn(async move { OpenAiLanguageModel::load("text-embedding-ada-002") })
-            .await;
-        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
-
-        OpenAiEmbeddingProvider {
-            api_url,
-            model,
-            credential,
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn get_api_key(&self) -> Result<String> {
-        match self.credential.read().clone() {
-            ProviderCredential::Credentials { api_key } => Ok(api_key),
-            _ => Err(anyhow!("api credentials not provided")),
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_url: &str,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{api_url}/embeddings"))
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAiEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-impl CredentialProvider for OpenAiEmbeddingProvider {
-    fn has_credentials(&self) -> bool {
-        match *self.credential.read() {
-            ProviderCredential::Credentials { .. } => true,
-            _ => false,
-        }
-    }
-
-    fn retrieve_credentials(&self, cx: &mut AppContext) -> BoxFuture<ProviderCredential> {
-        let existing_credential = self.credential.read().clone();
-        let retrieved_credential = match existing_credential {
-            ProviderCredential::Credentials { .. } => {
-                return async move { existing_credential }.boxed()
-            }
-            _ => {
-                if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
-                    async move { ProviderCredential::Credentials { api_key } }.boxed()
-                } else {
-                    let credentials = cx.read_credentials(OPEN_AI_API_URL);
-                    async move {
-                        if let Some(Some((_, api_key))) = credentials.await.log_err() {
-                            if let Some(api_key) = String::from_utf8(api_key).log_err() {
-                                ProviderCredential::Credentials { api_key }
-                            } else {
-                                ProviderCredential::NoCredentials
-                            }
-                        } else {
-                            ProviderCredential::NoCredentials
-                        }
-                    }
-                    .boxed()
-                }
-            }
-        };
-
-        async move {
-            let retrieved_credential = retrieved_credential.await;
-            *self.credential.write() = retrieved_credential.clone();
-            retrieved_credential
-        }
-        .boxed()
-    }
-
-    fn save_credentials(
-        &self,
-        cx: &mut AppContext,
-        credential: ProviderCredential,
-    ) -> BoxFuture<()> {
-        *self.credential.write() = credential.clone();
-        let credential = credential.clone();
-        let write_credentials = match credential {
-            ProviderCredential::Credentials { api_key } => {
-                Some(cx.write_credentials(OPEN_AI_API_URL, "Bearer", api_key.as_bytes()))
-            }
-            _ => None,
-        };
-
-        async move {
-            if let Some(write_credentials) = write_credentials {
-                write_credentials.await.log_err();
-            }
-        }
-        .boxed()
-    }
-
-    fn delete_credentials(&self, cx: &mut AppContext) -> BoxFuture<()> {
-        *self.credential.write() = ProviderCredential::NoCredentials;
-        let delete_credentials = cx.delete_credentials(OPEN_AI_API_URL);
-        async move {
-            delete_credentials.await.log_err();
-        }
-        .boxed()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAiEmbeddingProvider {
-    fn base_model(&self) -> Box<dyn LanguageModel> {
-        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
-        model
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-
-    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let api_url = self.api_url.as_str();
-        let api_key = self.get_api_key()?;
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    &api_url,
-                    &api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAiEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
tokens: {:?}", - response.usage.total_tokens - ); - - // If we complete a request successfully that was previously rate_limited - // resolve the rate limit - if rate_limiting { - self.resolve_rate_limit() - } - - return Ok(response - .data - .into_iter() - .map(|embedding| Embedding::from(embedding.embedding)) - .collect()); - } - StatusCode::TOO_MANY_REQUESTS => { - rate_limiting = true; - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - let delay_duration = { - let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64); - if let Some(time_to_reset) = - response.headers().get("x-ratelimit-reset-tokens") - { - if let Ok(time_str) = time_to_reset.to_str() { - parse(time_str).unwrap_or(delay) - } else { - delay - } - } else { - delay - } - }; - - // If we've previously rate limited, increment the duration but not the count - let reset_time = Instant::now().add(delay_duration); - self.update_reset_time(reset_time); - - log::trace!( - "openai rate limiting: waiting {:?} until lifted", - &delay_duration - ); - - self.executor.timer(delay_duration).await; - } - _ => { - let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - return Err(anyhow!( - "open ai bad request: {:?} {:?}", - &response.status(), - body - )); - } - } - } - Err(anyhow!("openai max retries")) - } -} diff --git a/crates/ai/src/providers/open_ai/model.rs b/crates/ai/src/providers/open_ai/model.rs deleted file mode 100644 index f2f75977e488e6..00000000000000 --- a/crates/ai/src/providers/open_ai/model.rs +++ /dev/null @@ -1,59 +0,0 @@ -use anyhow::anyhow; -use tiktoken_rs::CoreBPE; - -use crate::models::{LanguageModel, TruncationDirection}; - -use super::open_ai_bpe_tokenizer; - -#[derive(Clone)] -pub struct OpenAiLanguageModel { - name: String, - bpe: Option, -} - -impl OpenAiLanguageModel { - pub fn load(model_name: &str) -> Self { - let bpe = tiktoken_rs::get_bpe_from_model(model_name) - .unwrap_or(open_ai_bpe_tokenizer().to_owned()); - OpenAiLanguageModel { - name: model_name.to_string(), - bpe: Some(bpe), - } - } -} - -impl LanguageModel for OpenAiLanguageModel { - fn name(&self) -> String { - self.name.clone() - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - if let Some(bpe) = &self.bpe { - anyhow::Ok(bpe.encode_with_special_tokens(content).len()) - } else { - Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn truncate( - &self, - content: &str, - length: usize, - direction: TruncationDirection, - ) -> anyhow::Result { - if let Some(bpe) = &self.bpe { - let tokens = bpe.encode_with_special_tokens(content); - if tokens.len() > length { - match direction { - TruncationDirection::End => bpe.decode(tokens[..length].to_vec()), - TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()), - } - } else { - bpe.decode(tokens) - } - } else { - Err(anyhow!("bpe for open ai model was not retrieved")) - } - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name)) - } -} diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs deleted file mode 100644 index f10ca4f5fa7ffb..00000000000000 --- a/crates/ai/src/test.rs +++ /dev/null @@ -1,206 +0,0 @@ -use std::{ - sync::atomic::{self, AtomicUsize, Ordering}, - time::Instant, -}; - -use async_trait::async_trait; -use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; -use gpui::AppContext; -use parking_lot::Mutex; - -use crate::{ - auth::{CredentialProvider, 
ProviderCredential}, - completion::{CompletionProvider, CompletionRequest}, - embedding::{Embedding, EmbeddingProvider}, - models::{LanguageModel, TruncationDirection}, -}; - -#[derive(Clone)] -pub struct FakeLanguageModel { - pub capacity: usize, -} - -impl LanguageModel for FakeLanguageModel { - fn name(&self) -> String { - "dummy".to_string() - } - fn count_tokens(&self, content: &str) -> anyhow::Result { - anyhow::Ok(content.chars().collect::>().len()) - } - fn truncate( - &self, - content: &str, - length: usize, - direction: TruncationDirection, - ) -> anyhow::Result { - println!("TRYING TO TRUNCATE: {:?}", length.clone()); - - if length > self.count_tokens(content)? { - println!("NOT TRUNCATING"); - return anyhow::Ok(content.to_string()); - } - - anyhow::Ok(match direction { - TruncationDirection::End => content.chars().collect::>()[..length] - .into_iter() - .collect::(), - TruncationDirection::Start => content.chars().collect::>()[length..] - .into_iter() - .collect::(), - }) - } - fn capacity(&self) -> anyhow::Result { - anyhow::Ok(self.capacity) - } -} - -#[derive(Default)] -pub struct FakeEmbeddingProvider { - pub embedding_count: AtomicUsize, -} - -impl Clone for FakeEmbeddingProvider { - fn clone(&self) -> Self { - FakeEmbeddingProvider { - embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)), - } - } -} - -impl FakeEmbeddingProvider { - pub fn embedding_count(&self) -> usize { - self.embedding_count.load(atomic::Ordering::SeqCst) - } - - pub fn embed_sync(&self, span: &str) -> Embedding { - let mut result = vec![1.0; 26]; - for letter in span.chars() { - let letter = letter.to_ascii_lowercase(); - if letter as u32 >= 'a' as u32 { - let ix = (letter as u32) - ('a' as u32); - if ix < 26 { - result[ix as usize] += 1.0; - } - } - } - - let norm = result.iter().map(|x| x * x).sum::().sqrt(); - for x in &mut result { - *x /= norm; - } - - result.into() - } -} - -impl CredentialProvider for FakeEmbeddingProvider { - fn has_credentials(&self) -> bool { - true - } - - fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture { - async { ProviderCredential::NotNeeded }.boxed() - } - - fn save_credentials( - &self, - _cx: &mut AppContext, - _credential: ProviderCredential, - ) -> BoxFuture<()> { - async {}.boxed() - } - - fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> { - async {}.boxed() - } -} - -#[async_trait] -impl EmbeddingProvider for FakeEmbeddingProvider { - fn base_model(&self) -> Box { - Box::new(FakeLanguageModel { capacity: 1000 }) - } - fn max_tokens_per_batch(&self) -> usize { - 1000 - } - - fn rate_limit_expiration(&self) -> Option { - None - } - - async fn embed_batch(&self, spans: Vec) -> anyhow::Result> { - self.embedding_count - .fetch_add(spans.len(), atomic::Ordering::SeqCst); - - anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect()) - } -} - -pub struct FakeCompletionProvider { - last_completion_tx: Mutex>>, -} - -impl Clone for FakeCompletionProvider { - fn clone(&self) -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } -} - -impl FakeCompletionProvider { - pub fn new() -> Self { - Self { - last_completion_tx: Mutex::new(None), - } - } - - pub fn send_completion(&self, completion: impl Into) { - let mut tx = self.last_completion_tx.lock(); - tx.as_mut().unwrap().try_send(completion.into()).unwrap(); - } - - pub fn finish_completion(&self) { - self.last_completion_tx.lock().take().unwrap(); - } -} - -impl CredentialProvider for FakeCompletionProvider { - fn has_credentials(&self) 
-> bool { - true - } - - fn retrieve_credentials(&self, _cx: &mut AppContext) -> BoxFuture { - async { ProviderCredential::NotNeeded }.boxed() - } - - fn save_credentials( - &self, - _cx: &mut AppContext, - _credential: ProviderCredential, - ) -> BoxFuture<()> { - async {}.boxed() - } - - fn delete_credentials(&self, _cx: &mut AppContext) -> BoxFuture<()> { - async {}.boxed() - } -} - -impl CompletionProvider for FakeCompletionProvider { - fn base_model(&self) -> Box { - let model: Box = Box::new(FakeLanguageModel { capacity: 8190 }); - model - } - fn complete( - &self, - _prompt: Box, - ) -> BoxFuture<'static, anyhow::Result>>> { - let (tx, rx) = mpsc::channel(1); - *self.last_completion_tx.lock() = Some(tx); - async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() - } - fn box_clone(&self) -> Box { - Box::new((*self).clone()) - } -} diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index d84075a632b2d6..45b2f530a53d47 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -5,17 +5,14 @@ edition = "2021" publish = false license = "GPL-3.0-or-later" -[lints] -workspace = true - [lib] path = "src/assistant.rs" doctest = false [dependencies] -ai.workspace = true anyhow.workspace = true chrono.workspace = true +client.workspace = true collections.workspace = true editor.workspace = true fs.workspace = true @@ -26,12 +23,13 @@ language.workspace = true log.workspace = true menu.workspace = true multi_buffer.workspace = true +open_ai = { workspace = true, features = ["schemars"] } ordered-float.workspace = true +parking_lot.workspace = true project.workspace = true regex.workspace = true schemars.workspace = true search.workspace = true -semantic_index.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true @@ -45,7 +43,6 @@ uuid.workspace = true workspace.workspace = true [dev-dependencies] -ai = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index ae01db5adea194..036ae5c0716e45 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,22 +1,24 @@ pub mod assistant_panel; pub mod assistant_settings; mod codegen; +mod completion_provider; mod prompts; +mod saved_conversation; mod streaming_diff; -use ai::providers::open_ai::Role; -use anyhow::Result; pub use assistant_panel::AssistantPanel; -use assistant_settings::OpenAiModel; +use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel}; use chrono::{DateTime, Local}; -use collections::HashMap; -use fs::Fs; -use futures::StreamExt; +use client::{proto, Client}; +pub(crate) use completion_provider::*; use gpui::{actions, AppContext, SharedString}; -use regex::Regex; +pub(crate) use saved_conversation::*; use serde::{Deserialize, Serialize}; -use std::{cmp::Reverse, ffi::OsStr, path::PathBuf, sync::Arc}; -use util::paths::CONVERSATIONS_DIR; +use settings::Settings; +use std::{ + fmt::{self, Display}, + sync::Arc, +}; actions!( assistant, @@ -30,7 +32,6 @@ actions!( ResetKey, InlineAssist, ToggleIncludeConversation, - ToggleRetrieveContext, ] ); @@ -39,6 +40,139 @@ actions!( )] struct MessageId(usize); +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl Role { + pub fn cycle(&mut self) { + *self = match self { + 
Role::User => Role::Assistant, + Role::Assistant => Role::System, + Role::System => Role::User, + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "user"), + Role::Assistant => write!(f, "assistant"), + Role::System => write!(f, "system"), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum LanguageModel { + ZedDotDev(ZedDotDevModel), + OpenAi(OpenAiModel), +} + +impl Default for LanguageModel { + fn default() -> Self { + LanguageModel::ZedDotDev(ZedDotDevModel::default()) + } +} + +impl LanguageModel { + pub fn telemetry_id(&self) -> String { + match self { + LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), + LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()), + } + } + + pub fn display_name(&self) -> String { + match self { + LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()), + LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + LanguageModel::OpenAi(model) => tiktoken_rs::model::get_context_size(model.id()), + LanguageModel::ZedDotDev(model) => match model { + ZedDotDevModel::GptThreePointFiveTurbo + | ZedDotDevModel::GptFour + | ZedDotDevModel::GptFourTurbo => tiktoken_rs::model::get_context_size(model.id()), + ZedDotDevModel::Custom(_) => 30720, // TODO: Base this on the selected model. + }, + } + } + + pub fn id(&self) -> &str { + match self { + LanguageModel::OpenAi(model) => model.id(), + LanguageModel::ZedDotDev(model) => model.id(), + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelRequestMessage { + pub role: Role, + pub content: String, +} + +impl LanguageModelRequestMessage { + pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { + proto::LanguageModelRequestMessage { + role: match self.role { + Role::User => proto::LanguageModelRole::LanguageModelUser, + Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, + Role::System => proto::LanguageModelRole::LanguageModelSystem, + } as i32, + content: self.content.clone(), + } + } +} + +#[derive(Debug, Default, Serialize)] +pub struct LanguageModelRequest { + pub model: LanguageModel, + pub messages: Vec, + pub stop: Vec, + pub temperature: f32, +} + +impl LanguageModelRequest { + pub fn to_proto(&self) -> proto::CompleteWithLanguageModel { + proto::CompleteWithLanguageModel { + model: self.model.id().to_string(), + messages: self.messages.iter().map(|m| m.to_proto()).collect(), + stop: self.stop.clone(), + temperature: self.temperature, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct LanguageModelResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct LanguageModelUsage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct LanguageModelChoiceDelta { + pub index: u32, + pub delta: LanguageModelResponseMessage, + pub finish_reason: Option, +} + #[derive(Clone, Debug, Serialize, Deserialize)] struct MessageMetadata { role: Role, @@ -53,71 +187,9 @@ enum MessageStatus { Error(SharedString), } -#[derive(Serialize, Deserialize)] -struct SavedMessage { - id: MessageId, - start: usize, -} - -#[derive(Serialize, Deserialize)] -struct SavedConversation { - id: Option, - zed: String, - version: String, - text: String, - messages: Vec, - 
message_metadata: HashMap, - summary: String, - api_url: Option, - model: OpenAiModel, -} - -impl SavedConversation { - const VERSION: &'static str = "0.1.0"; -} - -struct SavedConversationMetadata { - title: String, - path: PathBuf, - mtime: chrono::DateTime, -} - -impl SavedConversationMetadata { - pub async fn list(fs: Arc) -> Result> { - fs.create_dir(&CONVERSATIONS_DIR).await?; - - let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?; - let mut conversations = Vec::::new(); - while let Some(path) = paths.next().await { - let path = path?; - if path.extension() != Some(OsStr::new("json")) { - continue; - } - - let pattern = r" - \d+.zed.json$"; - let re = Regex::new(pattern).unwrap(); - - let metadata = fs.metadata(&path).await?; - if let Some((file_name, metadata)) = path - .file_name() - .and_then(|name| name.to_str()) - .zip(metadata) - { - let title = re.replace(file_name, ""); - conversations.push(Self { - title: title.into_owned(), - path, - mtime: metadata.mtime.into(), - }); - } - } - conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime)); - - Ok(conversations) - } -} - -pub fn init(cx: &mut AppContext) { +pub fn init(client: Arc, cx: &mut AppContext) { + AssistantSettings::register(cx); + completion_provider::init(client, cx); assistant_panel::init(cx); } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 4093df72804fd8..419bd10f4e96fc 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -1,21 +1,13 @@ use crate::{ - assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAiModel}, + assistant_settings::{AssistantDockPosition, AssistantSettings, ZedDotDevModel}, codegen::{self, Codegen, CodegenKind}, prompts::generate_content_prompt, - Assist, CycleMessageRole, InlineAssist, MessageId, MessageMetadata, MessageStatus, + Assist, CompletionProvider, CycleMessageRole, InlineAssist, LanguageModel, + LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageMetadata, MessageStatus, NewConversation, QuoteSelection, ResetKey, Role, SavedConversation, SavedConversationMetadata, - SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, ToggleRetrieveContext, + SavedMessage, Split, ToggleFocus, ToggleIncludeConversation, }; -use ai::prompts::repository_context::PromptCodeSnippet; -use ai::{ - auth::ProviderCredential, - completion::{CompletionProvider, CompletionRequest}, - providers::open_ai::{ - OpenAiCompletionProvider, OpenAiCompletionProviderKind, OpenAiRequest, RequestMessage, - OPEN_AI_API_URL, - }, -}; -use anyhow::{anyhow, Result}; +use anyhow::Result; use chrono::{DateTime, Local}; use collections::{hash_map, HashMap, HashSet, VecDeque}; use editor::{ @@ -24,35 +16,25 @@ use editor::{ BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, ToDisplayPoint, }, scroll::{Autoscroll, AutoscrollStrategy}, - Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, ToOffset, + Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBufferSnapshot, ToOffset as _, ToPoint, }; use fs::Fs; use futures::StreamExt; use gpui::{ - canvas, div, point, relative, rems, uniform_list, Action, AnyElement, AppContext, - AsyncAppContext, AsyncWindowContext, ClipboardItem, Context, EventEmitter, FocusHandle, - FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, IntoElement, Model, - ModelContext, ParentElement, Pixels, PromptLevel, Render, SharedString, + canvas, div, point, relative, rems, 
uniform_list, Action, AnyElement, AnyView, AppContext, + AsyncAppContext, AsyncWindowContext, AvailableSpace, ClipboardItem, Context, EventEmitter, + FocusHandle, FocusableView, FontStyle, FontWeight, HighlightStyle, InteractiveElement, + IntoElement, Model, ModelContext, ParentElement, Pixels, Render, SharedString, StatefulInteractiveElement, Styled, Subscription, Task, TextStyle, UniformListScrollHandle, View, ViewContext, VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext, }; use language::{language_settings::SoftWrap, Buffer, BufferId, LanguageRegistry, ToOffset as _}; +use parking_lot::Mutex; use project::Project; use search::{buffer_search::DivRegistrar, BufferSearchBar}; -use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::Settings; -use std::{ - cell::Cell, - cmp, - fmt::Write, - iter, - ops::Range, - path::{Path, PathBuf}, - rc::Rc, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{cmp, fmt::Write, iter, ops::Range, path::PathBuf, sync::Arc, time::Duration}; use telemetry_events::AssistantKind; use theme::ThemeSettings; use ui::{ @@ -69,7 +51,6 @@ use workspace::{ }; pub fn init(cx: &mut AppContext) { - AssistantSettings::register(cx); cx.observe_new_views( |workspace: &mut Workspace, _cx: &mut ViewContext| { workspace @@ -88,27 +69,29 @@ pub struct AssistantPanel { workspace: WeakView, width: Option, height: Option, - active_editor_index: Option, - prev_active_editor_index: Option, - editors: Vec>, + active_conversation_editor: Option, + show_saved_conversations: bool, saved_conversations: Vec, saved_conversations_scroll_handle: UniformListScrollHandle, zoomed: bool, focus_handle: FocusHandle, toolbar: View, - completion_provider: Arc, - api_key_editor: Option>, languages: Arc, fs: Arc, - subscriptions: Vec, + _subscriptions: Vec, next_inline_assist_id: usize, pending_inline_assists: HashMap, pending_inline_assist_ids_by_editor: HashMap, Vec>, include_conversation_in_next_inline_assist: bool, inline_prompt_history: VecDeque, _watch_saved_conversations: Task>, - semantic_index: Option>, - retrieve_context_in_next_inline_assist: bool, + model: LanguageModel, + authentication_prompt: Option, +} + +struct ActiveConversationEditor { + editor: View, + _subscriptions: Vec, } impl AssistantPanel { @@ -124,22 +107,6 @@ impl AssistantPanel { .await .log_err() .unwrap_or_default(); - let (provider_kind, api_url, model_name) = cx.update(|cx| { - let settings = AssistantSettings::get_global(cx); - anyhow::Ok(( - settings.provider_kind()?, - settings.provider_api_url()?, - settings.provider_model_name()?, - )) - })??; - - let completion_provider = OpenAiCompletionProvider::new( - api_url, - provider_kind, - model_name, - cx.background_executor().clone(), - ) - .await; // TODO: deserialize state. 
let workspace_handle = workspace.clone(); @@ -168,41 +135,48 @@ impl AssistantPanel { let toolbar = cx.new_view(|cx| { let mut toolbar = Toolbar::new(); toolbar.set_can_navigate(false, cx); - toolbar.add_item(cx.new_view(|cx| BufferSearchBar::new(cx)), cx); + toolbar.add_item(cx.new_view(BufferSearchBar::new), cx); toolbar }); - let semantic_index = SemanticIndex::global(cx); - let focus_handle = cx.focus_handle(); - cx.on_focus_in(&focus_handle, Self::focus_in).detach(); - cx.on_focus_out(&focus_handle, Self::focus_out).detach(); + let subscriptions = vec![ + cx.on_focus_in(&focus_handle, Self::focus_in), + cx.on_focus_out(&focus_handle, Self::focus_out), + cx.observe_global::({ + let mut prev_settings_version = + CompletionProvider::global(cx).settings_version(); + move |this, cx| { + this.completion_provider_changed(prev_settings_version, cx); + prev_settings_version = + CompletionProvider::global(cx).settings_version(); + } + }), + ]; + let model = CompletionProvider::global(cx).default_model(); Self { workspace: workspace_handle, - active_editor_index: Default::default(), - prev_active_editor_index: Default::default(), - editors: Default::default(), + active_conversation_editor: None, + show_saved_conversations: false, saved_conversations, saved_conversations_scroll_handle: Default::default(), zoomed: false, focus_handle, toolbar, - completion_provider: Arc::new(completion_provider), - api_key_editor: None, languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), width: None, height: None, - subscriptions: Default::default(), + _subscriptions: subscriptions, next_inline_assist_id: 0, pending_inline_assists: Default::default(), pending_inline_assist_ids_by_editor: Default::default(), include_conversation_in_next_inline_assist: false, inline_prompt_history: Default::default(), _watch_saved_conversations, - semantic_index, - retrieve_context_in_next_inline_assist: false, + model, + authentication_prompt: None, } }) }) @@ -214,14 +188,8 @@ impl AssistantPanel { .update(cx, |toolbar, cx| toolbar.focus_changed(true, cx)); cx.notify(); if self.focus_handle.is_focused(cx) { - if self.has_credentials() { - if let Some(editor) = self.active_editor() { - cx.focus_view(editor); - } - } - - if let Some(api_key_editor) = self.api_key_editor.as_ref() { - cx.focus_view(api_key_editor); + if let Some(editor) = self.active_conversation_editor() { + cx.focus_view(editor); } } } @@ -232,6 +200,30 @@ impl AssistantPanel { cx.notify(); } + fn completion_provider_changed( + &mut self, + prev_settings_version: usize, + cx: &mut ViewContext, + ) { + if self.is_authenticated(cx) { + self.authentication_prompt = None; + + let model = CompletionProvider::global(cx).default_model(); + self.set_model(model, cx); + + if self.active_conversation_editor().is_none() { + self.new_conversation(cx); + } + } else if self.authentication_prompt.is_none() + || prev_settings_version != CompletionProvider::global(cx).settings_version() + { + self.authentication_prompt = + Some(cx.update_global::(|provider, cx| { + provider.authentication_prompt(cx) + })); + } + } + pub fn inline_assist( workspace: &mut Workspace, _: &InlineAssist, @@ -250,7 +242,7 @@ impl AssistantPanel { }; let project = workspace.project().clone(); - if assistant.update(cx, |assistant, _| assistant.has_credentials()) { + if assistant.update(cx, |assistant, cx| assistant.is_authenticated(cx)) { assistant.update(cx, |assistant, cx| { assistant.new_inline_assist(&active_editor, cx, &project) }); @@ -258,9 +250,9 @@ impl 
AssistantPanel { let assistant = assistant.downgrade(); cx.spawn(|workspace, mut cx| async move { assistant - .update(&mut cx, |assistant, cx| assistant.load_credentials(cx))? - .await; - if assistant.update(&mut cx, |assistant, _| assistant.has_credentials())? { + .update(&mut cx, |assistant, cx| assistant.authenticate(cx))? + .await?; + if assistant.update(&mut cx, |assistant, cx| assistant.is_authenticated(cx))? { assistant.update(&mut cx, |assistant, cx| { assistant.new_inline_assist(&active_editor, cx, &project) })?; @@ -311,34 +303,11 @@ impl AssistantPanel { }; let inline_assist_id = post_inc(&mut self.next_inline_assist_id); - let provider = self.completion_provider.clone(); - - let codegen = cx.new_model(|cx| { - Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) - }); - if let Some(semantic_index) = self.semantic_index.clone() { - let project = project.clone(); - cx.spawn(|_, mut cx| async move { - let previously_indexed = semantic_index - .update(&mut cx, |index, cx| { - index.project_previously_indexed(&project, cx) - })? - .await - .unwrap_or(false); - if previously_indexed { - let _ = semantic_index - .update(&mut cx, |index, cx| { - index.index_project(project.clone(), cx) - })? - .await; - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } + let codegen = + cx.new_model(|cx| Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, cx)); - let measurements = Rc::new(Cell::new(BlockMeasurements::default())); + let measurements = Arc::new(Mutex::new(BlockMeasurements::default())); let inline_assistant = cx.new_view(|cx| { InlineAssistant::new( inline_assist_id, @@ -348,9 +317,6 @@ impl AssistantPanel { codegen.clone(), self.workspace.clone(), cx, - self.retrieve_context_in_next_inline_assist, - self.semantic_index.clone(), - project.clone(), ) }); let block_id = editor.update(cx, |editor, cx| { @@ -365,10 +331,10 @@ impl AssistantPanel { render: Arc::new({ let inline_assistant = inline_assistant.clone(); move |cx: &mut BlockContext| { - measurements.set(BlockMeasurements { + *measurements.lock() = BlockMeasurements { anchor_x: cx.anchor_x, gutter_width: cx.gutter_dimensions.width, - }); + }; inline_assistant.clone().into_any_element() } }), @@ -456,7 +422,7 @@ impl AssistantPanel { .entry(editor.downgrade()) .or_default() .push(inline_assist_id); - self.update_highlights_for_editor(&editor, cx); + self.update_highlights_for_editor(editor, cx); } fn handle_inline_assistant_event( @@ -470,15 +436,8 @@ impl AssistantPanel { InlineAssistantEvent::Confirmed { prompt, include_conversation, - retrieve_context, } => { - self.confirm_inline_assist( - assist_id, - prompt, - *include_conversation, - cx, - *retrieve_context, - ); + self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx); } InlineAssistantEvent::Canceled => { self.finish_inline_assist(assist_id, true, cx); @@ -491,9 +450,6 @@ impl AssistantPanel { } => { self.include_conversation_in_next_inline_assist = *include_conversation; } - InlineAssistantEvent::RetrieveContextToggled { retrieve_context } => { - self.retrieve_context_in_next_inline_assist = *retrieve_context - } } } @@ -575,10 +531,9 @@ impl AssistantPanel { user_prompt: &str, include_conversation: bool, cx: &mut ViewContext, - retrieve_context: bool, ) { let conversation = if include_conversation { - self.active_editor() + self.active_conversation_editor() .map(|editor| editor.read(cx).conversation.clone()) } else { None @@ -599,17 +554,13 @@ impl AssistantPanel { let project = pending_assist.project.clone(); - let 
project_name = if let Some(project) = project.upgrade() {
-            Some(
-                project
-                    .read(cx)
-                    .worktree_root_names(cx)
-                    .collect::<Vec<&str>>()
-                    .join("/"),
-            )
-        } else {
-            None
-        };
+        let project_name = project.upgrade().map(|project| {
+            project
+                .read(cx)
+                .worktree_root_names(cx)
+                .collect::<Vec<&str>>()
+                .join("/")
+        });
 
         self.inline_prompt_history
             .retain(|prompt| prompt != user_prompt);
@@ -652,7 +603,7 @@ impl AssistantPanel {
         // If Markdown or No Language is Known, increase the randomness for more creative output
         // If Code, decrease temperature to get more deterministic outputs
         let temperature = if let Some(language) = language_name.clone() {
-            if *language != *"Markdown" {
+            if language.as_ref() != "Markdown" {
                 0.5
             } else {
                 1.0
@@ -663,61 +614,9 @@ impl AssistantPanel {
         let user_prompt = user_prompt.to_string();
 
-        let snippets = if retrieve_context {
-            let Some(project) = project.upgrade() else {
-                return;
-            };
-
-            let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
-                let search_results = semantic_index.update(cx, |this, cx| {
-                    this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
-                });
-
-                cx.background_executor()
-                    .spawn(async move { search_results.await.unwrap_or_default() })
-            } else {
-                Task::ready(Vec::new())
-            };
-
-            let snippets = cx.spawn(|_, mut cx| async move {
-                let mut snippets = Vec::new();
-                for result in search_results.await {
-                    snippets.push(PromptCodeSnippet::new(
-                        result.buffer,
-                        result.range,
-                        &mut cx,
-                    )?);
-                }
-                anyhow::Ok(snippets)
-            });
-            snippets
-        } else {
-            Task::ready(Ok(Vec::new()))
-        };
-
-        let Some(mut model_name) = AssistantSettings::get_global(cx)
-            .provider_model_name()
-            .log_err()
-        else {
-            return;
-        };
-
-        let prompt = cx.background_executor().spawn({
-            let model_name = model_name.clone();
-            async move {
-                let snippets = snippets.await?;
-
-                let language_name = language_name.as_deref();
-                generate_content_prompt(
-                    user_prompt,
-                    language_name,
-                    buffer,
-                    range,
-                    snippets,
-                    &model_name,
-                    project_name,
-                )
-            }
+        let prompt = cx.background_executor().spawn(async move {
+            let language_name = language_name.as_deref();
+            generate_content_prompt(user_prompt, language_name, buffer, range, project_name)
         });
 
         let mut messages = Vec::new();
@@ -729,25 +628,24 @@ impl AssistantPanel {
                     .messages(cx)
                     .map(|message| message.to_open_ai_message(buffer)),
             );
-            model_name = conversation.model.full_name().to_string();
         }
+        let model = self.model.clone();
         cx.spawn(|_, mut cx| async move {
            // I don't know if we want to return a ? here.
let prompt = prompt.await?; - messages.push(RequestMessage { + messages.push(LanguageModelRequestMessage { role: Role::User, content: prompt, }); - let request = Box::new(OpenAiRequest { - model: model_name, + let request = LanguageModelRequest { + model, messages, - stream: true, stop: vec!["|END|>".to_string()], temperature, - }); + }; codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx))?; anyhow::Ok(()) @@ -781,7 +679,7 @@ impl AssistantPanel { } else { editor.highlight_background::( background_ranges, - |theme| theme.editor_active_line_background, // todo("use the appropriate color") + |theme| theme.editor_active_line_background, // todo!("use the appropriate color") cx, ); } @@ -801,54 +699,82 @@ impl AssistantPanel { }); } - fn build_api_key_editor(&mut self, cx: &mut WindowContext<'_>) { - self.api_key_editor = Some(build_api_key_editor(cx)); - } - fn new_conversation(&mut self, cx: &mut ViewContext) -> View { let editor = cx.new_view(|cx| { ConversationEditor::new( - self.completion_provider.clone(), + self.model.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), cx, ) }); - self.add_conversation(editor.clone(), cx); + self.show_conversation(editor.clone(), cx); editor } - fn add_conversation(&mut self, editor: View, cx: &mut ViewContext) { - self.subscriptions - .push(cx.subscribe(&editor, Self::handle_conversation_editor_event)); + fn show_conversation( + &mut self, + conversation_editor: View, + cx: &mut ViewContext, + ) { + let mut subscriptions = Vec::new(); + subscriptions + .push(cx.subscribe(&conversation_editor, Self::handle_conversation_editor_event)); - let conversation = editor.read(cx).conversation.clone(); - self.subscriptions - .push(cx.observe(&conversation, |_, _, cx| cx.notify())); + let conversation = conversation_editor.read(cx).conversation.clone(); + subscriptions.push(cx.observe(&conversation, |_, _, cx| cx.notify())); + + let editor = conversation_editor.read(cx).editor.clone(); + self.toolbar.update(cx, |toolbar, cx| { + toolbar.set_active_item(Some(&editor), cx); + }); + if self.focus_handle.contains_focused(cx) { + cx.focus_view(&editor); + } + self.active_conversation_editor = Some(ActiveConversationEditor { + editor: conversation_editor, + _subscriptions: subscriptions, + }); + self.show_saved_conversations = false; - let index = self.editors.len(); - self.editors.push(editor); - self.set_active_editor_index(Some(index), cx); + cx.notify(); } - fn set_active_editor_index(&mut self, index: Option, cx: &mut ViewContext) { - self.prev_active_editor_index = self.active_editor_index; - self.active_editor_index = index; - if let Some(editor) = self.active_editor() { - let editor = editor.read(cx).editor.clone(); - self.toolbar.update(cx, |toolbar, cx| { - toolbar.set_active_item(Some(&editor), cx); - }); - if self.focus_handle.contains_focused(cx) { - cx.focus_view(&editor); - } - } else { - self.toolbar.update(cx, |toolbar, cx| { - toolbar.set_active_item(None, cx); - }); - } + fn cycle_model(&mut self, cx: &mut ViewContext) { + let next_model = match &self.model { + LanguageModel::OpenAi(model) => LanguageModel::OpenAi(match &model { + open_ai::Model::ThreePointFiveTurbo => open_ai::Model::Four, + open_ai::Model::Four => open_ai::Model::FourTurbo, + open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo, + }), + LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model { + ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour, + ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo, 
+ ZedDotDevModel::GptFourTurbo => { + match CompletionProvider::global(cx).default_model() { + LanguageModel::ZedDotDev(custom) => custom, + _ => ZedDotDevModel::GptThreePointFiveTurbo, + } + } + ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo, + }), + }; + + self.set_model(next_model, cx); + } + fn set_model(&mut self, model: LanguageModel, cx: &mut ViewContext) { + self.model = model.clone(); + if let Some(editor) = self.active_conversation_editor() { + editor.update(cx, |active_conversation, cx| { + active_conversation + .conversation + .update(cx, |conversation, cx| { + conversation.set_model(model, cx); + }) + }) + } cx.notify(); } @@ -863,49 +789,6 @@ impl AssistantPanel { } } - fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - if let Some(api_key) = self - .api_key_editor - .as_ref() - .map(|editor| editor.read(cx).text(cx)) - { - if !api_key.is_empty() { - let credential = ProviderCredential::Credentials { - api_key: api_key.clone(), - }; - - let completion_provider = self.completion_provider.clone(); - cx.spawn(|this, mut cx| async move { - cx.update(|cx| completion_provider.save_credentials(cx, credential))? - .await; - - this.update(&mut cx, |this, cx| { - this.api_key_editor.take(); - this.focus_handle.focus(cx); - cx.notify(); - }) - }) - .detach_and_log_err(cx); - } - } else { - cx.propagate(); - } - } - - fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { - let completion_provider = self.completion_provider.clone(); - cx.spawn(|this, mut cx| async move { - cx.update(|cx| completion_provider.delete_credentials(cx))? - .await; - this.update(&mut cx, |this, cx| { - this.build_api_key_editor(cx); - this.focus_handle.focus(cx); - cx.notify(); - }) - }) - .detach_and_log_err(cx); - } - fn toggle_zoom(&mut self, _: &workspace::ToggleZoom, cx: &mut ViewContext) { if self.zoomed { cx.emit(PanelEvent::ZoomOut) @@ -958,58 +841,27 @@ impl AssistantPanel { } } - fn active_editor(&self) -> Option<&View> { - self.editors.get(self.active_editor_index?) 
+ fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { + CompletionProvider::global(cx) + .reset_credentials(cx) + .detach_and_log_err(cx); } - fn render_api_key_editor( - &self, - editor: &View, - cx: &mut ViewContext, - ) -> impl IntoElement { - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color: if editor.read(cx).read_only(cx) { - cx.theme().colors().text_disabled - } else { - cx.theme().colors().text - }, - font_family: settings.ui_font.family.clone(), - font_features: settings.ui_font.features, - font_size: rems(0.875).into(), - font_weight: FontWeight::NORMAL, - font_style: FontStyle::Normal, - line_height: relative(1.3), - background_color: None, - underline: None, - strikethrough: None, - white_space: WhiteSpace::Normal, - }; - EditorElement::new( - &editor, - EditorStyle { - background: cx.theme().colors().editor_background, - local_player: cx.theme().players().local(), - text: text_style, - ..Default::default() - }, - ) + fn active_conversation_editor(&self) -> Option<&View> { + Some(&self.active_conversation_editor.as_ref()?.editor) } fn render_hamburger_button(cx: &mut ViewContext) -> impl IntoElement { IconButton::new("hamburger_button", IconName::Menu) .on_click(cx.listener(|this, _event, cx| { - if this.active_editor().is_some() { - this.set_active_editor_index(None, cx); - } else { - this.set_active_editor_index(this.prev_active_editor_index, cx); - } + this.show_saved_conversations = !this.show_saved_conversations; + cx.notify(); })) .tooltip(|cx| Tooltip::text("Conversation History", cx)) } fn render_editor_tools(&self, cx: &mut ViewContext) -> Vec { - if self.active_editor().is_some() { + if self.active_conversation_editor().is_some() { vec![ Self::render_split_button(cx).into_any_element(), Self::render_quote_button(cx).into_any_element(), @@ -1023,7 +875,7 @@ impl AssistantPanel { fn render_split_button(cx: &mut ViewContext) -> impl IntoElement { IconButton::new("split_button", IconName::Snip) .on_click(cx.listener(|this, _event, cx| { - if let Some(active_editor) = this.active_editor() { + if let Some(active_editor) = this.active_conversation_editor() { active_editor.update(cx, |editor, cx| editor.split(&Default::default(), cx)); } })) @@ -1034,7 +886,7 @@ impl AssistantPanel { fn render_assist_button(cx: &mut ViewContext) -> impl IntoElement { IconButton::new("assist_button", IconName::MagicWand) .on_click(cx.listener(|this, _event, cx| { - if let Some(active_editor) = this.active_editor() { + if let Some(active_editor) = this.active_conversation_editor() { active_editor.update(cx, |editor, cx| editor.assist(&Default::default(), cx)); } })) @@ -1111,202 +963,185 @@ impl AssistantPanel { fn open_conversation(&mut self, path: PathBuf, cx: &mut ViewContext) -> Task> { cx.focus(&self.focus_handle); - if let Some(ix) = self.editor_index_for_path(&path, cx) { - self.set_active_editor_index(Some(ix), cx); - return Task::ready(Ok(())); - } - let fs = self.fs.clone(); let workspace = self.workspace.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { - let saved_conversation = fs.load(&path).await?; - let saved_conversation = serde_json::from_str(&saved_conversation)?; - let conversation = - Conversation::deserialize(saved_conversation, path.clone(), languages, &mut cx) - .await?; + let saved_conversation = SavedConversation::load(&path, fs.as_ref()).await?; + let model = this.update(&mut cx, |this, _| this.model.clone())?; + let conversation = Conversation::deserialize( + saved_conversation, 
+ model, + path.clone(), + languages, + &mut cx, + ) + .await?; this.update(&mut cx, |this, cx| { - // If, by the time we've loaded the conversation, the user has already opened - // the same conversation, we don't want to open it again. - if let Some(ix) = this.editor_index_for_path(&path, cx) { - this.set_active_editor_index(Some(ix), cx); - } else { - let editor = cx.new_view(|cx| { - ConversationEditor::for_conversation(conversation, fs, workspace, cx) - }); - this.add_conversation(editor, cx); - } + let editor = cx.new_view(|cx| { + ConversationEditor::for_conversation(conversation, fs, workspace, cx) + }); + this.show_conversation(editor, cx); })?; Ok(()) }) } - fn editor_index_for_path(&self, path: &Path, cx: &AppContext) -> Option { - self.editors - .iter() - .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) - } - - fn has_credentials(&mut self) -> bool { - self.completion_provider.has_credentials() + fn is_authenticated(&mut self, cx: &mut ViewContext) -> bool { + CompletionProvider::global(cx).is_authenticated() } - fn load_credentials(&mut self, cx: &mut ViewContext) -> Task<()> { - let completion_provider = self.completion_provider.clone(); - cx.spawn(|_, mut cx| async move { - if let Some(retrieve_credentials) = cx - .update(|cx| completion_provider.retrieve_credentials(cx)) - .log_err() - { - retrieve_credentials.await; - } - }) + fn authenticate(&mut self, cx: &mut ViewContext) -> Task> { + cx.update_global::(|provider, cx| provider.authenticate(cx)) } -} -fn build_api_key_editor(cx: &mut WindowContext) -> View { - cx.new_view(|cx| { - let mut editor = Editor::single_line(cx); - editor.set_placeholder_text("sk-000000000000000000000000000000000000000000000000", cx); - editor - }) -} - -impl Render for AssistantPanel { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - if let Some(api_key_editor) = self.api_key_editor.clone() { - const INSTRUCTIONS: [&'static str; 6] = [ - "To use the assistant panel or inline assistant, you need to add your OpenAI API key.", - " - You can create an API key at: platform.openai.com/api-keys", - " - Make sure your OpenAI account has credits", - " - Having a subscription for another service like GitHub Copilot won't work.", - " ", - "Paste your OpenAI API key and press Enter to use the assistant:" - ]; - - v_flex() - .p_4() - .size_full() - .on_action(cx.listener(AssistantPanel::save_credentials)) - .track_focus(&self.focus_handle) - .children( - INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), - ) - .child( - h_flex() - .w_full() - .my_2() - .px_2() - .py_1() - .bg(cx.theme().colors().editor_background) - .rounded_md() - .child(self.render_api_key_editor(&api_key_editor, cx)), - ) - .child( + fn render_signed_in(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let header = TabBar::new("assistant_header") + .start_child( + h_flex().gap_1().child(Self::render_hamburger_button(cx)), // .children(title), + ) + .children(self.active_conversation_editor().map(|editor| { + h_flex() + .h(rems(Tab::CONTAINER_HEIGHT_IN_REMS)) + .flex_1() + .px_2() + .child(Label::new(editor.read(cx).title(cx)).into_element()) + })) + .when(self.focus_handle.contains_focused(cx), |this| { + this.end_child( h_flex() .gap_2() - .child(Label::new("Click on").size(LabelSize::Small)) - .child(Icon::new(IconName::Ai).size(IconSize::XSmall)) + .when(self.active_conversation_editor().is_some(), |this| { + this.child(h_flex().gap_1().children(self.render_editor_tools(cx))) + .child( + 
ui::Divider::vertical() + .inset() + .color(ui::DividerColor::Border), + ) + }) .child( - Label::new("in the status bar to close this panel.") - .size(LabelSize::Small), + h_flex() + .gap_1() + .child(Self::render_plus_button(cx)) + .child(self.render_zoom_button(cx)), ), ) - } else { - let header = TabBar::new("assistant_header") - .start_child( - h_flex().gap_1().child(Self::render_hamburger_button(cx)), // .children(title), - ) - .children(self.active_editor().map(|editor| { - h_flex() - .h(rems(Tab::CONTAINER_HEIGHT_IN_REMS)) - .flex_1() - .px_2() - .child(Label::new(editor.read(cx).title(cx)).into_element()) - })) - .when(self.focus_handle.contains_focused(cx), |this| { - this.end_child( - h_flex() - .gap_2() - .when(self.active_editor().is_some(), |this| { - this.child(h_flex().gap_1().children(self.render_editor_tools(cx))) - .child( - ui::Divider::vertical() - .inset() - .color(ui::DividerColor::Border), - ) - }) - .child( - h_flex() - .gap_1() - .child(Self::render_plus_button(cx)) - .child(self.render_zoom_button(cx)), - ), - ) - }); + }); - let contents = if self.active_editor().is_some() { - let mut registrar = DivRegistrar::new( - |panel, cx| panel.toolbar.read(cx).item_of_type::(), - cx, - ); - BufferSearchBar::register(&mut registrar); - registrar.into_div() + let contents = if self.active_conversation_editor().is_some() { + let mut registrar = DivRegistrar::new( + |panel, cx| panel.toolbar.read(cx).item_of_type::(), + cx, + ); + BufferSearchBar::register(&mut registrar); + registrar.into_div() + } else { + div() + }; + v_flex() + .key_context("AssistantPanel") + .size_full() + .on_action(cx.listener(|this, _: &workspace::NewFile, cx| { + this.new_conversation(cx); + })) + .on_action(cx.listener(AssistantPanel::toggle_zoom)) + .on_action(cx.listener(AssistantPanel::deploy)) + .on_action(cx.listener(AssistantPanel::select_next_match)) + .on_action(cx.listener(AssistantPanel::select_prev_match)) + .on_action(cx.listener(AssistantPanel::handle_editor_cancel)) + .on_action(cx.listener(AssistantPanel::reset_credentials)) + .track_focus(&self.focus_handle) + .child(header) + .children(if self.toolbar.read(cx).hidden() { + None } else { - div() - }; - v_flex() - .key_context("AssistantPanel") - .size_full() - .on_action(cx.listener(|this, _: &workspace::NewFile, cx| { - this.new_conversation(cx); - })) - .on_action(cx.listener(AssistantPanel::reset_credentials)) - .on_action(cx.listener(AssistantPanel::toggle_zoom)) - .on_action(cx.listener(AssistantPanel::deploy)) - .on_action(cx.listener(AssistantPanel::select_next_match)) - .on_action(cx.listener(AssistantPanel::select_prev_match)) - .on_action(cx.listener(AssistantPanel::handle_editor_cancel)) - .track_focus(&self.focus_handle) - .child(header) - .children(if self.toolbar.read(cx).hidden() { - None - } else { - Some(self.toolbar.clone()) - }) - .child( - contents - .flex_1() - .child(if let Some(editor) = self.active_editor() { - editor.clone().into_any_element() - } else { - let view = cx.view().clone(); - let scroll_handle = self.saved_conversations_scroll_handle.clone(); - let conversation_count = self.saved_conversations.len(); - canvas( - move |bounds, cx| { - let mut list = uniform_list( - view, - "saved_conversations", - conversation_count, - |this, range, cx| { - range - .map(|ix| this.render_saved_conversation(ix, cx)) - .collect() - }, - ) - .track_scroll(scroll_handle) - .into_any_element(); - list.layout(bounds.origin, bounds.size.into(), cx); - list + Some(self.toolbar.clone()) + }) + 
.child(contents.flex_1().child( + if self.show_saved_conversations || self.active_conversation_editor().is_none() { + let view = cx.view().clone(); + let scroll_handle = self.saved_conversations_scroll_handle.clone(); + let conversation_count = self.saved_conversations.len(); + canvas( + move |bounds, cx| { + let mut saved_conversations = uniform_list( + view, + "saved_conversations", + conversation_count, + |this, range, cx| { + range + .map(|ix| this.render_saved_conversation(ix, cx)) + .collect() }, - |_bounds, mut list, cx| list.paint(cx), ) - .size_full() - .into_any_element() - }), - ) + .track_scroll(scroll_handle) + .into_any_element(); + saved_conversations.layout( + bounds.origin, + bounds.size.map(AvailableSpace::Definite), + cx, + ); + saved_conversations + }, + |_bounds, mut saved_conversations, cx| saved_conversations.paint(cx), + ) + .size_full() + .into_any_element() + } else { + let editor = self.active_conversation_editor().unwrap(); + let conversation = editor.read(cx).conversation.clone(); + div() + .size_full() + .child(editor.clone()) + .child( + h_flex() + .absolute() + .gap_1() + .top_3() + .right_5() + .child(self.render_model(&conversation, cx)) + .children(self.render_remaining_tokens(&conversation, cx)), + ) + .into_any_element() + }, + )) + } + + fn render_model( + &self, + conversation: &Model, + cx: &mut ViewContext, + ) -> impl IntoElement { + Button::new("current_model", conversation.read(cx).model.display_name()) + .style(ButtonStyle::Filled) + .tooltip(move |cx| Tooltip::text("Change Model", cx)) + .on_click(cx.listener(|this, _, cx| this.cycle_model(cx))) + } + + fn render_remaining_tokens( + &self, + conversation: &Model, + cx: &mut ViewContext, + ) -> Option { + let remaining_tokens = conversation.read(cx).remaining_tokens()?; + let remaining_tokens_color = if remaining_tokens <= 0 { + Color::Error + } else if remaining_tokens <= 500 { + Color::Warning + } else { + Color::Default + }; + Some(Label::new(remaining_tokens.to_string()).color(remaining_tokens_color)) + } +} + +impl Render for AssistantPanel { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + if let Some(authentication_prompt) = self.authentication_prompt.as_ref() { + authentication_prompt.clone().into_any() + } else { + self.render_signed_in(cx).into_any_element() } } } @@ -1335,7 +1170,7 @@ impl Panel for AssistantPanel { DockPosition::Bottom => AssistantDockPosition::Bottom, DockPosition::Right => AssistantDockPosition::Right, }; - settings.dock = Some(dock); + settings.set_dock(dock); }); } @@ -1343,9 +1178,9 @@ impl Panel for AssistantPanel { let settings = AssistantSettings::get_global(cx); match self.position(cx) { DockPosition::Left | DockPosition::Right => { - self.width.unwrap_or_else(|| settings.default_width) + self.width.unwrap_or(settings.default_width) } - DockPosition::Bottom => self.height.unwrap_or_else(|| settings.default_height), + DockPosition::Bottom => self.height.unwrap_or(settings.default_height), } } @@ -1368,13 +1203,11 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - let load_credentials = self.load_credentials(cx); + let load_credentials = self.authenticate(cx); cx.spawn(|this, mut cx| async move { - load_credentials.await; + load_credentials.await?; this.update(&mut cx, |this, cx| { - if !this.has_credentials() { - this.build_api_key_editor(cx); - } else if this.editors.is_empty() { + if this.is_authenticated(cx) && this.active_conversation_editor().is_none() { 
this.new_conversation(cx); } }) @@ -1426,24 +1259,21 @@ struct Conversation { pending_summary: Task>, completion_count: usize, pending_completions: Vec, - model: OpenAiModel, - api_url: Option, + model: LanguageModel, token_count: Option, - max_token_count: usize, pending_token_count: Task>, pending_save: Task>, path: Option, _subscriptions: Vec, - completion_provider: Arc, } impl EventEmitter for Conversation {} impl Conversation { fn new( + model: LanguageModel, language_registry: Arc, cx: &mut ModelContext, - completion_provider: Arc, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.new_model(|cx| { @@ -1460,16 +1290,6 @@ impl Conversation { buffer }); - let settings = AssistantSettings::get_global(cx); - let model = settings - .provider_model() - .log_err() - .unwrap_or(OpenAiModel::FourTurbo); - let api_url = settings - .provider_api_url() - .log_err() - .unwrap_or_else(|| OPEN_AI_API_URL.to_string()); - let mut this = Self { id: Some(Uuid::new_v4().to_string()), message_anchors: Default::default(), @@ -1480,15 +1300,12 @@ impl Conversation { completion_count: Default::default(), pending_completions: Default::default(), token_count: None, - max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()), pending_token_count: Task::ready(None), - api_url: Some(api_url), model, _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, buffer, - completion_provider, }; let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), @@ -1527,13 +1344,12 @@ impl Conversation { .as_ref() .map(|summary| summary.text.clone()) .unwrap_or_default(), - model: self.model, - api_url: self.api_url.clone(), } } async fn deserialize( saved_conversation: SavedConversation, + model: LanguageModel, path: PathBuf, language_registry: Arc, cx: &mut AsyncAppContext, @@ -1542,21 +1358,6 @@ impl Conversation { Some(id) => Some(id), None => Some(Uuid::new_v4().to_string()), }; - let model = saved_conversation.model; - let api_url = saved_conversation.api_url; - let completion_provider: Arc = Arc::new( - OpenAiCompletionProvider::new( - api_url - .clone() - .unwrap_or_else(|| OPEN_AI_API_URL.to_string()), - OpenAiCompletionProviderKind::OpenAi, - model.full_name().into(), - cx.background_executor().clone(), - ) - .await, - ); - cx.update(|cx| completion_provider.retrieve_credentials(cx))? 
- .await; let markdown = language_registry.language_for_name("Markdown"); let mut message_anchors = Vec::new(); @@ -1600,15 +1401,12 @@ impl Conversation { completion_count: Default::default(), pending_completions: Default::default(), token_count: None, - max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()), pending_token_count: Task::ready(None), - api_url, model, _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), buffer, - completion_provider, }; this.count_remaining_tokens(cx); this @@ -1621,50 +1419,25 @@ impl Conversation { event: &language::Event, cx: &mut ModelContext, ) { - match event { - language::Event::Edited => { - self.count_remaining_tokens(cx); - cx.emit(ConversationEvent::MessagesEdited); - } - _ => {} + if *event == language::Event::Edited { + self.count_remaining_tokens(cx); + cx.emit(ConversationEvent::MessagesEdited); } } fn count_remaining_tokens(&mut self, cx: &mut ModelContext) { - let messages = self - .messages(cx) - .map(|message| tiktoken_rs::ChatCompletionRequestMessage { - role: match message.role { - Role::User => "user".into(), - Role::Assistant => "assistant".into(), - Role::System => "system".into(), - }, - content: Some( - self.buffer - .read(cx) - .text_for_range(message.offset_range) - .collect(), - ), - name: None, - function_call: None, - }) - .collect::>(); - let model = self.model; + let request = self.to_completion_request(cx); self.pending_token_count = cx.spawn(|this, mut cx| { async move { cx.background_executor() .timer(Duration::from_millis(200)) .await; + let token_count = cx - .background_executor() - .spawn(async move { - tiktoken_rs::num_tokens_from_messages(&model.full_name(), &messages) - }) + .update(|cx| CompletionProvider::global(cx).count_tokens(request, cx))? .await?; this.update(&mut cx, |this, cx| { - this.max_token_count = - tiktoken_rs::model::get_context_size(&this.model.full_name()); this.token_count = Some(token_count); cx.notify() })?; @@ -1675,13 +1448,12 @@ impl Conversation { } fn remaining_tokens(&self) -> Option { - Some(self.max_token_count as isize - self.token_count? as isize) + Some(self.model.max_token_count() as isize - self.token_count? 
as isize) } - fn set_model(&mut self, model: OpenAiModel, cx: &mut ModelContext) { + fn set_model(&mut self, model: LanguageModel, cx: &mut ModelContext) { self.model = model; self.count_remaining_tokens(cx); - cx.notify(); } fn assist( @@ -1727,24 +1499,13 @@ impl Conversation { } if should_assist { - if !self.completion_provider.has_credentials() { + if !CompletionProvider::global(cx).is_authenticated() { log::info!("completion provider has no credentials"); return Default::default(); } - let request: Box = Box::new(OpenAiRequest { - model: self.model.full_name().to_string(), - messages: self - .messages(cx) - .filter(|message| matches!(message.status, MessageStatus::Done)) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .collect(), - stream: true, - stop: vec![], - temperature: 1.0, - }); - - let stream = self.completion_provider.complete(request); + let request = self.to_completion_request(cx); + let stream = CompletionProvider::global(cx).complete(request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1810,7 +1571,7 @@ impl Conversation { )); } } - cx.notify(); + cx.emit(ConversationEvent::MessagesEdited); } }) .ok(); @@ -1826,6 +1587,20 @@ impl Conversation { user_messages } + fn to_completion_request(&self, cx: &mut ModelContext) -> LanguageModelRequest { + let request = LanguageModelRequest { + model: self.model.clone(), + messages: self + .messages(cx) + .filter(|message| matches!(message.status, MessageStatus::Done)) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .collect(), + stop: vec![], + temperature: 1.0, + }; + request + } + fn cancel_last_assist(&mut self) -> bool { self.pending_completions.pop().is_some() } @@ -2002,7 +1777,7 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - if !self.completion_provider.has_credentials() { + if !CompletionProvider::global(cx).is_authenticated() { return; } @@ -2010,20 +1785,19 @@ impl Conversation { .messages(cx) .take(2) .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .chain(Some(RequestMessage { + .chain(Some(LanguageModelRequestMessage { role: Role::User, content: "Summarize the conversation into a short title without punctuation" .into(), })); - let request: Box = Box::new(OpenAiRequest { - model: self.model.full_name().to_string(), + let request = LanguageModelRequest { + model: self.model.clone(), messages: messages.collect(), - stream: true, stop: vec![], temperature: 1.0, - }); + }; - let stream = self.completion_provider.complete(request); + let stream = CompletionProvider::global(cx).complete(request); self.pending_summary = cx.spawn(|this, mut cx| { async move { let mut messages = stream.await?; @@ -2210,14 +1984,13 @@ struct ConversationEditor { impl ConversationEditor { fn new( - completion_provider: Arc, + model: LanguageModel, language_registry: Arc, fs: Arc, workspace: WeakView, cx: &mut ViewContext, ) -> Self { - let conversation = - cx.new_model(|cx| Conversation::new(language_registry, cx, completion_provider)); + let conversation = cx.new_model(|cx| Conversation::new(model, language_registry, cx)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -2255,12 +2028,14 @@ impl ConversationEditor { } fn assist(&mut self, _: &Assist, cx: &mut ViewContext) { - report_assistant_event( - self.workspace.clone(), - self.conversation.read(cx).id.clone(), - AssistantKind::Panel, - cx, - ); + 
self.conversation.update(cx, |conversation, cx| { + report_assistant_event( + self.workspace.clone(), + Some(conversation), + AssistantKind::Panel, + cx, + ) + }); let cursors = self.cursors(cx); @@ -2543,7 +2318,7 @@ impl ConversationEditor { if let Some(text) = text { panel.update(cx, |panel, cx| { let conversation = panel - .active_editor() + .active_conversation_editor() .cloned() .unwrap_or_else(|| panel.new_conversation(cx)); conversation.update(cx, |conversation, cx| { @@ -2572,7 +2347,7 @@ impl ConversationEditor { spanned_messages += 1; write!(&mut copied_text, "## {}\n\n", message.role).unwrap(); for chunk in conversation.buffer.read(cx).text_for_range(range) { - copied_text.push_str(&chunk); + copied_text.push_str(chunk); } copied_text.push('\n'); } @@ -2591,7 +2366,7 @@ impl ConversationEditor { fn split(&mut self, _: &Split, cx: &mut ViewContext) { self.conversation.update(cx, |conversation, cx| { let selections = self.editor.read(cx).selections.disjoint_anchors(); - for selection in selections.into_iter() { + for selection in selections.as_ref() { let buffer = self.editor.read(cx).buffer().read(cx).snapshot(cx); let range = selection .map(|endpoint| endpoint.to_offset(&buffer)) @@ -2607,13 +2382,6 @@ impl ConversationEditor { }); } - fn cycle_model(&mut self, cx: &mut ViewContext) { - self.conversation.update(cx, |conversation, cx| { - let new_model = conversation.model.cycle(); - conversation.set_model(new_model, cx); - }); - } - fn title(&self, cx: &AppContext) -> String { self.conversation .read(cx) @@ -2622,28 +2390,6 @@ impl ConversationEditor { .map(|summary| summary.text.clone()) .unwrap_or_else(|| "New Conversation".into()) } - - fn render_current_model(&self, cx: &mut ViewContext) -> impl IntoElement { - Button::new( - "current_model", - self.conversation.read(cx).model.short_name(), - ) - .style(ButtonStyle::Filled) - .tooltip(move |cx| Tooltip::text("Change Model", cx)) - .on_click(cx.listener(|this, _, cx| this.cycle_model(cx))) - } - - fn render_remaining_tokens(&self, cx: &mut ViewContext) -> Option { - let remaining_tokens = self.conversation.read(cx).remaining_tokens()?; - let remaining_tokens_color = if remaining_tokens <= 0 { - Color::Error - } else if remaining_tokens <= 500 { - Color::Warning - } else { - Color::Default - }; - Some(Label::new(remaining_tokens.to_string()).color(remaining_tokens_color)) - } } impl EventEmitter for ConversationEditor {} @@ -2667,15 +2413,6 @@ impl Render for ConversationEditor { .bg(cx.theme().colors().editor_background) .child(self.editor.clone()), ) - .child( - h_flex() - .absolute() - .gap_1() - .top_3() - .right_5() - .child(self.render_current_model(cx)) - .children(self.render_remaining_tokens(cx)), - ) } } @@ -2703,11 +2440,11 @@ pub struct Message { } impl Message { - fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage { + fn to_open_ai_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage { let content = buffer .text_for_range(self.offset_range.clone()) .collect::(); - RequestMessage { + LanguageModelRequestMessage { role: self.role, content: content.trim_end().into(), } @@ -2718,16 +2455,12 @@ enum InlineAssistantEvent { Confirmed { prompt: String, include_conversation: bool, - retrieve_context: bool, }, Canceled, Dismissed, IncludeConversationToggled { include_conversation: bool, }, - RetrieveContextToggled { - retrieve_context: bool, - }, } struct InlineAssistant { @@ -2736,24 +2469,19 @@ struct InlineAssistant { workspace: WeakView, confirmed: bool, include_conversation: bool, - 
measurements: Rc>, + measurements: Arc>, prompt_history: VecDeque, prompt_history_ix: Option, pending_prompt: String, codegen: Model, _subscriptions: Vec, - retrieve_context: bool, - semantic_index: Option>, - semantic_permissioned: Option, - project: WeakModel, - maintain_rate_limit: Option>, } impl EventEmitter for InlineAssistant {} impl Render for InlineAssistant { fn render(&mut self, cx: &mut ViewContext) -> impl Element { - let measurements = self.measurements.get(); + let measurements = *self.measurements.lock(); h_flex() .w_full() .py_2() @@ -2762,7 +2490,6 @@ impl Render for InlineAssistant { .on_action(cx.listener(Self::confirm)) .on_action(cx.listener(Self::cancel)) .on_action(cx.listener(Self::toggle_include_conversation)) - .on_action(cx.listener(Self::toggle_retrieve_context)) .on_action(cx.listener(Self::move_up)) .on_action(cx.listener(Self::move_down)) .child( @@ -2783,24 +2510,6 @@ impl Render for InlineAssistant { ) }), ) - .children(if SemanticIndex::enabled(cx) { - Some( - IconButton::new("retrieve_context", IconName::MagnifyingGlass) - .on_click(cx.listener(|this, _, cx| { - this.toggle_retrieve_context(&ToggleRetrieveContext, cx) - })) - .selected(self.retrieve_context) - .tooltip(|cx| { - Tooltip::for_action( - "Retrieve Context", - &ToggleRetrieveContext, - cx, - ) - }), - ) - } else { - None - }) .children(if let Some(error) = self.codegen.read(cx).error() { let error_message = SharedString::from(error.to_string()); Some( @@ -2819,11 +2528,6 @@ impl Render for InlineAssistant { .ml(measurements.anchor_x - measurements.gutter_width) .child(self.render_prompt_editor(cx)), ) - .children(if self.retrieve_context { - self.retrieve_context_status(cx) - } else { - None - }) } } @@ -2834,18 +2538,14 @@ impl FocusableView for InlineAssistant { } impl InlineAssistant { - #[allow(clippy::too_many_arguments)] fn new( id: usize, - measurements: Rc>, + measurements: Arc>, include_conversation: bool, prompt_history: VecDeque, codegen: Model, workspace: WeakView, cx: &mut ViewContext, - retrieve_context: bool, - semantic_index: Option>, - project: Model, ) -> Self { let prompt_editor = cx.new_view(|cx| { let mut editor = Editor::single_line(cx); @@ -2858,16 +2558,12 @@ impl InlineAssistant { }); cx.focus_view(&prompt_editor); - let mut subscriptions = vec![ + let subscriptions = vec![ cx.observe(&codegen, Self::handle_codegen_changed), cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events), ]; - if let Some(semantic_index) = semantic_index.clone() { - subscriptions.push(cx.observe(&semantic_index, Self::semantic_index_changed)); - } - - let assistant = Self { + Self { id, prompt_editor, workspace, @@ -2879,33 +2575,7 @@ impl InlineAssistant { pending_prompt: String::new(), codegen, _subscriptions: subscriptions, - retrieve_context, - semantic_permissioned: None, - semantic_index, - project: project.downgrade(), - maintain_rate_limit: None, - }; - - assistant.index_project(cx).log_err(); - - assistant - } - - fn semantic_permissioned(&self, cx: &mut ViewContext) -> Task> { - if let Some(value) = self.semantic_permissioned { - return Task::ready(Ok(value)); } - - let Some(project) = self.project.upgrade() else { - return Task::ready(Err(anyhow!("project was dropped"))); - }; - - self.semantic_index - .as_ref() - .map(|semantic| { - semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx)) - }) - .unwrap_or(Task::ready(Ok(false))) } fn handle_prompt_editor_events( @@ -2920,37 +2590,6 @@ impl InlineAssistant { } } - fn semantic_index_changed( - &mut 
self, - semantic_index: Model, - cx: &mut ViewContext, - ) { - let Some(project) = self.project.upgrade() else { - return; - }; - - let status = semantic_index.read(cx).status(&project); - match status { - SemanticIndexStatus::Indexing { - rate_limit_expiry: Some(_), - .. - } => { - if self.maintain_rate_limit.is_none() { - self.maintain_rate_limit = Some(cx.spawn(|this, mut cx| async move { - loop { - cx.background_executor().timer(Duration::from_secs(1)).await; - this.update(&mut cx, |_, cx| cx.notify()).log_err(); - } - })); - } - return; - } - _ => { - self.maintain_rate_limit = None; - } - } - } - fn handle_codegen_changed(&mut self, _: Model, cx: &mut ViewContext) { let is_read_only = !self.codegen.read(cx).idle(); self.prompt_editor.update(cx, |editor, cx| { @@ -2983,161 +2622,12 @@ impl InlineAssistant { cx.emit(InlineAssistantEvent::Confirmed { prompt, include_conversation: self.include_conversation, - retrieve_context: self.retrieve_context, }); self.confirmed = true; cx.notify(); } } - fn toggle_retrieve_context(&mut self, _: &ToggleRetrieveContext, cx: &mut ViewContext) { - let semantic_permissioned = self.semantic_permissioned(cx); - - let Some(project) = self.project.upgrade() else { - return; - }; - - let project_name = project - .read(cx) - .worktree_root_names(cx) - .collect::>() - .join("/"); - let is_plural = project_name.chars().filter(|letter| *letter == '/').count() > 0; - let prompt_text = format!("Would you like to index the '{}' project{} for context retrieval? This requires sending code to the OpenAI API", project_name, - if is_plural { - "s" - } else {""}); - - cx.spawn(|this, mut cx| async move { - // If Necessary prompt user - if !semantic_permissioned.await.unwrap_or(false) { - let answer = this.update(&mut cx, |_, cx| { - cx.prompt( - PromptLevel::Info, - prompt_text.as_str(), - None, - &["Continue", "Cancel"], - ) - })?; - - if answer.await? == 0 { - this.update(&mut cx, |this, _| { - this.semantic_permissioned = Some(true); - })?; - } else { - return anyhow::Ok(()); - } - } - - // If permissioned, update context appropriately - this.update(&mut cx, |this, cx| { - this.retrieve_context = !this.retrieve_context; - - cx.emit(InlineAssistantEvent::RetrieveContextToggled { - retrieve_context: this.retrieve_context, - }); - - if this.retrieve_context { - this.index_project(cx).log_err(); - } - - cx.notify(); - })?; - - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - - fn index_project(&self, cx: &mut ViewContext) -> anyhow::Result<()> { - let Some(project) = self.project.upgrade() else { - return Err(anyhow!("project was dropped!")); - }; - - let semantic_permissioned = self.semantic_permissioned(cx); - if let Some(semantic_index) = SemanticIndex::global(cx) { - cx.spawn(|_, mut cx| async move { - // This has to be updated to accommodate for semantic_permissions - if semantic_permissioned.await.unwrap_or(false) { - semantic_index - .update(&mut cx, |index, cx| index.index_project(project, cx))? - .await - } else { - Err(anyhow!("project is not permissioned for semantic indexing")) - } - }) - .detach_and_log_err(cx); - } - - anyhow::Ok(()) - } - - fn retrieve_context_status(&self, cx: &mut ViewContext) -> Option { - let Some(project) = self.project.upgrade() else { - return None; - }; - - let semantic_index = SemanticIndex::global(cx)?; - let status = semantic_index.update(cx, |index, _| index.status(&project)); - match status { - SemanticIndexStatus::NotAuthenticated {} => Some( - div() - .id("error") - .tooltip(|cx| Tooltip::text("Not Authenticated. 
Please ensure you have a valid 'OPENAI_API_KEY' in your environment variables.", cx)) - .child(Icon::new(IconName::XCircle)) - .into_any_element() - ), - - SemanticIndexStatus::NotIndexed {} => Some( - div() - .id("error") - .tooltip(|cx| Tooltip::text("Not Indexed", cx)) - .child(Icon::new(IconName::XCircle)) - .into_any_element() - ), - - SemanticIndexStatus::Indexing { - remaining_files, - rate_limit_expiry, - } => { - let mut status_text = if remaining_files == 0 { - "Indexing...".to_string() - } else { - format!("Remaining files to index: {remaining_files}") - }; - - if let Some(rate_limit_expiry) = rate_limit_expiry { - let remaining_seconds = rate_limit_expiry.duration_since(Instant::now()); - if remaining_seconds > Duration::from_secs(0) && remaining_files > 0 { - write!( - status_text, - " (rate limit expires in {}s)", - remaining_seconds.as_secs() - ) - .unwrap(); - } - } - - let status_text = SharedString::from(status_text); - Some( - div() - .id("update") - .tooltip(move |cx| Tooltip::text(status_text.clone(), cx)) - .child(Icon::new(IconName::Update).color(Color::Info)) - .into_any_element() - ) - } - - SemanticIndexStatus::Indexed {} => Some( - div() - .id("check") - .tooltip(|cx| Tooltip::text("Index up to date", cx)) - .child(Icon::new(IconName::Check).color(Color::Success)) - .into_any_element() - ), - } - } - fn toggle_include_conversation( &mut self, _: &ToggleIncludeConversation, @@ -3255,23 +2745,47 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { } } +fn report_assistant_event( + workspace: WeakView, + conversation: Option<&Conversation>, + assistant_kind: AssistantKind, + cx: &mut AppContext, +) { + let Some(workspace) = workspace.upgrade() else { + return; + }; + + let client = workspace.read(cx).project().read(cx).client(); + let telemetry = client.telemetry(); + + let conversation_id = conversation.and_then(|conversation| conversation.id.clone()); + let model_id = conversation + .map(|c| c.model.telemetry_id()) + .unwrap_or_else(|| { + CompletionProvider::global(cx) + .default_model() + .telemetry_id() + }); + telemetry.report_assistant_event(conversation_id, assistant_kind, model_id) +} + #[cfg(test)] mod tests { use super::*; - use crate::MessageId; - use ai::test::FakeCompletionProvider; + use crate::{FakeCompletionProvider, MessageId}; use gpui::{AppContext, TestAppContext}; use settings::SettingsStore; #[gpui::test] fn test_inserting_and_removing_messages(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); + cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default())); cx.set_global(settings_store); init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let completion_provider = Arc::new(FakeCompletionProvider::new()); - let conversation = cx.new_model(|cx| Conversation::new(registry, cx, completion_provider)); + let conversation = + cx.new_model(|cx| Conversation::new(LanguageModel::default(), registry, cx)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3398,11 +2912,12 @@ mod tests { fn test_message_splitting(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); + cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default())); init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let completion_provider = Arc::new(FakeCompletionProvider::new()); - let conversation = cx.new_model(|cx| 
Conversation::new(registry, cx, completion_provider)); + let conversation = + cx.new_model(|cx| Conversation::new(LanguageModel::default(), registry, cx)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3496,11 +3011,12 @@ mod tests { #[gpui::test] fn test_messages_for_offsets(cx: &mut AppContext) { let settings_store = SettingsStore::test(cx); + cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default())); cx.set_global(settings_store); init(cx); let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone())); - let completion_provider = Arc::new(FakeCompletionProvider::new()); - let conversation = cx.new_model(|cx| Conversation::new(registry, cx, completion_provider)); + let conversation = + cx.new_model(|cx| Conversation::new(LanguageModel::default(), registry, cx)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3581,11 +3097,11 @@ mod tests { async fn test_serialization(cx: &mut TestAppContext) { let settings_store = cx.update(SettingsStore::test); cx.set_global(settings_store); + cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default())); cx.update(init); let registry = Arc::new(LanguageRegistry::test(cx.executor())); - let completion_provider = Arc::new(FakeCompletionProvider::new()); let conversation = - cx.new_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); + cx.new_model(|cx| Conversation::new(LanguageModel::default(), registry.clone(), cx)); let buffer = conversation.read_with(cx, |conversation, _| conversation.buffer.clone()); let message_0 = conversation.read_with(cx, |conversation, _| conversation.message_anchors[0].id); @@ -3621,6 +3137,7 @@ mod tests { let deserialized_conversation = Conversation::deserialize( conversation.read_with(cx, |conversation, cx| conversation.serialize(cx)), + LanguageModel::default(), Default::default(), registry.clone(), &mut cx.to_async(), @@ -3654,23 +3171,3 @@ mod tests { .collect() } } - -fn report_assistant_event( - workspace: WeakView, - conversation_id: Option, - assistant_kind: AssistantKind, - cx: &AppContext, -) { - let Some(workspace) = workspace.upgrade() else { - return; - }; - - let client = workspace.read(cx).project().read(cx).client(); - let telemetry = client.telemetry(); - - let Ok(model_name) = AssistantSettings::get_global(cx).provider_model_name() else { - return; - }; - - telemetry.report_assistant_event(conversation_id, assistant_kind, &model_name) -} diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 007e994389d4ac..f338f7e8fb9ebb 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -1,169 +1,296 @@ -use ai::providers::open_ai::{ - AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL, -}; -use anyhow::anyhow; +use std::fmt; + use gpui::Pixels; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; +pub use open_ai::Model as OpenAiModel; +use schemars::{ + schema::{InstanceType, Metadata, Schema, SchemaObject}, + JsonSchema, +}; +use serde::{ + de::{self, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; use settings::Settings; -#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum OpenAiModel { - #[serde(rename = "gpt-3.5-turbo-0613")] - ThreePointFiveTurbo, - #[serde(rename = 
"gpt-4-0613")] - Four, - #[serde(rename = "gpt-4-1106-preview")] - FourTurbo, +#[derive(Clone, Debug, Default, PartialEq)] +pub enum ZedDotDevModel { + GptThreePointFiveTurbo, + GptFour, + #[default] + GptFourTurbo, + Custom(String), } -impl OpenAiModel { - pub fn full_name(&self) -> &'static str { - match self { - Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613", - Self::Four => "gpt-4-0613", - Self::FourTurbo => "gpt-4-1106-preview", +impl Serialize for ZedDotDevModel { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(self.id()) + } +} + +impl<'de> Deserialize<'de> for ZedDotDevModel { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ZedDotDevModelVisitor; + + impl<'de> Visitor<'de> for ZedDotDevModelVisitor { + type Value = ZedDotDevModel; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string for a ZedDotDevModel variant or a custom model") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo), + "gpt-4" => Ok(ZedDotDevModel::GptFour), + "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo), + _ => Ok(ZedDotDevModel::Custom(value.to_owned())), + } + } } + + deserializer.deserialize_str(ZedDotDevModelVisitor) } +} + +impl JsonSchema for ZedDotDevModel { + fn schema_name() -> String { + "ZedDotDevModel".to_owned() + } + + fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { + let variants = vec![ + "gpt-3.5-turbo".to_owned(), + "gpt-4".to_owned(), + "gpt-4-turbo-preview".to_owned(), + ]; + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + enum_values: Some(variants.into_iter().map(|s| s.into()).collect()), + metadata: Some(Box::new(Metadata { + title: Some("ZedDotDevModel".to_owned()), + default: Some(serde_json::json!("gpt-4-turbo-preview")), + examples: vec![ + serde_json::json!("gpt-3.5-turbo"), + serde_json::json!("gpt-4"), + serde_json::json!("gpt-4-turbo-preview"), + serde_json::json!("custom-model-name"), + ], + ..Default::default() + })), + ..Default::default() + }) + } +} - pub fn short_name(&self) -> &'static str { +impl ZedDotDevModel { + pub fn id(&self) -> &str { match self { - Self::ThreePointFiveTurbo => "gpt-3.5-turbo", - Self::Four => "gpt-4", - Self::FourTurbo => "gpt-4-turbo", + Self::GptThreePointFiveTurbo => "gpt-3.5-turbo", + Self::GptFour => "gpt-4", + Self::GptFourTurbo => "gpt-4-turbo-preview", + Self::Custom(id) => id, } } - pub fn cycle(&self) -> Self { + pub fn display_name(&self) -> &str { match self { - Self::ThreePointFiveTurbo => Self::Four, - Self::Four => Self::FourTurbo, - Self::FourTurbo => Self::ThreePointFiveTurbo, + Self::GptThreePointFiveTurbo => "gpt-3.5-turbo", + Self::GptFour => "gpt-4", + Self::GptFourTurbo => "gpt-4-turbo", + Self::Custom(id) => id.as_str(), } } } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum AssistantDockPosition { Left, + #[default] Right, Bottom, } -#[derive(Debug, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)] +#[serde(tag = "name", rename_all = "snake_case")] +pub enum AssistantProvider { + #[serde(rename = "zed.dev")] + ZedDotDev { + #[serde(default)] + default_model: ZedDotDevModel, + }, + #[serde(rename = "openai")] + OpenAi { + 
#[serde(default)] + default_model: OpenAiModel, + #[serde(default = "open_ai_url")] + api_url: String, + }, +} + +impl Default for AssistantProvider { + fn default() -> Self { + Self::ZedDotDev { + default_model: ZedDotDevModel::default(), + } + } +} + +fn open_ai_url() -> String { + "https://api.openai.com/v1".into() +} + +#[derive(Default, Debug, Deserialize, Serialize)] pub struct AssistantSettings { - /// Whether to show the assistant panel button in the status bar. pub button: bool, - /// Where to dock the assistant. pub dock: AssistantDockPosition, - /// Default width in pixels when the assistant is docked to the left or right. pub default_width: Pixels, - /// Default height in pixels when the assistant is docked to the bottom. pub default_height: Pixels, - /// The default OpenAI model to use when starting new conversations. - #[deprecated = "Please use `provider.default_model` instead."] - pub default_open_ai_model: OpenAiModel, - /// OpenAI API base URL to use when starting new conversations. - #[deprecated = "Please use `provider.api_url` instead."] - pub openai_api_url: String, - /// The settings for the AI provider. - pub provider: AiProviderSettings, + pub provider: AssistantProvider, } -impl AssistantSettings { - pub fn provider_kind(&self) -> anyhow::Result { - match &self.provider { - AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi), - AiProviderSettings::AzureOpenAi(settings) => { - let deployment_id = settings - .deployment_id - .clone() - .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?; - let api_version = settings - .api_version - .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?; - - Ok(OpenAiCompletionProviderKind::AzureOpenAi { - deployment_id, - api_version, - }) - } - } +/// Assistant panel settings +#[derive(Clone, Serialize, Deserialize, Debug)] +#[serde(untagged)] +pub enum AssistantSettingsContent { + Versioned(VersionedAssistantSettingsContent), + Legacy(LegacyAssistantSettingsContent), +} + +impl JsonSchema for AssistantSettingsContent { + fn schema_name() -> String { + VersionedAssistantSettingsContent::schema_name() } - pub fn provider_api_url(&self) -> anyhow::Result { - match &self.provider { - AiProviderSettings::OpenAi(settings) => Ok(settings - .api_url - .clone() - .unwrap_or_else(|| OPEN_AI_API_URL.to_string())), - AiProviderSettings::AzureOpenAi(settings) => settings - .api_url - .clone() - .ok_or_else(|| anyhow!("no Azure OpenAI API URL")), - } + fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema { + VersionedAssistantSettingsContent::json_schema(gen) } - pub fn provider_model(&self) -> anyhow::Result { - match &self.provider { - AiProviderSettings::OpenAi(settings) => { - Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo)) - } - AiProviderSettings::AzureOpenAi(settings) => { - let deployment_id = settings - .deployment_id - .as_deref() - .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?; - - match deployment_id { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-preview - "gpt-4" | "gpt-4-32k" => Ok(OpenAiModel::Four), - // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-35 - "gpt-35-turbo" | "gpt-35-turbo-16k" | "gpt-35-turbo-instruct" => { - Ok(OpenAiModel::ThreePointFiveTurbo) + fn is_referenceable() -> bool { + VersionedAssistantSettingsContent::is_referenceable() + } +} + +impl Default for AssistantSettingsContent { + fn default() -> Self { + 
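        // `AssistantSettingsContent` is declared `#[serde(untagged)]` above, so a
        // settings document carrying a `"version"` key deserializes as `Versioned`
        // while the old flat shape falls through to `Legacy`. Sketches of both
        // (field values are illustrative):
        //
        //     { "version": "1", "provider": { "name": "openai" } }
        //     { "default_open_ai_model": "gpt-4-0613" }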
Self::Versioned(VersionedAssistantSettingsContent::default()) + } +} + +impl AssistantSettingsContent { + fn upgrade(&self) -> AssistantSettingsContentV1 { + match self { + AssistantSettingsContent::Versioned(settings) => match settings { + VersionedAssistantSettingsContent::V1(settings) => settings.clone(), + }, + AssistantSettingsContent::Legacy(settings) => { + if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() { + AssistantSettingsContentV1 { + button: settings.button, + dock: settings.dock, + default_width: settings.default_width, + default_height: settings.default_height, + provider: Some(AssistantProvider::OpenAi { + default_model: settings + .default_open_ai_model + .clone() + .unwrap_or_default(), + api_url: open_ai_api_url.clone(), + }), + } + } else if let Some(open_ai_model) = settings.default_open_ai_model.clone() { + AssistantSettingsContentV1 { + button: settings.button, + dock: settings.dock, + default_width: settings.default_width, + default_height: settings.default_height, + provider: Some(AssistantProvider::OpenAi { + default_model: open_ai_model, + api_url: open_ai_url(), + }), + } + } else { + AssistantSettingsContentV1 { + button: settings.button, + dock: settings.dock, + default_width: settings.default_width, + default_height: settings.default_height, + provider: None, } - _ => Err(anyhow!( - "no matching OpenAI model found for deployment ID: '{deployment_id}'" - )), } } } } - pub fn provider_model_name(&self) -> anyhow::Result { - match &self.provider { - AiProviderSettings::OpenAi(settings) => Ok(settings - .default_model - .unwrap_or(OpenAiModel::FourTurbo) - .full_name() - .to_string()), - AiProviderSettings::AzureOpenAi(settings) => settings - .deployment_id - .clone() - .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")), + pub fn set_dock(&mut self, dock: AssistantDockPosition) { + match self { + AssistantSettingsContent::Versioned(settings) => match settings { + VersionedAssistantSettingsContent::V1(settings) => { + settings.dock = Some(dock); + } + }, + AssistantSettingsContent::Legacy(settings) => { + settings.dock = Some(dock); + } } } } -impl Settings for AssistantSettings { - const KEY: Option<&'static str> = Some("assistant"); - - type FileContent = AssistantSettingsContent; +#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] +#[serde(tag = "version")] +pub enum VersionedAssistantSettingsContent { + #[serde(rename = "1")] + V1(AssistantSettingsContentV1), +} - fn load( - default_value: &Self::FileContent, - user_values: &[&Self::FileContent], - _: &mut gpui::AppContext, - ) -> anyhow::Result { - Self::load_via_json_merge(default_value, user_values) +impl Default for VersionedAssistantSettingsContent { + fn default() -> Self { + Self::V1(AssistantSettingsContentV1 { + button: None, + dock: None, + default_width: None, + default_height: None, + provider: None, + }) } } -/// Assistant panel settings -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] -pub struct AssistantSettingsContent { +#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)] +pub struct AssistantSettingsContentV1 { + /// Whether to show the assistant panel button in the status bar. + /// + /// Default: true + button: Option, + /// Where to dock the assistant. + /// + /// Default: right + dock: Option, + /// Default width in pixels when the assistant is docked to the left or right. + /// + /// Default: 640 + default_width: Option, + /// Default height in pixels when the assistant is docked to the bottom. 
+    ///
+    /// Default: 320
+    default_height: Option<f32>,
+    /// The provider of the assistant service.
+    ///
+    /// This can either be the internal `zed.dev` service or an external `openai` service,
+    /// each with their respective default models and configurations.
+    provider: Option<AssistantProvider>,
+}
+
+#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
+pub struct LegacyAssistantSettingsContent {
     /// Whether to show the assistant panel button in the status bar.
     ///
     /// Default: true
@@ -180,88 +307,164 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: 320
     pub default_height: Option<f32>,
-    /// Deprecated: Please use `provider.default_model` instead.
     /// The default OpenAI model to use when starting new conversations.
     ///
     /// Default: gpt-4-1106-preview
-    #[deprecated = "Please use `provider.default_model` instead."]
     pub default_open_ai_model: Option<OpenAiModel>,
-    /// Deprecated: Please use `provider.api_url` instead.
     /// OpenAI API base URL to use when starting new conversations.
     ///
     /// Default: https://api.openai.com/v1
-    #[deprecated = "Please use `provider.api_url` instead."]
     pub openai_api_url: Option<String>,
-    /// The settings for the AI provider.
-    #[serde(default)]
-    pub provider: AiProviderSettingsContent,
 }

-#[derive(Debug, Clone, Deserialize)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum AiProviderSettings {
-    /// The settings for the OpenAI provider.
-    #[serde(rename = "openai")]
-    OpenAi(OpenAiProviderSettings),
-    /// The settings for the Azure OpenAI provider.
-    #[serde(rename = "azure_openai")]
-    AzureOpenAi(AzureOpenAiProviderSettings),
-}
-
-/// The settings for the AI provider used by the Zed Assistant.
-#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
-#[serde(tag = "type", rename_all = "snake_case")]
-pub enum AiProviderSettingsContent {
-    /// The settings for the OpenAI provider.
-    #[serde(rename = "openai")]
-    OpenAi(OpenAiProviderSettingsContent),
-    /// The settings for the Azure OpenAI provider.
- #[serde(rename = "azure_openai")] - AzureOpenAi(AzureOpenAiProviderSettingsContent), -} + type FileContent = AssistantSettingsContent; -impl Default for AiProviderSettingsContent { - fn default() -> Self { - Self::OpenAi(OpenAiProviderSettingsContent::default()) + fn load( + default_value: &Self::FileContent, + user_values: &[&Self::FileContent], + _: &mut gpui::AppContext, + ) -> anyhow::Result { + let mut settings = AssistantSettings::default(); + + for value in [default_value].iter().chain(user_values) { + let value = value.upgrade(); + merge(&mut settings.button, value.button); + merge(&mut settings.dock, value.dock); + merge( + &mut settings.default_width, + value.default_width.map(Into::into), + ); + merge( + &mut settings.default_height, + value.default_height.map(Into::into), + ); + if let Some(provider) = value.provider.clone() { + match (&mut settings.provider, provider) { + ( + AssistantProvider::ZedDotDev { default_model }, + AssistantProvider::ZedDotDev { + default_model: default_model_override, + }, + ) => { + *default_model = default_model_override; + } + ( + AssistantProvider::OpenAi { + default_model, + api_url, + }, + AssistantProvider::OpenAi { + default_model: default_model_override, + api_url: api_url_override, + }, + ) => { + *default_model = default_model_override; + *api_url = api_url_override; + } + (merged, provider_override) => { + *merged = provider_override; + } + } + } + } + + Ok(settings) } } -#[derive(Debug, Clone, Deserialize)] -pub struct OpenAiProviderSettings { - /// The OpenAI API base URL to use when starting new conversations. - pub api_url: Option, - /// The default OpenAI model to use when starting new conversations. - pub default_model: Option, +fn merge(target: &mut T, value: Option) { + if let Some(value) = value { + *target = value; + } } -#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)] -pub struct OpenAiProviderSettingsContent { - /// The OpenAI API base URL to use when starting new conversations. - /// - /// Default: https://api.openai.com/v1 - pub api_url: Option, - /// The default OpenAI model to use when starting new conversations. - /// - /// Default: gpt-4-1106-preview - pub default_model: Option, -} +#[cfg(test)] +mod tests { + use gpui::AppContext; + use settings::SettingsStore; -#[derive(Debug, Clone, Deserialize)] -pub struct AzureOpenAiProviderSettings { - /// The Azure OpenAI API base URL to use when starting new conversations. - pub api_url: Option, - /// The Azure OpenAI API version. - pub api_version: Option, - /// The Azure OpenAI API deployment ID. - pub deployment_id: Option, -} + use super::*; + + #[gpui::test] + fn test_deserialize_assistant_settings(cx: &mut AppContext) { + let store = settings::SettingsStore::test(cx); + cx.set_global(store); -#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)] -pub struct AzureOpenAiProviderSettingsContent { - /// The Azure OpenAI API base URL to use when starting new conversations. - pub api_url: Option, - /// The Azure OpenAI API version. - pub api_version: Option, - /// The Azure OpenAI deployment ID. - pub deployment_id: Option, + // Settings default to gpt-4-turbo. + AssistantSettings::register(cx); + assert_eq!( + AssistantSettings::get_global(cx).provider, + AssistantProvider::OpenAi { + default_model: OpenAiModel::FourTurbo, + api_url: open_ai_url() + } + ); + + // Ensure backward-compatibility. 
+ cx.update_global::(|store, cx| { + store + .set_user_settings( + r#"{ + "assistant": { + "openai_api_url": "test-url", + } + }"#, + cx, + ) + .unwrap(); + }); + assert_eq!( + AssistantSettings::get_global(cx).provider, + AssistantProvider::OpenAi { + default_model: OpenAiModel::FourTurbo, + api_url: "test-url".into() + } + ); + cx.update_global::(|store, cx| { + store + .set_user_settings( + r#"{ + "assistant": { + "default_open_ai_model": "gpt-4-0613" + } + }"#, + cx, + ) + .unwrap(); + }); + assert_eq!( + AssistantSettings::get_global(cx).provider, + AssistantProvider::OpenAi { + default_model: OpenAiModel::Four, + api_url: open_ai_url() + } + ); + + // The new version supports setting a custom model when using zed.dev. + cx.update_global::(|store, cx| { + store + .set_user_settings( + r#"{ + "assistant": { + "version": "1", + "provider": { + "name": "zed.dev", + "default_model": "custom" + } + } + }"#, + cx, + ) + .unwrap(); + }); + assert_eq!( + AssistantSettings::get_global(cx).provider, + AssistantProvider::ZedDotDev { + default_model: ZedDotDevModel::Custom("custom".into()) + } + ); + } } diff --git a/crates/assistant/src/codegen.rs b/crates/assistant/src/codegen.rs index 04d08a3315f505..18672d6455c133 100644 --- a/crates/assistant/src/codegen.rs +++ b/crates/assistant/src/codegen.rs @@ -1,12 +1,13 @@ -use crate::streaming_diff::{Hunk, StreamingDiff}; -use ai::completion::{CompletionProvider, CompletionRequest}; +use crate::{ + streaming_diff::{Hunk, StreamingDiff}, + CompletionProvider, LanguageModelRequest, +}; use anyhow::Result; use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint}; use futures::{channel::mpsc, SinkExt, Stream, StreamExt}; use gpui::{EventEmitter, Model, ModelContext, Task}; use language::{Rope, TransactionId}; -use multi_buffer; -use std::{cmp, future, ops::Range, sync::Arc}; +use std::{cmp, future, ops::Range}; pub enum Event { Finished, @@ -20,7 +21,6 @@ pub enum CodegenKind { } pub struct Codegen { - provider: Arc, buffer: Model, snapshot: MultiBufferSnapshot, kind: CodegenKind, @@ -35,15 +35,9 @@ pub struct Codegen { impl EventEmitter for Codegen {} impl Codegen { - pub fn new( - buffer: Model, - kind: CodegenKind, - provider: Arc, - cx: &mut ModelContext, - ) -> Self { + pub fn new(buffer: Model, kind: CodegenKind, cx: &mut ModelContext) -> Self { let snapshot = buffer.read(cx).snapshot(cx); Self { - provider, buffer: buffer.clone(), snapshot, kind, @@ -94,7 +88,7 @@ impl Codegen { self.error.as_ref() } - pub fn start(&mut self, prompt: Box, cx: &mut ModelContext) { + pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext) { let range = self.range(); let snapshot = self.snapshot.clone(); let selected_text = snapshot @@ -108,7 +102,7 @@ impl Codegen { .next() .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row)); - let response = self.provider.complete(prompt); + let response = CompletionProvider::global(cx).complete(prompt); self.generation = cx.spawn(|this, mut cx| { async move { let generate = async { @@ -305,7 +299,7 @@ fn strip_invalid_spans_from_codeblock( } if first_line { - if buffer == "" || buffer == "`" || buffer == "``" { + if buffer.is_empty() || buffer == "`" || buffer == "``" { return future::ready(None); } else if buffer.starts_with("```") { starts_with_markdown_codeblock = true; @@ -360,8 +354,9 @@ fn strip_invalid_spans_from_codeblock( mod tests { use std::sync::Arc; + use crate::FakeCompletionProvider; + use super::*; - use ai::test::FakeCompletionProvider; use 
futures::stream::{self}; use gpui::{Context, TestAppContext}; use indoc::indoc; @@ -378,15 +373,11 @@ mod tests { pub name: String, } - impl CompletionRequest for DummyCompletionRequest { - fn data(&self) -> serde_json::Result { - serde_json::to_string(self) - } - } - #[gpui::test(iterations = 10)] async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) { + let provider = FakeCompletionProvider::default(); cx.set_global(cx.update(SettingsStore::test)); + cx.set_global(CompletionProvider::Fake(provider.clone())); cx.update(language_settings::init); let text = indoc! {" @@ -405,19 +396,10 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let provider = Arc::new(FakeCompletionProvider::new()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - CodegenKind::Transform { range }, - provider.clone(), - cx, - ) - }); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx)); - let request = Box::new(DummyCompletionRequest { - name: "test".to_string(), - }); + let request = LanguageModelRequest::default(); codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( @@ -430,8 +412,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); - println!("CHUNK: {:?}", &chunk); - provider.send_completion(chunk); + provider.send_completion(chunk.into()); new_text = suffix; cx.background_executor.run_until_parked(); } @@ -456,6 +437,8 @@ mod tests { cx: &mut TestAppContext, mut rng: StdRng, ) { + let provider = FakeCompletionProvider::default(); + cx.set_global(CompletionProvider::Fake(provider.clone())); cx.set_global(cx.update(SettingsStore::test)); cx.update(language_settings::init); @@ -472,19 +455,10 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6)) }); - let provider = Arc::new(FakeCompletionProvider::new()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - CodegenKind::Generate { position }, - provider.clone(), - cx, - ) - }); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx)); - let request = Box::new(DummyCompletionRequest { - name: "test".to_string(), - }); + let request = LanguageModelRequest::default(); codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( @@ -497,7 +471,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); - provider.send_completion(chunk); + provider.send_completion(chunk.into()); new_text = suffix; cx.background_executor.run_until_parked(); } @@ -522,6 +496,8 @@ mod tests { cx: &mut TestAppContext, mut rng: StdRng, ) { + let provider = FakeCompletionProvider::default(); + cx.set_global(CompletionProvider::Fake(provider.clone())); cx.set_global(cx.update(SettingsStore::test)); cx.update(language_settings::init); @@ -538,19 +514,10 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2)) }); - let provider = Arc::new(FakeCompletionProvider::new()); - let codegen = cx.new_model(|cx| { - Codegen::new( - buffer.clone(), - CodegenKind::Generate { position }, - provider.clone(), - cx, - ) - }); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx)); - let request = 
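            // With the provider resolved through the global, the rewritten tests no
            // longer build a provider-specific request payload. A sketch of the new
            // flow (all names come from this changeset; the streamed chunk is
            // illustrative):
            //
            //     let provider = FakeCompletionProvider::default();
            //     cx.set_global(CompletionProvider::Fake(provider.clone()));
            //     codegen.update(cx, |codegen, cx| {
            //         codegen.start(LanguageModelRequest::default(), cx)
            //     });
            //     provider.send_completion("fn main() {}".into()); // stream one chunk
            //     provider.finish_completion();                    // close the stream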
Box::new(DummyCompletionRequest { - name: "test".to_string(), - }); + let request = LanguageModelRequest::default(); codegen.update(cx, |codegen, cx| codegen.start(request, cx)); let mut new_text = concat!( @@ -563,8 +530,7 @@ mod tests { let max_len = cmp::min(new_text.len(), 10); let len = rng.gen_range(1..=max_len); let (chunk, suffix) = new_text.split_at(len); - println!("{:?}", &chunk); - provider.send_completion(chunk); + provider.send_completion(chunk.into()); new_text = suffix; cx.background_executor.run_until_parked(); } diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs new file mode 100644 index 00000000000000..d3cdc9e71614e3 --- /dev/null +++ b/crates/assistant/src/completion_provider.rs @@ -0,0 +1,188 @@ +#[cfg(test)] +mod fake; +mod open_ai; +mod zed; + +#[cfg(test)] +pub use fake::*; +pub use open_ai::*; +pub use zed::*; + +use crate::{ + assistant_settings::{AssistantProvider, AssistantSettings}, + LanguageModel, LanguageModelRequest, +}; +use anyhow::Result; +use client::Client; +use futures::{future::BoxFuture, stream::BoxStream}; +use gpui::{AnyView, AppContext, Task, WindowContext}; +use settings::{Settings, SettingsStore}; +use std::sync::Arc; + +pub fn init(client: Arc, cx: &mut AppContext) { + let mut settings_version = 0; + let provider = match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { default_model } => { + CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( + default_model.clone(), + client.clone(), + settings_version, + cx, + )) + } + AssistantProvider::OpenAi { + default_model, + api_url, + } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new( + default_model.clone(), + api_url.clone(), + client.http_client(), + settings_version, + )), + }; + cx.set_global(provider); + + cx.observe_global::(move |cx| { + settings_version += 1; + cx.update_global::(|provider, cx| { + match (&mut *provider, &AssistantSettings::get_global(cx).provider) { + ( + CompletionProvider::OpenAi(provider), + AssistantProvider::OpenAi { + default_model, + api_url, + }, + ) => { + provider.update(default_model.clone(), api_url.clone(), settings_version); + } + ( + CompletionProvider::ZedDotDev(provider), + AssistantProvider::ZedDotDev { default_model }, + ) => { + provider.update(default_model.clone(), settings_version); + } + (CompletionProvider::OpenAi(_), AssistantProvider::ZedDotDev { default_model }) => { + *provider = CompletionProvider::ZedDotDev(ZedDotDevCompletionProvider::new( + default_model.clone(), + client.clone(), + settings_version, + cx, + )); + } + ( + CompletionProvider::ZedDotDev(_), + AssistantProvider::OpenAi { + default_model, + api_url, + }, + ) => { + *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new( + default_model.clone(), + api_url.clone(), + client.http_client(), + settings_version, + )); + } + #[cfg(test)] + (CompletionProvider::Fake(_), _) => unimplemented!(), + } + }) + }) + .detach(); +} + +pub enum CompletionProvider { + OpenAi(OpenAiCompletionProvider), + ZedDotDev(ZedDotDevCompletionProvider), + #[cfg(test)] + Fake(FakeCompletionProvider), +} + +impl gpui::Global for CompletionProvider {} + +impl CompletionProvider { + pub fn global(cx: &AppContext) -> &Self { + cx.global::() + } + + pub fn settings_version(&self) -> usize { + match self { + CompletionProvider::OpenAi(provider) => provider.settings_version(), + CompletionProvider::ZedDotDev(provider) => provider.settings_version(), + #[cfg(test)] + CompletionProvider::Fake(_) => 
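            // The `Fake` variant exists only for tests, so the settings plumbing is
            // deliberately left unimplemented. Production callers never match on the
            // variants directly; they dispatch through the global, e.g. (a sketch):
            //
            //     let provider = CompletionProvider::global(cx);
            //     if provider.is_authenticated() {
            //         let stream = provider.complete(request);
            //     }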
unimplemented!(), + } + } + + pub fn is_authenticated(&self) -> bool { + match self { + CompletionProvider::OpenAi(provider) => provider.is_authenticated(), + CompletionProvider::ZedDotDev(provider) => provider.is_authenticated(), + #[cfg(test)] + CompletionProvider::Fake(_) => true, + } + } + + pub fn authenticate(&self, cx: &AppContext) -> Task> { + match self { + CompletionProvider::OpenAi(provider) => provider.authenticate(cx), + CompletionProvider::ZedDotDev(provider) => provider.authenticate(cx), + #[cfg(test)] + CompletionProvider::Fake(_) => Task::ready(Ok(())), + } + } + + pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + match self { + CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx), + CompletionProvider::ZedDotDev(provider) => provider.authentication_prompt(cx), + #[cfg(test)] + CompletionProvider::Fake(_) => unimplemented!(), + } + } + + pub fn reset_credentials(&self, cx: &AppContext) -> Task> { + match self { + CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx), + CompletionProvider::ZedDotDev(_) => Task::ready(Ok(())), + #[cfg(test)] + CompletionProvider::Fake(_) => Task::ready(Ok(())), + } + } + + pub fn default_model(&self) -> LanguageModel { + match self { + CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.default_model()), + CompletionProvider::ZedDotDev(provider) => { + LanguageModel::ZedDotDev(provider.default_model()) + } + #[cfg(test)] + CompletionProvider::Fake(_) => unimplemented!(), + } + } + + pub fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + match self { + CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx), + CompletionProvider::ZedDotDev(provider) => provider.count_tokens(request, cx), + #[cfg(test)] + CompletionProvider::Fake(_) => unimplemented!(), + } + } + + pub fn complete( + &self, + request: LanguageModelRequest, + ) -> BoxFuture<'static, Result>>> { + match self { + CompletionProvider::OpenAi(provider) => provider.complete(request), + CompletionProvider::ZedDotDev(provider) => provider.complete(request), + #[cfg(test)] + CompletionProvider::Fake(provider) => provider.complete(), + } + } +} diff --git a/crates/assistant/src/completion_provider/fake.rs b/crates/assistant/src/completion_provider/fake.rs new file mode 100644 index 00000000000000..9c06796a376c74 --- /dev/null +++ b/crates/assistant/src/completion_provider/fake.rs @@ -0,0 +1,29 @@ +use anyhow::Result; +use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use std::sync::Arc; + +#[derive(Clone, Default)] +pub struct FakeCompletionProvider { + current_completion_tx: Arc>>>, +} + +impl FakeCompletionProvider { + pub fn complete(&self) -> BoxFuture<'static, Result>>> { + let (tx, rx) = mpsc::unbounded(); + *self.current_completion_tx.lock() = Some(tx); + async move { Ok(rx.map(Ok).boxed()) }.boxed() + } + + pub fn send_completion(&self, chunk: String) { + self.current_completion_tx + .lock() + .as_ref() + .unwrap() + .unbounded_send(chunk) + .unwrap(); + } + + pub fn finish_completion(&self) { + self.current_completion_tx.lock().take(); + } +} diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs new file mode 100644 index 00000000000000..f4c29a47e826b3 --- /dev/null +++ b/crates/assistant/src/completion_provider/open_ai.rs @@ -0,0 +1,301 @@ +use crate::{ + assistant_settings::OpenAiModel, CompletionProvider, 
LanguageModel, LanguageModelRequest, Role, +}; +use anyhow::{anyhow, Result}; +use editor::{Editor, EditorElement, EditorStyle}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; +use gpui::{AnyView, AppContext, FontStyle, FontWeight, Task, TextStyle, View, WhiteSpace}; +use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole}; +use settings::Settings; +use std::{env, sync::Arc}; +use theme::ThemeSettings; +use ui::prelude::*; +use util::{http::HttpClient, ResultExt}; + +pub struct OpenAiCompletionProvider { + api_key: Option, + api_url: String, + default_model: OpenAiModel, + http_client: Arc, + settings_version: usize, +} + +impl OpenAiCompletionProvider { + pub fn new( + default_model: OpenAiModel, + api_url: String, + http_client: Arc, + settings_version: usize, + ) -> Self { + Self { + api_key: None, + api_url, + default_model, + http_client, + settings_version, + } + } + + pub fn update(&mut self, default_model: OpenAiModel, api_url: String, settings_version: usize) { + self.default_model = default_model; + self.api_url = api_url; + self.settings_version = settings_version; + } + + pub fn settings_version(&self) -> usize { + self.settings_version + } + + pub fn is_authenticated(&self) -> bool { + self.api_key.is_some() + } + + pub fn authenticate(&self, cx: &AppContext) -> Task> { + if self.is_authenticated() { + Task::ready(Ok(())) + } else { + let api_url = self.api_url.clone(); + cx.spawn(|mut cx| async move { + let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { + api_key + } else { + let (_, api_key) = cx + .update(|cx| cx.read_credentials(&api_url))? + .await? + .ok_or_else(|| anyhow!("credentials not found"))?; + String::from_utf8(api_key)? + }; + cx.update_global::(|provider, _cx| { + if let CompletionProvider::OpenAi(provider) = provider { + provider.api_key = Some(api_key); + } + }) + }) + } + } + + pub fn reset_credentials(&self, cx: &AppContext) -> Task> { + let delete_credentials = cx.delete_credentials(&self.api_url); + cx.spawn(|mut cx| async move { + delete_credentials.await.log_err(); + cx.update_global::(|provider, _cx| { + if let CompletionProvider::OpenAi(provider) = provider { + provider.api_key = None; + } + }) + }) + } + + pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx)) + .into() + } + + pub fn default_model(&self) -> OpenAiModel { + self.default_model.clone() + } + + pub fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + count_open_ai_tokens(request, cx.background_executor()) + } + + pub fn complete( + &self, + request: LanguageModelRequest, + ) -> BoxFuture<'static, Result>>> { + let request = self.to_open_ai_request(request); + + let http_client = self.http_client.clone(); + let api_key = self.api_key.clone(); + let api_url = self.api_url.clone(); + async move { + let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; + let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let response = request.await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed(); + Ok(stream) + } + .boxed() + } + + fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request { + let model = match request.model { + LanguageModel::ZedDotDev(_) => 
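            // A zed.dev model can still be referenced here, e.g. if the request was
            // built before the provider setting changed; rather than failing, the
            // request falls back to this provider's configured default, and only an
            // explicit `LanguageModel::OpenAi(..)` is forwarded verbatim.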
self.default_model(), + LanguageModel::OpenAi(model) => model, + }; + + Request { + model, + messages: request + .messages + .into_iter() + .map(|msg| RequestMessage { + role: msg.role.into(), + content: msg.content, + }) + .collect(), + stream: true, + stop: request.stop, + temperature: request.temperature, + } + } +} + +pub fn count_open_ai_tokens( + request: LanguageModelRequest, + background_executor: &gpui::BackgroundExecutor, +) -> BoxFuture<'static, Result> { + background_executor + .spawn(async move { + let messages = request + .messages + .into_iter() + .map(|message| tiktoken_rs::ChatCompletionRequestMessage { + role: match message.role { + Role::User => "user".into(), + Role::Assistant => "assistant".into(), + Role::System => "system".into(), + }, + content: Some(message.content), + name: None, + function_call: None, + }) + .collect::>(); + + tiktoken_rs::num_tokens_from_messages(request.model.id(), &messages) + }) + .boxed() +} + +impl From for open_ai::Role { + fn from(val: Role) -> Self { + match val { + Role::User => OpenAiRole::User, + Role::Assistant => OpenAiRole::Assistant, + Role::System => OpenAiRole::System, + } + } +} + +struct AuthenticationPrompt { + api_key: View, + api_url: String, +} + +impl AuthenticationPrompt { + fn new(api_url: String, cx: &mut WindowContext) -> Self { + Self { + api_key: cx.new_view(|cx| { + let mut editor = Editor::single_line(cx); + editor.set_placeholder_text( + "sk-000000000000000000000000000000000000000000000000", + cx, + ); + editor + }), + api_url, + } + } + + fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + let api_key = self.api_key.read(cx).text(cx); + if api_key.is_empty() { + return; + } + + let write_credentials = cx.write_credentials(&self.api_url, "Bearer", api_key.as_bytes()); + cx.spawn(|_, mut cx| async move { + write_credentials.await?; + cx.update_global::(|provider, _cx| { + if let CompletionProvider::OpenAi(provider) = provider { + provider.api_key = Some(api_key); + } + }) + }) + .detach_and_log_err(cx); + } + + fn render_api_key_editor(&self, cx: &mut ViewContext) -> impl IntoElement { + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_features: settings.ui_font.features, + font_size: rems(0.875).into(), + font_weight: FontWeight::NORMAL, + font_style: FontStyle::Normal, + line_height: relative(1.3), + background_color: None, + underline: None, + strikethrough: None, + white_space: WhiteSpace::Normal, + }; + EditorElement::new( + &self.api_key, + EditorStyle { + background: cx.theme().colors().editor_background, + local_player: cx.theme().players().local(), + text: text_style, + ..Default::default() + }, + ) + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + const INSTRUCTIONS: [&str; 6] = [ + "To use the assistant panel or inline assistant, you need to add your OpenAI API key.", + " - You can create an API key at: platform.openai.com/api-keys", + " - Make sure your OpenAI account has credits", + " - Having a subscription for another service like GitHub Copilot won't work.", + "", + "Paste your OpenAI API key below and hit enter to use the assistant:", + ]; + + v_flex() + .p_4() + .size_full() + .on_action(cx.listener(Self::save_api_key)) + .children( + INSTRUCTIONS.map(|instruction| Label::new(instruction).size(LabelSize::Small)), + ) + .child( + h_flex() + .w_full() + .my_2() + .px_2() + .py_1() + 
.bg(cx.theme().colors().editor_background) + .rounded_md() + .child(self.render_api_key_editor(cx)), + ) + .child( + Label::new( + "You can also assign the OPENAI_API_KEY environment variable and restart Zed.", + ) + .size(LabelSize::Small), + ) + .child( + h_flex() + .gap_2() + .child(Label::new("Click on").size(LabelSize::Small)) + .child(Icon::new(IconName::Ai).size(IconSize::XSmall)) + .child( + Label::new("in the status bar to close this panel.").size(LabelSize::Small), + ), + ) + .into_any() + } +} diff --git a/crates/assistant/src/completion_provider/zed.rs b/crates/assistant/src/completion_provider/zed.rs new file mode 100644 index 00000000000000..0febb05278ee5d --- /dev/null +++ b/crates/assistant/src/completion_provider/zed.rs @@ -0,0 +1,167 @@ +use crate::{ + assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, + LanguageModelRequest, +}; +use anyhow::{anyhow, Result}; +use client::{proto, Client}; +use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; +use gpui::{AnyView, AppContext, Task}; +use std::{future, sync::Arc}; +use ui::prelude::*; + +pub struct ZedDotDevCompletionProvider { + client: Arc, + default_model: ZedDotDevModel, + settings_version: usize, + status: client::Status, + _maintain_client_status: Task<()>, +} + +impl ZedDotDevCompletionProvider { + pub fn new( + default_model: ZedDotDevModel, + client: Arc, + settings_version: usize, + cx: &mut AppContext, + ) -> Self { + let mut status_rx = client.status(); + let status = *status_rx.borrow(); + let maintain_client_status = cx.spawn(|mut cx| async move { + while let Some(status) = status_rx.next().await { + let _ = cx.update_global::(|provider, _cx| { + if let CompletionProvider::ZedDotDev(provider) = provider { + provider.status = status; + } else { + unreachable!() + } + }); + } + }); + Self { + client, + default_model, + settings_version, + status, + _maintain_client_status: maintain_client_status, + } + } + + pub fn update(&mut self, default_model: ZedDotDevModel, settings_version: usize) { + self.default_model = default_model; + self.settings_version = settings_version; + } + + pub fn settings_version(&self) -> usize { + self.settings_version + } + + pub fn default_model(&self) -> ZedDotDevModel { + self.default_model.clone() + } + + pub fn is_authenticated(&self) -> bool { + self.status.is_connected() + } + + pub fn authenticate(&self, cx: &AppContext) -> Task> { + let client = self.client.clone(); + cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await }) + } + + pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|_cx| AuthenticationPrompt).into() + } + + pub fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + match request.model { + crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(), + crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour) + | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo) + | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => { + count_open_ai_tokens(request, cx.background_executor()) + } + crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => { + let request = self.client.request(proto::CountTokensWithLanguageModel { + model, + messages: request + .messages + .iter() + .map(|message| message.to_proto()) + .collect(), + }); + async move { + let response = request.await?; + Ok(response.token_count as usize) + } + 
.boxed() + } + } + } + + pub fn complete( + &self, + request: LanguageModelRequest, + ) -> BoxFuture<'static, Result>>> { + let request = proto::CompleteWithLanguageModel { + model: request.model.id().to_string(), + messages: request + .messages + .iter() + .map(|message| message.to_proto()) + .collect(), + stop: request.stop, + temperature: request.temperature, + }; + + self.client + .request_stream(request) + .map_ok(|stream| { + stream + .filter_map(|response| async move { + match response { + Ok(mut response) => Some(Ok(response.choices.pop()?.delta?.content?)), + Err(error) => Some(Err(error)), + } + }) + .boxed() + }) + .boxed() + } +} + +struct AuthenticationPrompt; + +impl Render for AuthenticationPrompt { + fn render(&mut self, _cx: &mut ViewContext) -> impl IntoElement { + const LABEL: &str = "Generate and analyze code with language models. You can dialog with the assistant in this panel or transform code inline."; + + v_flex().gap_6().p_4().child(Label::new(LABEL)).child( + v_flex() + .gap_2() + .child( + Button::new("sign_in", "Sign in") + .icon_color(Color::Muted) + .icon(IconName::Github) + .icon_position(IconPosition::Start) + .style(ButtonStyle::Filled) + .full_width() + .on_click(|_, cx| { + CompletionProvider::global(cx) + .authenticate(cx) + .detach_and_log_err(cx); + }), + ) + .child( + div().flex().w_full().items_center().child( + Label::new("Sign in to enable collaboration.") + .color(Color::Muted) + .size(LabelSize::Small), + ), + ), + ) + } +} diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index f40a841f4c64cc..80dfc45c4f3607 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -1,394 +1,95 @@ -use ai::models::LanguageModel; -use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate}; -use ai::prompts::file_context::FileContext; -use ai::prompts::generate::GenerateInlineContent; -use ai::prompts::preamble::EngineerPreamble; -use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext}; -use ai::providers::open_ai::OpenAiLanguageModel; -use language::{BufferSnapshot, OffsetRangeExt, ToOffset}; -use std::cmp::{self, Reverse}; -use std::ops::Range; -use std::sync::Arc; - -#[allow(dead_code)] -fn summarize(buffer: &BufferSnapshot, selected_range: Range) -> String { - #[derive(Debug)] - struct Match { - collapse: Range, - keep: Vec>, - } - - let selected_range = selected_range.to_offset(buffer); - let mut ts_matches = buffer.matches(0..buffer.len(), |grammar| { - Some(&grammar.embedding_config.as_ref()?.query) - }); - let configs = ts_matches - .grammars() - .iter() - .map(|g| g.embedding_config.as_ref().unwrap()) - .collect::>(); - let mut matches = Vec::new(); - while let Some(mat) = ts_matches.peek() { - let config = &configs[mat.grammar_index]; - if let Some(collapse) = mat.captures.iter().find_map(|cap| { - if Some(cap.index) == config.collapse_capture_ix { - Some(cap.node.byte_range()) - } else { - None - } - }) { - let mut keep = Vec::new(); - for capture in mat.captures.iter() { - if Some(capture.index) == config.keep_capture_ix { - keep.push(capture.node.byte_range()); - } else { - continue; - } - } - ts_matches.advance(); - matches.push(Match { collapse, keep }); - } else { - ts_matches.advance(); - } - } - matches.sort_unstable_by_key(|mat| (mat.collapse.start, Reverse(mat.collapse.end))); - let mut matches = matches.into_iter().peekable(); - - let mut summary = String::new(); - let mut offset = 0; - let mut flushed_selection = false; - while let 
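// The stream shaping used by both `complete` implementations above, in
// isolation: a fallible stream of chunks is narrowed to just the content
// deltas, and chunks carrying no delta are skipped rather than treated as
// errors. A minimal sketch with the futures crate:
use futures::{stream, StreamExt};

async fn content_deltas_demo() {
    let chunks: Vec<anyhow::Result<Option<&str>>> =
        vec![Ok(Some("Hel")), Ok(None), Ok(Some("lo"))];
    let deltas: Vec<_> = stream::iter(chunks)
        .filter_map(|chunk| async move {
            match chunk {
                Ok(delta) => delta.map(Ok), // drop chunks with no content
                Err(error) => Some(Err(error)),
            }
        })
        .collect()
        .await;
    assert_eq!(deltas.len(), 2); // "Hel" and "lo"; the empty chunk was dropped
}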
Some(mat) = matches.next() { - // Keep extending the collapsed range if the next match surrounds - // the current one. - while let Some(next_mat) = matches.peek() { - if mat.collapse.start <= next_mat.collapse.start - && mat.collapse.end >= next_mat.collapse.end - { - matches.next().unwrap(); - } else { - break; - } - } - - if offset > mat.collapse.start { - // Skip collapsed nodes that have already been summarized. - offset = cmp::max(offset, mat.collapse.end); - continue; - } - - if offset <= selected_range.start && selected_range.start <= mat.collapse.end { - if !flushed_selection { - // The collapsed node ends after the selection starts, so we'll flush the selection first. - summary.extend(buffer.text_for_range(offset..selected_range.start)); - summary.push_str("<|S|"); - if selected_range.end == selected_range.start { - summary.push_str(">"); - } else { - summary.extend(buffer.text_for_range(selected_range.clone())); - summary.push_str("|E|>"); - } - offset = selected_range.end; - flushed_selection = true; - } - - // If the selection intersects the collapsed node, we won't collapse it. - if selected_range.end >= mat.collapse.start { - continue; - } - } - - summary.extend(buffer.text_for_range(offset..mat.collapse.start)); - for keep in mat.keep { - summary.extend(buffer.text_for_range(keep)); - } - offset = mat.collapse.end; - } - - // Flush selection if we haven't already done so. - if !flushed_selection && offset <= selected_range.start { - summary.extend(buffer.text_for_range(offset..selected_range.start)); - summary.push_str("<|S|"); - if selected_range.end == selected_range.start { - summary.push_str(">"); - } else { - summary.extend(buffer.text_for_range(selected_range.clone())); - summary.push_str("|E|>"); - } - offset = selected_range.end; - } - - summary.extend(buffer.text_for_range(offset..buffer.len())); - summary -} +use language::BufferSnapshot; +use std::{fmt::Write, ops::Range}; pub fn generate_content_prompt( user_prompt: String, language_name: Option<&str>, buffer: BufferSnapshot, range: Range, - search_results: Vec, - model: &str, project_name: Option, ) -> anyhow::Result { - // Using new Prompt Templates - let openai_model: Arc = Arc::new(OpenAiLanguageModel::load(model)); - let lang_name = if let Some(language_name) = language_name { - Some(language_name.to_string()) - } else { - None - }; + let mut prompt = String::new(); - let args = PromptArguments { - model: openai_model, - language_name: lang_name.clone(), - project_name, - snippets: search_results.clone(), - reserved_tokens: 1000, - buffer: Some(buffer), - selected_range: Some(range), - user_prompt: Some(user_prompt.clone()), + let content_type = match language_name { + None | Some("Markdown" | "Plain Text") => { + writeln!(prompt, "You are an expert engineer.")?; + "Text" + } + Some(language_name) => { + writeln!(prompt, "You are an expert {language_name} engineer.")?; + writeln!( + prompt, + "Your answer MUST always and only be valid {}.", + language_name + )?; + "Code" + } }; - let templates: Vec<(PromptPriority, Box)> = vec![ - (PromptPriority::Mandatory, Box::new(EngineerPreamble {})), - ( - PromptPriority::Ordered { order: 1 }, - Box::new(RepositoryContext {}), - ), - ( - PromptPriority::Ordered { order: 0 }, - Box::new(FileContext {}), - ), - ( - PromptPriority::Mandatory, - Box::new(GenerateInlineContent {}), - ), - ]; - let chain = PromptChain::new(args, templates); - let (prompt, _) = chain.generate(true)?; - - anyhow::Ok(prompt) -} + if let Some(project_name) = project_name { + writeln!( + prompt, 
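// The rewritten prompt builder appends sections with `writeln!` into a String,
// which requires `std::fmt::Write` to be in scope (imported at the top of the
// new prompts.rs). The preamble step on its own:
use std::fmt::Write as _;

fn preamble(language_name: Option<&str>) -> String {
    let mut prompt = String::new();
    match language_name {
        // Prose-like buffers get a generic preamble.
        None | Some("Markdown" | "Plain Text") => {
            writeln!(prompt, "You are an expert engineer.").unwrap();
        }
        Some(lang) => {
            writeln!(prompt, "You are an expert {lang} engineer.").unwrap();
            writeln!(prompt, "Your answer MUST always and only be valid {lang}.").unwrap();
        }
    }
    prompt
}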
+ "You are currently working inside the '{project_name}' project in code editor Zed." + )?; + } -#[cfg(test)] -pub(crate) mod tests { - use super::*; - use gpui::{AppContext, Context}; - use indoc::indoc; - use language::{ - language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig, - LanguageMatcher, Point, - }; - use settings::SettingsStore; - use std::sync::Arc; + // Include file content. + for chunk in buffer.text_for_range(0..range.start) { + prompt.push_str(chunk); + } - pub(crate) fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::language()), - ) - .with_embedding_query( - r#" - ( - [(line_comment) (attribute_item)]* @context - . - [ - (struct_item - name: (_) @name) + if range.is_empty() { + prompt.push_str("<|START|>"); + } else { + prompt.push_str("<|START|"); + } - (enum_item - name: (_) @name) + for chunk in buffer.text_for_range(range.clone()) { + prompt.push_str(chunk); + } - (impl_item - trait: (_)? @name - "for"? @name - type: (_) @name) + if !range.is_empty() { + prompt.push_str("|END|>"); + } - (trait_item - name: (_) @name) + for chunk in buffer.text_for_range(range.end..buffer.len()) { + prompt.push_str(chunk); + } - (function_item - name: (_) @name - body: (block - "{" @keep - "}" @keep) @collapse) + prompt.push('\n'); - (macro_definition - name: (_) @name) - ] @item - ) - "#, + if range.is_empty() { + writeln!( + prompt, + "Assume the cursor is located where the `<|START|>` span is." + ) + .unwrap(); + writeln!( + prompt, + "{content_type} can't be replaced, so assume your answer will be inserted at the cursor.", ) - .unwrap() + .unwrap(); + writeln!( + prompt, + "Generate {content_type} based on the users prompt: {user_prompt}", + ) + .unwrap(); + } else { + writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap(); + writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap(); + writeln!( + prompt, + "Double check that you only return code and not the '<|START|' and '|END|'> spans" + ) + .unwrap(); } - #[gpui::test] - fn test_outline_for_prompt(cx: &mut AppContext) { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language_settings::init(cx); - let text = indoc! {" - struct X { - a: usize, - b: usize, - } - - impl X { - - fn new() -> Self { - let a = 1; - let b = 2; - Self { a, b } - } - - pub fn a(&self, param: bool) -> usize { - self.a - } - - pub fn b(&self) -> usize { - self.b - } - } - "}; - let buffer = cx.new_model(|cx| { - Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx) - }); - let snapshot = buffer.read(cx).snapshot(); - - assert_eq!( - summarize(&snapshot, Point::new(1, 4)..Point::new(1, 4)), - indoc! {" - struct X { - <|S|>a: usize, - b: usize, - } - - impl X { - - fn new() -> Self {} - - pub fn a(&self, param: bool) -> usize {} - - pub fn b(&self) -> usize {} - } - "} - ); - - assert_eq!( - summarize(&snapshot, Point::new(8, 12)..Point::new(8, 14)), - indoc! 
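// A standalone sketch of the selection markup the prompt uses: an empty range
// becomes a single `<|START|>` cursor marker, while a non-empty selection is
// wrapped in `<|START|` and `|END|>` spans.
use std::ops::Range;

fn mark_selection(text: &str, range: Range<usize>) -> String {
    let mut out = String::with_capacity(text.len() + 16);
    out.push_str(&text[..range.start]);
    if range.is_empty() {
        out.push_str("<|START|>");
    } else {
        out.push_str("<|START|");
        out.push_str(&text[range.clone()]);
        out.push_str("|END|>");
    }
    out.push_str(&text[range.end..]);
    out
}
// mark_selection("let x = 1;", 4..5) == "let <|START|x|END|> = 1;"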
{" - struct X { - a: usize, - b: usize, - } - - impl X { - - fn new() -> Self { - let <|S|a |E|>= 1; - let b = 2; - Self { a, b } - } - - pub fn a(&self, param: bool) -> usize {} - - pub fn b(&self) -> usize {} - } - "} - ); - - assert_eq!( - summarize(&snapshot, Point::new(6, 0)..Point::new(6, 0)), - indoc! {" - struct X { - a: usize, - b: usize, - } + writeln!(prompt, "Never make remarks about the output.").unwrap(); + writeln!( + prompt, + "Do not return anything else, except the generated {content_type}." + ) + .unwrap(); - impl X { - <|S|> - fn new() -> Self {} - - pub fn a(&self, param: bool) -> usize {} - - pub fn b(&self) -> usize {} - } - "} - ); - - assert_eq!( - summarize(&snapshot, Point::new(21, 0)..Point::new(21, 0)), - indoc! {" - struct X { - a: usize, - b: usize, - } - - impl X { - - fn new() -> Self {} - - pub fn a(&self, param: bool) -> usize {} - - pub fn b(&self) -> usize {} - } - <|S|>"} - ); - - // Ensure nested functions get collapsed properly. - let text = indoc! {" - struct X { - a: usize, - b: usize, - } - - impl X { - - fn new() -> Self { - let a = 1; - let b = 2; - Self { a, b } - } - - pub fn a(&self, param: bool) -> usize { - let a = 30; - fn nested() -> usize { - 3 - } - self.a + nested() - } - - pub fn b(&self) -> usize { - self.b - } - } - "}; - buffer.update(cx, |buffer, cx| buffer.set_text(text, cx)); - let snapshot = buffer.read(cx).snapshot(); - assert_eq!( - summarize(&snapshot, Point::new(0, 0)..Point::new(0, 0)), - indoc! {" - <|S|>struct X { - a: usize, - b: usize, - } - - impl X { - - fn new() -> Self {} - - pub fn a(&self, param: bool) -> usize {} - - pub fn b(&self) -> usize {} - } - "} - ); - } + Ok(prompt) } diff --git a/crates/assistant/src/saved_conversation.rs b/crates/assistant/src/saved_conversation.rs new file mode 100644 index 00000000000000..5e6ce613226ad6 --- /dev/null +++ b/crates/assistant/src/saved_conversation.rs @@ -0,0 +1,121 @@ +use crate::{assistant_settings::OpenAiModel, MessageId, MessageMetadata}; +use anyhow::{anyhow, Result}; +use collections::HashMap; +use fs::Fs; +use futures::StreamExt; +use regex::Regex; +use serde::{Deserialize, Serialize}; +use std::{ + cmp::Reverse, + ffi::OsStr, + path::{Path, PathBuf}, + sync::Arc, +}; +use util::paths::CONVERSATIONS_DIR; + +#[derive(Serialize, Deserialize)] +pub struct SavedMessage { + pub id: MessageId, + pub start: usize, +} + +#[derive(Serialize, Deserialize)] +pub struct SavedConversation { + pub id: Option, + pub zed: String, + pub version: String, + pub text: String, + pub messages: Vec, + pub message_metadata: HashMap, + pub summary: String, +} + +impl SavedConversation { + pub const VERSION: &'static str = "0.2.0"; + + pub async fn load(path: &Path, fs: &dyn Fs) -> Result { + let saved_conversation = fs.load(path).await?; + let saved_conversation_json = + serde_json::from_str::(&saved_conversation)?; + match saved_conversation_json + .get("version") + .ok_or_else(|| anyhow!("version not found"))? 
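// `SavedConversation::load` sniffs the "version" field before committing to a
// schema, then upgrades older payloads field by field. The pattern in miniature
// with serde_json, assuming a current `Doc` and a legacy `DocV0`:
use anyhow::{anyhow, Result};
use serde::Deserialize;

#[derive(Deserialize)]
struct Doc { version: String, text: String }

#[derive(Deserialize)]
struct DocV0 { version: String, body: String }

fn load(json: &str) -> Result<Doc> {
    let value: serde_json::Value = serde_json::from_str(json)?;
    let version = value
        .get("version")
        .and_then(|v| v.as_str())
        .ok_or_else(|| anyhow!("version not found"))?
        .to_owned();
    match version.as_str() {
        "0.2.0" => Ok(serde_json::from_value(value)?),
        "0.1.0" => {
            // Deserialize with the old schema, then upgrade.
            let old: DocV0 = serde_json::from_value(value)?;
            Ok(Doc { version: old.version, text: old.body })
        }
        other => Err(anyhow!("unrecognized saved conversation version: {other}")),
    }
}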
+ { + serde_json::Value::String(version) => match version.as_str() { + Self::VERSION => Ok(serde_json::from_value::(saved_conversation_json)?), + "0.1.0" => { + let saved_conversation = + serde_json::from_value::(saved_conversation_json)?; + Ok(Self { + id: saved_conversation.id, + zed: saved_conversation.zed, + version: saved_conversation.version, + text: saved_conversation.text, + messages: saved_conversation.messages, + message_metadata: saved_conversation.message_metadata, + summary: saved_conversation.summary, + }) + } + _ => Err(anyhow!( + "unrecognized saved conversation version: {}", + version + )), + }, + _ => Err(anyhow!("version not found on saved conversation")), + } + } +} + +#[derive(Serialize, Deserialize)] +struct SavedConversationV0_1_0 { + id: Option, + zed: String, + version: String, + text: String, + messages: Vec, + message_metadata: HashMap, + summary: String, + api_url: Option, + model: OpenAiModel, +} + +pub struct SavedConversationMetadata { + pub title: String, + pub path: PathBuf, + pub mtime: chrono::DateTime, +} + +impl SavedConversationMetadata { + pub async fn list(fs: Arc) -> Result> { + fs.create_dir(&CONVERSATIONS_DIR).await?; + + let mut paths = fs.read_dir(&CONVERSATIONS_DIR).await?; + let mut conversations = Vec::::new(); + while let Some(path) = paths.next().await { + let path = path?; + if path.extension() != Some(OsStr::new("json")) { + continue; + } + + let pattern = r" - \d+.zed.json$"; + let re = Regex::new(pattern).unwrap(); + + let metadata = fs.metadata(&path).await?; + if let Some((file_name, metadata)) = path + .file_name() + .and_then(|name| name.to_str()) + .zip(metadata) + { + let title = re.replace(file_name, ""); + conversations.push(Self { + title: title.into_owned(), + path, + mtime: metadata.mtime.into(), + }); + } + } + conversations.sort_unstable_by_key(|conversation| Reverse(conversation.mtime)); + + Ok(conversations) + } +} diff --git a/crates/assistant/src/streaming_diff.rs b/crates/assistant/src/streaming_diff.rs index 7399a7b4faf262..cba7758dde15af 100644 --- a/crates/assistant/src/streaming_diff.rs +++ b/crates/assistant/src/streaming_diff.rs @@ -197,12 +197,10 @@ impl StreamingDiff { } else { hunks.push(Hunk::Remove { len: char_len }) } + } else if let Some(Hunk::Keep { len }) = hunks.last_mut() { + *len += char_len; } else { - if let Some(Hunk::Keep { len }) = hunks.last_mut() { - *len += char_len; - } else { - hunks.push(Hunk::Keep { len: char_len }) - } + hunks.push(Hunk::Keep { len: char_len }) } } diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 754a47baa49537..5abd5305795159 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -13,7 +13,7 @@ use async_tungstenite::tungstenite::{ use clock::SystemClock; use collections::HashMap; use futures::{ - channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, + channel::oneshot, future::LocalBoxFuture, AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, }; use gpui::{ @@ -36,7 +36,10 @@ use std::{ future::Future, marker::PhantomData, path::PathBuf, - sync::{atomic::AtomicU64, Arc, Weak}, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, Weak, + }, time::{Duration, Instant}, }; use telemetry::Telemetry; @@ -442,7 +445,7 @@ impl Client { } pub fn id(&self) -> u64 { - self.id.load(std::sync::atomic::Ordering::SeqCst) + self.id.load(Ordering::SeqCst) } pub fn http_client(&self) -> Arc { @@ -450,7 +453,7 @@ impl Client { } pub fn set_id(&self, id: u64) -> &Self { - 
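// Titles in `SavedConversationMetadata::list` come from stripping the trailing
// ` - <digits>.zed.json` off the file name (dots left unescaped, as in the code
// above), and newest-first ordering uses `Reverse` on the mtime. The suffix
// stripping in isolation:
use regex::Regex;

fn title_from_file_name(file_name: &str) -> String {
    let re = Regex::new(r" - \d+.zed.json$").unwrap();
    re.replace(file_name, "").into_owned()
}
// title_from_file_name("Fix the bug - 1708472634.zed.json") == "Fix the bug"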
self.id.store(id, std::sync::atomic::Ordering::SeqCst); + self.id.store(id, Ordering::SeqCst); self } @@ -1260,6 +1263,30 @@ impl Client { .map_ok(|envelope| envelope.payload) } + pub fn request_stream( + &self, + request: T, + ) -> impl Future>>> { + let client_id = self.id.load(Ordering::SeqCst); + log::debug!( + "rpc request start. client_id:{}. name:{}", + client_id, + T::NAME + ); + let response = self + .connection_id() + .map(|conn_id| self.peer.request_stream(conn_id, request)); + async move { + let response = response?.await; + log::debug!( + "rpc request finish. client_id:{}. name:{}", + client_id, + T::NAME + ); + response + } + } + pub fn request_envelope( &self, request: T, diff --git a/crates/client/src/telemetry.rs b/crates/client/src/telemetry.rs index d769bcaa5cdc24..ad370253fdfdf8 100644 --- a/crates/client/src/telemetry.rs +++ b/crates/client/src/telemetry.rs @@ -261,7 +261,7 @@ impl Telemetry { self: &Arc, conversation_id: Option, kind: AssistantKind, - model: &str, + model: String, ) { let event = Event::Assistant(AssistantEvent { conversation_id, diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 49cdcebbffd5e1..6fcea2ae04a90b 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -31,10 +31,12 @@ collections.workspace = true dashmap = "5.4" envy = "0.4.2" futures.workspace = true +google_ai.workspace = true hex.workspace = true live_kit_server.workspace = true log.workspace = true nanoid = "0.4" +open_ai.workspace = true parking_lot.workspace = true prometheus = "0.13" prost.workspace = true @@ -80,7 +82,6 @@ git = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } -lazy_static.workspace = true live_kit_client = { workspace = true, features = ["test-support"] } lsp = { workspace = true, features = ["test-support"] } menu.workspace = true diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index ae26e219d0e5a3..d51b88f668dbac 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -379,6 +379,16 @@ CREATE TABLE extension_versions ( CREATE UNIQUE INDEX "index_extensions_external_id" ON "extensions" ("external_id"); CREATE INDEX "index_extensions_total_download_count" ON "extensions" ("total_download_count"); +CREATE TABLE rate_buckets ( + user_id INT NOT NULL, + rate_limit_name VARCHAR(255) NOT NULL, + token_count INT NOT NULL, + last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL, + PRIMARY KEY (user_id, rate_limit_name), + FOREIGN KEY (user_id) REFERENCES users(id) +); +CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name); + CREATE TABLE hosted_projects ( id INTEGER PRIMARY KEY AUTOINCREMENT, channel_id INTEGER NOT NULL REFERENCES channels(id), diff --git a/crates/collab/migrations/20240220234826_add_rate_buckets.sql b/crates/collab/migrations/20240220234826_add_rate_buckets.sql new file mode 100644 index 00000000000000..864a4373034fc5 --- /dev/null +++ b/crates/collab/migrations/20240220234826_add_rate_buckets.sql @@ -0,0 +1,11 @@ +CREATE TABLE IF NOT EXISTS rate_buckets ( + user_id INT NOT NULL, + rate_limit_name VARCHAR(255) NOT NULL, + token_count INT NOT NULL, + last_refill TIMESTAMP WITHOUT TIME ZONE NOT NULL, + PRIMARY KEY (user_id, rate_limit_name), + CONSTRAINT fk_user + FOREIGN KEY 
(user_id) REFERENCES users(id) +); + +CREATE INDEX idx_user_id_rate_limit ON rate_buckets (user_id, rate_limit_name); diff --git a/crates/collab/src/ai.rs b/crates/collab/src/ai.rs new file mode 100644 index 00000000000000..4634166799aed2 --- /dev/null +++ b/crates/collab/src/ai.rs @@ -0,0 +1,75 @@ +use anyhow::{anyhow, Result}; +use rpc::proto; + +pub fn language_model_request_to_open_ai( + request: proto::CompleteWithLanguageModel, +) -> Result { + Ok(open_ai::Request { + model: open_ai::Model::from_id(&request.model).unwrap_or(open_ai::Model::FourTurbo), + messages: request + .messages + .into_iter() + .map(|message| { + let role = proto::LanguageModelRole::from_i32(message.role) + .ok_or_else(|| anyhow!("invalid role {}", message.role))?; + Ok(open_ai::RequestMessage { + role: match role { + proto::LanguageModelRole::LanguageModelUser => open_ai::Role::User, + proto::LanguageModelRole::LanguageModelAssistant => { + open_ai::Role::Assistant + } + proto::LanguageModelRole::LanguageModelSystem => open_ai::Role::System, + }, + content: message.content, + }) + }) + .collect::>>()?, + stream: true, + stop: request.stop, + temperature: request.temperature, + }) +} + +pub fn language_model_request_to_google_ai( + request: proto::CompleteWithLanguageModel, +) -> Result { + Ok(google_ai::GenerateContentRequest { + contents: request + .messages + .into_iter() + .map(language_model_request_message_to_google_ai) + .collect::>>()?, + generation_config: None, + safety_settings: None, + }) +} + +pub fn language_model_request_message_to_google_ai( + message: proto::LanguageModelRequestMessage, +) -> Result { + let role = proto::LanguageModelRole::from_i32(message.role) + .ok_or_else(|| anyhow!("invalid role {}", message.role))?; + + Ok(google_ai::Content { + parts: vec![google_ai::Part::TextPart(google_ai::TextPart { + text: message.content, + })], + role: match role { + proto::LanguageModelRole::LanguageModelUser => google_ai::Role::User, + proto::LanguageModelRole::LanguageModelAssistant => google_ai::Role::Model, + proto::LanguageModelRole::LanguageModelSystem => google_ai::Role::User, + }, + }) +} + +pub fn count_tokens_request_to_google_ai( + request: proto::CountTokensWithLanguageModel, +) -> Result { + Ok(google_ai::CountTokensRequest { + contents: request + .messages + .into_iter() + .map(language_model_request_message_to_google_ai) + .collect::>>()?, + }) +} diff --git a/crates/collab/src/api/extensions.rs b/crates/collab/src/api/extensions.rs index 4e52aac56d39bf..9e9038d886a9fb 100644 --- a/crates/collab/src/api/extensions.rs +++ b/crates/collab/src/api/extensions.rs @@ -1,6 +1,5 @@ use crate::{ db::{ExtensionMetadata, NewExtensionVersion}, - executor::Executor, AppState, Error, Result, }; use anyhow::{anyhow, Context as _}; @@ -136,7 +135,7 @@ async fn download_extension( const EXTENSION_FETCH_INTERVAL: Duration = Duration::from_secs(5 * 60); const EXTENSION_DOWNLOAD_URL_LIFETIME: Duration = Duration::from_secs(3 * 60); -pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc, executor: Executor) { +pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc) { let Some(blob_store_client) = app_state.blob_store_client.clone() else { log::info!("no blob store client"); return; @@ -146,6 +145,7 @@ pub fn fetch_extensions_from_blob_store_periodically(app_state: Arc, e return; }; + let executor = app_state.executor.clone(); executor.spawn_detached({ let executor = executor.clone(); async move { diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs 
index 0326cf43743207..7f2e345a591507 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -10,6 +10,7 @@ pub mod hosted_projects; pub mod messages; pub mod notifications; pub mod projects; +pub mod rate_buckets; pub mod rooms; pub mod servers; pub mod users; diff --git a/crates/collab/src/db/queries/rate_buckets.rs b/crates/collab/src/db/queries/rate_buckets.rs new file mode 100644 index 00000000000000..58b62170f4a35f --- /dev/null +++ b/crates/collab/src/db/queries/rate_buckets.rs @@ -0,0 +1,58 @@ +use super::*; +use crate::db::tables::rate_buckets; +use sea_orm::{ColumnTrait, EntityTrait, QueryFilter}; + +impl Database { + /// Saves the rate limit for the given user and rate limit name if the last_refill is later + /// than the currently saved timestamp. + pub async fn save_rate_buckets(&self, buckets: &[rate_buckets::Model]) -> Result<()> { + if buckets.is_empty() { + return Ok(()); + } + + self.transaction(|tx| async move { + rate_buckets::Entity::insert_many(buckets.iter().map(|bucket| { + rate_buckets::ActiveModel { + user_id: ActiveValue::Set(bucket.user_id), + rate_limit_name: ActiveValue::Set(bucket.rate_limit_name.clone()), + token_count: ActiveValue::Set(bucket.token_count), + last_refill: ActiveValue::Set(bucket.last_refill), + } + })) + .on_conflict( + OnConflict::columns([ + rate_buckets::Column::UserId, + rate_buckets::Column::RateLimitName, + ]) + .update_columns([ + rate_buckets::Column::TokenCount, + rate_buckets::Column::LastRefill, + ]) + .to_owned(), + ) + .exec(&*tx) + .await?; + + Ok(()) + }) + .await + } + + /// Retrieves the rate limit for the given user and rate limit name. + pub async fn get_rate_bucket( + &self, + user_id: UserId, + rate_limit_name: &str, + ) -> Result> { + self.transaction(|tx| async move { + let rate_limit = rate_buckets::Entity::find() + .filter(rate_buckets::Column::UserId.eq(user_id)) + .filter(rate_buckets::Column::RateLimitName.eq(rate_limit_name)) + .one(&*tx) + .await?; + + Ok(rate_limit) + }) + .await + } +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index 468e7390ab3c64..6864cc3782dd96 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -22,6 +22,7 @@ pub mod observed_buffer_edits; pub mod observed_channel_messages; pub mod project; pub mod project_collaborator; +pub mod rate_buckets; pub mod room; pub mod room_participant; pub mod server; diff --git a/crates/collab/src/db/tables/rate_buckets.rs b/crates/collab/src/db/tables/rate_buckets.rs new file mode 100644 index 00000000000000..e16db36814c7f0 --- /dev/null +++ b/crates/collab/src/db/tables/rate_buckets.rs @@ -0,0 +1,31 @@ +use crate::db::UserId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "rate_buckets")] +pub struct Model { + #[sea_orm(primary_key, auto_increment = false)] + pub user_id: UserId, + #[sea_orm(primary_key, auto_increment = false)] + pub rate_limit_name: String, + pub token_count: i32, + pub last_refill: DateTime, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 71ca4788e1fe1a..9cc4271e2d5477 100644 --- a/crates/collab/src/lib.rs +++ 
b/crates/collab/src/lib.rs @@ -1,8 +1,10 @@ +pub mod ai; pub mod api; pub mod auth; pub mod db; pub mod env; pub mod executor; +mod rate_limiter; pub mod rpc; #[cfg(test)] @@ -13,6 +15,7 @@ use aws_config::{BehaviorVersion, Region}; use axum::{http::StatusCode, response::IntoResponse}; use db::{ChannelId, Database}; use executor::Executor; +pub use rate_limiter::*; use serde::Deserialize; use std::{path::PathBuf, sync::Arc}; use util::ResultExt; @@ -126,6 +129,8 @@ pub struct Config { pub blob_store_secret_key: Option, pub blob_store_bucket: Option, pub zed_environment: Arc, + pub openai_api_key: Option>, + pub google_ai_api_key: Option>, pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, @@ -147,12 +152,14 @@ pub struct AppState { pub db: Arc, pub live_kit_client: Option>, pub blob_store_client: Option, + pub rate_limiter: Arc, + pub executor: Executor, pub clickhouse_client: Option, pub config: Config, } impl AppState { - pub async fn new(config: Config) -> Result> { + pub async fn new(config: Config, executor: Executor) -> Result> { let mut db_options = db::ConnectOptions::new(config.database_url.clone()); db_options.max_connections(config.database_max_connections); let mut db = Database::new(db_options, Executor::Production).await?; @@ -173,10 +180,13 @@ impl AppState { None }; + let db = Arc::new(db); let this = Self { - db: Arc::new(db), + db: db.clone(), live_kit_client, blob_store_client: build_blob_store_client(&config).await.log_err(), + rate_limiter: Arc::new(RateLimiter::new(db)), + executor, clickhouse_client: config .clickhouse_url .as_ref() diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 4ac872de445300..59b4d377afe416 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -7,7 +7,7 @@ use axum::{ }; use collab::{ api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState, - Config, MigrateConfig, Result, + Config, MigrateConfig, RateLimiter, Result, }; use db::Database; use std::{ @@ -62,18 +62,27 @@ async fn main() -> Result<()> { run_migrations().await?; - let state = AppState::new(config).await?; + let state = AppState::new(config, Executor::Production).await?; let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); + let epoch = state + .db + .create_server(&state.config.zed_environment) + .await?; + let rpc_server = collab::rpc::Server::new(epoch, state.clone()); + rpc_server.start().await?; + + fetch_extensions_from_blob_store_periodically(state.clone()); + RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone()); + let rpc_server = if is_collab { let epoch = state .db .create_server(&state.config.zed_environment) .await?; - let rpc_server = - collab::rpc::Server::new(epoch, state.clone(), Executor::Production); + let rpc_server = collab::rpc::Server::new(epoch, state.clone()); rpc_server.start().await?; Some(rpc_server) @@ -82,7 +91,7 @@ async fn main() -> Result<()> { }; if is_api { - fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production); + fetch_extensions_from_blob_store_periodically(state.clone()); } let mut app = collab::api::routes(rpc_server.clone(), state.clone()); diff --git a/crates/collab/src/rate_limiter.rs b/crates/collab/src/rate_limiter.rs new file mode 100644 index 00000000000000..e6f1fcbaebdcc6 --- /dev/null +++ b/crates/collab/src/rate_limiter.rs @@ -0,0 +1,274 @@ +use crate::{db::UserId, 
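// `AppState::new` now owns the executor and a shared rate limiter built over
// the same database handle, so the diff clones the Arc<Database> before moving
// it. The ownership shape in miniature (these types are stand-ins, not the real
// collab types):
use std::sync::Arc;

struct Database;
struct RateLimiterStandIn { _db: Arc<Database> }
struct State { db: Arc<Database>, rate_limiter: Arc<RateLimiterStandIn> }

fn build_state() -> Arc<State> {
    let db = Arc::new(Database);
    Arc::new(State {
        rate_limiter: Arc::new(RateLimiterStandIn { _db: db.clone() }),
        db,
    })
}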
executor::Executor, Database, Error, Result}; +use anyhow::anyhow; +use chrono::{DateTime, Duration, Utc}; +use dashmap::{DashMap, DashSet}; +use sea_orm::prelude::DateTimeUtc; +use std::sync::Arc; +use util::ResultExt; + +pub trait RateLimit: 'static { + fn capacity() -> usize; + fn refill_duration() -> Duration; + fn db_name() -> &'static str; +} + +/// Used to enforce per-user rate limits +pub struct RateLimiter { + buckets: DashMap<(UserId, String), RateBucket>, + dirty_buckets: DashSet<(UserId, String)>, + db: Arc<Database>, +} + +impl RateLimiter { + pub fn new(db: Arc<Database>) -> Self { + RateLimiter { + buckets: DashMap::new(), + dirty_buckets: DashSet::new(), + db, + } + } + + /// Spawns a new task that periodically saves rate limit data to the database. + pub fn save_periodically(rate_limiter: Arc<Self>, executor: Executor) { + const RATE_LIMITER_SAVE_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10); + + executor.clone().spawn_detached(async move { + loop { + executor.sleep(RATE_LIMITER_SAVE_INTERVAL).await; + rate_limiter.save().await.log_err(); + } + }); + } + + /// Returns an error if the user has exceeded the specified `RateLimit`. + /// Attempts to read the bucket from the database if no cached RateBucket currently exists. + pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> { + self.check_internal::<T>(user_id, Utc::now()).await + } + + async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> { + let bucket_key = (user_id, T::db_name().to_string()); + + // Attempt to fetch the bucket from the database if it hasn't been cached. + // For now, we keep buckets in memory for the lifetime of the process rather than expiring them, + // but this enforces limits across restarts so long as the database is reachable. + if !self.buckets.contains_key(&bucket_key) { + if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() { + self.buckets.insert(bucket_key.clone(), bucket); + self.dirty_buckets.insert(bucket_key.clone()); + } + } + + let mut bucket = self + .buckets + .entry(bucket_key.clone()) + .or_insert_with(|| RateBucket::new(T::capacity(), T::refill_duration(), now)); + + if bucket.value_mut().allow(now) { + self.dirty_buckets.insert(bucket_key); + Ok(()) + } else { + Err(anyhow!("rate limit exceeded"))? + } + } + + async fn load_bucket<K: RateLimit>( + &self, + user_id: UserId, + ) -> Result<Option<RateBucket>, Error> { + Ok(self + .db + .get_rate_bucket(user_id, K::db_name()) + .await? + .map(|saved_bucket| RateBucket { + capacity: K::capacity(), + refill_time_per_token: K::refill_duration(), + token_count: saved_bucket.token_count as usize, + last_refill: DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc), + })) + } + + pub async fn save(&self) -> Result<()> { + let mut buckets = Vec::new(); + self.dirty_buckets.retain(|key| { + if let Some(bucket) = self.buckets.get(&key) { + buckets.push(crate::db::rate_buckets::Model { + user_id: key.0, + rate_limit_name: key.1.clone(), + token_count: bucket.token_count as i32, + last_refill: bucket.last_refill.naive_utc(), + }); + } + false + }); + + match self.db.save_rate_buckets(&buckets).await { + Ok(()) => Ok(()), + Err(err) => { + for bucket in buckets { + self.dirty_buckets + .insert((bucket.user_id, bucket.rate_limit_name)); + } + Err(err) + } + } + } +} + +#[derive(Clone)] +struct RateBucket { + capacity: usize, + token_count: usize, + refill_time_per_token: Duration, + last_refill: DateTimeUtc, +} + +impl RateBucket { + fn new(capacity: usize, refill_duration: Duration, now: DateTimeUtc) -> Self { + RateBucket { + capacity, + token_count: capacity, + refill_time_per_token: refill_duration / capacity as i32, + last_refill: now, + } + } + + fn allow(&mut self, now: DateTimeUtc) -> bool { + self.refill(now); + if self.token_count > 0 { + self.token_count -= 1; + true + } else { + false + } + } + + fn refill(&mut self, now: DateTimeUtc) { + let elapsed = now - self.last_refill; + if elapsed >= self.refill_time_per_token { + let new_tokens = + elapsed.num_milliseconds() / self.refill_time_per_token.num_milliseconds(); + + self.token_count = (self.token_count + new_tokens as usize).min(self.capacity); + self.last_refill = now; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::{NewUserParams, TestDb}; + use gpui::TestAppContext; + + #[gpui::test] + async fn test_rate_limiter(cx: &mut TestAppContext) { + let test_db = TestDb::sqlite(cx.executor().clone()); + let db = test_db.db().clone(); + let user_1 = db + .create_user( + "user-1@zed.dev", + false, + NewUserParams { + github_login: "user-1".into(), + github_user_id: 1, + }, + ) + .await + .unwrap() + .user_id; + let user_2 = db + .create_user( + "user-2@zed.dev", + false, + NewUserParams { + github_login: "user-2".into(), + github_user_id: 2, + }, + ) + .await + .unwrap() + .user_id; + + let mut now = Utc::now(); + + let rate_limiter = RateLimiter::new(db.clone()); + + // User 1 can access resource A two times before being rate-limited. + rate_limiter + .check_internal::<RateLimitA>(user_1, now) + .await + .unwrap(); + rate_limiter + .check_internal::<RateLimitA>(user_1, now) + .await + .unwrap(); + rate_limiter + .check_internal::<RateLimitA>(user_1, now) + .await + .unwrap_err(); + + // User 2 can access resource A and user 1 can access resource B. + rate_limiter + .check_internal::<RateLimitA>(user_2, now) + .await + .unwrap(); + rate_limiter + .check_internal::<RateLimitB>(user_1, now) + .await + .unwrap(); + + // After one second, user 1 can make another request before being rate-limited again. + now += Duration::seconds(1); + rate_limiter + .check_internal::<RateLimitA>(user_1, now) + .await + .unwrap(); + rate_limiter + .check_internal::<RateLimitA>(user_1, now) + .await + .unwrap_err(); + + rate_limiter.save().await.unwrap(); + + // Rate limits are reloaded from the database, so user 1 is still rate-limited + // for resource A.
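// The refill math above is a classic token bucket: each limit's refill_duration
// is divided by its capacity to get a per-token interval, and elapsed time since
// last_refill mints whole tokens up to capacity. A self-contained version using
// plain milliseconds (assumes a monotonic clock):
struct Bucket {
    capacity: u64,
    tokens: u64,
    refill_ms_per_token: u64,
    last_refill_ms: u64,
}

impl Bucket {
    fn allow(&mut self, now_ms: u64) -> bool {
        let elapsed = now_ms - self.last_refill_ms;
        if elapsed >= self.refill_ms_per_token {
            let minted = elapsed / self.refill_ms_per_token;
            self.tokens = (self.tokens + minted).min(self.capacity);
            // Like the code above, this resets last_refill to `now`, discarding
            // the remainder of a partially elapsed interval.
            self.last_refill_ms = now_ms;
        }
        if self.tokens > 0 {
            self.tokens -= 1;
            true
        } else {
            false
        }
    }
}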
+ let rate_limiter = RateLimiter::new(db.clone()); + rate_limiter + .check_internal::(user_1, now) + .await + .unwrap_err(); + } + + struct RateLimitA; + + impl RateLimit for RateLimitA { + fn capacity() -> usize { + 2 + } + + fn refill_duration() -> Duration { + Duration::seconds(2) + } + + fn db_name() -> &'static str { + "rate-limit-a" + } + } + + struct RateLimitB; + + impl RateLimit for RateLimitB { + fn capacity() -> usize { + 10 + } + + fn refill_duration() -> Duration { + Duration::seconds(3) + } + + fn db_name() -> &'static str { + "rate-limit-b" + } + } +} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 319f752f8bbb9d..959f3aef62c881 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -9,9 +9,9 @@ use crate::{ User, UserId, }, executor::Executor, - AppState, Error, Result, + AppState, Error, RateLimit, RateLimiter, Result, }; -use anyhow::anyhow; +use anyhow::{anyhow, Context as _}; use async_tungstenite::tungstenite::{ protocol::CloseFrame as TungsteniteCloseFrame, Message as TungsteniteMessage, }; @@ -30,6 +30,8 @@ use axum::{ }; use collections::{HashMap, HashSet}; pub use connection_pool::{ConnectionPool, ZedVersion}; +use core::fmt::{self, Debug, Formatter}; + use futures::{ channel::oneshot, future::{self, BoxFuture}, @@ -39,15 +41,14 @@ use futures::{ use prometheus::{register_int_gauge, IntGauge}; use rpc::{ proto::{ - self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo, - RequestMessage, ShareProject, UpdateChannelBufferCollaborators, + self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LanguageModelRole, + LiveKitConnectionInfo, RequestMessage, ShareProject, UpdateChannelBufferCollaborators, }, Connection, ConnectionId, ErrorCode, ErrorCodeExt, ErrorExt, Peer, Receipt, TypedEnvelope, }; use serde::{Serialize, Serializer}; use std::{ any::TypeId, - fmt, future::Future, marker::PhantomData, mem, @@ -64,7 +65,7 @@ use time::OffsetDateTime; use tokio::sync::{watch, Semaphore}; use tower::ServiceBuilder; use tracing::{field, info_span, instrument, Instrument}; -use util::SemanticVersion; +use util::{http::IsahcHttpClient, SemanticVersion}; pub const RECONNECT_TIMEOUT: Duration = Duration::from_secs(30); @@ -92,6 +93,18 @@ impl Response { } } +struct StreamingResponse { + peer: Arc, + receipt: Receipt, +} + +impl StreamingResponse { + fn send(&self, payload: R::Response) -> Result<()> { + self.peer.respond(self.receipt, payload)?; + Ok(()) + } +} + #[derive(Clone)] struct Session { user_id: UserId, @@ -100,6 +113,8 @@ struct Session { peer: Arc, connection_pool: Arc>, live_kit_client: Option>, + http_client: IsahcHttpClient, + rate_limiter: Arc, _executor: Executor, } @@ -124,8 +139,8 @@ impl Session { } } -impl fmt::Debug for Session { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl Debug for Session { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("Session") .field("user_id", &self.user_id) .field("connection_id", &self.connection_id) @@ -148,7 +163,6 @@ pub struct Server { peer: Arc, pub(crate) connection_pool: Arc>, app_state: Arc, - executor: Executor, handlers: HashMap, teardown: watch::Sender, } @@ -175,12 +189,11 @@ where } impl Server { - pub fn new(id: ServerId, app_state: Arc, executor: Executor) -> Arc { + pub fn new(id: ServerId, app_state: Arc) -> Arc { let mut server = Self { id: parking_lot::Mutex::new(id), peer: Peer::new(id.0 as u32), - app_state, - executor, + app_state: app_state.clone(), connection_pool: Default::default(), 
handlers: Default::default(), teardown: watch::channel(false).0, @@ -280,7 +293,30 @@ impl Server { .add_message_handler(update_followers) .add_request_handler(get_private_user_info) .add_message_handler(acknowledge_channel_message) - .add_message_handler(acknowledge_buffer_version); + .add_message_handler(acknowledge_buffer_version) + .add_streaming_request_handler({ + let app_state = app_state.clone(); + move |request, response, session| { + complete_with_language_model( + request, + response, + session, + app_state.config.openai_api_key.clone(), + app_state.config.google_ai_api_key.clone(), + ) + } + }) + .add_request_handler({ + let app_state = app_state.clone(); + move |request, response, session| { + count_tokens_with_language_model( + request, + response, + session, + app_state.config.google_ai_api_key.clone(), + ) + } + }); Arc::new(server) } @@ -289,12 +325,12 @@ impl Server { let server_id = *self.id.lock(); let app_state = self.app_state.clone(); let peer = self.peer.clone(); - let timeout = self.executor.sleep(CLEANUP_TIMEOUT); + let timeout = self.app_state.executor.sleep(CLEANUP_TIMEOUT); let pool = self.connection_pool.clone(); let live_kit_client = self.app_state.live_kit_client.clone(); let span = info_span!("start server"); - self.executor.spawn_detached( + self.app_state.executor.spawn_detached( async move { tracing::info!("waiting for cleanup timeout"); timeout.await; @@ -536,6 +572,40 @@ impl Server { }) } + fn add_streaming_request_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(M, StreamingResponse, Session) -> Fut, + Fut: Send + Future>, + M: RequestMessage, + { + let handler = Arc::new(handler); + self.add_handler(move |envelope, session| { + let receipt = envelope.receipt(); + let handler = handler.clone(); + async move { + let peer = session.peer.clone(); + let response = StreamingResponse { + peer: peer.clone(), + receipt, + }; + match (handler)(envelope.payload, response, session).await { + Ok(()) => { + peer.end_stream(receipt)?; + Ok(()) + } + Err(error) => { + let proto_err = match &error { + Error::Internal(err) => err.to_proto(), + _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(), + }; + peer.respond_with_error(receipt, proto_err)?; + Err(error) + } + } + } + }) + } + #[allow(clippy::too_many_arguments)] pub fn handle_connection( self: &Arc, @@ -569,6 +639,14 @@ impl Server { tracing::Span::current().record("connection_id", format!("{}", connection_id)); tracing::info!("connection opened"); + let http_client = match IsahcHttpClient::new() { + Ok(http_client) => http_client, + Err(error) => { + tracing::error!(?error, "failed to create HTTP client"); + return; + } + }; + let session = Session { user_id, connection_id, @@ -576,7 +654,9 @@ impl Server { peer: this.peer.clone(), connection_pool: this.connection_pool.clone(), live_kit_client: this.app_state.live_kit_client.clone(), - _executor: executor.clone() + http_client, + rate_limiter: this.app_state.rate_limiter.clone(), + _executor: executor.clone(), }; if let Err(error) = this.send_initial_client_update(connection_id, user, zed_version, send_connection_id, &session).await { @@ -3220,6 +3300,207 @@ async fn acknowledge_buffer_version( Ok(()) } +struct CompleteWithLanguageModelRateLimit; + +impl RateLimit for CompleteWithLanguageModelRateLimit { + fn capacity() -> usize { + std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(120) // Picked arbitrarily + } + + fn refill_duration() -> 
chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name() -> &'static str { + "complete-with-language-model" + } +} + +async fn complete_with_language_model( + request: proto::CompleteWithLanguageModel, + response: StreamingResponse, + session: Session, + open_ai_api_key: Option>, + google_ai_api_key: Option>, +) -> Result<()> { + authorize_access_to_language_models(&session).await?; + session + .rate_limiter + .check::(session.user_id) + .await?; + + if request.model.starts_with("gpt") { + let api_key = + open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?; + complete_with_open_ai(request, response, session, api_key).await?; + } else if request.model.starts_with("gemini") { + let api_key = google_ai_api_key + .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; + complete_with_google_ai(request, response, session, api_key).await?; + } + + Ok(()) +} + +async fn complete_with_open_ai( + request: proto::CompleteWithLanguageModel, + response: StreamingResponse, + session: Session, + api_key: Arc, +) -> Result<()> { + const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; + + let mut completion_stream = open_ai::stream_completion( + &session.http_client, + OPEN_AI_API_URL, + &api_key, + crate::ai::language_model_request_to_open_ai(request)?, + ) + .await + .context("open_ai::stream_completion request failed")?; + + while let Some(event) = completion_stream.next().await { + let event = event?; + response.send(proto::LanguageModelResponse { + choices: event + .choices + .into_iter() + .map(|choice| proto::LanguageModelChoiceDelta { + index: choice.index, + delta: Some(proto::LanguageModelResponseMessage { + role: choice.delta.role.map(|role| match role { + open_ai::Role::User => LanguageModelRole::LanguageModelUser, + open_ai::Role::Assistant => LanguageModelRole::LanguageModelAssistant, + open_ai::Role::System => LanguageModelRole::LanguageModelSystem, + } as i32), + content: choice.delta.content, + }), + finish_reason: choice.finish_reason, + }) + .collect(), + })?; + } + + Ok(()) +} + +async fn complete_with_google_ai( + request: proto::CompleteWithLanguageModel, + response: StreamingResponse, + session: Session, + api_key: Arc, +) -> Result<()> { + let mut stream = google_ai::stream_generate_content( + &session.http_client, + google_ai::API_URL, + api_key.as_ref(), + crate::ai::language_model_request_to_google_ai(request)?, + ) + .await + .context("google_ai::stream_generate_content request failed")?; + + while let Some(event) = stream.next().await { + let event = event?; + response.send(proto::LanguageModelResponse { + choices: event + .candidates + .unwrap_or_default() + .into_iter() + .map(|candidate| proto::LanguageModelChoiceDelta { + index: candidate.index as u32, + delta: Some(proto::LanguageModelResponseMessage { + role: Some(match candidate.content.role { + google_ai::Role::User => LanguageModelRole::LanguageModelUser, + google_ai::Role::Model => LanguageModelRole::LanguageModelAssistant, + } as i32), + content: Some( + candidate + .content + .parts + .into_iter() + .filter_map(|part| match part { + google_ai::Part::TextPart(part) => Some(part.text), + google_ai::Part::InlineDataPart(_) => None, + }) + .collect(), + ), + }), + finish_reason: candidate.finish_reason.map(|reason| reason.to_string()), + }) + .collect(), + })?; + } + + Ok(()) +} + +struct CountTokensWithLanguageModelRateLimit; + +impl RateLimit for CountTokensWithLanguageModelRateLimit { + fn capacity() -> usize { + 
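// Provider routing in `complete_with_language_model` is by model-id prefix;
// a model matching neither prefix falls through, and the stream simply ends
// without producing choices. Reduced to its shape:
enum Provider {
    OpenAi,
    GoogleAi,
}

fn route(model: &str) -> Option<Provider> {
    if model.starts_with("gpt") {
        Some(Provider::OpenAi)
    } else if model.starts_with("gemini") {
        Some(Provider::GoogleAi)
    } else {
        None // mirrors the silent fall-through in the handler above
    }
}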
std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(600) // Picked arbitrarily + } + + fn refill_duration() -> chrono::Duration { + chrono::Duration::hours(1) + } + + fn db_name() -> &'static str { + "count-tokens-with-language-model" + } +} + +async fn count_tokens_with_language_model( + request: proto::CountTokensWithLanguageModel, + response: Response, + session: Session, + google_ai_api_key: Option>, +) -> Result<()> { + authorize_access_to_language_models(&session).await?; + + if !request.model.starts_with("gemini") { + return Err(anyhow!( + "counting tokens for model: {:?} is not supported", + request.model + ))?; + } + + session + .rate_limiter + .check::(session.user_id) + .await?; + + let api_key = google_ai_api_key + .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; + let tokens_response = google_ai::count_tokens( + &session.http_client, + google_ai::API_URL, + &api_key, + crate::ai::count_tokens_request_to_google_ai(request)?, + ) + .await?; + response.send(proto::CountTokensResponse { + token_count: tokens_response.total_tokens as u32, + })?; + Ok(()) +} + +async fn authorize_access_to_language_models(session: &Session) -> Result<(), Error> { + let db = session.db().await; + let flags = db.get_user_flags(session.user_id).await?; + if flags.iter().any(|flag| flag == "language-models") { + Ok(()) + } else { + Err(anyhow!("permission denied"))? + } +} + /// Start receiving chat updates for a channel async fn join_channel_chat( request: proto::JoinChannelChat, diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 6998343ab0a113..1aa13a2e6c7714 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -2,7 +2,7 @@ use crate::{ db::{tests::TestDb, NewUserParams, UserId}, executor::Executor, rpc::{Server, ZedVersion, CLEANUP_TIMEOUT, RECONNECT_TIMEOUT}, - AppState, Config, + AppState, Config, RateLimiter, }; use anyhow::anyhow; use call::ActiveCall; @@ -93,17 +93,14 @@ impl TestServer { deterministic.clone(), ) .unwrap(); - let app_state = Self::build_app_state(&test_db, &live_kit_server).await; + let executor = Executor::Deterministic(deterministic.clone()); + let app_state = Self::build_app_state(&test_db, &live_kit_server, executor.clone()).await; let epoch = app_state .db .create_server(&app_state.config.zed_environment) .await .unwrap(); - let server = Server::new( - epoch, - app_state.clone(), - Executor::Deterministic(deterministic.clone()), - ); + let server = Server::new(epoch, app_state.clone()); server.start().await.unwrap(); // Advance clock to ensure the server's cleanup task is finished. 
deterministic.advance_clock(CLEANUP_TIMEOUT); @@ -482,12 +479,15 @@ impl TestServer { pub async fn build_app_state( test_db: &TestDb, - fake_server: &live_kit_client::TestServer, + live_kit_test_server: &live_kit_client::TestServer, + executor: Executor, ) -> Arc { Arc::new(AppState { db: test_db.db().clone(), - live_kit_client: Some(Arc::new(fake_server.create_api_client())), + live_kit_client: Some(Arc::new(live_kit_test_server.create_api_client())), blob_store_client: None, + rate_limiter: Arc::new(RateLimiter::new(test_db.db().clone())), + executor, clickhouse_client: None, config: Config { http_port: 0, @@ -506,6 +506,8 @@ impl TestServer { blob_store_access_key: None, blob_store_secret_key: None, blob_store_bucket: None, + openai_api_key: None, + google_ai_api_key: None, clickhouse_url: None, clickhouse_user: None, clickhouse_password: None, diff --git a/crates/google_ai/Cargo.toml b/crates/google_ai/Cargo.toml new file mode 100644 index 00000000000000..8d383608d53c97 --- /dev/null +++ b/crates/google_ai/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "google_ai" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/google_ai.rs" + +[dependencies] +anyhow.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +util.workspace = true diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs new file mode 100644 index 00000000000000..4fe461981fe97f --- /dev/null +++ b/crates/google_ai/src/google_ai.rs @@ -0,0 +1,266 @@ +use anyhow::{anyhow, Result}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use util::http::HttpClient; + +pub const API_URL: &str = "https://generativelanguage.googleapis.com"; + +pub async fn stream_generate_content( + client: &T, + api_url: &str, + api_key: &str, + request: GenerateContentRequest, +) -> Result>> { + let uri = format!( + "{}/v1beta/models/gemini-pro:streamGenerateContent?alt=sse&key={}", + api_url, api_key + ); + + let request = serde_json::to_string(&request)?; + let mut response = client.post_json(&uri, request.into()).await?; + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + if let Some(line) = line.strip_prefix("data: ") { + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(anyhow!(error))), + } + } else { + None + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut text = String::new(); + response.body_mut().read_to_string(&mut text).await?; + Err(anyhow!( + "error during streamGenerateContent, status code: {:?}, body: {}", + response.status(), + text + )) + } +} + +pub async fn count_tokens( + client: &T, + api_url: &str, + api_key: &str, + request: CountTokensRequest, +) -> Result { + let uri = format!( + "{}/v1beta/models/gemini-pro:countTokens?key={}", + api_url, api_key + ); + let request = serde_json::to_string(&request)?; + let mut response = client.post_json(&uri, request.into()).await?; + let mut text = String::new(); + response.body_mut().read_to_string(&mut text).await?; + if response.status().is_success() { + Ok(serde_json::from_str::(&text)?) 
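// `stream_generate_content` consumes the `alt=sse` response line by line: only
// lines with a `data: ` prefix carry JSON payloads; everything else (comments,
// keep-alives, blank lines) is ignored. The parsing step in isolation:
fn parse_sse_line(line: &str) -> Option<serde_json::Result<serde_json::Value>> {
    line.strip_prefix("data: ").map(serde_json::from_str)
}
// parse_sse_line("data: {\"totalTokens\": 3}") -> Some(Ok(...))
// parse_sse_line(": keep-alive comment")       -> None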
+ +#[derive(Debug, Serialize, Deserialize)] +pub enum Task { + #[serde(rename = "generateContent")] + GenerateContent, + #[serde(rename = "streamGenerateContent")] + StreamGenerateContent, + #[serde(rename = "countTokens")] + CountTokens, + #[serde(rename = "embedContent")] + EmbedContent, + #[serde(rename = "batchEmbedContents")] + BatchEmbedContents, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentRequest { + pub contents: Vec<Content>, + pub generation_config: Option<GenerationConfig>, + pub safety_settings: Option<Vec<SafetySetting>>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentResponse { + pub candidates: Option<Vec<GenerateContentCandidate>>, + pub prompt_feedback: Option<PromptFeedback>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentCandidate { + pub index: usize, + pub content: Content, + pub finish_reason: Option<String>, + pub finish_message: Option<String>, + pub safety_ratings: Option<Vec<SafetyRating>>, + pub citation_metadata: Option<CitationMetadata>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Content { + pub parts: Vec<Part>, + pub role: Role, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub enum Role { + User, + Model, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Part { + TextPart(TextPart), + InlineDataPart(InlineDataPart), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct TextPart { + pub text: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InlineDataPart { + pub inline_data: GenerativeContentBlob, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerativeContentBlob { + pub mime_type: String, + pub data: String, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CitationSource { + pub start_index: Option<usize>, + pub end_index: Option<usize>, + pub uri: Option<String>, + pub license: Option<String>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CitationMetadata { + pub citation_sources: Vec<CitationSource>, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptFeedback { + pub block_reason: Option<String>, + pub safety_ratings: Vec<SafetyRating>, + pub block_reason_message: Option<String>, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerationConfig { + pub candidate_count: Option<usize>, + pub stop_sequences: Option<Vec<String>>, + pub max_output_tokens: Option<usize>, + pub temperature: Option<f64>, + pub top_p: Option<f64>, + pub top_k: Option<usize>, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SafetySetting { + pub category: HarmCategory, + pub threshold: HarmBlockThreshold, +} + +#[derive(Debug, Serialize, Deserialize)] +pub enum HarmCategory { + #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] + Unspecified, + #[serde(rename = "HARM_CATEGORY_DEROGATORY")] + Derogatory, + #[serde(rename = "HARM_CATEGORY_TOXICITY")] + Toxicity, + #[serde(rename = "HARM_CATEGORY_VIOLENCE")] + Violence, + #[serde(rename = "HARM_CATEGORY_SEXUAL")] + Sexual, + #[serde(rename = "HARM_CATEGORY_MEDICAL")] + Medical, + #[serde(rename = "HARM_CATEGORY_DANGEROUS")] + Dangerous, + #[serde(rename = "HARM_CATEGORY_HARASSMENT")] + Harassment, + #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")] + HateSpeech,
"HARM_CATEGORY_SEXUALLY_EXPLICIT")] + SexuallyExplicit, + #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")] + DangerousContent, +} + +#[derive(Debug, Serialize)] +pub enum HarmBlockThreshold { + #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] + Unspecified, + #[serde(rename = "BLOCK_LOW_AND_ABOVE")] + BlockLowAndAbove, + #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")] + BlockMediumAndAbove, + #[serde(rename = "BLOCK_ONLY_HIGH")] + BlockOnlyHigh, + #[serde(rename = "BLOCK_NONE")] + BlockNone, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum HarmProbability { + #[serde(rename = "HARM_PROBABILITY_UNSPECIFIED")] + Unspecified, + Negligible, + Low, + Medium, + High, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SafetyRating { + pub category: HarmCategory, + pub probability: HarmProbability, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct CountTokensRequest { + pub contents: Vec, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CountTokensResponse { + pub total_tokens: usize, +} diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml new file mode 100644 index 00000000000000..4c9acdcbc0f1d5 --- /dev/null +++ b/crates/open_ai/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "open_ai" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/open_ai.rs" + +[features] +default = [] +schemars = ["dep:schemars"] + +[dependencies] +anyhow.workspace = true +futures.workspace = true +schemars = { workspace = true, optional = true } +serde.workspace = true +serde_json.workspace = true +util.workspace = true diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs new file mode 100644 index 00000000000000..7bd7e19d5ded54 --- /dev/null +++ b/crates/open_ai/src/open_ai.rs @@ -0,0 +1,182 @@ +use anyhow::{anyhow, Result}; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use serde::{Deserialize, Serialize}; +use std::convert::TryFrom; +use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +impl TryFrom for Role { + type Error = anyhow::Error; + + fn try_from(value: String) -> Result { + match value.as_str() { + "user" => Ok(Self::User), + "assistant" => Ok(Self::Assistant), + "system" => Ok(Self::System), + _ => Err(anyhow!("invalid role '{value}'")), + } + } +} + +impl From for String { + fn from(val: Role) -> Self { + match val { + Role::User => "user".to_owned(), + Role::Assistant => "assistant".to_owned(), + Role::System => "system".to_owned(), + } + } +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] +pub enum Model { + #[serde(rename = "gpt-3.5-turbo", alias = "gpt-3.5-turbo-0613")] + ThreePointFiveTurbo, + #[serde(rename = "gpt-4", alias = "gpt-4-0613")] + Four, + #[serde(rename = "gpt-4-turbo-preview", alias = "gpt-4-1106-preview")] + #[default] + FourTurbo, +} + +impl Model { + pub fn from_id(id: &str) -> Result { + match id { + "gpt-3.5-turbo" => Ok(Self::ThreePointFiveTurbo), + "gpt-4" => Ok(Self::Four), + "gpt-4-turbo-preview" => Ok(Self::FourTurbo), + _ => Err(anyhow!("invalid model id")), + } + } + + pub fn id(&self) -> &'static str { + match self { + Self::ThreePointFiveTurbo => "gpt-3.5-turbo", + Self::Four 
=> "gpt-4", + Self::FourTurbo => "gpt-4-turbo-preview", + } + } + + pub fn display_name(&self) -> &'static str { + match self { + Self::ThreePointFiveTurbo => "gpt-3.5-turbo", + Self::Four => "gpt-4", + Self::FourTurbo => "gpt-4-turbo", + } + } +} + +#[derive(Debug, Serialize)] +pub struct Request { + pub model: Model, + pub messages: Vec, + pub stream: bool, + pub stop: Vec, + pub temperature: f32, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct RequestMessage { + pub role: Role, + pub content: String, +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ResponseMessage { + pub role: Option, + pub content: Option, +} + +#[derive(Deserialize, Debug)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, +} + +#[derive(Deserialize, Debug)] +pub struct ChoiceDelta { + pub index: u32, + pub delta: ResponseMessage, + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub struct ResponseStreamEvent { + pub created: u32, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +pub async fn stream_completion( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result>> { + let uri = format!("{api_url}/chat/completions"); + let request = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", api_key)) + .body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let line = line.strip_prefix("data: ")?; + if line == "[DONE]" { + None + } else { + match serde_json::from_str(line) { + Ok(response) => Some(Ok(response)), + Err(error) => Some(Err(anyhow!(error))), + } + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + #[derive(Deserialize)] + struct OpenAiResponse { + error: OpenAiError, + } + + #[derive(Deserialize)] + struct OpenAiError { + message: String, + } + + match serde_json::from_str::(&body) { + Ok(response) if !response.error.message.is_empty() => Err(anyhow!( + "Failed to connect to OpenAI API: {}", + response.error.message, + )), + + _ => Err(anyhow!( + "Failed to connect to OpenAI API: {} {}", + response.status(), + body, + )), + } + } +} diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 2bdfa23011c60f..0d67d62d902f4d 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -1,7 +1,7 @@ syntax = "proto3"; package zed.messages; -// Looking for a number? Search "// Current max" +// Looking for a number? 
Search "// current max" message PeerId { uint32 owner_id = 1; @@ -26,6 +26,7 @@ message Envelope { Error error = 6; Ping ping = 7; Test test = 8; + EndStream end_stream = 165; CreateRoom create_room = 9; CreateRoomResponse create_room_response = 10; @@ -198,6 +199,11 @@ message Envelope { GetImplementationResponse get_implementation_response = 163; JoinHostedProject join_hosted_project = 164; + + CompleteWithLanguageModel complete_with_language_model = 166; + LanguageModelResponse language_model_response = 167; + CountTokensWithLanguageModel count_tokens_with_language_model = 168; + CountTokensResponse count_tokens_response = 169; // current max } reserved 158 to 161; @@ -236,6 +242,8 @@ enum ErrorCode { reserved 6; } +message EndStream {} + message Test { uint64 id = 1; } @@ -1718,3 +1726,45 @@ message SetRoomParticipantRole { uint64 user_id = 2; ChannelRole role = 3; } + +message CompleteWithLanguageModel { + string model = 1; + repeated LanguageModelRequestMessage messages = 2; + repeated string stop = 3; + float temperature = 4; +} + +message LanguageModelRequestMessage { + LanguageModelRole role = 1; + string content = 2; +} + +enum LanguageModelRole { + LanguageModelUser = 0; + LanguageModelAssistant = 1; + LanguageModelSystem = 2; +} + +message LanguageModelResponseMessage { + optional LanguageModelRole role = 1; + optional string content = 2; +} + +message LanguageModelResponse { + repeated LanguageModelChoiceDelta choices = 1; +} + +message LanguageModelChoiceDelta { + uint32 index = 1; + LanguageModelResponseMessage delta = 2; + optional string finish_reason = 3; +} + +message CountTokensWithLanguageModel { + string model = 1; + repeated LanguageModelRequestMessage messages = 2; +} + +message CountTokensResponse { + uint32 token_count = 1; +} diff --git a/crates/rpc/src/error.rs b/crates/rpc/src/error.rs index 858029a02b4d24..f589863f2d0a35 100644 --- a/crates/rpc/src/error.rs +++ b/crates/rpc/src/error.rs @@ -80,7 +80,7 @@ pub trait ErrorExt { fn error_tag(&self, k: &str) -> Option<&str>; /// to_proto() converts the error into a proto::Error fn to_proto(&self) -> proto::Error; - /// + /// Clones the error and turns into an [anyhow::Error]. 
fn cloned(&self) -> anyhow::Error; } diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index cd1cdaffcccabf..8e026953c13270 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -9,19 +9,21 @@ use collections::HashMap; use futures::{ channel::{mpsc, oneshot}, stream::BoxStream, - FutureExt, SinkExt, StreamExt, TryFutureExt, + FutureExt, SinkExt, Stream, StreamExt, TryFutureExt, }; use parking_lot::{Mutex, RwLock}; use serde::{ser::SerializeStruct, Serialize}; -use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant}; use std::{ + fmt, future, future::Future, marker::PhantomData, + sync::atomic::Ordering::SeqCst, sync::{ atomic::{self, AtomicU32}, Arc, }, time::Duration, + time::Instant, }; use tracing::instrument; @@ -118,6 +120,15 @@ pub struct ConnectionState { >, >, >, + #[allow(clippy::type_complexity)] + #[serde(skip)] + stream_response_channels: Arc< + Mutex< + Option< + HashMap, oneshot::Sender<()>)>>, + >, + >, + >, } const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); @@ -171,17 +182,28 @@ impl Peer { outgoing_tx, next_message_id: Default::default(), response_channels: Arc::new(Mutex::new(Some(Default::default()))), + stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))), }; let mut writer = MessageStream::new(connection.tx); let mut reader = MessageStream::new(connection.rx); let this = self.clone(); let response_channels = connection_state.response_channels.clone(); + let stream_response_channels = connection_state.stream_response_channels.clone(); + let handle_io = async move { tracing::trace!(%connection_id, "handle io future: start"); let _end_connection = util::defer(|| { response_channels.lock().take(); + if let Some(channels) = stream_response_channels.lock().take() { + for channel in channels.values() { + let _ = channel.unbounded_send(( + Err(anyhow!("connection closed")), + oneshot::channel().0, + )); + } + } this.connections.write().remove(&connection_id); tracing::trace!(%connection_id, "handle io future: end"); }); @@ -273,12 +295,14 @@ impl Peer { }; let response_channels = connection_state.response_channels.clone(); + let stream_response_channels = connection_state.stream_response_channels.clone(); self.connections .write() .insert(connection_id, connection_state); let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| { let response_channels = response_channels.clone(); + let stream_response_channels = stream_response_channels.clone(); async move { let message_id = incoming.id; tracing::trace!(?incoming, "incoming message future: start"); @@ -293,8 +317,15 @@ impl Peer { responding_to, "incoming response: received" ); - let channel = response_channels.lock().as_mut()?.remove(&responding_to); - if let Some(tx) = channel { + let response_channel = + response_channels.lock().as_mut()?.remove(&responding_to); + let stream_response_channel = stream_response_channels + .lock() + .as_ref()? 
+ if let Some(tx) = response_channel { let requester_resumed = oneshot::channel(); if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) { tracing::trace!( @@ -319,6 +350,31 @@ impl Peer { responding_to, "incoming response: requester resumed" ); + } else if let Some(tx) = stream_response_channel { + let requester_resumed = oneshot::channel(); + if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) { + tracing::debug!( + %connection_id, + message_id, + responding_to = responding_to, + ?error, + "incoming stream response: request future dropped", + ); + } + + tracing::debug!( + %connection_id, + message_id, + responding_to, + "incoming stream response: waiting to resume requester" + ); + let _ = requester_resumed.1.await; + tracing::debug!( + %connection_id, + message_id, + responding_to, + "incoming stream response: requester resumed" + ); } else { let message_type = proto::build_typed_envelope(connection_id, received_at, incoming) @@ -451,6 +507,66 @@ impl Peer { } } + pub fn request_stream<T: RequestMessage>( + &self, + receiver_id: ConnectionId, + request: T, + ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> { + let (tx, rx) = mpsc::unbounded(); + let send = self.connection_state(receiver_id).and_then(|connection| { + let message_id = connection.next_message_id.fetch_add(1, SeqCst); + let stream_response_channels = connection.stream_response_channels.clone(); + stream_response_channels + .lock() + .as_mut() + .ok_or_else(|| anyhow!("connection was closed"))? + .insert(message_id, tx); + connection + .outgoing_tx + .unbounded_send(proto::Message::Envelope( + request.into_envelope(message_id, None, None), + )) + .map_err(|_| anyhow!("connection was closed"))?; + Ok((message_id, stream_response_channels)) + }); + + async move { + let (message_id, stream_response_channels) = send?; + let stream_response_channels = Arc::downgrade(&stream_response_channels); + + Ok(rx.filter_map(move |(response, _barrier)| { + let stream_response_channels = stream_response_channels.clone(); + future::ready(match response { + Ok(response) => { + if let Some(proto::envelope::Payload::Error(error)) = &response.payload { + Some(Err(anyhow!( + "RPC request {} failed - {}", + T::NAME, + error.message + ))) + } else if let Some(proto::envelope::Payload::EndStream(_)) = + &response.payload + { + // Remove the transmitting end of the response channel to end the stream. + if let Some(channels) = stream_response_channels.upgrade() { + if let Some(channels) = channels.lock().as_mut() { + channels.remove(&message_id); + } + } + None + } else { + Some( + T::Response::from_envelope(response) + .ok_or_else(|| anyhow!("received response of the wrong type")), + ) + } + } + Err(error) => Some(Err(error)), + }) + })) + } + } +
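On the requesting side, the returned stream yields typed responses until the server's `EndStream` drops the channel, while an `Error` payload surfaces as an `Err` item instead. A hypothetical caller, with the peer, connection id, and message values as placeholders:

```rust
use anyhow::Result;
use futures::StreamExt;
use rpc::{proto, ConnectionId, Peer};

async fn stream_completion(peer: &Peer, server: ConnectionId) -> Result<()> {
    let request = proto::CompleteWithLanguageModel {
        model: "gpt-4-turbo-preview".into(),
        messages: Vec::new(),
        stop: Vec::new(),
        temperature: 1.0,
    };
    let mut responses = peer.request_stream(server, request).await?;
    // The loop ends cleanly when EndStream removes the sender; errors
    // (including a dropped connection) propagate through the `?`.
    while let Some(response) = responses.next().await {
        for choice in response?.choices {
            let _delta = choice.delta; // apply the delta to the buffered completion
        }
    }
    Ok(())
}
```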
pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> { let connection = self.connection_state(receiver_id)?; let message_id = connection @@ -503,6 +619,24 @@ impl Peer { Ok(()) } + pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> { + let connection = self.connection_state(receipt.sender_id)?; + let message_id = connection + .next_message_id + .fetch_add(1, atomic::Ordering::SeqCst); + + let message = proto::EndStream {}; + + connection + .outgoing_tx + .unbounded_send(proto::Message::Envelope(message.into_envelope( + message_id, + Some(receipt.message_id), + None, + )))?; + Ok(()) + } + pub fn respond_with_error( &self, receipt: Receipt<T>, diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 40d0d6e3c4f164..b25b01a798ef78 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -149,7 +149,10 @@ messages!( (CallCanceled, Foreground), (CancelCall, Foreground), (ChannelMessageSent, Foreground), + (CompleteWithLanguageModel, Background), (CopyProjectEntry, Foreground), + (CountTokensWithLanguageModel, Background), + (CountTokensResponse, Background), (CreateBufferForPeer, Foreground), (CreateChannel, Foreground), (CreateChannelResponse, Foreground), @@ -160,6 +163,7 @@ messages!( (DeleteChannel, Foreground), (DeleteNotification, Foreground), (DeleteProjectEntry, Foreground), + (EndStream, Foreground), (Error, Foreground), (ExpandProjectEntry, Foreground), (ExpandProjectEntryResponse, Foreground), @@ -211,6 +215,7 @@ messages!( (JoinProjectResponse, Foreground), (JoinRoom, Foreground), (JoinRoomResponse, Foreground), + (LanguageModelResponse, Background), (LeaveChannelBuffer, Background), (LeaveChannelChat, Foreground), (LeaveProject, Foreground), @@ -300,6 +305,8 @@ request_messages!( (Call, Ack), (CancelCall, Ack), (CopyProjectEntry, ProjectEntryResponse), + (CompleteWithLanguageModel, LanguageModelResponse), + (CountTokensWithLanguageModel, CountTokensResponse), (CreateChannel, CreateChannelResponse), (CreateProjectEntry, ProjectEntryResponse), (CreateRoom, CreateRoomResponse), diff --git a/crates/search/Cargo.toml b/crates/search/Cargo.toml index e54b5e0338cb5e..a6546f6ae74f1e 100644 --- a/crates/search/Cargo.toml +++ b/crates/search/Cargo.toml @@ -22,7 +22,6 @@ gpui.workspace = true language.workspace = true menu.workspace = true project.workspace = true -semantic_index.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/search/src/buffer_search.rs b/crates/search/src/buffer_search.rs index 7d7626dd2f5405..f14ad4100e47e2 100644 --- a/crates/search/src/buffer_search.rs +++ b/crates/search/src/buffer_search.rs @@ -705,11 +705,6 @@ impl BufferSearchBar { option.as_button(is_active, action) } pub fn activate_search_mode(&mut self, mode: SearchMode, cx: &mut ViewContext<Self>) { - assert_ne!( - mode, - SearchMode::Semantic, - "Semantic search is not supported in buffer search" - ); if mode == self.current_mode { return; } @@ -1022,7 +1017,7 @@ impl BufferSearchBar { } } fn cycle_mode(&mut self, _: &CycleMode, cx: &mut ViewContext<Self>) { - self.activate_search_mode(next_mode(&self.current_mode, false), cx); + self.activate_search_mode(next_mode(&self.current_mode), cx); }
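With the Semantic arm gone, cycling search modes is a pure two-state toggle (see the `next_mode` change in `mode.rs` just below). A standalone mirror of the new behavior, independent of the crate's types:

```rust
#[derive(Copy, Clone, Debug, PartialEq)]
enum SearchMode {
    Text,
    Regex,
}

// Mirrors the simplified next_mode in mode.rs: Text <-> Regex, nothing else.
fn next_mode(mode: &SearchMode) -> SearchMode {
    match mode {
        SearchMode::Text => SearchMode::Regex,
        SearchMode::Regex => SearchMode::Text,
    }
}

fn main() {
    assert_eq!(next_mode(&SearchMode::Text), SearchMode::Regex);
    assert_eq!(next_mode(&SearchMode::Regex), SearchMode::Text);
}
```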
fn toggle_replace(&mut self, _: &ToggleReplace, cx: &mut ViewContext<Self>) { if let Some(_) = &self.active_searchable_item { diff --git a/crates/search/src/mode.rs b/crates/search/src/mode.rs index 3fd53cee49e88a..957eb707a5262b 100644 --- a/crates/search/src/mode.rs +++ b/crates/search/src/mode.rs @@ -1,13 +1,12 @@ use gpui::{Action, SharedString}; -use crate::{ActivateRegexMode, ActivateSemanticMode, ActivateTextMode}; +use crate::{ActivateRegexMode, ActivateTextMode}; // TODO: Update the default search mode to get from config #[derive(Copy, Clone, Debug, Default, PartialEq)] pub enum SearchMode { #[default] Text, - Semantic, Regex, } @@ -15,7 +14,6 @@ impl SearchMode { pub(crate) fn label(&self) -> &'static str { match self { SearchMode::Text => "Text", - SearchMode::Semantic => "Semantic", SearchMode::Regex => "Regex", } } @@ -25,22 +23,14 @@ impl SearchMode { pub(crate) fn action(&self) -> Box<dyn Action> { match self { SearchMode::Text => ActivateTextMode.boxed_clone(), - SearchMode::Semantic => ActivateSemanticMode.boxed_clone(), SearchMode::Regex => ActivateRegexMode.boxed_clone(), } } } -pub(crate) fn next_mode(mode: &SearchMode, semantic_enabled: bool) -> SearchMode { +pub(crate) fn next_mode(mode: &SearchMode) -> SearchMode { match mode { SearchMode::Text => SearchMode::Regex, - SearchMode::Regex => { - if semantic_enabled { - SearchMode::Semantic - } else { - SearchMode::Text - } - } - SearchMode::Semantic => SearchMode::Text, + SearchMode::Regex => SearchMode::Text, } } diff --git a/crates/search/src/project_search.rs b/crates/search/src/project_search.rs index da8d84b437f6c8..15f3d6184df4bd 100644 --- a/crates/search/src/project_search.rs +++ b/crates/search/src/project_search.rs @@ -1,33 +1,26 @@ use crate::{ - history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateSemanticMode, - ActivateTextMode, CycleMode, NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, - SearchOptions, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored, - ToggleReplace, ToggleWholeWord, + history::SearchHistory, mode::SearchMode, ActivateRegexMode, ActivateTextMode, CycleMode, + NextHistoryQuery, PreviousHistoryQuery, ReplaceAll, ReplaceNext, SearchOptions, + SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleIncludeIgnored, ToggleReplace, + ToggleWholeWord, }; -use anyhow::{Context as _, Result}; -use collections::HashMap; +use anyhow::Context as _; +use collections::{HashMap, HashSet}; use editor::{ actions::SelectAll, items::active_match_index, scroll::{Autoscroll, Axis}, - Anchor, Editor, EditorEvent, MultiBuffer, MAX_TAB_TITLE_LEN, + Anchor, Editor, EditorElement, EditorEvent, EditorStyle, MultiBuffer, MAX_TAB_TITLE_LEN, }; -use editor::{EditorElement, EditorStyle}; use gpui::{ actions, div, Action, AnyElement, AnyView, AppContext, Context as _, Element, EntityId, EventEmitter, FocusHandle, FocusableView, FontStyle, FontWeight, Global, Hsla, - InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point, - PromptLevel, Render, SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, - VisualContext, WeakModel, WeakView, WhiteSpace, WindowContext, + InteractiveElement, IntoElement, KeyContext, Model, ModelContext, ParentElement, Point, Render, + SharedString, Styled, Subscription, Task, TextStyle, View, ViewContext, VisualContext, + WeakModel, WeakView, WhiteSpace, WindowContext, }; use menu::Confirm; -use project::{ - search::{SearchInputs, 
SearchQuery}, - Project, -}; -use semantic_index::{SemanticIndex, SemanticIndexStatus}; - -use collections::HashSet; +use project::{search::SearchQuery, Project}; use settings::Settings; use smol::stream::StreamExt; use std::{ @@ -35,22 +28,20 @@ use std::{ mem, ops::{Not, Range}, path::{Path, PathBuf}, - time::{Duration, Instant}, }; use theme::ThemeSettings; -use workspace::{DeploySearch, NewSearch}; - use ui::{ h_flex, prelude::*, v_flex, Icon, IconButton, IconName, Label, LabelCommon, LabelSize, Selectable, ToggleButton, Tooltip, }; -use util::{paths::PathMatcher, ResultExt as _}; +use util::paths::PathMatcher; use workspace::{ item::{BreadcrumbText, Item, ItemEvent, ItemHandle}, searchable::{Direction, SearchableItem, SearchableItemHandle}, ItemNavHistory, Pane, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId, }; +use workspace::{DeploySearch, NewSearch}; const MIN_INPUT_WIDTH_REMS: f32 = 15.; const MAX_INPUT_WIDTH_REMS: f32 = 30.; @@ -86,12 +77,6 @@ pub fn init(cx: &mut AppContext) { register_workspace_action(workspace, move |search_bar, _: &ActivateTextMode, cx| { search_bar.activate_search_mode(SearchMode::Text, cx) }); - register_workspace_action( - workspace, - move |search_bar, _: &ActivateSemanticMode, cx| { - search_bar.activate_search_mode(SearchMode::Semantic, cx) - }, - ); register_workspace_action(workspace, move |search_bar, action: &CycleMode, cx| { search_bar.cycle_mode(action, cx) }); @@ -159,8 +144,6 @@ pub struct ProjectSearchView { query_editor: View, replacement_editor: View, results_editor: View, - semantic_state: Option, - semantic_permissioned: Option, search_options: SearchOptions, panels_with_errors: HashSet, active_match_index: Option, @@ -174,12 +157,6 @@ pub struct ProjectSearchView { _subscriptions: Vec, } -struct SemanticState { - index_status: SemanticIndexStatus, - maintain_rate_limit: Option>, - _subscription: Subscription, -} - #[derive(Debug, Clone)] struct ProjectSearchSettings { search_options: SearchOptions, @@ -282,68 +259,6 @@ impl ProjectSearch { })); cx.notify(); } - - fn semantic_search(&mut self, inputs: &SearchInputs, cx: &mut ModelContext) { - let search = SemanticIndex::global(cx).map(|index| { - index.update(cx, |semantic_index, cx| { - semantic_index.search_project( - self.project.clone(), - inputs.as_str().to_owned(), - 10, - inputs.files_to_include().to_vec(), - inputs.files_to_exclude().to_vec(), - cx, - ) - }) - }); - self.search_id += 1; - self.match_ranges.clear(); - self.search_history.add(inputs.as_str().to_string()); - self.no_results = None; - self.pending_search = Some(cx.spawn(|this, mut cx| async move { - let results = search?.await.log_err()?; - let matches = results - .into_iter() - .map(|result| (result.buffer, vec![result.range.start..result.range.start])); - - this.update(&mut cx, |this, cx| { - this.no_results = Some(true); - this.excerpts.update(cx, |excerpts, cx| { - excerpts.clear(cx); - }); - }) - .ok()?; - for (buffer, ranges) in matches { - let mut match_ranges = this - .update(&mut cx, |this, cx| { - this.no_results = Some(false); - this.excerpts.update(cx, |excerpts, cx| { - excerpts.stream_excerpts_with_context_lines(buffer, ranges, 3, cx) - }) - }) - .ok()?; - while let Some(match_range) = match_ranges.next().await { - this.update(&mut cx, |this, cx| { - this.match_ranges.push(match_range); - while let Ok(Some(match_range)) = match_ranges.try_next() { - this.match_ranges.push(match_range); - } - cx.notify(); - }) - .ok()?; - } - } - - this.update(&mut cx, |this, cx| { - 
this.pending_search.take(); - cx.notify(); - }) - .ok()?; - - None - })); - cx.notify(); - } } #[derive(Clone, Debug, PartialEq, Eq)] @@ -358,8 +273,6 @@ impl EventEmitter for ProjectSearchView {} impl Render for ProjectSearchView { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - const PLEASE_AUTHENTICATE: &str = "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables. If you authenticated using the Assistant Panel, please restart Zed to Authenticate."; - if self.has_matches() { div() .flex_1() @@ -370,7 +283,7 @@ impl Render for ProjectSearchView { let model = self.model.read(cx); let has_no_results = model.no_results.unwrap_or(false); let is_search_underway = model.pending_search.is_some(); - let mut major_text = if is_search_underway { + let major_text = if is_search_underway { Label::new("Searching...") } else if has_no_results { Label::new("No results") @@ -378,43 +291,6 @@ impl Render for ProjectSearchView { Label::new(format!("{} search all files", self.current_mode.label())) }; - let mut show_minor_text = true; - let semantic_status = self.semantic_state.as_ref().and_then(|semantic| { - let status = semantic.index_status; - match status { - SemanticIndexStatus::NotAuthenticated => { - major_text = Label::new("Not Authenticated"); - show_minor_text = false; - Some(PLEASE_AUTHENTICATE.to_string()) - } - SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()), - SemanticIndexStatus::Indexing { - remaining_files, - rate_limit_expiry, - } => { - if remaining_files == 0 { - Some("Indexing...".to_string()) - } else { - if let Some(rate_limit_expiry) = rate_limit_expiry { - let remaining_seconds = - rate_limit_expiry.duration_since(Instant::now()); - if remaining_seconds > Duration::from_secs(0) { - Some(format!( - "Remaining files to index (rate limit resets in {}s): {}", - remaining_seconds.as_secs(), - remaining_files - )) - } else { - Some(format!("Remaining files to index: {}", remaining_files)) - } - } else { - Some(format!("Remaining files to index: {}", remaining_files)) - } - } - } - SemanticIndexStatus::NotIndexed => None, - } - }); let major_text = div().justify_center().max_w_96().child(major_text); let minor_text: Option = if let Some(no_results) = model.no_results { @@ -424,12 +300,7 @@ impl Render for ProjectSearchView { None } } else { - if let Some(mut semantic_status) = semantic_status { - semantic_status.extend(self.landing_text_minor().chars()); - Some(semantic_status.into()) - } else { - Some(self.landing_text_minor()) - } + Some(self.landing_text_minor()) }; let minor_text = minor_text.map(|text| { div() @@ -676,58 +547,6 @@ impl ProjectSearchView { }); } - fn index_project(&mut self, cx: &mut ViewContext) { - if let Some(semantic_index) = SemanticIndex::global(cx) { - // Semantic search uses no options - self.search_options = SearchOptions::none(); - - let project = self.model.read(cx).project.clone(); - - semantic_index.update(cx, |semantic_index, cx| { - semantic_index - .index_project(project.clone(), cx) - .detach_and_log_err(cx); - }); - - self.semantic_state = Some(SemanticState { - index_status: semantic_index.read(cx).status(&project), - maintain_rate_limit: None, - _subscription: cx.observe(&semantic_index, Self::semantic_index_changed), - }); - self.semantic_index_changed(semantic_index, cx); - } - } - - fn semantic_index_changed( - &mut self, - semantic_index: Model, - cx: &mut ViewContext, - ) { - let project = self.model.read(cx).project.clone(); - if let Some(semantic_state) = self.semantic_state.as_mut() 
{ - cx.notify(); - semantic_state.index_status = semantic_index.read(cx).status(&project); - if let SemanticIndexStatus::Indexing { - rate_limit_expiry: Some(_), - .. - } = &semantic_state.index_status - { - if semantic_state.maintain_rate_limit.is_none() { - semantic_state.maintain_rate_limit = - Some(cx.spawn(|this, mut cx| async move { - loop { - cx.background_executor().timer(Duration::from_secs(1)).await; - this.update(&mut cx, |_, cx| cx.notify()).log_err(); - } - })); - return; - } - } else { - semantic_state.maintain_rate_limit = None; - } - } - } - fn clear_search(&mut self, cx: &mut ViewContext) { self.model.update(cx, |model, cx| { model.pending_search = None; @@ -750,63 +569,7 @@ impl ProjectSearchView { self.clear_search(cx); self.current_mode = mode; self.active_match_index = None; - - match mode { - SearchMode::Semantic => { - let has_permission = self.semantic_permissioned(cx); - self.active_match_index = None; - cx.spawn(|this, mut cx| async move { - let has_permission = has_permission.await?; - - if !has_permission { - let answer = this.update(&mut cx, |this, cx| { - let project = this.model.read(cx).project.clone(); - let project_name = project - .read(cx) - .worktree_root_names(cx) - .collect::>() - .join("/"); - let is_plural = - project_name.chars().filter(|letter| *letter == '/').count() > 0; - let prompt_text = format!("Would you like to index the '{}' project{} for semantic search? This requires sending code to the OpenAI API", project_name, - if is_plural { - "s" - } else {""}); - cx.prompt( - PromptLevel::Info, - prompt_text.as_str(), - None, - &["Continue", "Cancel"], - ) - })?; - - if answer.await? == 0 { - this.update(&mut cx, |this, _| { - this.semantic_permissioned = Some(true); - })?; - } else { - this.update(&mut cx, |this, cx| { - this.semantic_permissioned = Some(false); - debug_assert_ne!(previous_mode, SearchMode::Semantic, "Tried to re-enable semantic search mode after user modal was rejected"); - this.activate_search_mode(previous_mode, cx); - })?; - return anyhow::Ok(()); - } - } - - this.update(&mut cx, |this, cx| { - this.index_project(cx); - })?; - - anyhow::Ok(()) - }).detach_and_log_err(cx); - } - SearchMode::Regex | SearchMode::Text => { - self.semantic_state = None; - self.active_match_index = None; - self.search(cx); - } - } + self.search(cx); cx.update_global(|state: &mut ActiveSettings, cx| { state.0.insert( @@ -973,8 +736,6 @@ impl ProjectSearchView { model, query_editor, results_editor, - semantic_state: None, - semantic_permissioned: None, search_options: options, panels_with_errors: HashSet::default(), active_match_index: None, @@ -990,19 +751,6 @@ impl ProjectSearchView { this } - fn semantic_permissioned(&mut self, cx: &mut ViewContext) -> Task> { - if let Some(value) = self.semantic_permissioned { - return Task::ready(Ok(value)); - } - - SemanticIndex::global(cx) - .map(|semantic| { - let project = self.model.read(cx).project.clone(); - semantic.update(cx, |this, cx| this.project_previously_indexed(&project, cx)) - }) - .unwrap_or(Task::ready(Ok(false))) - } - pub fn new_search_in_directory( workspace: &mut Workspace, dir_path: &Path, @@ -1126,22 +874,8 @@ impl ProjectSearchView { } fn search(&mut self, cx: &mut ViewContext) { - let mode = self.current_mode; - match mode { - SearchMode::Semantic => { - if self.semantic_state.is_some() { - if let Some(query) = self.build_search_query(cx) { - self.model - .update(cx, |model, cx| model.semantic_search(query.as_inner(), cx)); - } - } - } - - _ => { - if let Some(query) = 
self.build_search_query(cx) { - self.model.update(cx, |model, cx| model.search(query, cx)); - } - } + if let Some(query) = self.build_search_query(cx) { + self.model.update(cx, |model, cx| model.search(query, cx)); } } @@ -1356,7 +1090,6 @@ impl ProjectSearchView { fn landing_text_minor(&self) -> SharedString { match self.current_mode { SearchMode::Text | SearchMode::Regex => "Include/exclude specific paths with the filter option. Matching exact word and/or casing is available too.".into(), - SearchMode::Semantic => "\nSimply explain the code you are looking to find. ex. 'prompt user for permissions to index their project'".into() } } fn border_color_for(&self, panel: InputPanel, cx: &WindowContext) -> Hsla { @@ -1387,8 +1120,7 @@ impl ProjectSearchBar { fn cycle_mode(&self, _: &CycleMode, cx: &mut ViewContext) { if let Some(view) = self.active_project_search.as_ref() { view.update(cx, |this, cx| { - let new_mode = - crate::mode::next_mode(&this.current_mode, SemanticIndex::enabled(cx)); + let new_mode = crate::mode::next_mode(&this.current_mode); this.activate_search_mode(new_mode, cx); let editor_handle = this.query_editor.focus_handle(cx); cx.focus(&editor_handle); @@ -1681,7 +1413,6 @@ impl Render for ProjectSearchBar { }); } let search = search.read(cx); - let semantic_is_available = SemanticIndex::enabled(cx); let query_column = h_flex() .flex_1() @@ -1711,12 +1442,8 @@ impl Render for ProjectSearchBar { .unwrap_or_default(), ), ) - .when(search.current_mode != SearchMode::Semantic, |this| { - this.child( - IconButton::new( - "project-search-case-sensitive", - IconName::CaseSensitive, - ) + .child( + IconButton::new("project-search-case-sensitive", IconName::CaseSensitive) .tooltip(|cx| { Tooltip::for_action( "Toggle case sensitive", @@ -1728,18 +1455,17 @@ impl Render for ProjectSearchBar { .on_click(cx.listener(|this, _, cx| { this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx); })), - ) - .child( - IconButton::new("project-search-whole-word", IconName::WholeWord) - .tooltip(|cx| { - Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx) - }) - .selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx)) - .on_click(cx.listener(|this, _, cx| { - this.toggle_search_option(SearchOptions::WHOLE_WORD, cx); - })), - ) - }), + ) + .child( + IconButton::new("project-search-whole-word", IconName::WholeWord) + .tooltip(|cx| { + Tooltip::for_action("Toggle whole word", &ToggleWholeWord, cx) + }) + .selected(self.is_option_enabled(SearchOptions::WHOLE_WORD, cx)) + .on_click(cx.listener(|this, _, cx| { + this.toggle_search_option(SearchOptions::WHOLE_WORD, cx); + })), + ), ); let mode_column = v_flex().items_start().justify_start().child( @@ -1775,33 +1501,8 @@ impl Render for ProjectSearchBar { cx, ) }) - .map(|this| { - if semantic_is_available { - this.middle() - } else { - this.last() - } - }), - ) - .when(semantic_is_available, |this| { - this.child( - ToggleButton::new("project-search-semantic-button", "Semantic") - .style(ButtonStyle::Filled) - .size(ButtonSize::Large) - .selected(search.current_mode == SearchMode::Semantic) - .on_click(cx.listener(|this, _, cx| { - this.activate_search_mode(SearchMode::Semantic, cx) - })) - .tooltip(|cx| { - Tooltip::for_action( - "Toggle semantic search", - &ActivateSemanticMode, - cx, - ) - }) - .last(), - ) - }), + .last(), + ), ) .child( IconButton::new("project-search-toggle-replace", IconName::Replace) @@ -1929,21 +1630,16 @@ impl Render for ProjectSearchBar { .border_color(search.border_color_for(InputPanel::Include, cx)) 
.rounded_lg() .child(self.render_text_input(&search.included_files_editor, cx)) - .when(search.current_mode != SearchMode::Semantic, |this| { - this.child( - SearchOptions::INCLUDE_IGNORED.as_button( - search - .search_options - .contains(SearchOptions::INCLUDE_IGNORED), - cx.listener(|this, _, cx| { - this.toggle_search_option( - SearchOptions::INCLUDE_IGNORED, - cx, - ); - }), - ), - ) - }), + .child( + SearchOptions::INCLUDE_IGNORED.as_button( + search + .search_options + .contains(SearchOptions::INCLUDE_IGNORED), + cx.listener(|this, _, cx| { + this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx); + }), + ), + ), ) .child( h_flex() @@ -1972,9 +1668,6 @@ impl Render for ProjectSearchBar { .on_action(cx.listener(|this, _: &ActivateRegexMode, cx| { this.activate_search_mode(SearchMode::Regex, cx) })) - .on_action(cx.listener(|this, _: &ActivateSemanticMode, cx| { - this.activate_search_mode(SearchMode::Semantic, cx) - })) .capture_action(cx.listener(|this, action, cx| { this.tab(action, cx); cx.stop_propagation(); @@ -1987,35 +1680,33 @@ impl Render for ProjectSearchBar { .on_action(cx.listener(|this, action, cx| { this.cycle_mode(action, cx); })) - .when(search.current_mode != SearchMode::Semantic, |this| { - this.on_action(cx.listener(|this, action, cx| { - this.toggle_replace(action, cx); - })) - .on_action(cx.listener(|this, _: &ToggleWholeWord, cx| { - this.toggle_search_option(SearchOptions::WHOLE_WORD, cx); - })) - .on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| { - this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx); - })) - .on_action(cx.listener(|this, action, cx| { - if let Some(search) = this.active_project_search.as_ref() { - search.update(cx, |this, cx| { - this.replace_next(action, cx); - }) - } - })) - .on_action(cx.listener(|this, action, cx| { - if let Some(search) = this.active_project_search.as_ref() { - search.update(cx, |this, cx| { - this.replace_all(action, cx); - }) - } + .on_action(cx.listener(|this, action, cx| { + this.toggle_replace(action, cx); + })) + .on_action(cx.listener(|this, _: &ToggleWholeWord, cx| { + this.toggle_search_option(SearchOptions::WHOLE_WORD, cx); + })) + .on_action(cx.listener(|this, _: &ToggleCaseSensitive, cx| { + this.toggle_search_option(SearchOptions::CASE_SENSITIVE, cx); + })) + .on_action(cx.listener(|this, action, cx| { + if let Some(search) = this.active_project_search.as_ref() { + search.update(cx, |this, cx| { + this.replace_next(action, cx); + }) + } + })) + .on_action(cx.listener(|this, action, cx| { + if let Some(search) = this.active_project_search.as_ref() { + search.update(cx, |this, cx| { + this.replace_all(action, cx); + }) + } + })) + .when(search.filters_enabled, |this| { + this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| { + this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx); })) - .when(search.filters_enabled, |this| { - this.on_action(cx.listener(|this, _: &ToggleIncludeIgnored, cx| { - this.toggle_search_option(SearchOptions::INCLUDE_IGNORED, cx); - })) - }) }) .on_action(cx.listener(Self::select_next_match)) .on_action(cx.listener(Self::select_prev_match)) @@ -2039,12 +1730,6 @@ impl ToolbarItemView for ProjectSearchBar { self.subscription = None; self.active_project_search = None; if let Some(search) = active_pane_item.and_then(|i| i.downcast::()) { - search.update(cx, |search, cx| { - if search.current_mode == SearchMode::Semantic { - search.index_project(cx); - } - }); - self.subscription = Some(cx.observe(&search, |_, _, cx| cx.notify())); 
self.active_project_search = Some(search); ToolbarItemLocation::PrimaryLeft {} @@ -2123,9 +1808,8 @@ pub mod tests { use editor::DisplayPoint; use gpui::{Action, TestAppContext, WindowHandle}; use project::FakeFs; - use semantic_index::semantic_index_settings::SemanticIndexSettings; use serde_json::json; - use settings::{Settings, SettingsStore}; + use settings::SettingsStore; use std::sync::Arc; use workspace::DeploySearch; @@ -3446,8 +3130,6 @@ pub mod tests { let settings = SettingsStore::test(cx); cx.set_global(settings); - SemanticIndexSettings::register(cx); - theme::init(theme::LoadThemes::JustBase, cx); language::init(cx); diff --git a/crates/search/src/search.rs b/crates/search/src/search.rs index 18e287bfeebbad..a585c15361fc12 100644 --- a/crates/search/src/search.rs +++ b/crates/search/src/search.rs @@ -33,7 +33,6 @@ actions!( NextHistoryQuery, PreviousHistoryQuery, ActivateTextMode, - ActivateSemanticMode, ActivateRegexMode, ReplaceAll, ReplaceNext, diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml deleted file mode 100644 index 957a5e3cdf680e..00000000000000 --- a/crates/semantic_index/Cargo.toml +++ /dev/null @@ -1,66 +0,0 @@ -[package] -name = "semantic_index" -version = "0.1.0" -edition = "2021" -publish = false -license = "GPL-3.0-or-later" - -[lints] -workspace = true - -[lib] -path = "src/semantic_index.rs" -doctest = false - -[dependencies] -ai.workspace = true -anyhow.workspace = true -collections.workspace = true -futures.workspace = true -gpui.workspace = true -language.workspace = true -lazy_static.workspace = true -log.workspace = true -ndarray = { version = "0.15.0" } -ordered-float.workspace = true -parking_lot.workspace = true -postage.workspace = true -project.workspace = true -rand.workspace = true -release_channel.workspace = true -rpc.workspace = true -rusqlite.workspace = true -schemars.workspace = true -serde.workspace = true -serde_json.workspace = true -settings.workspace = true -sha1 = "0.10.5" -smol.workspace = true -tree-sitter.workspace = true -util.workspace = true -workspace.workspace = true - -[dev-dependencies] -ai = { workspace = true, features = ["test-support"] } -collections = { workspace = true, features = ["test-support"] } -ctor.workspace = true -env_logger.workspace = true -gpui = { workspace = true, features = ["test-support"] } -language = { workspace = true, features = ["test-support"] } -pretty_assertions.workspace = true -project = { workspace = true, features = ["test-support"] } -rand.workspace = true -rpc = { workspace = true, features = ["test-support"] } -settings = { workspace = true, features = ["test-support"]} -tempfile.workspace = true -tree-sitter-cpp.workspace = true -tree-sitter-elixir.workspace = true -tree-sitter-json.workspace = true -tree-sitter-lua.workspace = true -tree-sitter-php.workspace = true -tree-sitter-ruby.workspace = true -tree-sitter-rust.workspace = true -tree-sitter-toml.workspace = true -tree-sitter-typescript.workspace = true -unindent.workspace = true -workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/semantic_index/LICENSE-GPL b/crates/semantic_index/LICENSE-GPL deleted file mode 120000 index 89e542f750cd38..00000000000000 --- a/crates/semantic_index/LICENSE-GPL +++ /dev/null @@ -1 +0,0 @@ -../../LICENSE-GPL \ No newline at end of file diff --git a/crates/semantic_index/README.md b/crates/semantic_index/README.md deleted file mode 100644 index 75ccb41b84468b..00000000000000 --- a/crates/semantic_index/README.md +++ /dev/null @@ -1,20 
+0,0 @@ - -# Semantic Index - -## Evaluation - -### Metrics - -nDCG@k: -- "The value of NDCG is determined by comparing the relevance of the items returned by the search engine to the relevance of the item that a hypothetical "ideal" search engine would return. -- "The relevance of result is represented by a score (also known as a 'grade') that is assigned to the search query. The scores of these results are then discounted based on their position in the search results -- did they get recommended first or last?" - -MRR@k: -- "Mean reciprocal rank quantifies the rank of the first relevant item found in the recommendation list." - -MAP@k: -- "Mean average precision averages the precision@k metric at each relevant item position in the recommendation list. - -Resources: -- [Evaluating recommendation metrics](https://www.shaped.ai/blog/evaluating-recommendation-systems-map-mmr-ndcg) -- [Math Walkthrough](https://towardsdatascience.com/demystifying-ndcg-bee3be58cfe0) diff --git a/crates/semantic_index/eval/gpt-engineer.json b/crates/semantic_index/eval/gpt-engineer.json deleted file mode 100644 index d008cc65d13b0c..00000000000000 --- a/crates/semantic_index/eval/gpt-engineer.json +++ /dev/null @@ -1,114 +0,0 @@ -{ - "repo": "https://github.com/AntonOsika/gpt-engineer.git", - "commit": "7735a6445bae3611c62f521e6464c67c957f87c2", - "assertions": [ - { - "query": "How do I contribute to this project?", - "matches": [ - ".github/CONTRIBUTING.md:1", - "ROADMAP.md:48" - ] - }, - { - "query": "What version of the openai package is active?", - "matches": [ - "pyproject.toml:14" - ] - }, - { - "query": "Ask user for clarification", - "matches": [ - "gpt_engineer/steps.py:69" - ] - }, - { - "query": "generate tests for python code", - "matches": [ - "gpt_engineer/steps.py:153" - ] - }, - { - "query": "get item from database based on key", - "matches": [ - "gpt_engineer/db.py:42", - "gpt_engineer/db.py:68" - ] - }, - { - "query": "prompt user to select files", - "matches": [ - "gpt_engineer/file_selector.py:171", - "gpt_engineer/file_selector.py:306", - "gpt_engineer/file_selector.py:289", - "gpt_engineer/file_selector.py:234" - ] - }, - { - "query": "send to rudderstack", - "matches": [ - "gpt_engineer/collect.py:11", - "gpt_engineer/collect.py:38" - ] - }, - { - "query": "parse code blocks from chat messages", - "matches": [ - "gpt_engineer/chat_to_files.py:10", - "docs/intro/chat_parsing.md:1" - ] - }, - { - "query": "how do I use the docker cli?", - "matches": [ - "docker/README.md:1" - ] - }, - { - "query": "ask the user if the code ran successfully?", - "matches": [ - "gpt_engineer/learning.py:54" - ] - }, - { - "query": "how is consent granted by the user?", - "matches": [ - "gpt_engineer/learning.py:107", - "gpt_engineer/learning.py:130", - "gpt_engineer/learning.py:152" - ] - }, - { - "query": "what are all the different steps the agent can take?", - "matches": [ - "docs/intro/steps_module.md:1", - "gpt_engineer/steps.py:391" - ] - }, - { - "query": "ask the user for clarification?", - "matches": [ - "gpt_engineer/steps.py:69" - ] - }, - { - "query": "what models are available?", - "matches": [ - "gpt_engineer/ai.py:315", - "gpt_engineer/ai.py:341", - "docs/open-models.md:1" - ] - }, - { - "query": "what is the current focus of the project?", - "matches": [ - "ROADMAP.md:11" - ] - }, - { - "query": "does the agent know how to fix code?", - "matches": [ - "gpt_engineer/steps.py:367" - ] - } - ] -} diff --git a/crates/semantic_index/eval/tree-sitter.json 
b/crates/semantic_index/eval/tree-sitter.json deleted file mode 100644 index d3dcc86937d723..00000000000000 --- a/crates/semantic_index/eval/tree-sitter.json +++ /dev/null @@ -1,104 +0,0 @@ -{ - "repo": "https://github.com/tree-sitter/tree-sitter.git", - "commit": "46af27796a76c72d8466627d499f2bca4af958ee", - "assertions": [ - { - "query": "What attributes are available for the tags configuration struct?", - "matches": [ - "tags/src/lib.rs:24" - ] - }, - { - "query": "create a new tag configuration", - "matches": [ - "tags/src/lib.rs:119" - ] - }, - { - "query": "generate tags based on config", - "matches": [ - "tags/src/lib.rs:261" - ] - }, - { - "query": "match on ts quantifier in rust", - "matches": [ - "lib/binding_rust/lib.rs:139" - ] - }, - { - "query": "cli command to generate tags", - "matches": [ - "cli/src/tags.rs:10" - ] - }, - { - "query": "what version of the tree-sitter-tags package is active?", - "matches": [ - "tags/Cargo.toml:4" - ] - }, - { - "query": "Insert a new parse state", - "matches": [ - "cli/src/generate/build_tables/build_parse_table.rs:153" - ] - }, - { - "query": "Handle conflict when numerous actions occur on the same symbol", - "matches": [ - "cli/src/generate/build_tables/build_parse_table.rs:363", - "cli/src/generate/build_tables/build_parse_table.rs:442" - ] - }, - { - "query": "Match based on associativity of actions", - "matches": [ - "cri/src/generate/build_tables/build_parse_table.rs:542" - ] - }, - { - "query": "Format token set display", - "matches": [ - "cli/src/generate/build_tables/item.rs:246" - ] - }, - { - "query": "extract choices from rule", - "matches": [ - "cli/src/generate/prepare_grammar/flatten_grammar.rs:124" - ] - }, - { - "query": "How do we identify if a symbol is being used?", - "matches": [ - "cli/src/generate/prepare_grammar/flatten_grammar.rs:175" - ] - }, - { - "query": "How do we launch the playground?", - "matches": [ - "cli/src/playground.rs:46" - ] - }, - { - "query": "How do we test treesitter query matches in rust?", - "matches": [ - "cli/src/query_testing.rs:152", - "cli/src/tests/query_test.rs:781", - "cli/src/tests/query_test.rs:2163", - "cli/src/tests/query_test.rs:3781", - "cli/src/tests/query_test.rs:887" - ] - }, - { - "query": "What does the CLI do?", - "matches": [ - "cli/README.md:10", - "cli/loader/README.md:3", - "docs/section-5-implementation.md:14", - "docs/section-5-implementation.md:18" - ] - } - ] -} diff --git a/crates/semantic_index/src/db.rs b/crates/semantic_index/src/db.rs deleted file mode 100644 index 242e80026a4fff..00000000000000 --- a/crates/semantic_index/src/db.rs +++ /dev/null @@ -1,594 +0,0 @@ -use crate::{ - parsing::{Span, SpanDigest}, - SEMANTIC_INDEX_VERSION, -}; -use ai::embedding::Embedding; -use anyhow::{anyhow, Context, Result}; -use collections::HashMap; -use futures::channel::oneshot; -use gpui::BackgroundExecutor; -use ndarray::{Array1, Array2}; -use ordered_float::OrderedFloat; -use project::Fs; -use rpc::proto::Timestamp; -use rusqlite::params; -use rusqlite::types::Value; -use std::{ - future::Future, - ops::Range, - path::{Path, PathBuf}, - rc::Rc, - sync::Arc, - time::SystemTime, -}; -use util::{paths::PathMatcher, TryFutureExt}; - -pub fn argsort(data: &[T]) -> Vec { - let mut indices = (0..data.len()).collect::>(); - indices.sort_by_key(|&i| &data[i]); - indices.reverse(); - indices -} - -#[derive(Debug)] -pub struct FileRecord { - pub id: usize, - pub relative_path: String, - pub mtime: Timestamp, -} - -#[derive(Clone)] -pub struct VectorDatabase { - path: Arc, - 
transactions: - smol::channel::Sender>, -} - -impl VectorDatabase { - pub async fn new( - fs: Arc, - path: Arc, - executor: BackgroundExecutor, - ) -> Result { - if let Some(db_directory) = path.parent() { - fs.create_dir(db_directory).await?; - } - - let (transactions_tx, transactions_rx) = smol::channel::unbounded::< - Box, - >(); - executor - .spawn({ - let path = path.clone(); - async move { - let mut connection = rusqlite::Connection::open(&path)?; - - connection.pragma_update(None, "journal_mode", "wal")?; - connection.pragma_update(None, "synchronous", "normal")?; - connection.pragma_update(None, "cache_size", 1000000)?; - connection.pragma_update(None, "temp_store", "MEMORY")?; - - while let Ok(transaction) = transactions_rx.recv().await { - transaction(&mut connection); - } - - anyhow::Ok(()) - } - .log_err() - }) - .detach(); - let this = Self { - transactions: transactions_tx, - path, - }; - this.initialize_database().await?; - Ok(this) - } - - pub fn path(&self) -> &Arc { - &self.path - } - - fn transact(&self, f: F) -> impl Future> - where - F: 'static + Send + FnOnce(&rusqlite::Transaction) -> Result, - T: 'static + Send, - { - let (tx, rx) = oneshot::channel(); - let transactions = self.transactions.clone(); - async move { - if transactions - .send(Box::new(|connection| { - let result = connection - .transaction() - .map_err(|err| anyhow!(err)) - .and_then(|transaction| { - let result = f(&transaction)?; - transaction.commit()?; - Ok(result) - }); - let _ = tx.send(result); - })) - .await - .is_err() - { - return Err(anyhow!("connection was dropped"))?; - } - rx.await? - } - } - - fn initialize_database(&self) -> impl Future> { - self.transact(|db| { - rusqlite::vtab::array::load_module(&db)?; - - // Delete existing tables, if SEMANTIC_INDEX_VERSION is bumped - let version_query = db.prepare("SELECT version from semantic_index_config"); - let version = version_query - .and_then(|mut query| query.query_row([], |row| row.get::<_, i64>(0))); - if version.map_or(false, |version| version == SEMANTIC_INDEX_VERSION as i64) { - log::trace!("vector database schema up to date"); - return Ok(()); - } - - log::trace!("vector database schema out of date. updating..."); - // We renamed the `documents` table to `spans`, so we want to drop - // `documents` without recreating it if it exists. 
- db.execute("DROP TABLE IF EXISTS documents", []) - .context("failed to drop 'documents' table")?; - db.execute("DROP TABLE IF EXISTS spans", []) - .context("failed to drop 'spans' table")?; - db.execute("DROP TABLE IF EXISTS files", []) - .context("failed to drop 'files' table")?; - db.execute("DROP TABLE IF EXISTS worktrees", []) - .context("failed to drop 'worktrees' table")?; - db.execute("DROP TABLE IF EXISTS semantic_index_config", []) - .context("failed to drop 'semantic_index_config' table")?; - - // Initialize Vector Databasing Tables - db.execute( - "CREATE TABLE semantic_index_config ( - version INTEGER NOT NULL - )", - [], - )?; - - db.execute( - "INSERT INTO semantic_index_config (version) VALUES (?1)", - params![SEMANTIC_INDEX_VERSION], - )?; - - db.execute( - "CREATE TABLE worktrees ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - absolute_path VARCHAR NOT NULL - ); - CREATE UNIQUE INDEX worktrees_absolute_path ON worktrees (absolute_path); - ", - [], - )?; - - db.execute( - "CREATE TABLE files ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - worktree_id INTEGER NOT NULL, - relative_path VARCHAR NOT NULL, - mtime_seconds INTEGER NOT NULL, - mtime_nanos INTEGER NOT NULL, - FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE - )", - [], - )?; - - db.execute( - "CREATE UNIQUE INDEX files_worktree_id_and_relative_path ON files (worktree_id, relative_path)", - [], - )?; - - db.execute( - "CREATE TABLE spans ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - file_id INTEGER NOT NULL, - start_byte INTEGER NOT NULL, - end_byte INTEGER NOT NULL, - name VARCHAR NOT NULL, - embedding BLOB NOT NULL, - digest BLOB NOT NULL, - FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE - )", - [], - )?; - db.execute( - "CREATE INDEX spans_digest ON spans (digest)", - [], - )?; - - log::trace!("vector database initialized with updated schema."); - Ok(()) - }) - } - - pub fn delete_file( - &self, - worktree_id: i64, - delete_path: Arc, - ) -> impl Future> { - self.transact(move |db| { - db.execute( - "DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2", - params![worktree_id, delete_path.to_str()], - )?; - Ok(()) - }) - } - - pub fn insert_file( - &self, - worktree_id: i64, - path: Arc, - mtime: SystemTime, - spans: Vec, - ) -> impl Future> { - self.transact(move |db| { - // Return the existing ID, if both the file and mtime match - let mtime = Timestamp::from(mtime); - - db.execute( - " - REPLACE INTO files - (worktree_id, relative_path, mtime_seconds, mtime_nanos) - VALUES (?1, ?2, ?3, ?4) - ", - params![worktree_id, path.to_str(), mtime.seconds, mtime.nanos], - )?; - - let file_id = db.last_insert_rowid(); - - let mut query = db.prepare( - " - INSERT INTO spans - (file_id, start_byte, end_byte, name, embedding, digest) - VALUES (?1, ?2, ?3, ?4, ?5, ?6) - ", - )?; - - for span in spans { - query.execute(params![ - file_id, - span.range.start.to_string(), - span.range.end.to_string(), - span.name, - span.embedding, - span.digest - ])?; - } - - Ok(()) - }) - } - - pub fn worktree_previously_indexed( - &self, - worktree_root_path: &Path, - ) -> impl Future> { - let worktree_root_path = worktree_root_path.to_string_lossy().into_owned(); - self.transact(move |db| { - let mut worktree_query = - db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - let worktree_id = - worktree_query.query_row(params![worktree_root_path], |row| row.get::<_, i64>(0)); - - Ok(worktree_id.is_ok()) - }) - } - - pub fn embeddings_for_digests( - &self, - digests: Vec, - ) -> impl Future>> { - 
self.transact(move |db| { - let mut query = db.prepare( - " - SELECT digest, embedding - FROM spans - WHERE digest IN rarray(?) - ", - )?; - let mut embeddings_by_digest = HashMap::default(); - let digests = Rc::new( - digests - .into_iter() - .map(|digest| Value::Blob(digest.0.to_vec())) - .collect::>(), - ); - let rows = query.query_map(params![digests], |row| { - Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) - })?; - - for (digest, embedding) in rows.flatten() { - embeddings_by_digest.insert(digest, embedding); - } - - Ok(embeddings_by_digest) - }) - } - - pub fn embeddings_for_files( - &self, - worktree_id_file_paths: HashMap>>, - ) -> impl Future>> { - self.transact(move |db| { - let mut query = db.prepare( - " - SELECT digest, embedding - FROM spans - LEFT JOIN files ON files.id = spans.file_id - WHERE files.worktree_id = ? AND files.relative_path IN rarray(?) - ", - )?; - let mut embeddings_by_digest = HashMap::default(); - for (worktree_id, file_paths) in worktree_id_file_paths { - let file_paths = Rc::new( - file_paths - .into_iter() - .map(|p| Value::Text(p.to_string_lossy().into_owned())) - .collect::>(), - ); - let rows = query.query_map(params![worktree_id, file_paths], |row| { - Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?)) - })?; - - for (digest, embedding) in rows.flatten() { - embeddings_by_digest.insert(digest, embedding); - } - } - - Ok(embeddings_by_digest) - }) - } - - pub fn find_or_create_worktree( - &self, - worktree_root_path: Arc, - ) -> impl Future> { - self.transact(move |db| { - let mut worktree_query = - db.prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?; - let worktree_id = worktree_query - .query_row(params![worktree_root_path.to_string_lossy()], |row| { - row.get::<_, i64>(0) - }); - - if worktree_id.is_ok() { - return Ok(worktree_id?); - } - - // If worktree_id is Err, insert new worktree - db.execute( - "INSERT into worktrees (absolute_path) VALUES (?1)", - params![worktree_root_path.to_string_lossy()], - )?; - Ok(db.last_insert_rowid()) - }) - } - - pub fn get_file_mtimes( - &self, - worktree_id: i64, - ) -> impl Future>> { - self.transact(move |db| { - let mut statement = db.prepare( - " - SELECT relative_path, mtime_seconds, mtime_nanos - FROM files - WHERE worktree_id = ?1 - ORDER BY relative_path", - )?; - let mut result: HashMap = HashMap::default(); - for row in statement.query_map(params![worktree_id], |row| { - Ok(( - row.get::<_, String>(0)?.into(), - Timestamp { - seconds: row.get(1)?, - nanos: row.get(2)?, - } - .into(), - )) - })? { - let row = row?; - result.insert(row.0, row.1); - } - Ok(result) - }) - } - - pub fn top_k_search( - &self, - query_embedding: &Embedding, - limit: usize, - file_ids: &[i64], - ) -> impl Future)>>> { - let file_ids = file_ids.to_vec(); - let query = query_embedding.clone().0; - let query = Array1::from_vec(query); - self.transact(move |db| { - let mut query_statement = db.prepare( - " - SELECT - id, embedding - FROM - spans - WHERE - file_id IN rarray(?) - ", - )?; - - let deserialized_rows = query_statement - .query_map(params![ids_to_sql(&file_ids)], |row| { - Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?)) - })? 
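// [Editor's illustrative sketch, not part of the patch] The `IN rarray(?)`
// idiom used by the queries above, standalone: rusqlite's "array" feature
// provides a virtual table so one bound parameter can carry a whole set of
// values. Table and column names are examples.
use rusqlite::{types::Value, Connection, Result};
use std::rc::Rc;

fn embeddings_for_digests(db: &Connection, digests: Vec<[u8; 20]>) -> Result<Vec<Vec<u8>>> {
    // The rarray virtual table must be registered once per connection.
    rusqlite::vtab::array::load_module(db)?;
    let set = Rc::new(
        digests
            .into_iter()
            .map(|digest| Value::Blob(digest.to_vec()))
            .collect::<Vec<Value>>(),
    );
    let mut stmt = db.prepare("SELECT embedding FROM spans WHERE digest IN rarray(?1)")?;
    let rows = stmt.query_map([set], |row| row.get::<_, Vec<u8>>(0))?;
    rows.collect()
}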
- .filter_map(|row| row.ok()) - .collect::>(); - - if deserialized_rows.len() == 0 { - return Ok(Vec::new()); - } - - // Get Length of Embeddings Returned - let embedding_len = deserialized_rows[0].1 .0.len(); - - let batch_n = 1000; - let mut batches = Vec::new(); - let mut batch_ids = Vec::new(); - let mut batch_embeddings: Vec = Vec::new(); - deserialized_rows.iter().for_each(|(id, embedding)| { - batch_ids.push(id); - batch_embeddings.extend(&embedding.0); - - if batch_ids.len() == batch_n { - let embeddings = std::mem::take(&mut batch_embeddings); - let ids = std::mem::take(&mut batch_ids); - let array = Array2::from_shape_vec((ids.len(), embedding_len), embeddings); - match array { - Ok(array) => { - batches.push((ids, array)); - } - Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), - } - } - }); - - if batch_ids.len() > 0 { - let array = Array2::from_shape_vec( - (batch_ids.len(), embedding_len), - batch_embeddings.clone(), - ); - match array { - Ok(array) => { - batches.push((batch_ids.clone(), array)); - } - Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err), - } - } - - let mut ids: Vec = Vec::new(); - let mut results = Vec::new(); - for (batch_ids, array) in batches { - let scores = array - .dot(&query.t()) - .to_vec() - .iter() - .map(|score| OrderedFloat(*score)) - .collect::>>(); - results.extend(scores); - ids.extend(batch_ids); - } - - let sorted_idx = argsort(&results); - let mut sorted_results = Vec::new(); - let last_idx = limit.min(sorted_idx.len()); - for idx in &sorted_idx[0..last_idx] { - sorted_results.push((ids[*idx] as i64, results[*idx])) - } - - Ok(sorted_results) - }) - } - - pub fn retrieve_included_file_ids( - &self, - worktree_ids: &[i64], - includes: &[PathMatcher], - excludes: &[PathMatcher], - ) -> impl Future>> { - let worktree_ids = worktree_ids.to_vec(); - let includes = includes.to_vec(); - let excludes = excludes.to_vec(); - self.transact(move |db| { - let mut file_query = db.prepare( - " - SELECT - id, relative_path - FROM - files - WHERE - worktree_id IN rarray(?) - ", - )?; - - let mut file_ids = Vec::::new(); - let mut rows = file_query.query([ids_to_sql(&worktree_ids)])?; - - while let Some(row) = rows.next()? { - let file_id = row.get(0)?; - let relative_path = row.get_ref(1)?.as_str()?; - let included = - includes.is_empty() || includes.iter().any(|glob| glob.is_match(relative_path)); - let excluded = excludes.iter().any(|glob| glob.is_match(relative_path)); - if included && !excluded { - file_ids.push(file_id); - } - } - - anyhow::Ok(file_ids) - }) - } - - pub fn spans_for_ids( - &self, - ids: &[i64], - ) -> impl Future)>>> { - let ids = ids.to_vec(); - self.transact(move |db| { - let mut statement = db.prepare( - " - SELECT - spans.id, - files.worktree_id, - files.relative_path, - spans.start_byte, - spans.end_byte - FROM - spans, files - WHERE - spans.file_id = files.id AND - spans.id in rarray(?) 
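// [Editor's illustrative sketch, not part of the patch] The scoring core of
// top_k_search above: embeddings are packed row-wise into an ndarray matrix
// so a single matrix-vector product scores an entire batch. This assumes
// unit-length embeddings, which makes the dot product a cosine similarity.
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
use std::cmp::Reverse;

fn top_k(ids: &[i64], embeddings: &[Vec<f32>], query: &[f32], k: usize) -> Vec<(i64, OrderedFloat<f32>)> {
    let flat: Vec<f32> = embeddings.iter().flatten().copied().collect();
    let matrix = Array2::from_shape_vec((ids.len(), query.len()), flat)
        .expect("all embeddings must have the query's dimension");
    let scores = matrix.dot(&Array1::from_vec(query.to_vec()));
    let mut ranked: Vec<_> = ids
        .iter()
        .zip(scores.iter())
        .map(|(id, score)| (*id, OrderedFloat(*score)))
        .collect();
    // Sort descending by similarity and keep only the best k.
    ranked.sort_by_key(|&(_, score)| Reverse(score));
    ranked.truncate(k);
    ranked
}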
- ", - )?; - - let result_iter = statement.query_map(params![ids_to_sql(&ids)], |row| { - Ok(( - row.get::<_, i64>(0)?, - row.get::<_, i64>(1)?, - row.get::<_, String>(2)?.into(), - row.get(3)?..row.get(4)?, - )) - })?; - - let mut values_by_id = HashMap::)>::default(); - for row in result_iter { - let (id, worktree_id, path, range) = row?; - values_by_id.insert(id, (worktree_id, path, range)); - } - - let mut results = Vec::with_capacity(ids.len()); - for id in &ids { - let value = values_by_id - .remove(id) - .ok_or(anyhow!("missing span id {}", id))?; - results.push(value); - } - - Ok(results) - }) - } -} - -fn ids_to_sql(ids: &[i64]) -> Rc> { - Rc::new( - ids.iter() - .copied() - .map(|v| rusqlite::types::Value::from(v)) - .collect::>(), - ) -} diff --git a/crates/semantic_index/src/embedding_queue.rs b/crates/semantic_index/src/embedding_queue.rs deleted file mode 100644 index a2371a1196b598..00000000000000 --- a/crates/semantic_index/src/embedding_queue.rs +++ /dev/null @@ -1,169 +0,0 @@ -use crate::{parsing::Span, JobHandle}; -use ai::embedding::EmbeddingProvider; -use gpui::BackgroundExecutor; -use parking_lot::Mutex; -use smol::channel; -use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime}; - -#[derive(Clone)] -pub struct FileToEmbed { - pub worktree_id: i64, - pub path: Arc, - pub mtime: SystemTime, - pub spans: Vec, - pub job_handle: JobHandle, -} - -impl std::fmt::Debug for FileToEmbed { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("FileToEmbed") - .field("worktree_id", &self.worktree_id) - .field("path", &self.path) - .field("mtime", &self.mtime) - .field("spans", &self.spans) - .finish_non_exhaustive() - } -} - -impl PartialEq for FileToEmbed { - fn eq(&self, other: &Self) -> bool { - self.worktree_id == other.worktree_id - && self.path == other.path - && self.mtime == other.mtime - && self.spans == other.spans - } -} - -pub struct EmbeddingQueue { - embedding_provider: Arc, - pending_batch: Vec, - executor: BackgroundExecutor, - pending_batch_token_count: usize, - finished_files_tx: channel::Sender, - finished_files_rx: channel::Receiver, -} - -#[derive(Clone)] -pub struct FileFragmentToEmbed { - file: Arc>, - span_range: Range, -} - -impl EmbeddingQueue { - pub fn new( - embedding_provider: Arc, - executor: BackgroundExecutor, - ) -> Self { - let (finished_files_tx, finished_files_rx) = channel::unbounded(); - Self { - embedding_provider, - executor, - pending_batch: Vec::new(), - pending_batch_token_count: 0, - finished_files_tx, - finished_files_rx, - } - } - - pub fn push(&mut self, file: FileToEmbed) { - if file.spans.is_empty() { - self.finished_files_tx.try_send(file).unwrap(); - return; - } - - let file = Arc::new(Mutex::new(file)); - - self.pending_batch.push(FileFragmentToEmbed { - file: file.clone(), - span_range: 0..0, - }); - - let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; - for (ix, span) in file.lock().spans.iter().enumerate() { - let span_token_count = if span.embedding.is_none() { - span.token_count - } else { - 0 - }; - - let next_token_count = self.pending_batch_token_count + span_token_count; - if next_token_count > self.embedding_provider.max_tokens_per_batch() { - let range_end = fragment_range.end; - self.flush(); - self.pending_batch.push(FileFragmentToEmbed { - file: file.clone(), - span_range: range_end..range_end, - }); - fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range; - } - - fragment_range.end = ix + 1; - 
self.pending_batch_token_count += span_token_count; - } - } - - pub fn flush(&mut self) { - let batch = mem::take(&mut self.pending_batch); - self.pending_batch_token_count = 0; - if batch.is_empty() { - return; - } - - let finished_files_tx = self.finished_files_tx.clone(); - let embedding_provider = self.embedding_provider.clone(); - - self.executor - .spawn(async move { - let mut spans = Vec::new(); - for fragment in &batch { - let file = fragment.file.lock(); - spans.extend( - file.spans[fragment.span_range.clone()] - .iter() - .filter(|d| d.embedding.is_none()) - .map(|d| d.content.clone()), - ); - } - - // If spans is 0, just send the fragment to the finished files if its the last one. - if spans.is_empty() { - for fragment in batch.clone() { - if let Some(file) = Arc::into_inner(fragment.file) { - finished_files_tx.try_send(file.into_inner()).unwrap(); - } - } - return; - }; - - match embedding_provider.embed_batch(spans).await { - Ok(embeddings) => { - let mut embeddings = embeddings.into_iter(); - for fragment in batch { - for span in &mut fragment.file.lock().spans[fragment.span_range.clone()] - .iter_mut() - .filter(|d| d.embedding.is_none()) - { - if let Some(embedding) = embeddings.next() { - span.embedding = Some(embedding); - } else { - log::error!("number of embeddings != number of documents"); - } - } - - if let Some(file) = Arc::into_inner(fragment.file) { - finished_files_tx.try_send(file.into_inner()).unwrap(); - } - } - } - Err(error) => { - log::error!("{:?}", error); - } - } - }) - .detach(); - } - - pub fn finished_files(&self) -> channel::Receiver { - self.finished_files_rx.clone() - } -} diff --git a/crates/semantic_index/src/parsing.rs b/crates/semantic_index/src/parsing.rs deleted file mode 100644 index e6f4a37d10daab..00000000000000 --- a/crates/semantic_index/src/parsing.rs +++ /dev/null @@ -1,414 +0,0 @@ -use ai::{ - embedding::{Embedding, EmbeddingProvider}, - models::TruncationDirection, -}; -use anyhow::{anyhow, Result}; -use collections::HashSet; -use language::{Grammar, Language}; -use rusqlite::{ - types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef}, - ToSql, -}; -use sha1::{Digest, Sha1}; -use std::{ - borrow::Cow, - cmp::{self, Reverse}, - ops::Range, - path::Path, - sync::Arc, -}; -use tree_sitter::{Parser, QueryCursor}; - -#[derive(Debug, PartialEq, Eq, Clone, Hash)] -pub struct SpanDigest(pub [u8; 20]); - -impl FromSql for SpanDigest { - fn column_result(value: ValueRef) -> FromSqlResult { - let blob = value.as_blob()?; - let bytes = - blob.try_into() - .map_err(|_| rusqlite::types::FromSqlError::InvalidBlobSize { - expected_size: 20, - blob_size: blob.len(), - })?; - return Ok(SpanDigest(bytes)); - } -} - -impl ToSql for SpanDigest { - fn to_sql(&self) -> rusqlite::Result { - self.0.to_sql() - } -} - -impl From<&'_ str> for SpanDigest { - fn from(value: &'_ str) -> Self { - let mut sha1 = Sha1::new(); - sha1.update(value); - Self(sha1.finalize().into()) - } -} - -#[derive(Debug, PartialEq, Clone)] -pub struct Span { - pub name: String, - pub range: Range, - pub content: String, - pub embedding: Option, - pub digest: SpanDigest, - pub token_count: usize, -} - -const CODE_CONTEXT_TEMPLATE: &str = - "The below code snippet is from file ''\n\n```\n\n```"; -const ENTIRE_FILE_TEMPLATE: &str = - "The below snippet is from file ''\n\n```\n\n```"; -const MARKDOWN_CONTEXT_TEMPLATE: &str = "The below file contents is from file ''\n\n"; -pub const PARSEABLE_ENTIRE_FILE_TYPES: &[&str] = &[ - "TOML", "YAML", "CSS", "HEEX", "ERB", "SVELTE", "HTML", "Scheme", 
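// [Editor's illustrative sketch, not part of the patch] SpanDigest::from
// above, standalone: a SHA-1 over the rendered span text serves as the
// cache key for previously computed embeddings, so unchanged spans are
// never re-embedded.
use sha1::{Digest, Sha1};

fn span_digest(content: &str) -> [u8; 20] {
    let mut hasher = Sha1::new();
    hasher.update(content.as_bytes());
    hasher.finalize().into()
}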
-]; - -pub struct CodeContextRetriever { - pub parser: Parser, - pub cursor: QueryCursor, - pub embedding_provider: Arc, -} - -// Every match has an item, this represents the fundamental treesitter symbol and anchors the search -// Every match has one or more 'name' captures. These indicate the display range of the item for deduplication. -// If there are preceding comments, we track this with a context capture -// If there is a piece that should be collapsed in hierarchical queries, we capture it with a collapse capture -// If there is a piece that should be kept inside a collapsed node, we capture it with a keep capture -#[derive(Debug, Clone)] -pub struct CodeContextMatch { - pub start_col: usize, - pub item_range: Option>, - pub name_range: Option>, - pub context_ranges: Vec>, - pub collapse_ranges: Vec>, -} - -impl CodeContextRetriever { - pub fn new(embedding_provider: Arc) -> Self { - Self { - parser: Parser::new(), - cursor: QueryCursor::new(), - embedding_provider, - } - } - - fn parse_entire_file( - &self, - relative_path: Option<&Path>, - language_name: Arc, - content: &str, - ) -> Result> { - let document_span = ENTIRE_FILE_TEMPLATE - .replace( - "", - &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()), - ) - .replace("", language_name.as_ref()) - .replace("", &content); - let digest = SpanDigest::from(document_span.as_str()); - let model = self.embedding_provider.base_model(); - let document_span = model.truncate( - &document_span, - model.capacity()?, - ai::models::TruncationDirection::End, - )?; - let token_count = model.count_tokens(&document_span)?; - - Ok(vec![Span { - range: 0..content.len(), - content: document_span, - embedding: Default::default(), - name: language_name.to_string(), - digest, - token_count, - }]) - } - - fn parse_markdown_file( - &self, - relative_path: Option<&Path>, - content: &str, - ) -> Result> { - let document_span = MARKDOWN_CONTEXT_TEMPLATE - .replace( - "", - &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()), - ) - .replace("", &content); - let digest = SpanDigest::from(document_span.as_str()); - - let model = self.embedding_provider.base_model(); - let document_span = model.truncate( - &document_span, - model.capacity()?, - ai::models::TruncationDirection::End, - )?; - let token_count = model.count_tokens(&document_span)?; - - Ok(vec![Span { - range: 0..content.len(), - content: document_span, - embedding: None, - name: "Markdown".to_string(), - digest, - token_count, - }]) - } - - fn get_matches_in_file( - &mut self, - content: &str, - grammar: &Arc, - ) -> Result> { - let embedding_config = grammar - .embedding_config - .as_ref() - .ok_or_else(|| anyhow!("no embedding queries"))?; - self.parser.set_language(&grammar.ts_language).unwrap(); - - let tree = self - .parser - .parse(&content, None) - .ok_or_else(|| anyhow!("parsing failed"))?; - - let mut captures: Vec = Vec::new(); - let mut collapse_ranges: Vec> = Vec::new(); - let mut keep_ranges: Vec> = Vec::new(); - for mat in self.cursor.matches( - &embedding_config.query, - tree.root_node(), - content.as_bytes(), - ) { - let mut start_col = 0; - let mut item_range: Option> = None; - let mut name_range: Option> = None; - let mut context_ranges: Vec> = Vec::new(); - collapse_ranges.clear(); - keep_ranges.clear(); - for capture in mat.captures { - if capture.index == embedding_config.item_capture_ix { - item_range = Some(capture.node.byte_range()); - start_col = capture.node.start_position().column; - } else if Some(capture.index) == 
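// [Editor's illustrative sketch, not part of the patch] The capture loop in
// get_matches_in_file above, reduced to plain tree-sitter. The grammar and
// query string are stand-ins; the real code loads per-language embedding
// queries from the grammar's config. Assumes a tree-sitter version where
// QueryMatches is a plain iterator and a tree-sitter-rust crate exposing
// language(), matching the for-loop style of the code above.
use tree_sitter::{Parser, Query, QueryCursor};

fn function_ranges(source: &str) -> Vec<std::ops::Range<usize>> {
    let language = tree_sitter_rust::language();
    let mut parser = Parser::new();
    parser
        .set_language(&language)
        .expect("grammar incompatible with linked tree-sitter");
    let tree = parser.parse(source, None).expect("parser produced no tree");

    let query = Query::new(&language, "(function_item) @item").expect("malformed query");
    let mut cursor = QueryCursor::new();
    let mut ranges = Vec::new();
    for mat in cursor.matches(&query, tree.root_node(), source.as_bytes()) {
        for capture in mat.captures {
            ranges.push(capture.node.byte_range());
        }
    }
    ranges
}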
embedding_config.name_capture_ix { - name_range = Some(capture.node.byte_range()); - } else if Some(capture.index) == embedding_config.context_capture_ix { - context_ranges.push(capture.node.byte_range()); - } else if Some(capture.index) == embedding_config.collapse_capture_ix { - collapse_ranges.push(capture.node.byte_range()); - } else if Some(capture.index) == embedding_config.keep_capture_ix { - keep_ranges.push(capture.node.byte_range()); - } - } - - captures.push(CodeContextMatch { - start_col, - item_range, - name_range, - context_ranges, - collapse_ranges: subtract_ranges(&collapse_ranges, &keep_ranges), - }); - } - Ok(captures) - } - - pub fn parse_file_with_template( - &mut self, - relative_path: Option<&Path>, - content: &str, - language: Arc, - ) -> Result> { - let language_name = language.name(); - - if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) { - return self.parse_entire_file(relative_path, language_name, &content); - } else if ["Markdown", "Plain Text"].contains(&language_name.as_ref()) { - return self.parse_markdown_file(relative_path, &content); - } - - let mut spans = self.parse_file(content, language)?; - for span in &mut spans { - let document_content = CODE_CONTEXT_TEMPLATE - .replace( - "", - &relative_path.map_or(Cow::Borrowed("untitled"), |path| path.to_string_lossy()), - ) - .replace("", language_name.as_ref()) - .replace("item", &span.content); - - let model = self.embedding_provider.base_model(); - let document_content = model.truncate( - &document_content, - model.capacity()?, - TruncationDirection::End, - )?; - let token_count = model.count_tokens(&document_content)?; - - span.content = document_content; - span.token_count = token_count; - } - Ok(spans) - } - - pub fn parse_file(&mut self, content: &str, language: Arc) -> Result> { - let grammar = language - .grammar() - .ok_or_else(|| anyhow!("no grammar for language"))?; - - // Iterate through query matches - let matches = self.get_matches_in_file(content, grammar)?; - - let language_scope = language.default_scope(); - let placeholder = language_scope.collapsed_placeholder(); - - let mut spans = Vec::new(); - let mut collapsed_ranges_within = Vec::new(); - let mut parsed_name_ranges = HashSet::default(); - for (i, context_match) in matches.iter().enumerate() { - // Items which are collapsible but not embeddable have no item range - let item_range = if let Some(item_range) = context_match.item_range.clone() { - item_range - } else { - continue; - }; - - // Checks for deduplication - let name; - if let Some(name_range) = context_match.name_range.clone() { - name = content - .get(name_range.clone()) - .map_or(String::new(), |s| s.to_string()); - if parsed_name_ranges.contains(&name_range) { - continue; - } - parsed_name_ranges.insert(name_range); - } else { - name = String::new(); - } - - collapsed_ranges_within.clear(); - 'outer: for remaining_match in &matches[(i + 1)..] 
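// [Editor's illustrative sketch, not part of the patch] The substitution
// step of parse_file_with_template above. The placeholder token names are
// reconstructed assumptions consistent with the replace() calls in the
// deleted code; truncation to the model's context window is elided.
fn render_code_context(path: &str, language: &str, item: &str) -> String {
    const CODE_CONTEXT_TEMPLATE: &str =
        "The below code snippet is from file '<path>'\n\n```<language>\n<item>\n```";
    CODE_CONTEXT_TEMPLATE
        .replace("<path>", path)
        .replace("<language>", language)
        .replace("<item>", item)
}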
{ - for collapsed_range in &remaining_match.collapse_ranges { - if item_range.start <= collapsed_range.start - && item_range.end >= collapsed_range.end - { - collapsed_ranges_within.push(collapsed_range.clone()); - } else { - break 'outer; - } - } - } - - collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end))); - - let mut span_content = String::new(); - for context_range in &context_match.context_ranges { - add_content_from_range( - &mut span_content, - content, - context_range.clone(), - context_match.start_col, - ); - span_content.push_str("\n"); - } - - let mut offset = item_range.start; - for collapsed_range in &collapsed_ranges_within { - if collapsed_range.start > offset { - add_content_from_range( - &mut span_content, - content, - offset..collapsed_range.start, - context_match.start_col, - ); - offset = collapsed_range.start; - } - - if collapsed_range.end > offset { - span_content.push_str(placeholder); - offset = collapsed_range.end; - } - } - - if offset < item_range.end { - add_content_from_range( - &mut span_content, - content, - offset..item_range.end, - context_match.start_col, - ); - } - - let sha1 = SpanDigest::from(span_content.as_str()); - spans.push(Span { - name, - content: span_content, - range: item_range.clone(), - embedding: None, - digest: sha1, - token_count: 0, - }) - } - - return Ok(spans); - } -} - -pub(crate) fn subtract_ranges( - ranges: &[Range], - ranges_to_subtract: &[Range], -) -> Vec> { - let mut result = Vec::new(); - - let mut ranges_to_subtract = ranges_to_subtract.iter().peekable(); - - for range in ranges { - let mut offset = range.start; - - while offset < range.end { - if let Some(range_to_subtract) = ranges_to_subtract.peek() { - if offset < range_to_subtract.start { - let next_offset = cmp::min(range_to_subtract.start, range.end); - result.push(offset..next_offset); - offset = next_offset; - } else { - let next_offset = cmp::min(range_to_subtract.end, range.end); - offset = next_offset; - } - - if offset >= range_to_subtract.end { - ranges_to_subtract.next(); - } - } else { - result.push(offset..range.end); - offset = range.end; - } - } - } - - result -} - -fn add_content_from_range( - output: &mut String, - content: &str, - range: Range, - start_col: usize, -) { - for mut line in content.get(range.clone()).unwrap_or("").lines() { - for _ in 0..start_col { - if line.starts_with(' ') { - line = &line[1..]; - } else { - break; - } - } - output.push_str(line); - output.push('\n'); - } - output.pop(); -} diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs deleted file mode 100644 index b93bd433cd541d..00000000000000 --- a/crates/semantic_index/src/semantic_index.rs +++ /dev/null @@ -1,1308 +0,0 @@ -mod db; -mod embedding_queue; -mod parsing; -pub mod semantic_index_settings; - -#[cfg(test)] -mod semantic_index_tests; - -use crate::semantic_index_settings::SemanticIndexSettings; -use ai::embedding::{Embedding, EmbeddingProvider}; -use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL}; -use anyhow::{anyhow, Context as _, Result}; -use collections::{BTreeMap, HashMap, HashSet}; -use db::VectorDatabase; -use embedding_queue::{EmbeddingQueue, FileToEmbed}; -use futures::{future, FutureExt, StreamExt}; -use gpui::{ - AppContext, AsyncAppContext, BorrowWindow, Context, Global, Model, ModelContext, Task, - ViewContext, WeakModel, -}; -use language::{Anchor, Bias, Buffer, Language, LanguageRegistry}; -use lazy_static::lazy_static; -use ordered_float::OrderedFloat; -use 
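// [Editor's illustrative sketch, not part of the patch] A worked example of
// subtract_ranges above: a "keep" range punched out of a "collapse" range
// yields the two flanking pieces.
fn main() {
    let collapse = vec![10..30];
    let keep = vec![15..20];
    assert_eq!(subtract_ranges(&collapse, &keep), vec![10..15, 20..30]);
}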
parking_lot::Mutex; -use parsing::{CodeContextRetriever, Span, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES}; -use postage::watch; -use project::{Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId}; -use release_channel::ReleaseChannel; -use settings::Settings; -use smol::channel; -use std::{ - cmp::Reverse, - env, - future::Future, - mem, - ops::Range, - path::{Path, PathBuf}, - sync::{Arc, Weak}, - time::{Duration, Instant, SystemTime}, -}; -use util::paths::PathMatcher; -use util::{http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt}; -use workspace::Workspace; - -const SEMANTIC_INDEX_VERSION: usize = 11; -const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60); -const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250); - -lazy_static! { - static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); -} - -pub fn init( - fs: Arc, - http_client: Arc, - language_registry: Arc, - cx: &mut AppContext, -) { - SemanticIndexSettings::register(cx); - - let db_file_path = EMBEDDINGS_DIR - .join(Path::new(ReleaseChannel::global(cx).dev_name())) - .join("embeddings_db"); - - cx.observe_new_views( - |workspace: &mut Workspace, cx: &mut ViewContext| { - let Some(semantic_index) = SemanticIndex::global(cx) else { - return; - }; - let project = workspace.project().clone(); - - if project.read(cx).is_local() { - cx.app_mut() - .spawn(|mut cx| async move { - let previously_indexed = semantic_index - .update(&mut cx, |index, cx| { - index.project_previously_indexed(&project, cx) - })? - .await?; - if previously_indexed { - semantic_index - .update(&mut cx, |index, cx| index.index_project(project, cx))? - .await?; - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - }, - ) - .detach(); - - cx.spawn(move |cx| async move { - let embedding_provider = OpenAiEmbeddingProvider::new( - // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not - OPEN_AI_API_URL.to_string(), - http_client, - cx.background_executor().clone(), - ) - .await; - let semantic_index = SemanticIndex::new( - fs, - db_file_path, - Arc::new(embedding_provider), - language_registry, - cx.clone(), - ) - .await?; - - cx.update(|cx| cx.set_global(GlobalSemanticIndex(semantic_index.clone())))?; - - anyhow::Ok(()) - }) - .detach(); -} - -#[derive(Copy, Clone, Debug)] -pub enum SemanticIndexStatus { - NotAuthenticated, - NotIndexed, - Indexed, - Indexing { - remaining_files: usize, - rate_limit_expiry: Option, - }, -} - -pub struct SemanticIndex { - fs: Arc, - db: VectorDatabase, - embedding_provider: Arc, - language_registry: Arc, - parsing_files_tx: channel::Sender<(Arc>, PendingFile)>, - _embedding_task: Task<()>, - _parsing_files_tasks: Vec>, - projects: HashMap, ProjectState>, -} - -struct GlobalSemanticIndex(Model); - -impl Global for GlobalSemanticIndex {} - -struct ProjectState { - worktrees: HashMap, - pending_file_count_rx: watch::Receiver, - pending_file_count_tx: Arc>>, - pending_index: usize, - _subscription: gpui::Subscription, - _observe_pending_file_count: Task<()>, -} - -enum WorktreeState { - Registering(RegisteringWorktreeState), - Registered(RegisteredWorktreeState), -} - -impl WorktreeState { - fn is_registered(&self) -> bool { - matches!(self, Self::Registered(_)) - } - - fn paths_changed( - &mut self, - changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, - worktree: &Worktree, - ) { - let changed_paths = match self { - Self::Registering(state) => &mut state.changed_paths, - Self::Registered(state) => &mut 
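// [Editor's illustrative sketch, not part of the patch] The OPENAI_API_KEY
// lookup above uses lazy_static; the same read-once caching with only the
// standard library looks like this.
use std::{env, sync::OnceLock};

fn openai_api_key() -> Option<&'static str> {
    static KEY: OnceLock<Option<String>> = OnceLock::new();
    KEY.get_or_init(|| env::var("OPENAI_API_KEY").ok()).as_deref()
}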
state.changed_paths, - }; - - for (path, entry_id, change) in changes.iter() { - let Some(entry) = worktree.entry_for_id(*entry_id) else { - continue; - }; - let Some(mtime) = entry.mtime else { - continue; - }; - if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() { - continue; - } - changed_paths.insert( - path.clone(), - ChangedPathInfo { - mtime, - is_deleted: *change == PathChange::Removed, - }, - ); - } - } -} - -struct RegisteringWorktreeState { - changed_paths: BTreeMap, ChangedPathInfo>, - done_rx: watch::Receiver>, - _registration: Task<()>, -} - -impl RegisteringWorktreeState { - fn done(&self) -> impl Future { - let mut done_rx = self.done_rx.clone(); - async move { - while let Some(result) = done_rx.next().await { - if result.is_some() { - break; - } - } - } - } -} - -struct RegisteredWorktreeState { - db_id: i64, - changed_paths: BTreeMap, ChangedPathInfo>, -} - -struct ChangedPathInfo { - mtime: SystemTime, - is_deleted: bool, -} - -#[derive(Clone)] -pub struct JobHandle { - /// The outer Arc is here to count the clones of a JobHandle instance; - /// when the last handle to a given job is dropped, we decrement a counter (just once). - tx: Arc>>>, -} - -impl JobHandle { - fn new(tx: &Arc>>) -> Self { - *tx.lock().borrow_mut() += 1; - Self { - tx: Arc::new(Arc::downgrade(&tx)), - } - } -} - -impl ProjectState { - fn new(subscription: gpui::Subscription, cx: &mut ModelContext) -> Self { - let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0); - let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx)); - Self { - worktrees: Default::default(), - pending_file_count_rx: pending_file_count_rx.clone(), - pending_file_count_tx, - pending_index: 0, - _subscription: subscription, - _observe_pending_file_count: cx.spawn({ - let mut pending_file_count_rx = pending_file_count_rx.clone(); - |this, mut cx| async move { - while let Some(_) = pending_file_count_rx.next().await { - if this.update(&mut cx, |_, cx| cx.notify()).is_err() { - break; - } - } - } - }), - } - } - - fn worktree_id_for_db_id(&self, id: i64) -> Option { - self.worktrees - .iter() - .find_map(|(worktree_id, worktree_state)| match worktree_state { - WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id), - _ => None, - }) - } -} - -#[derive(Clone)] -pub struct PendingFile { - worktree_db_id: i64, - relative_path: Arc, - absolute_path: PathBuf, - language: Option>, - modified_time: SystemTime, - job_handle: JobHandle, -} - -#[derive(Clone)] -pub struct SearchResult { - pub buffer: Model, - pub range: Range, - pub similarity: OrderedFloat, -} - -impl SemanticIndex { - pub fn global(cx: &mut AppContext) -> Option> { - cx.try_global::() - .map(|semantic_index| semantic_index.0.clone()) - } - - pub fn authenticate(&mut self, cx: &mut AppContext) -> Task { - if !self.embedding_provider.has_credentials() { - let embedding_provider = self.embedding_provider.clone(); - cx.spawn(|cx| async move { - if let Some(retrieve_credentials) = cx - .update(|cx| embedding_provider.retrieve_credentials(cx)) - .log_err() - { - retrieve_credentials.await; - } - - embedding_provider.has_credentials() - }) - } else { - Task::ready(true) - } - } - - pub fn is_authenticated(&self) -> bool { - self.embedding_provider.has_credentials() - } - - pub fn enabled(cx: &AppContext) -> bool { - SemanticIndexSettings::get_global(cx).enabled - } - - pub fn status(&self, project: &Model) -> SemanticIndexStatus { - if !self.is_authenticated() { - return 
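// [Editor's illustrative sketch, not part of the patch] The JobHandle
// pattern above with std types: an RAII guard whose clones share one inner
// Arc, so the pending-job counter is decremented exactly once when the last
// clone drops, no matter how many times the guard was cloned.
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

#[derive(Clone)]
struct JobGuard {
    // The outer Arc counts clones of this guard; the inner Arc is the
    // shared pending-job counter.
    inner: Arc<Arc<AtomicUsize>>,
}

impl JobGuard {
    fn new(counter: &Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::SeqCst);
        Self { inner: Arc::new(counter.clone()) }
    }
}

impl Drop for JobGuard {
    fn drop(&mut self) {
        // get_mut succeeds only for the last clone, just as Arc::get_mut on
        // JobHandle::tx does in the original Drop impl.
        if let Some(counter) = Arc::get_mut(&mut self.inner) {
            counter.fetch_sub(1, Ordering::SeqCst);
        }
    }
}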
SemanticIndexStatus::NotAuthenticated; - } - - if let Some(project_state) = self.projects.get(&project.downgrade()) { - if project_state - .worktrees - .values() - .all(|worktree| worktree.is_registered()) - && project_state.pending_index == 0 - { - SemanticIndexStatus::Indexed - } else { - SemanticIndexStatus::Indexing { - remaining_files: *project_state.pending_file_count_rx.borrow(), - rate_limit_expiry: self.embedding_provider.rate_limit_expiration(), - } - } - } else { - SemanticIndexStatus::NotIndexed - } - } - - pub async fn new( - fs: Arc, - database_path: PathBuf, - embedding_provider: Arc, - language_registry: Arc, - mut cx: AsyncAppContext, - ) -> Result> { - let t0 = Instant::now(); - let database_path = Arc::from(database_path); - let db = VectorDatabase::new(fs.clone(), database_path, cx.background_executor().clone()) - .await?; - - log::trace!( - "db initialization took {:?} milliseconds", - t0.elapsed().as_millis() - ); - - cx.new_model(|cx| { - let t0 = Instant::now(); - let embedding_queue = - EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor().clone()); - let _embedding_task = cx.background_executor().spawn({ - let embedded_files = embedding_queue.finished_files(); - let db = db.clone(); - async move { - while let Ok(file) = embedded_files.recv().await { - db.insert_file(file.worktree_id, file.path, file.mtime, file.spans) - .await - .log_err(); - } - } - }); - - // Parse files into embeddable spans. - let (parsing_files_tx, parsing_files_rx) = - channel::unbounded::<(Arc>, PendingFile)>(); - let embedding_queue = Arc::new(Mutex::new(embedding_queue)); - let mut _parsing_files_tasks = Vec::new(); - for _ in 0..cx.background_executor().num_cpus() { - let fs = fs.clone(); - let mut parsing_files_rx = parsing_files_rx.clone(); - let embedding_provider = embedding_provider.clone(); - let embedding_queue = embedding_queue.clone(); - let background = cx.background_executor().clone(); - _parsing_files_tasks.push(cx.background_executor().spawn(async move { - let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - loop { - let mut timer = background.timer(EMBEDDING_QUEUE_FLUSH_TIMEOUT).fuse(); - let mut next_file_to_parse = parsing_files_rx.next().fuse(); - futures::select_biased! 
{ - next_file_to_parse = next_file_to_parse => { - if let Some((embeddings_for_digest, pending_file)) = next_file_to_parse { - Self::parse_file( - &fs, - pending_file, - &mut retriever, - &embedding_queue, - &embeddings_for_digest, - ) - .await - } else { - break; - } - }, - _ = timer => { - embedding_queue.lock().flush(); - } - } - } - })); - } - - log::trace!( - "semantic index task initialization took {:?} milliseconds", - t0.elapsed().as_millis() - ); - Self { - fs, - db, - embedding_provider, - language_registry, - parsing_files_tx, - _embedding_task, - _parsing_files_tasks, - projects: Default::default(), - } - }) - } - - async fn parse_file( - fs: &Arc, - pending_file: PendingFile, - retriever: &mut CodeContextRetriever, - embedding_queue: &Arc>, - embeddings_for_digest: &HashMap, - ) { - let Some(language) = pending_file.language else { - return; - }; - - if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() { - if let Some(mut spans) = retriever - .parse_file_with_template(Some(&pending_file.relative_path), &content, language) - .log_err() - { - log::trace!( - "parsed path {:?}: {} spans", - pending_file.relative_path, - spans.len() - ); - - for span in &mut spans { - if let Some(embedding) = embeddings_for_digest.get(&span.digest) { - span.embedding = Some(embedding.to_owned()); - } - } - - embedding_queue.lock().push(FileToEmbed { - worktree_id: pending_file.worktree_db_id, - path: pending_file.relative_path, - mtime: pending_file.modified_time, - job_handle: pending_file.job_handle, - spans, - }); - } - } - } - - pub fn project_previously_indexed( - &mut self, - project: &Model, - cx: &mut ModelContext, - ) -> Task> { - let worktrees_indexed_previously = project - .read(cx) - .worktrees() - .map(|worktree| { - self.db - .worktree_previously_indexed(&worktree.read(cx).abs_path()) - }) - .collect::>(); - cx.spawn(|_, _cx| async move { - let worktree_indexed_previously = - futures::future::join_all(worktrees_indexed_previously).await; - - Ok(worktree_indexed_previously - .iter() - .filter(|worktree| worktree.is_ok()) - .all(|v| v.as_ref().log_err().is_some_and(|v| v.to_owned()))) - }) - } - - fn project_entries_changed( - &mut self, - project: Model, - worktree_id: WorktreeId, - changes: Arc<[(Arc, ProjectEntryId, PathChange)]>, - cx: &mut ModelContext, - ) { - let Some(worktree) = project.read(cx).worktree_for_id(worktree_id, cx) else { - return; - }; - let project = project.downgrade(); - let Some(project_state) = self.projects.get_mut(&project) else { - return; - }; - - let worktree = worktree.read(cx); - let worktree_state = - if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) { - worktree_state - } else { - return; - }; - worktree_state.paths_changed(changes, worktree); - if let WorktreeState::Registered(_) = worktree_state { - cx.spawn(|this, mut cx| async move { - cx.background_executor() - .timer(BACKGROUND_INDEXING_DELAY) - .await; - if let Some((this, project)) = this.upgrade().zip(project.upgrade()) { - this.update(&mut cx, |this, cx| { - this.index_project(project, cx).detach_and_log_err(cx) - })?; - } - anyhow::Ok(()) - }) - .detach_and_log_err(cx); - } - } - - fn register_worktree( - &mut self, - project: Model, - worktree: Model, - cx: &mut ModelContext, - ) { - let project = project.downgrade(); - let project_state = if let Some(project_state) = self.projects.get_mut(&project) { - project_state - } else { - return; - }; - let worktree = if let Some(worktree) = worktree.read(cx).as_local() { - worktree - } else { - 
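// [Editor's illustrative sketch, not part of the patch] The parser-task
// select loop above, on plain smol: pull work from a channel, but flush the
// shared queue whenever no file arrives within the flush timeout. The 250ms
// constant mirrors EMBEDDING_QUEUE_FLUSH_TIMEOUT; the work items are
// simplified to strings.
use futures::{select_biased, FutureExt, StreamExt};
use std::time::Duration;

async fn parse_loop(mut files: smol::channel::Receiver<String>, flush: impl Fn()) {
    loop {
        let mut timer = smol::Timer::after(Duration::from_millis(250)).fuse();
        let mut next_file = files.next().fuse();
        select_biased! {
            file = next_file => match file {
                Some(file) => println!("parsing {file}"),
                None => break, // channel closed: all senders dropped
            },
            _ = timer => flush(),
        }
    }
}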
return; - }; - let worktree_abs_path = worktree.abs_path().clone(); - let scan_complete = worktree.scan_complete(); - let worktree_id = worktree.id(); - let db = self.db.clone(); - let language_registry = self.language_registry.clone(); - let (mut done_tx, done_rx) = watch::channel(); - let registration = cx.spawn(|this, mut cx| { - async move { - let register = async { - scan_complete.await; - let db_id = db.find_or_create_worktree(worktree_abs_path).await?; - let mut file_mtimes = db.get_file_mtimes(db_id).await?; - let worktree = if let Some(project) = project.upgrade() { - project - .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx)) - .ok() - .flatten() - .context("worktree not found")? - } else { - return anyhow::Ok(()); - }; - let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot())?; - let mut changed_paths = cx - .background_executor() - .spawn(async move { - let mut changed_paths = BTreeMap::new(); - for file in worktree.files(false, 0) { - let absolute_path = worktree.absolutize(&file.path)?; - - if file.is_external || file.is_ignored || file.is_symlink { - continue; - } - - if let Ok(language) = language_registry - .language_for_file_path(&absolute_path) - .await - { - // Test if file is valid parseable file - if !PARSEABLE_ENTIRE_FILE_TYPES - .contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } - let Some(new_mtime) = file.mtime else { - continue; - }; - - let stored_mtime = file_mtimes.remove(&file.path.to_path_buf()); - let already_stored = stored_mtime == Some(new_mtime); - - if !already_stored { - changed_paths.insert( - file.path.clone(), - ChangedPathInfo { - mtime: new_mtime, - is_deleted: false, - }, - ); - } - } - } - - // Clean up entries from database that are no longer in the worktree. - for (path, mtime) in file_mtimes { - changed_paths.insert( - path.into(), - ChangedPathInfo { - mtime, - is_deleted: true, - }, - ); - } - - anyhow::Ok(changed_paths) - }) - .await?; - this.update(&mut cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project) - .context("project not registered")?; - let project = project.upgrade().context("project was dropped")?; - - if let Some(WorktreeState::Registering(state)) = - project_state.worktrees.remove(&worktree_id) - { - changed_paths.extend(state.changed_paths); - } - project_state.worktrees.insert( - worktree_id, - WorktreeState::Registered(RegisteredWorktreeState { - db_id, - changed_paths, - }), - ); - this.index_project(project, cx).detach_and_log_err(cx); - - anyhow::Ok(()) - })??; - - anyhow::Ok(()) - }; - - if register.await.log_err().is_none() { - // Stop tracking this worktree if the registration failed. 
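// [Editor's illustrative sketch, not part of the patch] The change-detection
// core of register_worktree above: diff on-disk mtimes against the stored
// map; entries left over in the stored map correspond to deleted files.
// The (path, is_deleted) pairs mirror ChangedPathInfo.
use std::{collections::HashMap, path::PathBuf, time::SystemTime};

fn diff_mtimes(
    on_disk: &[(PathBuf, SystemTime)],
    mut stored: HashMap<PathBuf, SystemTime>,
) -> Vec<(PathBuf, bool)> {
    let mut changed = Vec::new();
    for (path, mtime) in on_disk {
        // A missing or stale stored mtime means the file needs re-indexing.
        if stored.remove(path) != Some(*mtime) {
            changed.push((path.clone(), false));
        }
    }
    // Anything still in `stored` no longer exists on disk.
    changed.extend(stored.into_keys().map(|path| (path, true)));
    changed
}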
- this.update(&mut cx, |this, _| { - if let Some(project_state) = this.projects.get_mut(&project) { - project_state.worktrees.remove(&worktree_id); - } - }) - .ok(); - } - - *done_tx.borrow_mut() = Some(()); - } - }); - project_state.worktrees.insert( - worktree_id, - WorktreeState::Registering(RegisteringWorktreeState { - changed_paths: Default::default(), - done_rx, - _registration: registration, - }), - ); - } - - fn project_worktrees_changed(&mut self, project: Model, cx: &mut ModelContext) { - let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade()) - { - project_state - } else { - return; - }; - - let mut worktrees = project - .read(cx) - .worktrees() - .filter(|worktree| worktree.read(cx).is_local()) - .collect::>(); - let worktree_ids = worktrees - .iter() - .map(|worktree| worktree.read(cx).id()) - .collect::>(); - - // Remove worktrees that are no longer present - project_state - .worktrees - .retain(|worktree_id, _| worktree_ids.contains(worktree_id)); - - // Register new worktrees - worktrees.retain(|worktree| { - let worktree_id = worktree.read(cx).id(); - !project_state.worktrees.contains_key(&worktree_id) - }); - for worktree in worktrees { - self.register_worktree(project.clone(), worktree, cx); - } - } - - pub fn pending_file_count(&self, project: &Model) -> Option> { - Some( - self.projects - .get(&project.downgrade())? - .pending_file_count_rx - .clone(), - ) - } - - pub fn search_project( - &mut self, - project: Model, - query: String, - limit: usize, - includes: Vec, - excludes: Vec, - cx: &mut ModelContext, - ) -> Task>> { - if query.is_empty() { - return Task::ready(Ok(Vec::new())); - } - - let index = self.index_project(project.clone(), cx); - let embedding_provider = self.embedding_provider.clone(); - - cx.spawn(|this, mut cx| async move { - index.await?; - let t0 = Instant::now(); - - let query = embedding_provider - .embed_batch(vec![query]) - .await? - .pop() - .context("could not embed query")?; - log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis()); - - let search_start = Instant::now(); - let modified_buffer_results = this.update(&mut cx, |this, cx| { - this.search_modified_buffers( - &project, - query.clone(), - limit, - &includes, - &excludes, - cx, - ) - })?; - let file_results = this.update(&mut cx, |this, cx| { - this.search_files(project, query, limit, includes, excludes, cx) - })?; - let (modified_buffer_results, file_results) = - futures::join!(modified_buffer_results, file_results); - - // Weave together the results from modified buffers and files. 
- let mut results = Vec::new(); - let mut modified_buffers = HashSet::default(); - for result in modified_buffer_results.log_err().unwrap_or_default() { - modified_buffers.insert(result.buffer.clone()); - results.push(result); - } - for result in file_results.log_err().unwrap_or_default() { - if !modified_buffers.contains(&result.buffer) { - results.push(result); - } - } - results.sort_by_key(|result| Reverse(result.similarity)); - results.truncate(limit); - log::trace!("Semantic search took {:?}", search_start.elapsed()); - Ok(results) - }) - } - - pub fn search_files( - &mut self, - project: Model, - query: Embedding, - limit: usize, - includes: Vec, - excludes: Vec, - cx: &mut ModelContext, - ) -> Task>> { - let db_path = self.db.path().clone(); - let fs = self.fs.clone(); - cx.spawn(|this, mut cx| async move { - let database = VectorDatabase::new( - fs.clone(), - db_path.clone(), - cx.background_executor().clone(), - ) - .await?; - - let worktree_db_ids = this.read_with(&cx, |this, _| { - let project_state = this - .projects - .get(&project.downgrade()) - .context("project was not indexed")?; - let worktree_db_ids = project_state - .worktrees - .values() - .filter_map(|worktree| { - if let WorktreeState::Registered(worktree) = worktree { - Some(worktree.db_id) - } else { - None - } - }) - .collect::>(); - anyhow::Ok(worktree_db_ids) - })??; - - let file_ids = database - .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes) - .await?; - - let batch_n = cx.background_executor().num_cpus(); - let ids_len = file_ids.clone().len(); - let minimum_batch_size = 50; - - let batch_size = { - let size = ids_len / batch_n; - if size < minimum_batch_size { - minimum_batch_size - } else { - size - } - }; - - let mut batch_results = Vec::new(); - for batch in file_ids.chunks(batch_size) { - let batch = batch.into_iter().map(|v| *v).collect::>(); - let fs = fs.clone(); - let db_path = db_path.clone(); - let query = query.clone(); - if let Some(db) = - VectorDatabase::new(fs, db_path.clone(), cx.background_executor().clone()) - .await - .log_err() - { - batch_results.push(async move { - db.top_k_search(&query, limit, batch.as_slice()).await - }); - } - } - - let batch_results = futures::future::join_all(batch_results).await; - - let mut results = Vec::new(); - for batch_result in batch_results { - if batch_result.is_ok() { - for (id, similarity) in batch_result.unwrap() { - let ix = match results - .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s)) - { - Ok(ix) => ix, - Err(ix) => ix, - }; - - results.insert(ix, (id, similarity)); - results.truncate(limit); - } - } - } - - let ids = results.iter().map(|(id, _)| *id).collect::>(); - let scores = results - .into_iter() - .map(|(_, score)| score) - .collect::>(); - let spans = database.spans_for_ids(ids.as_slice()).await?; - - let mut tasks = Vec::new(); - let mut ranges = Vec::new(); - let weak_project = project.downgrade(); - project.update(&mut cx, |project, cx| { - let this = this.upgrade().context("index was dropped")?; - for (worktree_db_id, file_path, byte_range) in spans { - let project_state = - if let Some(state) = this.read(cx).projects.get(&weak_project) { - state - } else { - return Err(anyhow!("project not added")); - }; - if let Some(worktree_id) = project_state.worktree_id_for_db_id(worktree_db_id) { - tasks.push(project.open_buffer((worktree_id, file_path), cx)); - ranges.push(byte_range); - } - } - - Ok(()) - })??; - - let buffers = futures::future::join_all(tasks).await; - Ok(buffers - .into_iter() - 
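// [Editor's illustrative sketch, not part of the patch] The
// binary_search_by_key insertion used in search_files above: keep the
// vector sorted descending by score and capped at `limit`, so merging the
// per-batch results stays cheap.
use ordered_float::OrderedFloat;
use std::cmp::Reverse;

fn insert_ranked(
    results: &mut Vec<(i64, OrderedFloat<f32>)>,
    id: i64,
    similarity: OrderedFloat<f32>,
    limit: usize,
) {
    let ix = match results.binary_search_by_key(&Reverse(similarity), |&(_, s)| Reverse(s)) {
        Ok(ix) | Err(ix) => ix,
    };
    results.insert(ix, (id, similarity));
    results.truncate(limit);
}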
.zip(ranges) - .zip(scores) - .filter_map(|((buffer, range), similarity)| { - let buffer = buffer.log_err()?; - let range = buffer - .read_with(&cx, |buffer, _| { - let start = buffer.clip_offset(range.start, Bias::Left); - let end = buffer.clip_offset(range.end, Bias::Right); - buffer.anchor_before(start)..buffer.anchor_after(end) - }) - .log_err()?; - Some(SearchResult { - buffer, - range, - similarity, - }) - }) - .collect()) - }) - } - - fn search_modified_buffers( - &self, - project: &Model, - query: Embedding, - limit: usize, - includes: &[PathMatcher], - excludes: &[PathMatcher], - cx: &mut ModelContext, - ) -> Task>> { - let modified_buffers = project - .read(cx) - .opened_buffers() - .into_iter() - .filter_map(|buffer_handle| { - let buffer = buffer_handle.read(cx); - let snapshot = buffer.snapshot(); - let excluded = snapshot.resolve_file_path(cx, false).map_or(false, |path| { - excludes.iter().any(|matcher| matcher.is_match(&path)) - }); - - let included = if includes.len() == 0 { - true - } else { - snapshot.resolve_file_path(cx, false).map_or(false, |path| { - includes.iter().any(|matcher| matcher.is_match(&path)) - }) - }; - - if buffer.is_dirty() && !excluded && included { - Some((buffer_handle, snapshot)) - } else { - None - } - }) - .collect::>(); - - let embedding_provider = self.embedding_provider.clone(); - let fs = self.fs.clone(); - let db_path = self.db.path().clone(); - let background = cx.background_executor().clone(); - cx.background_executor().spawn(async move { - let db = VectorDatabase::new(fs, db_path.clone(), background).await?; - let mut results = Vec::::new(); - - let mut retriever = CodeContextRetriever::new(embedding_provider.clone()); - for (buffer, snapshot) in modified_buffers { - let language = snapshot - .language_at(0) - .cloned() - .unwrap_or_else(|| language::PLAIN_TEXT.clone()); - let mut spans = retriever - .parse_file_with_template(None, &snapshot.text(), language) - .log_err() - .unwrap_or_default(); - if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db) - .await - .log_err() - .is_some() - { - for span in spans { - let similarity = span.embedding.unwrap().similarity(&query); - let ix = match results - .binary_search_by_key(&Reverse(similarity), |result| { - Reverse(result.similarity) - }) { - Ok(ix) => ix, - Err(ix) => ix, - }; - - let range = { - let start = snapshot.clip_offset(span.range.start, Bias::Left); - let end = snapshot.clip_offset(span.range.end, Bias::Right); - snapshot.anchor_before(start)..snapshot.anchor_after(end) - }; - - results.insert( - ix, - SearchResult { - buffer: buffer.clone(), - range, - similarity, - }, - ); - results.truncate(limit); - } - } - } - - Ok(results) - }) - } - - pub fn index_project( - &mut self, - project: Model, - cx: &mut ModelContext, - ) -> Task> { - if self.is_authenticated() { - self.index_project_internal(project, cx) - } else { - let authenticate = self.authenticate(cx); - cx.spawn(|this, mut cx| async move { - if authenticate.await { - this.update(&mut cx, |this, cx| this.index_project_internal(project, cx))? 
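// [Editor's illustrative sketch, not part of the patch] The include/exclude
// test in search_modified_buffers above, with the `globset` crate standing
// in for the editor's PathMatcher: an empty include list admits every path.
use globset::Glob;
use std::path::Path;

fn passes_filters(path: &Path, includes: &[Glob], excludes: &[Glob]) -> bool {
    let excluded = excludes.iter().any(|g| g.compile_matcher().is_match(path));
    let included = includes.is_empty()
        || includes.iter().any(|g| g.compile_matcher().is_match(path));
    included && !excluded
}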
- .await - } else { - Err(anyhow!("user is not authenticated")) - } - }) - } - } - - fn index_project_internal( - &mut self, - project: Model, - cx: &mut ModelContext, - ) -> Task> { - if !self.projects.contains_key(&project.downgrade()) { - let subscription = cx.subscribe(&project, |this, project, event, cx| match event { - project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => { - this.project_worktrees_changed(project.clone(), cx); - } - project::Event::WorktreeUpdatedEntries(worktree_id, changes) => { - this.project_entries_changed(project, *worktree_id, changes.clone(), cx); - } - _ => {} - }); - let project_state = ProjectState::new(subscription, cx); - self.projects.insert(project.downgrade(), project_state); - self.project_worktrees_changed(project.clone(), cx); - } - let project_state = self.projects.get_mut(&project.downgrade()).unwrap(); - project_state.pending_index += 1; - cx.notify(); - - let mut pending_file_count_rx = project_state.pending_file_count_rx.clone(); - let db = self.db.clone(); - let language_registry = self.language_registry.clone(); - let parsing_files_tx = self.parsing_files_tx.clone(); - let worktree_registration = self.wait_for_worktree_registration(&project, cx); - - cx.spawn(|this, mut cx| async move { - worktree_registration.await?; - - let mut pending_files = Vec::new(); - let mut files_to_delete = Vec::new(); - this.update(&mut cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project.downgrade()) - .context("project was dropped")?; - let pending_file_count_tx = &project_state.pending_file_count_tx; - - project_state - .worktrees - .retain(|worktree_id, worktree_state| { - let worktree = if let Some(worktree) = - project.read(cx).worktree_for_id(*worktree_id, cx) - { - worktree - } else { - return false; - }; - let worktree_state = - if let WorktreeState::Registered(worktree_state) = worktree_state { - worktree_state - } else { - return true; - }; - - for (path, info) in &worktree_state.changed_paths { - if info.is_deleted { - files_to_delete.push((worktree_state.db_id, path.clone())); - } else if let Ok(absolute_path) = worktree.read(cx).absolutize(path) { - let job_handle = JobHandle::new(pending_file_count_tx); - pending_files.push(PendingFile { - absolute_path, - relative_path: path.clone(), - language: None, - job_handle, - modified_time: info.mtime, - worktree_db_id: worktree_state.db_id, - }); - } - } - worktree_state.changed_paths.clear(); - true - }); - - anyhow::Ok(()) - })??; - - cx.background_executor() - .spawn(async move { - for (worktree_db_id, path) in files_to_delete { - db.delete_file(worktree_db_id, path).await.log_err(); - } - - let embeddings_for_digest = { - let mut files = HashMap::default(); - for pending_file in &pending_files { - files - .entry(pending_file.worktree_db_id) - .or_insert(Vec::new()) - .push(pending_file.relative_path.clone()); - } - Arc::new( - db.embeddings_for_files(files) - .await - .log_err() - .unwrap_or_default(), - ) - }; - - for mut pending_file in pending_files { - if let Ok(language) = language_registry - .language_for_file_path(&pending_file.relative_path) - .await - { - if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref()) - && &language.name().as_ref() != &"Markdown" - && language - .grammar() - .and_then(|grammar| grammar.embedding_config.as_ref()) - .is_none() - { - continue; - } - pending_file.language = Some(language); - } - parsing_files_tx - .try_send((embeddings_for_digest.clone(), pending_file)) - .ok(); - } - - // Wait until we're done 
indexing. - while let Some(count) = pending_file_count_rx.next().await { - if count == 0 { - break; - } - } - }) - .await; - - this.update(&mut cx, |this, cx| { - let project_state = this - .projects - .get_mut(&project.downgrade()) - .context("project was dropped")?; - project_state.pending_index -= 1; - cx.notify(); - anyhow::Ok(()) - })??; - - Ok(()) - }) - } - - fn wait_for_worktree_registration( - &self, - project: &Model, - cx: &mut ModelContext, - ) -> Task> { - let project = project.downgrade(); - cx.spawn(|this, cx| async move { - loop { - let mut pending_worktrees = Vec::new(); - this.upgrade() - .context("semantic index dropped")? - .read_with(&cx, |this, _| { - if let Some(project) = this.projects.get(&project) { - for worktree in project.worktrees.values() { - if let WorktreeState::Registering(worktree) = worktree { - pending_worktrees.push(worktree.done()); - } - } - } - })?; - - if pending_worktrees.is_empty() { - break; - } else { - future::join_all(pending_worktrees).await; - } - } - Ok(()) - }) - } - - async fn embed_spans( - spans: &mut [Span], - embedding_provider: &dyn EmbeddingProvider, - db: &VectorDatabase, - ) -> Result<()> { - let mut batch = Vec::new(); - let mut batch_tokens = 0; - let mut embeddings = Vec::new(); - - let digests = spans - .iter() - .map(|span| span.digest.clone()) - .collect::>(); - let embeddings_for_digests = db - .embeddings_for_digests(digests) - .await - .log_err() - .unwrap_or_default(); - - for span in &*spans { - if embeddings_for_digests.contains_key(&span.digest) { - continue; - }; - - if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() { - let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch)) - .await?; - embeddings.extend(batch_embeddings); - batch_tokens = 0; - } - - batch_tokens += span.token_count; - batch.push(span.content.clone()); - } - - if !batch.is_empty() { - let batch_embeddings = embedding_provider - .embed_batch(mem::take(&mut batch)) - .await?; - - embeddings.extend(batch_embeddings); - } - - let mut embeddings = embeddings.into_iter(); - for span in spans { - let embedding = if let Some(embedding) = embeddings_for_digests.get(&span.digest) { - Some(embedding.clone()) - } else { - embeddings.next() - }; - let embedding = embedding.context("failed to embed spans")?; - span.embedding = Some(embedding); - } - Ok(()) - } -} - -impl Drop for JobHandle { - fn drop(&mut self) { - if let Some(inner) = Arc::get_mut(&mut self.tx) { - // This is the last instance of the JobHandle (regardless of its origin - whether it was cloned or not) - if let Some(tx) = inner.upgrade() { - let mut tx = tx.lock(); - *tx.borrow_mut() -= 1; - } - } - } -} - -#[cfg(test)] -mod tests { - - use super::*; - #[test] - fn test_job_handle() { - let (job_count_tx, job_count_rx) = watch::channel_with(0); - let tx = Arc::new(Mutex::new(job_count_tx)); - let job_handle = JobHandle::new(&tx); - - assert_eq!(1, *job_count_rx.borrow()); - let new_job_handle = job_handle.clone(); - assert_eq!(1, *job_count_rx.borrow()); - drop(job_handle); - assert_eq!(1, *job_count_rx.borrow()); - drop(new_job_handle); - assert_eq!(0, *job_count_rx.borrow()); - } -} diff --git a/crates/semantic_index/src/semantic_index_settings.rs b/crates/semantic_index/src/semantic_index_settings.rs deleted file mode 100644 index 73fd49c8f5f61d..00000000000000 --- a/crates/semantic_index/src/semantic_index_settings.rs +++ /dev/null @@ -1,33 +0,0 @@ -use anyhow; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use 
settings::Settings; - -#[derive(Deserialize, Debug)] -pub struct SemanticIndexSettings { - pub enabled: bool, -} - -/// Configuration of semantic index, an alternate search engine available in -/// project search. -#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)] -pub struct SemanticIndexSettingsContent { - /// Whether or not to display the Semantic mode in project search. - /// - /// Default: true - pub enabled: Option, -} - -impl Settings for SemanticIndexSettings { - const KEY: Option<&'static str> = Some("semantic_index"); - - type FileContent = SemanticIndexSettingsContent; - - fn load( - default_value: &Self::FileContent, - user_values: &[&Self::FileContent], - _: &mut gpui::AppContext, - ) -> anyhow::Result { - Self::load_via_json_merge(default_value, user_values) - } -} diff --git a/crates/semantic_index/src/semantic_index_tests.rs b/crates/semantic_index/src/semantic_index_tests.rs deleted file mode 100644 index 728e12f0bc0c4c..00000000000000 --- a/crates/semantic_index/src/semantic_index_tests.rs +++ /dev/null @@ -1,1725 +0,0 @@ -use crate::{ - embedding_queue::EmbeddingQueue, - parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest}, - semantic_index_settings::SemanticIndexSettings, - FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT, -}; -use ai::test::FakeEmbeddingProvider; -use gpui::TestAppContext; -use language::{Language, LanguageConfig, LanguageMatcher, LanguageRegistry, ToOffset}; -use parking_lot::Mutex; -use pretty_assertions::assert_eq; -use project::{FakeFs, Fs, Project}; -use rand::{rngs::StdRng, Rng}; -use serde_json::json; -use settings::{Settings, SettingsStore}; -use std::{path::Path, sync::Arc, time::SystemTime}; -use unindent::Unindent; -use util::{paths::PathMatcher, RandomCharIter}; - -#[ctor::ctor] -fn init_logger() { - if std::env::var("RUST_LOG").is_ok() { - env_logger::init(); - } -} - -#[gpui::test] -async fn test_semantic_index(cx: &mut TestAppContext) { - init_test(cx); - - let fs = FakeFs::new(cx.background_executor.clone()); - fs.insert_tree( - "/the-root", - json!({ - "src": { - "file1.rs": " - fn aaa() { - println!(\"aaaaaaaaaaaa!\"); - } - - fn zzzzz() { - println!(\"SLEEPING\"); - } - ".unindent(), - "file2.rs": " - fn bbb() { - println!(\"bbbbbbbbbbbbb!\"); - } - struct pqpqpqp {} - ".unindent(), - "file3.toml": " - ZZZZZZZZZZZZZZZZZZ = 5 - ".unindent(), - } - }), - ) - .await; - - let languages = Arc::new(LanguageRegistry::test(cx.executor().clone())); - let rust_language = rust_lang(); - let toml_language = toml_lang(); - languages.add(rust_language); - languages.add(toml_language); - - let db_dir = tempfile::Builder::new() - .prefix("vector-store") - .tempdir() - .unwrap(); - let db_path = db_dir.path().join("db.sqlite"); - - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let semantic_index = SemanticIndex::new( - fs.clone(), - db_path, - embedding_provider.clone(), - languages, - cx.to_async(), - ) - .await - .unwrap(); - - let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await; - - let search_results = semantic_index.update(cx, |store, cx| { - store.search_project( - project.clone(), - "aaaaaabbbbzz".to_string(), - 5, - vec![], - vec![], - cx, - ) - }); - let pending_file_count = - semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap()); - cx.background_executor.run_until_parked(); - assert_eq!(*pending_file_count.borrow(), 3); - cx.background_executor - .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); - 
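// [Editor's illustrative sketch, not part of the patch] load_via_json_merge
// above, approximated with plain serde_json: user setting layers are folded
// over the defaults, with later layers overriding earlier keys.
use serde_json::Value;

fn merge_settings(default: Value, user_layers: Vec<Value>) -> Value {
    let mut merged = default;
    for layer in user_layers {
        if let (Value::Object(base), Value::Object(overlay)) = (&mut merged, layer) {
            for (key, value) in overlay {
                base.insert(key, value);
            }
        }
    }
    merged
}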
assert_eq!(*pending_file_count.borrow(), 0); - - let search_results = search_results.await.unwrap(); - assert_search_results( - &search_results, - &[ - (Path::new("src/file1.rs").into(), 0), - (Path::new("src/file2.rs").into(), 0), - (Path::new("src/file3.toml").into(), 0), - (Path::new("src/file1.rs").into(), 45), - (Path::new("src/file2.rs").into(), 45), - ], - cx, - ); - - // Test Include Files Functionality - let include_files = vec![PathMatcher::new("*.rs").unwrap()]; - let exclude_files = vec![PathMatcher::new("*.rs").unwrap()]; - let rust_only_search_results = semantic_index - .update(cx, |store, cx| { - store.search_project( - project.clone(), - "aaaaaabbbbzz".to_string(), - 5, - include_files, - vec![], - cx, - ) - }) - .await - .unwrap(); - - assert_search_results( - &rust_only_search_results, - &[ - (Path::new("src/file1.rs").into(), 0), - (Path::new("src/file2.rs").into(), 0), - (Path::new("src/file1.rs").into(), 45), - (Path::new("src/file2.rs").into(), 45), - ], - cx, - ); - - let no_rust_search_results = semantic_index - .update(cx, |store, cx| { - store.search_project( - project.clone(), - "aaaaaabbbbzz".to_string(), - 5, - vec![], - exclude_files, - cx, - ) - }) - .await - .unwrap(); - - assert_search_results( - &no_rust_search_results, - &[(Path::new("src/file3.toml").into(), 0)], - cx, - ); - - fs.save( - "/the-root/src/file2.rs".as_ref(), - &" - fn dddd() { println!(\"ddddd!\"); } - struct pqpqpqp {} - " - .unindent() - .into(), - Default::default(), - ) - .await - .unwrap(); - - cx.background_executor - .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); - - let prev_embedding_count = embedding_provider.embedding_count(); - let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx)); - cx.background_executor.run_until_parked(); - assert_eq!(*pending_file_count.borrow(), 1); - cx.background_executor - .advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); - assert_eq!(*pending_file_count.borrow(), 0); - index.await.unwrap(); - - assert_eq!( - embedding_provider.embedding_count() - prev_embedding_count, - 1 - ); -} - -#[gpui::test(iterations = 10)] -async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) { - let (outstanding_job_count, _) = postage::watch::channel_with(0); - let outstanding_job_count = Arc::new(Mutex::new(outstanding_job_count)); - - let files = (1..=3) - .map(|file_ix| FileToEmbed { - worktree_id: 5, - path: Path::new(&format!("path-{file_ix}")).into(), - mtime: SystemTime::now(), - spans: (0..rng.gen_range(4..22)) - .map(|document_ix| { - let content_len = rng.gen_range(10..100); - let content = RandomCharIter::new(&mut rng) - .with_simple_text() - .take(content_len) - .collect::(); - let digest = SpanDigest::from(content.as_str()); - Span { - range: 0..10, - embedding: None, - name: format!("document {document_ix}"), - content, - digest, - token_count: rng.gen_range(10..30), - } - }) - .collect(), - job_handle: JobHandle::new(&outstanding_job_count), - }) - .collect::>(); - - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - - let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background_executor.clone()); - for file in &files { - queue.push(file.clone()); - } - queue.flush(); - - cx.background_executor.run_until_parked(); - let finished_files = queue.finished_files(); - let mut embedded_files: Vec<_> = files - .iter() - .map(|_| finished_files.try_recv().expect("no finished file")) - .collect(); - - let expected_files: Vec<_> = files - .iter() - .map(|file| { - let mut file = 
file.clone(); - for doc in &mut file.spans { - doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref())); - } - file - }) - .collect(); - - embedded_files.sort_by_key(|f| f.path.clone()); - - assert_eq!(embedded_files, expected_files); -} - -#[track_caller] -fn assert_search_results( - actual: &[SearchResult], - expected: &[(Arc, usize)], - cx: &TestAppContext, -) { - let actual = actual - .iter() - .map(|search_result| { - search_result.buffer.read_with(cx, |buffer, _cx| { - ( - buffer.file().unwrap().path().clone(), - search_result.range.start.to_offset(buffer), - ) - }) - }) - .collect::>(); - assert_eq!(actual, expected); -} - -#[gpui::test] -async fn test_code_context_retrieval_rust() { - let language = rust_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = " - /// A doc comment - /// that spans multiple lines - #[gpui::test] - fn a() { - b - } - - impl C for D { - } - - impl E { - // This is also a preceding comment - pub fn function_1() -> Option<()> { - unimplemented!(); - } - - // This is a preceding comment - fn function_2() -> Result<()> { - unimplemented!(); - } - } - - #[derive(Clone)] - struct D { - name: String - } - " - .unindent(); - - let documents = retriever.parse_file(&text, language).unwrap(); - - assert_documents_eq( - &documents, - &[ - ( - " - /// A doc comment - /// that spans multiple lines - #[gpui::test] - fn a() { - b - }" - .unindent(), - text.find("fn a").unwrap(), - ), - ( - " - impl C for D { - }" - .unindent(), - text.find("impl C").unwrap(), - ), - ( - " - impl E { - // This is also a preceding comment - pub fn function_1() -> Option<()> { /* ... */ } - - // This is a preceding comment - fn function_2() -> Result<()> { /* ... 
*/ } - }" - .unindent(), - text.find("impl E").unwrap(), - ), - ( - " - // This is also a preceding comment - pub fn function_1() -> Option<()> { - unimplemented!(); - }" - .unindent(), - text.find("pub fn function_1").unwrap(), - ), - ( - " - // This is a preceding comment - fn function_2() -> Result<()> { - unimplemented!(); - }" - .unindent(), - text.find("fn function_2").unwrap(), - ), - ( - " - #[derive(Clone)] - struct D { - name: String - }" - .unindent(), - text.find("struct D").unwrap(), - ), - ], - ); -} - -#[gpui::test] -async fn test_code_context_retrieval_json() { - let language = json_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = r#" - { - "array": [1, 2, 3, 4], - "string": "abcdefg", - "nested_object": { - "array_2": [5, 6, 7, 8], - "string_2": "hijklmnop", - "boolean": true, - "none": null - } - } - "# - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[( - r#" - { - "array": [], - "string": "", - "nested_object": { - "array_2": [], - "string_2": "", - "boolean": true, - "none": null - } - }"# - .unindent(), - text.find('{').unwrap(), - )], - ); - - let text = r#" - [ - { - "name": "somebody", - "age": 42 - }, - { - "name": "somebody else", - "age": 43 - } - ] - "# - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[( - r#" - [{ - "name": "", - "age": 42 - }]"# - .unindent(), - text.find('[').unwrap(), - )], - ); -} - -fn assert_documents_eq( - documents: &[Span], - expected_contents_and_start_offsets: &[(String, usize)], -) { - assert_eq!( - documents - .iter() - .map(|document| (document.content.clone(), document.range.start)) - .collect::>(), - expected_contents_and_start_offsets - ); -} - -#[gpui::test] -async fn test_code_context_retrieval_javascript() { - let language = js_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = " - /* globals importScripts, backend */ - function _authorize() {} - - /** - * Sometimes the frontend build is way faster than backend. - */ - export async function authorizeBank() { - _authorize(pushModal, upgradingAccountId, {}); - } - - export class SettingsPage { - /* This is a test setting */ - constructor(page) { - this.page = page; - } - } - - /* This is a test comment */ - class TestClass {} - - /* Schema for editor_events in Clickhouse. */ - export interface ClickhouseEditorEvent { - installation_id: string - operation: string - } - " - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[ - ( - " - /* globals importScripts, backend */ - function _authorize() {}" - .unindent(), - 37, - ), - ( - " - /** - * Sometimes the frontend build is way faster than backend. 
- */ - export async function authorizeBank() { - _authorize(pushModal, upgradingAccountId, {}); - }" - .unindent(), - 131, - ), - ( - " - export class SettingsPage { - /* This is a test setting */ - constructor(page) { - this.page = page; - } - }" - .unindent(), - 225, - ), - ( - " - /* This is a test setting */ - constructor(page) { - this.page = page; - }" - .unindent(), - 290, - ), - ( - " - /* This is a test comment */ - class TestClass {}" - .unindent(), - 374, - ), - ( - " - /* Schema for editor_events in Clickhouse. */ - export interface ClickhouseEditorEvent { - installation_id: string - operation: string - }" - .unindent(), - 440, - ), - ], - ) -} - -#[gpui::test] -async fn test_code_context_retrieval_lua() { - let language = lua_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = r#" - -- Creates a new class - -- @param baseclass The Baseclass of this class, or nil. - -- @return A new class reference. - function classes.class(baseclass) - -- Create the class definition and metatable. - local classdef = {} - -- Find the super class, either Object or user-defined. - baseclass = baseclass or classes.Object - -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable. - setmetatable(classdef, { __index = baseclass }) - -- All class instances have a reference to the class object. - classdef.class = classdef - --- Recursively allocates the inheritance tree of the instance. - -- @param mastertable The 'root' of the inheritance tree. - -- @return Returns the instance with the allocated inheritance tree. - function classdef.alloc(mastertable) - -- All class instances have a reference to a superclass object. - local instance = { super = baseclass.alloc(mastertable) } - -- Any functions this instance does not know of will 'look up' to the superclass definition. - setmetatable(instance, { __index = classdef, __newindex = mastertable }) - return instance - end - end - "#.unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[ - (r#" - -- Creates a new class - -- @param baseclass The Baseclass of this class, or nil. - -- @return A new class reference. - function classes.class(baseclass) - -- Create the class definition and metatable. - local classdef = {} - -- Find the super class, either Object or user-defined. - baseclass = baseclass or classes.Object - -- If this class definition does not know of a function, it will 'look up' to the Baseclass via the __index of the metatable. - setmetatable(classdef, { __index = baseclass }) - -- All class instances have a reference to the class object. - classdef.class = classdef - --- Recursively allocates the inheritance tree of the instance. - -- @param mastertable The 'root' of the inheritance tree. - -- @return Returns the instance with the allocated inheritance tree. - function classdef.alloc(mastertable) - --[ ... ]-- - --[ ... ]-- - end - end"#.unindent(), - 114), - (r#" - --- Recursively allocates the inheritance tree of the instance. - -- @param mastertable The 'root' of the inheritance tree. - -- @return Returns the instance with the allocated inheritance tree. - function classdef.alloc(mastertable) - -- All class instances have a reference to a superclass object. 
- local instance = { super = baseclass.alloc(mastertable) } - -- Any functions this instance does not know of will 'look up' to the superclass definition. - setmetatable(instance, { __index = classdef, __newindex = mastertable }) - return instance - end"#.unindent(), 810), - ] - ); -} - -#[gpui::test] -async fn test_code_context_retrieval_elixir() { - let language = elixir_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = r#" - defmodule File.Stream do - @moduledoc """ - Defines a `File.Stream` struct returned by `File.stream!/3`. - - The following fields are public: - - * `path` - the file path - * `modes` - the file modes - * `raw` - a boolean indicating if bin functions should be used - * `line_or_bytes` - if reading should read lines or a given number of bytes - * `node` - the node the file belongs to - - """ - - defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil - - @type t :: %__MODULE__{} - - @doc false - def __build__(path, modes, line_or_bytes) do - raw = :lists.keyfind(:encoding, 1, modes) == false - - modes = - case raw do - true -> - case :lists.keyfind(:read_ahead, 1, modes) do - {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)] - {:read_ahead, _} -> [:raw | modes] - false -> [:raw, :read_ahead | modes] - end - - false -> - modes - end - - %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()} - - end"# - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[( - r#" - defmodule File.Stream do - @moduledoc """ - Defines a `File.Stream` struct returned by `File.stream!/3`. 
- - The following fields are public: - - * `path` - the file path - * `modes` - the file modes - * `raw` - a boolean indicating if bin functions should be used - * `line_or_bytes` - if reading should read lines or a given number of bytes - * `node` - the node the file belongs to - - """ - - defstruct path: nil, modes: [], line_or_bytes: :line, raw: true, node: nil - - @type t :: %__MODULE__{} - - @doc false - def __build__(path, modes, line_or_bytes) do - raw = :lists.keyfind(:encoding, 1, modes) == false - - modes = - case raw do - true -> - case :lists.keyfind(:read_ahead, 1, modes) do - {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)] - {:read_ahead, _} -> [:raw | modes] - false -> [:raw, :read_ahead | modes] - end - - false -> - modes - end - - %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()} - - end"# - .unindent(), - 0, - ),(r#" - @doc false - def __build__(path, modes, line_or_bytes) do - raw = :lists.keyfind(:encoding, 1, modes) == false - - modes = - case raw do - true -> - case :lists.keyfind(:read_ahead, 1, modes) do - {:read_ahead, false} -> [:raw | :lists.keydelete(:read_ahead, 1, modes)] - {:read_ahead, _} -> [:raw | modes] - false -> [:raw, :read_ahead | modes] - end - - false -> - modes - end - - %File.Stream{path: path, modes: modes, raw: raw, line_or_bytes: line_or_bytes, node: node()} - - end"#.unindent(), 574)], - ); -} - -#[gpui::test] -async fn test_code_context_retrieval_cpp() { - let language = cpp_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = " - /** - * @brief Main function - * @returns 0 on exit - */ - int main() { return 0; } - - /** - * This is a test comment - */ - class MyClass { // The class - public: // Access specifier - int myNum; // Attribute (int variable) - string myString; // Attribute (string variable) - }; - - // This is a test comment - enum Color { red, green, blue }; - - /** This is a preceding block comment - * This is the second line - */ - struct { // Structure declaration - int myNum; // Member (int variable) - string myString; // Member (string variable) - } myStructure; - - /** - * @brief Matrix class. - */ - template ::value || std::is_floating_point::value, - bool>::type> - class Matrix2 { - std::vector> _mat; - - public: - /** - * @brief Constructor - * @tparam Integer ensuring integers are being evaluated and not other - * data types. 
- * @param size denoting the size of Matrix as size x size - */ - template ::value, - Integer>::type> - explicit Matrix(const Integer size) { - for (size_t i = 0; i < size; ++i) { - _mat.emplace_back(std::vector(size, 0)); - } - } - }" - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[ - ( - " - /** - * @brief Main function - * @returns 0 on exit - */ - int main() { return 0; }" - .unindent(), - 54, - ), - ( - " - /** - * This is a test comment - */ - class MyClass { // The class - public: // Access specifier - int myNum; // Attribute (int variable) - string myString; // Attribute (string variable) - }" - .unindent(), - 112, - ), - ( - " - // This is a test comment - enum Color { red, green, blue }" - .unindent(), - 322, - ), - ( - " - /** This is a preceding block comment - * This is the second line - */ - struct { // Structure declaration - int myNum; // Member (int variable) - string myString; // Member (string variable) - } myStructure;" - .unindent(), - 425, - ), - ( - " - /** - * @brief Matrix class. - */ - template ::value || std::is_floating_point::value, - bool>::type> - class Matrix2 { - std::vector> _mat; - - public: - /** - * @brief Constructor - * @tparam Integer ensuring integers are being evaluated and not other - * data types. - * @param size denoting the size of Matrix as size x size - */ - template ::value, - Integer>::type> - explicit Matrix(const Integer size) { - for (size_t i = 0; i < size; ++i) { - _mat.emplace_back(std::vector(size, 0)); - } - } - }" - .unindent(), - 612, - ), - ( - " - explicit Matrix(const Integer size) { - for (size_t i = 0; i < size; ++i) { - _mat.emplace_back(std::vector(size, 0)); - } - }" - .unindent(), - 1226, - ), - ], - ); -} - -#[gpui::test] -async fn test_code_context_retrieval_ruby() { - let language = ruby_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = r#" - # This concern is inspired by "sudo mode" on GitHub. It - # is a way to re-authenticate a user before allowing them - # to see or perform an action. - # - # Add `before_action :require_challenge!` to actions you - # want to protect. - # - # The user will be shown a page to enter the challenge (which - # is either the password, or just the username when no - # password exists). Upon passing, there is a grace period - # during which no challenge will be asked from the user. - # - # Accessing challenge-protected resources during the grace - # period will refresh the grace period. - module ChallengableConcern - extend ActiveSupport::Concern - - CHALLENGE_TIMEOUT = 1.hour.freeze - - def require_challenge! - return if skip_challenge? - - if challenge_passed_recently? - session[:challenge_passed_at] = Time.now.utc - return - end - - @challenge = Form::Challenge.new(return_to: request.url) - - if params.key?(:form_challenge) - if challenge_passed? - session[:challenge_passed_at] = Time.now.utc - else - flash.now[:alert] = I18n.t('challenge.invalid_password') - render_challenge - end - else - render_challenge - end - end - - def challenge_passed? 
- current_user.valid_password?(challenge_params[:current_password]) - end - end - - class Animal - include Comparable - - attr_reader :legs - - def initialize(name, legs) - @name, @legs = name, legs - end - - def <=>(other) - legs <=> other.legs - end - end - - # Singleton method for car object - def car.wheels - puts "There are four wheels" - end"# - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[ - ( - r#" - # This concern is inspired by "sudo mode" on GitHub. It - # is a way to re-authenticate a user before allowing them - # to see or perform an action. - # - # Add `before_action :require_challenge!` to actions you - # want to protect. - # - # The user will be shown a page to enter the challenge (which - # is either the password, or just the username when no - # password exists). Upon passing, there is a grace period - # during which no challenge will be asked from the user. - # - # Accessing challenge-protected resources during the grace - # period will refresh the grace period. - module ChallengableConcern - extend ActiveSupport::Concern - - CHALLENGE_TIMEOUT = 1.hour.freeze - - def require_challenge! - # ... - end - - def challenge_passed? - # ... - end - end"# - .unindent(), - 558, - ), - ( - r#" - def require_challenge! - return if skip_challenge? - - if challenge_passed_recently? - session[:challenge_passed_at] = Time.now.utc - return - end - - @challenge = Form::Challenge.new(return_to: request.url) - - if params.key?(:form_challenge) - if challenge_passed? - session[:challenge_passed_at] = Time.now.utc - else - flash.now[:alert] = I18n.t('challenge.invalid_password') - render_challenge - end - else - render_challenge - end - end"# - .unindent(), - 663, - ), - ( - r#" - def challenge_passed? - current_user.valid_password?(challenge_params[:current_password]) - end"# - .unindent(), - 1254, - ), - ( - r#" - class Animal - include Comparable - - attr_reader :legs - - def initialize(name, legs) - # ... - end - - def <=>(other) - # ... - end - end"# - .unindent(), - 1363, - ), - ( - r#" - def initialize(name, legs) - @name, @legs = name, legs - end"# - .unindent(), - 1427, - ), - ( - r#" - def <=>(other) - legs <=> other.legs - end"# - .unindent(), - 1501, - ), - ( - r#" - # Singleton method for car object - def car.wheels - puts "There are four wheels" - end"# - .unindent(), - 1591, - ), - ], - ); -} - -#[gpui::test] -async fn test_code_context_retrieval_php() { - let language = php_lang(); - let embedding_provider = Arc::new(FakeEmbeddingProvider::default()); - let mut retriever = CodeContextRetriever::new(embedding_provider); - - let text = r#" - 100) { - throw new Exception(message: 'Progress cannot be greater than 100'); - } - - if ($this->achievements()->find($achievement->id)) { - throw new Exception(message: 'User already has this Achievement'); - } - - $this->achievements()->attach($achievement, [ - 'progress' => $progress ?? 
null, - ]); - - $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this))); - } - - public function achievements(): BelongsToMany - { - return $this->belongsToMany(related: Achievement::class) - ->withPivot(columns: 'progress') - ->where('is_secret', false) - ->using(AchievementUser::class); - } - } - - interface Multiplier - { - public function qualifies(array $data): bool; - - public function setMultiplier(): int; - } - - enum AuditType: string - { - case Add = 'add'; - case Remove = 'remove'; - case Reset = 'reset'; - case LevelUp = 'level_up'; - } - - ?>"# - .unindent(); - - let documents = retriever.parse_file(&text, language.clone()).unwrap(); - - assert_documents_eq( - &documents, - &[ - ( - r#" - /* - This is a multiple-lines comment block - that spans over multiple - lines - */ - function functionName() { - echo "Hello world!"; - }"# - .unindent(), - 123, - ), - ( - r#" - trait HasAchievements - { - /** - * @throws \Exception - */ - public function grantAchievement(Achievement $achievement, $progress = null): void - {/* ... */} - - public function achievements(): BelongsToMany - {/* ... */} - }"# - .unindent(), - 177, - ), - (r#" - /** - * @throws \Exception - */ - public function grantAchievement(Achievement $achievement, $progress = null): void - { - if ($progress > 100) { - throw new Exception(message: 'Progress cannot be greater than 100'); - } - - if ($this->achievements()->find($achievement->id)) { - throw new Exception(message: 'User already has this Achievement'); - } - - $this->achievements()->attach($achievement, [ - 'progress' => $progress ?? null, - ]); - - $this->when(value: ($progress === null) || ($progress === 100), callback: fn (): ?array => event(new AchievementAwarded(achievement: $achievement, user: $this))); - }"#.unindent(), 245), - (r#" - public function achievements(): BelongsToMany - { - return $this->belongsToMany(related: Achievement::class) - ->withPivot(columns: 'progress') - ->where('is_secret', false) - ->using(AchievementUser::class); - }"#.unindent(), 902), - (r#" - interface Multiplier - { - public function qualifies(array $data): bool; - - public function setMultiplier(): int; - }"#.unindent(), - 1146), - (r#" - enum AuditType: string - { - case Add = 'add'; - case Remove = 'remove'; - case Reset = 'reset'; - case LevelUp = 'level_up'; - }"#.unindent(), 1265) - ], - ); -} - -fn js_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Javascript".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["js".into()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_typescript::language_tsx()), - ) - .with_embedding_query( - &r#" - - ( - (comment)* @context - . - [ - (export_statement - (function_declaration - "async"? @name - "function" @name - name: (_) @name)) - (function_declaration - "async"? @name - "function" @name - name: (_) @name) - ] @item - ) - - ( - (comment)* @context - . - [ - (export_statement - (class_declaration - "class" @name - name: (_) @name)) - (class_declaration - "class" @name - name: (_) @name) - ] @item - ) - - ( - (comment)* @context - . - [ - (export_statement - (interface_declaration - "interface" @name - name: (_) @name)) - (interface_declaration - "interface" @name - name: (_) @name) - ] @item - ) - - ( - (comment)* @context - . 
- [ - (export_statement - (enum_declaration - "enum" @name - name: (_) @name)) - (enum_declaration - "enum" @name - name: (_) @name) - ] @item - ) - - ( - (comment)* @context - . - (method_definition - [ - "get" - "set" - "async" - "*" - "static" - ]* @name - name: (_) @name) @item - ) - - "# - .unindent(), - ) - .unwrap(), - ) -} - -fn rust_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".into()], - ..Default::default() - }, - collapsed_placeholder: " /* ... */ ".to_string(), - ..Default::default() - }, - Some(tree_sitter_rust::language()), - ) - .with_embedding_query( - r#" - ( - [(line_comment) (attribute_item)]* @context - . - [ - (struct_item - name: (_) @name) - - (enum_item - name: (_) @name) - - (impl_item - trait: (_)? @name - "for"? @name - type: (_) @name) - - (trait_item - name: (_) @name) - - (function_item - name: (_) @name - body: (block - "{" @keep - "}" @keep) @collapse) - - (macro_definition - name: (_) @name) - ] @item - ) - - (attribute_item) @collapse - (use_declaration) @collapse - "#, - ) - .unwrap(), - ) -} - -fn json_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "JSON".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["json".into()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_json::language()), - ) - .with_embedding_query( - r#" - (document) @item - - (array - "[" @keep - . - (object)? @keep - "]" @keep) @collapse - - (pair value: (string - "\"" @keep - "\"" @keep) @collapse) - "#, - ) - .unwrap(), - ) -} - -fn toml_lang() -> Arc { - Arc::new(Language::new( - LanguageConfig { - name: "TOML".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["toml".into()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_toml::language()), - )) -} - -fn cpp_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "CPP".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["cpp".into()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_cpp::language()), - ) - .with_embedding_query( - r#" - ( - (comment)* @context - . - (function_definition - (type_qualifier)? @name - type: (_)? @name - declarator: [ - (function_declarator - declarator: (_) @name) - (pointer_declarator - "*" @name - declarator: (function_declarator - declarator: (_) @name)) - (pointer_declarator - "*" @name - declarator: (pointer_declarator - "*" @name - declarator: (function_declarator - declarator: (_) @name))) - (reference_declarator - ["&" "&&"] @name - (function_declarator - declarator: (_) @name)) - ] - (type_qualifier)? @name) @item - ) - - ( - (comment)* @context - . - (template_declaration - (class_specifier - "class" @name - name: (_) @name) - ) @item - ) - - ( - (comment)* @context - . - (class_specifier - "class" @name - name: (_) @name) @item - ) - - ( - (comment)* @context - . - (enum_specifier - "enum" @name - name: (_) @name) @item - ) - - ( - (comment)* @context - . - (declaration - type: (struct_specifier - "struct" @name) - declarator: (_) @name) @item - ) - - "#, - ) - .unwrap(), - ) -} - -fn lua_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Lua".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["lua".into()], - ..Default::default() - }, - collapsed_placeholder: "--[ ... ]--".to_string(), - ..Default::default() - }, - Some(tree_sitter_lua::language()), - ) - .with_embedding_query( - r#" - ( - (comment)* @context - . 
- (function_declaration - "function" @name - name: (_) @name - (comment)* @collapse - body: (block) @collapse - ) @item - ) - "#, - ) - .unwrap(), - ) -} - -fn php_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "PHP".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["php".into()], - ..Default::default() - }, - collapsed_placeholder: "/* ... */".into(), - ..Default::default() - }, - Some(tree_sitter_php::language_php()), - ) - .with_embedding_query( - r#" - ( - (comment)* @context - . - [ - (function_definition - "function" @name - name: (_) @name - body: (_ - "{" @keep - "}" @keep) @collapse - ) - - (trait_declaration - "trait" @name - name: (_) @name) - - (method_declaration - "function" @name - name: (_) @name - body: (_ - "{" @keep - "}" @keep) @collapse - ) - - (interface_declaration - "interface" @name - name: (_) @name - ) - - (enum_declaration - "enum" @name - name: (_) @name - ) - - ] @item - ) - "#, - ) - .unwrap(), - ) -} - -fn ruby_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Ruby".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rb".into()], - ..Default::default() - }, - collapsed_placeholder: "# ...".to_string(), - ..Default::default() - }, - Some(tree_sitter_ruby::language()), - ) - .with_embedding_query( - r#" - ( - (comment)* @context - . - [ - (module - "module" @name - name: (_) @name) - (method - "def" @name - name: (_) @name - body: (body_statement) @collapse) - (class - "class" @name - name: (_) @name) - (singleton_method - "def" @name - object: (_) @name - "." @name - name: (_) @name - body: (body_statement) @collapse) - ] @item - ) - "#, - ) - .unwrap(), - ) -} - -fn elixir_lang() -> Arc { - Arc::new( - Language::new( - LanguageConfig { - name: "Elixir".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".into()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_elixir::language()), - ) - .with_embedding_query( - r#" - ( - (unary_operator - operator: "@" - operand: (call - target: (identifier) @unary - (#match? @unary "^(doc)$")) - ) @context - . - (call - target: (identifier) @name - (arguments - [ - (identifier) @name - (call - target: (identifier) @name) - (binary_operator - left: (call - target: (identifier) @name) - operator: "when") - ]) - (#any-match? @name "^(def|defp|defdelegate|defguard|defguardp|defmacro|defmacrop|defn|defnp)$")) @item - ) - - (call - target: (identifier) @name - (arguments (alias) @name) - (#any-match? 
@name "^(defmodule|defprotocol)$")) @item - "#, - ) - .unwrap(), - ) -} - -#[gpui::test] -fn test_subtract_ranges() { - assert_eq!( - subtract_ranges(&[0..5, 10..21], &[0..1, 4..5]), - vec![1..4, 10..21] - ); - - assert_eq!(subtract_ranges(&[0..5], &[1..2]), &[0..1, 2..5]); -} - -fn init_test(cx: &mut TestAppContext) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - SemanticIndexSettings::register(cx); - language::init(cx); - Project::init_settings(cx); - }); -} diff --git a/crates/settings/src/settings_store.rs b/crates/settings/src/settings_store.rs index 32522eca016d9e..bc96ac12a0167e 100644 --- a/crates/settings/src/settings_store.rs +++ b/crates/settings/src/settings_store.rs @@ -479,7 +479,28 @@ impl SettingsStore { merge_schema(target_schema, setting_schema.schema); } - fn merge_schema(target: &mut SchemaObject, source: SchemaObject) { + fn merge_schema(target: &mut SchemaObject, mut source: SchemaObject) { + let source_subschemas = source.subschemas(); + let target_subschemas = target.subschemas(); + if let Some(all_of) = source_subschemas.all_of.take() { + target_subschemas + .all_of + .get_or_insert(Vec::new()) + .extend(all_of); + } + if let Some(any_of) = source_subschemas.any_of.take() { + target_subschemas + .any_of + .get_or_insert(Vec::new()) + .extend(any_of); + } + if let Some(one_of) = source_subschemas.one_of.take() { + target_subschemas + .one_of + .get_or_insert(Vec::new()) + .extend(one_of); + } + if let Some(source) = source.object { let target_properties = &mut target.object().properties; for (key, value) in source.properties { diff --git a/crates/util/src/http.rs b/crates/util/src/http.rs index 0078cdee2f2fa8..6dfde7833be802 100644 --- a/crates/util/src/http.rs +++ b/crates/util/src/http.rs @@ -5,9 +5,8 @@ use futures_lite::FutureExt; use isahc::config::{Configurable, RedirectPolicy}; pub use isahc::{ http::{Method, StatusCode, Uri}, - Error, + AsyncBody, Error, HttpClient as IsahcHttpClient, Request, Response, }; -pub use isahc::{AsyncBody, Request, Response}; #[cfg(feature = "test-support")] use std::fmt; use std::{ diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index f39bc5c3f1b6da..21c3b8dd99970c 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -71,7 +71,6 @@ recent_projects.workspace = true release_channel.workspace = true rope.workspace = true search.workspace = true -semantic_index.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 9746b97dba09a4..ce0cdaf77fb48d 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -174,7 +174,7 @@ fn main() { node_runtime.clone(), cx, ); - assistant::init(cx); + assistant::init(client.clone(), cx); extension::init( fs.clone(), @@ -247,7 +247,6 @@ fn main() { tasks_ui::init(cx); channel::init(&client, user_store.clone(), cx); search::init(cx); - semantic_index::init(fs.clone(), http.clone(), languages.clone(), cx); vim::init(cx); terminal_view::init(cx); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 33f3dc1c92cdb7..80f4e60a6844fa 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -3060,7 +3060,7 @@ mod tests { collab_ui::init(&app_state, cx); project_panel::init((), cx); terminal_view::init(cx); - assistant::init(cx); + assistant::init(app_state.client.clone(), cx); initialize_workspace(app_state.clone(), cx); app_state }) diff --git a/docs/src/configuring_zed.md 
diff --git a/docs/src/configuring_zed.md b/docs/src/configuring_zed.md
index 06d41098f617ad..cec956347dbc02 100644
--- a/docs/src/configuring_zed.md
+++ b/docs/src/configuring_zed.md
@@ -606,28 +606,6 @@ These values take in the same options as the root-level settings with the same name.
 
 `boolean` values
 
-## Semantic Index
-
-- Description: Settings related to semantic index.
-- Setting: `semantic_index`
-- Default:
-
-```json
-"semantic_index": {
-  "enabled": false
-},
-```
-
-### Enabled
-
-- Description: Whether or not to display the `Semantic` mode in project search.
-- Setting: `enabled`
-- Default: `true`
-
-**Options**
-
-`boolean` values
-
 ## Show Call Status Icon
 
 - Description: Whether or not to show the call status icon in the status bar.
diff --git a/script/bootstrap b/script/bootstrap
index e23f42e80edd4e..3f1ebf666abf4c 100755
--- a/script/bootstrap
+++ b/script/bootstrap
@@ -11,3 +11,8 @@ cargo run -p collab -- migrate
 
 echo "seeding database..."
 script/seed-db
+
+if [[ "$OSTYPE" == "linux-gnu"* ]]; then
+    echo "Linux dependencies..."
+    script/linux
+fi
diff --git a/script/evaluate_semantic_index b/script/evaluate_semantic_index
deleted file mode 100755
index 9ecfe898c5c4ea..00000000000000
--- a/script/evaluate_semantic_index
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/bin/bash
-
-RUST_LOG=semantic_index=trace cargo run --example semantic_index_eval --release
diff --git a/script/gemini.py b/script/gemini.py
new file mode 100644
index 00000000000000..7f3a7130de4e9b
--- /dev/null
+++ b/script/gemini.py
@@ -0,0 +1,91 @@
+import subprocess
+import json
+import http.client
+import mimetypes
+import os
+
+def get_text_files():
+    text_files = []
+    # List all files tracked by Git
+    git_files_proc = subprocess.run(['git', 'ls-files'], stdout=subprocess.PIPE, text=True)
+    for file in git_files_proc.stdout.strip().split('\n'):
+        # Check MIME type for each file
+        mime_check_proc = subprocess.run(['file', '--mime', file], stdout=subprocess.PIPE, text=True)
+        if 'text' in mime_check_proc.stdout:
+            text_files.append(file)
+
+    print(f"File count: {len(text_files)}")
+
+    return text_files
+
+def get_file_contents(file):
+    # Read file content
+    with open(file, 'r') as f:
+        return f.read()
+
+
+def main():
+    GEMINI_API_KEY = os.environ.get('GEMINI_API_KEY')
+
+    # Your prompt
+    prompt = "Document the data types and dataflow in this codebase in preparation to port a streaming implementation to rust:\n\n"
+    # Fetch all text files
+    text_files = get_text_files()
+    code_blocks = []
+    for file in text_files:
+        file_contents = get_file_contents(file)
+        # Create a code block for each text file
+        code_blocks.append(f"\n`{file}`\n\n```{file_contents}```\n")
+
+    # Construct the JSON payload
+    payload = json.dumps({
+        "contents": [{
+            "parts": [{
+                "text": prompt + "".join(code_blocks)
+            }]
+        }]
+    })
+
+    # Prepare the HTTP connection
+    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
+
+    # Define headers (Content-Length must be the byte length, not the character count)
+    headers = {
+        'Content-Type': 'application/json',
+        'Content-Length': str(len(payload.encode('utf-8')))
+    }
+
+    # Output the content length in kilobytes
+    print(f"Content Length in kilobytes: {len(payload.encode('utf-8')) / 1024:.2f} KB")
+
+
+    # Send a request to count the tokens
+    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:countTokens?key={GEMINI_API_KEY}", body=payload, headers=headers)
+    # Get the response
+    response = conn.getresponse()
+    if response.status == 200:
+        token_count = json.loads(response.read().decode('utf-8')).get('totalTokens')
+        print(f"Token count: {token_count}")
+    else:
+        print(f"Failed to get token count. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
+
+
+    # Open a fresh connection; streamGenerateContent is also a POST carrying the same JSON payload
+    conn = http.client.HTTPSConnection("generativelanguage.googleapis.com")
+    conn.request("POST", f"/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={GEMINI_API_KEY}", body=payload, headers=headers)
+
+    # Get the response in a streaming manner
+    response = conn.getresponse()
+    if response.status == 200:
+        print("Successfully sent the data to the API.")
+        # Read the response in chunks
+        while chunk := response.read(4096):
+            print(chunk.decode('utf-8'))
+    else:
+        print(f"Failed to send the data to the API. Status code: {response.status}, Response body: {response.read().decode('utf-8')}")
+
+    # Close the connection
+    conn.close()
+
+if __name__ == "__main__":
+    main()
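For reference, a rough Rust counterpart to the script's token-count request, should this experiment ever move out of `script/`. This is a sketch only: it assumes `isahc` (re-exported through `crates/util/src/http.rs` above), `serde_json`, and `anyhow`; the endpoint and model name are copied from the script, and `count_tokens` itself is hypothetical:

```rust
use isahc::{prelude::*, Request};
use serde_json::json;

// Endpoint and model name come from script/gemini.py; everything else here
// is an assumption, not code that exists in this change.
fn count_tokens(api_key: &str, text: &str) -> anyhow::Result<u64> {
    let payload = json!({
        "contents": [{ "parts": [{ "text": text }] }]
    });

    let uri = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:countTokens?key={api_key}"
    );

    // Same shape as the Python request: a POST with a JSON body.
    let mut response = Request::post(uri)
        .header("content-type", "application/json")
        .body(payload.to_string())?
        .send()?;

    let body: serde_json::Value = serde_json::from_str(&response.text()?)?;
    body["totalTokens"]
        .as_u64()
        .ok_or_else(|| anyhow::anyhow!("missing totalTokens in response"))
}
```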
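The streaming half of the script lines up with the `AsyncBody` re-export added to `crates/util/src/http.rs`. A sketch of the same chunked read in Rust, under the same assumptions as the previous example plus `futures-lite`; none of this is code from the change itself:

```rust
use futures_lite::AsyncReadExt;
use isahc::{prelude::*, AsyncBody, Request, Response};

async fn stream_generate(api_key: &str, payload: String) -> anyhow::Result<()> {
    let uri = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro-latest:streamGenerateContent?key={api_key}"
    );

    // As in the fixed script above, the streaming endpoint is a POST with a JSON body.
    let mut response: Response<AsyncBody> = Request::post(uri)
        .header("content-type", "application/json")
        .body(payload)?
        .send_async()
        .await?;

    // AsyncBody implements AsyncRead, so chunks can be printed as they
    // arrive, mirroring the script's `response.read(4096)` loop.
    let mut buf = [0u8; 4096];
    loop {
        let n = response.body_mut().read(&mut buf).await?;
        if n == 0 {
            break;
        }
        print!("{}", String::from_utf8_lossy(&buf[..n]));
    }
    Ok(())
}
```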