Allow OpenAI API URL to be configured via assistant.openai_api_url #7552

Merged: 5 commits, Feb 12, 2024
Changes from 2 commits
2 changes: 2 additions & 0 deletions assets/settings/default.json
@@ -214,6 +214,8 @@
     "default_width": 640,
     // Default height when the assistant is docked to the bottom.
     "default_height": 320,
+    // The default OpenAI api endpoint to use when starting new conversations.
+    "openai_api_url": "https://api.openai.com/v1",
     // The default OpenAI model to use when starting new conversations. This
     // setting can take three values:
     //
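
With this default shipped, pointing the assistant at any OpenAI-compatible server becomes a one-key override in the user's settings.json. A minimal sketch (the key nests under "assistant" per this PR's title; the localhost address is a placeholder for, e.g., a local inference server):

    {
      "assistant": {
        "openai_api_url": "http://localhost:1234/v1"
      }
    }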
10 changes: 7 additions & 3 deletions crates/ai/src/providers/open_ai/completion.rs
@@ -103,6 +103,7 @@ pub struct OpenAiResponseStreamEvent {
 }
 
 pub async fn stream_completion(
+    api_url: String,
     credential: ProviderCredential,
     executor: BackgroundExecutor,
     request: Box<dyn CompletionRequest>,
@@ -117,7 +118,7 @@ pub async fn stream_completion(
     let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAiResponseStreamEvent>>();
 
     let json_data = request.data()?;
-    let mut response = Request::post(format!("{OPEN_AI_API_URL}/chat/completions"))
+    let mut response = Request::post(format!("{api_url}/chat/completions"))
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key))
         .body(json_data)?
@@ -195,18 +196,20 @@ pub async fn stream_completion
 
 #[derive(Clone)]
 pub struct OpenAiCompletionProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     executor: BackgroundExecutor,
 }
 
 impl OpenAiCompletionProvider {
-    pub async fn new(model_name: String, executor: BackgroundExecutor) -> Self {
+    pub async fn new(api_url: String, model_name: String, executor: BackgroundExecutor) -> Self {
         let model = executor
             .spawn(async move { OpenAiLanguageModel::load(&model_name) })
             .await;
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
         Self {
+            api_url,
             model,
             credential,
             executor,
@@ -303,7 +306,8 @@ impl CompletionProvider for OpenAiCompletionProvider {
         // which is currently model based, due to the language model.
         // At some point in the future we should rectify this.
         let credential = self.credential.read().clone();
-        let request = stream_completion(credential, self.executor.clone(), prompt);
+        let api_url = self.api_url.clone();
+        let request = stream_completion(api_url, credential, self.executor.clone(), prompt);
         async move {
             let response = request.await?;
             let stream = response
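
The provider builds the final URL by plain string interpolation rather than URL joining, so the configured base should not carry a trailing slash. A standalone sketch of the implication (not code from this PR):

    fn main() {
        // Mirrors the `format!("{api_url}/chat/completions")` join in the diff.
        let api_url = String::from("https://api.openai.com/v1");
        assert_eq!(
            format!("{api_url}/chat/completions"),
            "https://api.openai.com/v1/chat/completions"
        );

        // A base configured with a trailing slash would yield a double slash.
        let slashed = String::from("https://api.openai.com/v1/");
        assert_eq!(
            format!("{slashed}/chat/completions"),
            "https://api.openai.com/v1//chat/completions"
        );
    }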
13 changes: 11 additions & 2 deletions crates/ai/src/providers/open_ai/embedding.rs
@@ -35,6 +35,7 @@ lazy_static! {
 
 #[derive(Clone)]
 pub struct OpenAiEmbeddingProvider {
+    api_url: String,
     model: OpenAiLanguageModel,
     credential: Arc<RwLock<ProviderCredential>>,
     pub client: Arc<dyn HttpClient>,
@@ -69,7 +70,11 @@ struct OpenAiEmbeddingUsage {
 }
 
 impl OpenAiEmbeddingProvider {
-    pub async fn new(client: Arc<dyn HttpClient>, executor: BackgroundExecutor) -> Self {
+    pub async fn new(
+        api_url: String,
+        client: Arc<dyn HttpClient>,
+        executor: BackgroundExecutor,
+    ) -> Self {
         let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
         let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
 
@@ -80,6 +85,7 @@ impl OpenAiEmbeddingProvider {
         let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
 
         OpenAiEmbeddingProvider {
+            api_url,
             model,
             credential,
             client,
@@ -130,11 +136,12 @@ impl OpenAiEmbeddingProvider {
     }
     async fn send_request(
         &self,
+        api_url: &str,
         api_key: &str,
         spans: Vec<&str>,
         request_timeout: u64,
     ) -> Result<Response<AsyncBody>> {
-        let request = Request::post(format!("{OPEN_AI_API_URL}/embeddings"))
+        let request = Request::post(format!("{api_url}/embeddings"))
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
             .timeout(Duration::from_secs(request_timeout))
             .header("Content-Type", "application/json")
@@ -246,6 +253,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
         const MAX_RETRIES: usize = 4;
 
+        let api_url = self.api_url.as_str();
         let api_key = self.get_api_key()?;
 
         let mut request_number = 0;
@@ -255,6 +263,7 @@ impl EmbeddingProvider for OpenAiEmbeddingProvider {
         while request_number < MAX_RETRIES {
             response = self
                 .send_request(
+                    &api_url,
                    &api_key,
                     spans.iter().map(|x| &**x).collect(),
                     request_timeout,
1 change: 1 addition & 0 deletions crates/assistant/src/assistant.rs
@@ -68,6 +68,7 @@ struct SavedConversation {
     messages: Vec<SavedMessage>,
     message_metadata: HashMap<MessageId, MessageMetadata>,
     summary: String,
+    api_url: Option<String>,
     model: OpenAiModel,
 }
 
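Declaring the new field as `Option<String>` keeps old data readable: with serde's usual derive behavior for `Option` fields, conversation files saved before this change deserialize with `api_url: None` rather than erroring on the missing key.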
30 changes: 26 additions & 4 deletions crates/assistant/src/assistant_panel.rs
@@ -12,6 +12,7 @@ use ai::{
     completion::{CompletionProvider, CompletionRequest},
     providers::open_ai::{OpenAiCompletionProvider, OpenAiRequest, RequestMessage},
 };
+use ai::providers::open_ai::OPEN_AI_API_URL;
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use client::telemetry::AssistantKind;
@@ -121,10 +122,22 @@ impl AssistantPanel {
             .await
             .log_err()
             .unwrap_or_default();
-        // Defaulting currently to GPT4, allow for this to be set via config.
-        let completion_provider =
-            OpenAiCompletionProvider::new("gpt-4".into(), cx.background_executor().clone())
-                .await;
+        let (api_url, model_name) = cx
+            .update(|cx| {
+                let settings = AssistantSettings::get_global(cx);
+                (
+                    settings.openai_api_url.clone(),
+                    settings.default_open_ai_model.full_name().to_string(),
+                )
+            })
+            .log_err()
+            .unwrap();
+        let completion_provider = OpenAiCompletionProvider::new(
+            api_url,
+            model_name,
+            cx.background_executor().clone(),
+        )
+        .await;
 
         // TODO: deserialize state.
         let workspace_handle = workspace.clone();
@@ -1407,6 +1420,7 @@ struct Conversation {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     model: OpenAiModel,
+    api_url: Option<String>,
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
@@ -1441,6 +1455,7 @@ impl Conversation {
 
         let settings = AssistantSettings::get_global(cx);
         let model = settings.default_open_ai_model.clone();
+        let api_url = settings.openai_api_url.clone();
 
         let mut this = Self {
             id: Some(Uuid::new_v4().to_string()),
@@ -1454,6 +1469,7 @@
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url: Some(api_url),
             model: model.clone(),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
@@ -1499,6 +1515,7 @@ impl Conversation {
                 .map(|summary| summary.text.clone())
                 .unwrap_or_default(),
             model: self.model.clone(),
+            api_url: self.api_url.clone(),
         }
     }
 
@@ -1513,8 +1530,12 @@
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let api_url = saved_conversation.api_url;
         let completion_provider: Arc<dyn CompletionProvider> = Arc::new(
             OpenAiCompletionProvider::new(
+                api_url
+                    .clone()
+                    .unwrap_or_else(|| OPEN_AI_API_URL.to_string()),
                 model.full_name().into(),
                 cx.background_executor().clone(),
             )
@@ -1567,6 +1588,7 @@ impl Conversation {
             token_count: None,
             max_token_count: tiktoken_rs::model::get_context_size(&model.full_name()),
             pending_token_count: Task::ready(None),
+            api_url,
             model,
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
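Note the restore path above: a conversation whose saved `api_url` is `None` falls back to `OPEN_AI_API_URL` via `unwrap_or_else`, so transcripts from before this change keep using the canonical endpoint instead of failing to load.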
5 changes: 5 additions & 0 deletions crates/assistant/src/assistant_settings.rs
@@ -55,6 +55,7 @@ pub struct AssistantSettings {
     pub default_width: Pixels,
     pub default_height: Pixels,
     pub default_open_ai_model: OpenAiModel,
+    pub openai_api_url: String,
 }
 
 /// Assistant panel settings
@@ -80,6 +81,10 @@ pub struct AssistantSettingsContent {
     ///
     /// Default: gpt-4-1106-preview
     pub default_open_ai_model: Option<OpenAiModel>,
+    /// OpenAI api base url to use when starting new conversations.
+    ///
+    /// Default: https://api.openai.com/v1
+    pub openai_api_url: Option<String>,
 }
 
 impl Settings for AssistantSettings {
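Because default.json now ships a value for `openai_api_url`, the resolved `AssistantSettings::openai_api_url` is always a concrete `String`; the `Option<String>` in `AssistantSettingsContent` only records whether the user supplied an override.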
11 changes: 8 additions & 3 deletions crates/semantic_index/src/semantic_index.rs
@@ -8,7 +8,7 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use ai::embedding::{Embedding, EmbeddingProvider};
-use ai::providers::open_ai::OpenAiEmbeddingProvider;
+use ai::providers::open_ai::{OpenAiEmbeddingProvider, OPEN_AI_API_URL};
 use anyhow::{anyhow, Context as _, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -91,8 +91,13 @@ pub fn init(
         .detach();
 
     cx.spawn(move |cx| async move {
-        let embedding_provider =
-            OpenAiEmbeddingProvider::new(http_client, cx.background_executor().clone()).await;
+        let embedding_provider = OpenAiEmbeddingProvider::new(
+            // TODO: We should read it from config, but I'm not sure whether to reuse `openai_api_url` in assistant settings or not
+            OPEN_AI_API_URL.to_string(),
+            http_client,
+            cx.background_executor().clone(),
+        )
+        .await;
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
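Per the inline TODO, the semantic index still uses the canonical `OPEN_AI_API_URL` rather than reading from settings; whether it should reuse `assistant.openai_api_url` or grow its own key is left open here.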