Skip to content

Commit

Permalink
Partially fix assistant onboarding (#25313)
Browse files Browse the repository at this point in the history
While investigating #24896, I noticed two issues:

1. The default configuration for the `zed.dev` provider was using the
wrong string for Claude 3.5 Sonnet. This meant the provider would always
result as not configured until the user selected it from the model
picker, because we couldn't deserialize that string to a valid
`anthropic::Model` enum variant.
2. When clicking on `Open New Chat`/`Start New Thread` in the provider
configuration, we would select `Claude 3.5 Haiku` by default instead of
Claude 3.5 Sonnet.

Release Notes:

- Fixed some issues that caused AI providers to sometimes be
misconfigured.
  • Loading branch information
as-cii authored Feb 24, 2025
1 parent 535ba75 commit f517050
Show file tree
Hide file tree
Showing 16 changed files with 94 additions and 5 deletions.
2 changes: 1 addition & 1 deletion assets/settings/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@
// The provider to use.
"provider": "zed.dev",
// The model to use.
"model": "claude-3-5-sonnet"
"model": "claude-3-5-sonnet-latest"
}
},
// The settings for slash commands.
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,7 +978,7 @@ impl AssistantPanel {
.active_provider()
.map_or(true, |p| p.id() != provider.id())
{
if let Some(model) = provider.provided_models(cx).first().cloned() {
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
this.fs.clone(),
cx,
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant2/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ impl AssistantPanel {
active_provider.id() != provider.id()
})
{
if let Some(model) = provider.provided_models(cx).first().cloned() {
if let Some(model) = provider.default_model(cx) {
update_settings_file::<AssistantSettings>(
self.fs.clone(),
cx,
Expand Down
2 changes: 1 addition & 1 deletion crates/assistant_settings/src/assistant_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ mod tests {
AssistantSettings::get_global(cx).default_model,
LanguageModelSelection {
provider: "zed.dev".into(),
model: "claude-3-5-sonnet".into(),
model: "claude-3-5-sonnet-latest".into(),
}
);
});
Expand Down
3 changes: 2 additions & 1 deletion crates/google_ai/src/google_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ pub struct CountTokensResponse {
}

#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
#[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
Expand All @@ -308,6 +308,7 @@ pub enum Model {
#[serde(rename = "gemini-2.0-pro-exp")]
Gemini20Pro,
#[serde(rename = "gemini-2.0-flash")]
#[default]
Gemini20Flash,
#[serde(rename = "gemini-2.0-flash-thinking-exp")]
Gemini20FlashThinking,
Expand Down
4 changes: 4 additions & 0 deletions crates/language_model/src/fake_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
provider_name()
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
Some(Arc::new(FakeLanguageModel::default()))
}

fn provided_models(&self, _: &App) -> Vec<Arc<dyn LanguageModel>> {
vec![Arc::new(FakeLanguageModel::default())]
}
Expand Down
1 change: 1 addition & 0 deletions crates/language_model/src/language_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ pub trait LanguageModelProvider: 'static {
fn icon(&self) -> IconName {
IconName::ZedAssistant
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
fn is_authenticated(&self, cx: &App) -> bool;
Expand Down
11 changes: 11 additions & 0 deletions crates/language_models/src/provider/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
IconName::AiAnthropic
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
let model = anthropic::Model::default();
Some(Arc::new(AnthropicModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();

Expand Down
12 changes: 12 additions & 0 deletions crates/language_models/src/provider/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
IconName::AiZed
}

fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
let llm_api_token = self.state.read(cx).llm_api_token.clone();
let model = CloudModel::Anthropic(anthropic::Model::default());
Some(Arc::new(CloudLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
llm_api_token: llm_api_token.clone(),
client: self.client.clone(),
request_limiter: RateLimiter::new(4),
}))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();

Expand Down
8 changes: 8 additions & 0 deletions crates/language_models/src/provider/copilot_chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
IconName::Copilot
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
let model = CopilotChatModel::default();
Some(Arc::new(CopilotChatLanguageModel {
model,
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>)
}

fn provided_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
CopilotChatModel::iter()
.map(|model| {
Expand Down
11 changes: 11 additions & 0 deletions crates/language_models/src/provider/deepseek.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,17 @@ impl LanguageModelProvider for DeepSeekLanguageModelProvider {
IconName::AiDeepSeek
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
let model = deepseek::Model::Chat;
Some(Arc::new(DeepSeekLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();

Expand Down
11 changes: 11 additions & 0 deletions crates/language_models/src/provider/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,17 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
IconName::AiGoogle
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
let model = google_ai::Model::default();
Some(Arc::new(GoogleLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
}))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();

Expand Down
4 changes: 4 additions & 0 deletions crates/language_models/src/provider/lmstudio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
IconName::AiLmStudio
}

fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.provided_models(cx).into_iter().next()
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models: BTreeMap<String, lmstudio::Model> = BTreeMap::default();

Expand Down
11 changes: 11 additions & 0 deletions crates/language_models/src/provider/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,17 @@ impl LanguageModelProvider for MistralLanguageModelProvider {
IconName::AiMistral
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
let model = mistral::Model::default();
Some(Arc::new(MistralLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();

Expand Down
4 changes: 4 additions & 0 deletions crates/language_models/src/provider/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
IconName::AiOllama
}

fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
self.provided_models(cx).into_iter().next()
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

Expand Down
11 changes: 11 additions & 0 deletions crates/language_models/src/provider/open_ai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
IconName::AiOpenAi
}

fn default_model(&self, _cx: &App) -> Option<Arc<dyn LanguageModel>> {
let model = open_ai::Model::default();
Some(Arc::new(OpenAiLanguageModel {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}))
}

fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models = BTreeMap::default();

Expand Down

0 comments on commit f517050

Please sign in to comment.