Add max_output_tokens to OpenAI models and integrate into requests (zed-industries#16381)

### Pull Request Title
Introduce `max_output_tokens` Field for OpenAI Models


See DeepSeek's 8K `max_tokens` beta announcement: https://platform.deepseek.com/api-docs/news/news0725/#4-8k-max_tokens-betarelease-longer-possibilities

### Description
This commit introduces a new `max_output_tokens` field on the OpenAI
models, which specifies the maximum number of tokens a completion may
generate. The field is threaded through request handling across multiple
crates so that the output token limit is respected during language model
completions.


This change gives users finer control over the output length of OpenAI
model responses; a sketch of the mechanism follows the change list below.

### Changes
- Added `max_output_tokens` to the `Custom` variant of the
`open_ai::Model` enum.
- Updated the `into_open_ai` method in `LanguageModelRequest` to accept
and use `max_output_tokens`.
- Modified the `OpenAiLanguageModel` and `CloudLanguageModel`
implementations to pass `max_output_tokens` when converting requests.
- Ensured that the `max_output_tokens` field is correctly serialized and
deserialized in relevant structures.
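
### Example
A minimal sketch of the mechanism end to end, using simplified stand-in types rather than the actual Zed code: the model carries an optional output cap, and the `into_open_ai` conversion threads that cap into the request's `max_tokens`. The Zed-hosted path in `cloud.rs` below instead converts with `None` and pins the cap afterwards.

```rust
// Simplified stand-ins for illustration; the real types live in the
// `open_ai` and `language_model` crates touched below.
#[allow(dead_code)]
enum Model {
    FourOmni,
    Custom {
        name: String,
        max_output_tokens: Option<u32>,
    },
}

impl Model {
    fn max_output_tokens(&self) -> Option<u32> {
        match self {
            Model::FourOmni => Some(4096), // per-model default
            Model::Custom { max_output_tokens, .. } => *max_output_tokens,
        }
    }
}

struct Request {
    max_tokens: Option<u32>,
}

// Mirrors the new `into_open_ai(model, max_output_tokens)` signature.
fn into_open_ai(max_output_tokens: Option<u32>) -> Request {
    Request {
        max_tokens: max_output_tokens,
    }
}

fn main() {
    // A custom model's cap flows through unchanged.
    let model = Model::Custom {
        name: "deepseek-chat".into(), // hypothetical custom model
        max_output_tokens: Some(8192),
    };
    let request = into_open_ai(model.max_output_tokens());
    assert_eq!(request.max_tokens, Some(8192));

    // Zed-hosted models convert with `None`, then pin the cap explicitly.
    let mut zed_request = into_open_ai(None);
    zed_request.max_tokens = Some(4000);
    assert_eq!(zed_request.max_tokens, Some(4000));
}
```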

### Related Issue
zed-industries#16358

### Screenshots / Media
N/A

### Checklist
- [x] Code compiles correctly.
- [x] All tests pass.
- [ ] Documentation has been updated accordingly.
- [ ] Additional tests have been added to cover new functionality.

### Release Notes

- Added `max_output_tokens` field to OpenAI models for controlling
output token length.
Cupnfish authored Aug 21, 2024
1 parent 36d51fe commit f1778dd
Showing 6 changed files with 46 additions and 15 deletions.
4 changes: 2 additions & 2 deletions crates/assistant/src/assistant_settings.rs
@@ -153,8 +153,8 @@ impl AssistantSettingsContent {
                 models
                     .into_iter()
                     .filter_map(|model| match model {
-                        open_ai::Model::Custom { name, max_tokens } => {
-                            Some(language_model::provider::open_ai::AvailableModel { name, max_tokens })
+                        open_ai::Model::Custom { name, max_tokens, max_output_tokens } => {
+                            Some(language_model::provider::open_ai::AvailableModel { name, max_tokens, max_output_tokens })
                         }
                         _ => None,
                     })
10 changes: 6 additions & 4 deletions crates/language_model/src/provider/cloud.rs
@@ -254,6 +254,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                 name: model.name.clone(),
                 max_tokens: model.max_tokens,
+                max_output_tokens: model.max_output_tokens,
             }),
             AvailableProvider::Google => CloudModel::Google(google_ai::Model::Custom {
                 name: model.name.clone(),
@@ -513,7 +514,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
-                let request = request.into_open_ai(model.id().into());
+                let request = request.into_open_ai(model.id().into(), model.max_output_tokens());
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let response = Self::perform_llm_completion(
@@ -557,7 +558,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request = request.into_open_ai(model.id().into(), None);
                 request.max_tokens = Some(4000);
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -629,7 +630,8 @@ impl LanguageModel for CloudLanguageModel {
                     .boxed()
             }
             CloudModel::OpenAi(model) => {
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request =
+                    request.into_open_ai(model.id().into(), model.max_output_tokens());
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
@@ -676,7 +678,7 @@ impl LanguageModel for CloudLanguageModel {
             }
             CloudModel::Zed(model) => {
                 // All Zed models are OpenAI-based at the time of writing.
-                let mut request = request.into_open_ai(model.id().into());
+                let mut request = request.into_open_ai(model.id().into(), None);
                 request.tool_choice = Some(open_ai::ToolChoice::Other(
                     open_ai::ToolDefinition::Function {
                         function: open_ai::FunctionDefinition {
10 changes: 8 additions & 2 deletions crates/language_model/src/provider/open_ai.rs
@@ -40,6 +40,7 @@ pub struct OpenAiSettings {
 pub struct AvailableModel {
     pub name: String,
     pub max_tokens: usize,
+    pub max_output_tokens: Option<u32>,
 }
 
 pub struct OpenAiLanguageModelProvider {
@@ -170,6 +171,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                 open_ai::Model::Custom {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
+                    max_output_tokens: model.max_output_tokens,
                 },
             );
         }
@@ -275,6 +277,10 @@ impl LanguageModel for OpenAiLanguageModel {
         self.model.max_token_count()
     }
 
+    fn max_output_tokens(&self) -> Option<u32> {
+        self.model.max_output_tokens()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -288,7 +294,7 @@ impl LanguageModel for OpenAiLanguageModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        let request = request.into_open_ai(self.model.id().into());
+        let request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
         async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
     }
@@ -301,7 +307,7 @@ impl LanguageModel for OpenAiLanguageModel {
         schema: serde_json::Value,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
-        let mut request = request.into_open_ai(self.model.id().into());
+        let mut request = request.into_open_ai(self.model.id().into(), self.max_output_tokens());
         request.tool_choice = Some(ToolChoice::Other(ToolDefinition::Function {
             function: FunctionDefinition {
                 name: tool_name.clone(),
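
The `max_output_tokens` override above implies that the `LanguageModel` trait exposes such a method; the trait definition itself is not part of the diff shown here, so presumably it carries a `None` default that other providers inherit. A hedged sketch of that pattern, with stand-in types:

```rust
// Sketch only: the real `LanguageModel` trait in crates/language_model has
// many more methods. The assumption here is a defaulted accessor that
// providers without an output cap simply leave alone.
trait LanguageModel {
    fn max_output_tokens(&self) -> Option<u32> {
        None
    }
}

struct OpenAiLanguageModel {
    max_output_tokens: Option<u32>,
}

impl LanguageModel for OpenAiLanguageModel {
    fn max_output_tokens(&self) -> Option<u32> {
        self.max_output_tokens
    }
}

struct NoCapModel;

impl LanguageModel for NoCapModel {} // inherits the `None` default

fn main() {
    let openai = OpenAiLanguageModel { max_output_tokens: Some(16_384) };
    assert_eq!(openai.max_output_tokens(), Some(16_384));
    assert_eq!(NoCapModel.max_output_tokens(), None);
}
```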
4 changes: 2 additions & 2 deletions crates/language_model/src/request.rs
@@ -229,7 +229,7 @@ pub struct LanguageModelRequest {
 }
 
 impl LanguageModelRequest {
-    pub fn into_open_ai(self, model: String) -> open_ai::Request {
+    pub fn into_open_ai(self, model: String, max_output_tokens: Option<u32>) -> open_ai::Request {
         open_ai::Request {
             model,
             messages: self
@@ -251,7 +251,7 @@ impl LanguageModelRequest {
             stream: true,
             stop: self.stop,
             temperature: self.temperature,
-            max_tokens: None,
+            max_tokens: max_output_tokens,
             tools: Vec::new(),
             tool_choice: None,
         }
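
Because `open_ai::Request` marks `max_tokens` with `#[serde(default, skip_serializing_if = "Option::is_none")]` (see the `open_ai` crate diff below), a `None` cap is dropped from the JSON payload entirely and the provider's server-side default applies. A small sketch with a stand-in type, assuming the `serde` and `serde_json` crates:

```rust
use serde::Serialize;

// Stand-in for the wire-format request; only the relevant fields.
#[derive(Serialize)]
struct Request {
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
}

fn main() {
    let capped = Request {
        model: "gpt-4o".into(),
        max_tokens: Some(4096),
    };
    assert_eq!(
        serde_json::to_string(&capped).unwrap(),
        r#"{"model":"gpt-4o","max_tokens":4096}"#
    );

    // With no cap, the key is omitted rather than sent as null.
    let uncapped = Request {
        model: "gpt-4o".into(),
        max_tokens: None,
    };
    assert_eq!(
        serde_json::to_string(&uncapped).unwrap(),
        r#"{"model":"gpt-4o"}"#
    );
}
```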
12 changes: 9 additions & 3 deletions crates/language_model/src/settings.rs
@@ -172,9 +172,15 @@ impl OpenAiSettingsContent {
                 models
                     .into_iter()
                     .filter_map(|model| match model {
-                        open_ai::Model::Custom { name, max_tokens } => {
-                            Some(provider::open_ai::AvailableModel { name, max_tokens })
-                        }
+                        open_ai::Model::Custom {
+                            name,
+                            max_tokens,
+                            max_output_tokens,
+                        } => Some(provider::open_ai::AvailableModel {
+                            name,
+                            max_tokens,
+                            max_output_tokens,
+                        }),
                         _ => None,
                     })
                     .collect()
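
For illustration, a hedged sketch of how a settings entry carrying the new field deserializes into `AvailableModel`. The struct below is a stand-in mirroring the fields added in this commit; the exact shape of Zed's surrounding settings JSON may differ. Because the field is an `Option`, entries written before this commit still deserialize cleanly:

```rust
use serde::Deserialize;

// Stand-in mirroring `provider::open_ai::AvailableModel`.
#[derive(Debug, Deserialize)]
struct AvailableModel {
    name: String,
    max_tokens: usize,
    max_output_tokens: Option<u32>,
}

fn main() {
    let entry = r#"{ "name": "deepseek-chat", "max_tokens": 128000, "max_output_tokens": 8192 }"#;
    let model: AvailableModel = serde_json::from_str(entry).unwrap();
    println!("{model:?}");
    assert_eq!(model.max_output_tokens, Some(8192));

    // Older entries simply lack the field; it comes back as None.
    let legacy = r#"{ "name": "my-model", "max_tokens": 32000 }"#;
    let model: AvailableModel = serde_json::from_str(legacy).unwrap();
    assert_eq!(model.max_output_tokens, None);
}
```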
21 changes: 19 additions & 2 deletions crates/open_ai/src/open_ai.rs
@@ -66,7 +66,11 @@ pub enum Model {
     #[serde(rename = "gpt-4o-mini", alias = "gpt-4o-mini-2024-07-18")]
     FourOmniMini,
     #[serde(rename = "custom")]
-    Custom { name: String, max_tokens: usize },
+    Custom {
+        name: String,
+        max_tokens: usize,
+        max_output_tokens: Option<u32>,
+    },
 }
 
 impl Model {
@@ -113,6 +117,19 @@ impl Model {
             Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn max_output_tokens(&self) -> Option<u32> {
+        match self {
+            Self::ThreePointFiveTurbo => Some(4096),
+            Self::Four => Some(8192),
+            Self::FourTurbo => Some(4096),
+            Self::FourOmni => Some(4096),
+            Self::FourOmniMini => Some(16384),
+            Self::Custom {
+                max_output_tokens, ..
+            } => *max_output_tokens,
+        }
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -121,7 +138,7 @@ pub struct Request {
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
     #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub max_tokens: Option<usize>,
+    pub max_tokens: Option<u32>,
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
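
As a usage check, a unit test along these lines could pin the per-model defaults introduced above. This test is not part of the commit; it assumes it lives inside `crates/open_ai`, where `Model` is in scope:

```rust
#[cfg(test)]
mod tests {
    use super::Model;

    #[test]
    fn max_output_tokens_defaults() {
        assert_eq!(Model::FourOmniMini.max_output_tokens(), Some(16384));
        assert_eq!(Model::FourTurbo.max_output_tokens(), Some(4096));

        let custom = Model::Custom {
            name: "deepseek-chat".into(), // hypothetical custom model
            max_tokens: 128_000,
            max_output_tokens: Some(8_192),
        };
        assert_eq!(custom.max_output_tokens(), Some(8_192));
    }
}
```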
