estimate token use before sending openai completions #1112
```diff
@@ -14,6 +14,7 @@
 import json
 import logging
 import re
+import tiktoken
 from typing import List, Union

 import openai

@@ -223,6 +224,34 @@ def _call_model(
             if hasattr(self, arg) and arg not in self.suppressed_params:
                 create_args[arg] = getattr(self, arg)

+        # basic token boundary validation to ensure requests are not rejected for exceeding target context length
+        generation_max_tokens = create_args.get("max_tokens", None)
+        if generation_max_tokens is not None:
+            # count tokens in prompt and ensure max_tokens requested is <= context_len allowed
+            if (
+                hasattr(self, "context_len")
+                and self.context_len is not None
+                and generation_max_tokens > self.context_len
+            ):
+                logging.warning(
+                    f"Requested max_tokens {generation_max_tokens} exceeds context length {self.context_len}, reducing requested maximum"
+                )
+                generation_max_tokens = self.context_len
```
Review thread:

**Reviewer:** It looks like this disregards […].

**Author:** This PR is based on observed behavior of the OpenAI endpoints; it attempts to ensure a valid request can be made. OpenAI services set an upper bound on `max_tokens` […]. Hence, if we know enough about the target model in this runtime, we can make a best-effort estimate and avoid bashing against a brick wall making requests we can predict will return no valid inference response. If the runtime does not know the `context_len` […].

**Reviewer:** OK, that's crucial, good to know. We should document this here with reference to an OpenAI URI. Is the variable name usage consistent with elsewhere in garak?

**Author:** As to documenting, I could see adding some context noting that the assumptions made here are based on the OpenAI API spec. As a future iteration it may be of value to evaluate shifting […].

**Reviewer:** I think we're suffering from an overloading of `max_tokens`. With […], is this saying […]? If so, can you run through the logic behind this in simple, verbose, explicit terms?

**Author:** When making a request to OpenAI, the […]. If the user configures […]. Now consider a scenario where the model has plenty of context length, such as […]. This constrains […]. Another possible approach could be to simply not pass `max_tokens` […].

**Reviewer:** Alright, I am going to have to take a moment with the API guide to get on top of this. OpenAI model input capacity must be greater than prompt length plus output length? The examples make a ton of sense, thanks. This looks like a really helpful PR/feature. I think the result might just be a few variable renaming suggestions; will get back to this within a day or two.
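As a stand-in for the worked examples referenced above, here is a minimal sketch of the budget arithmetic the patch performs. All numbers are assumptions chosen for illustration (a hypothetical model with `context_len = 4096`, a user-configured `max_tokens = 4096`, and a 1000-token prompt), not values from the PR.

```python
# Sketch of the pre-flight token budget check, with assumed numbers.
context_len = 4096            # hypothetical model context window
requested_max_tokens = 4096   # hypothetical user-configured max_tokens
prompt_tokens = 1000          # hypothetical estimated prompt size

# Step 1: max_tokens cannot exceed the model's total context window.
budget = min(requested_max_tokens, context_len)   # 4096

# Step 2: the prompt occupies part of that window, so subtract its
# estimated size to get the tokens actually available for generation.
budget -= prompt_tokens                           # 3096

# Step 3: if nothing is left, the request cannot yield any completion,
# so fail early rather than send a request the API will reject.
if budget < 1:
    raise RuntimeError("A response cannot be created within the available context length")

print(budget)  # 3096 tokens left for the completion
```

This mirrors the two steps in the diff: clamp the requested `max_tokens` to `context_len`, then subtract the estimated prompt length before sending the request.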
The same hunk continues:

```diff
+
+            prompt_tokens = 0
+            try:
+                encoding = tiktoken.encoding_for_model(self.name)
+                prompt_tokens = len(encoding.encode(prompt))
+            except KeyError as e:
+                prompt_tokens = int(
+                    len(prompt.split()) * 4 / 3
+                )  # extra naive fallback 1 token ~= 3/4 of a word
+            generation_max_tokens -= prompt_tokens
+            create_args["max_tokens"] = generation_max_tokens
+            if generation_max_tokens < 1:  # allow at least a binary result token
+                raise garak.exception.GarakException(
+                    "A response cannot be created within the available context length"
+                )
+
         if self.generator == self.client.completions:
             if not isinstance(prompt, str):
                 msg = (
```
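For reference, a small self-contained sketch of the two estimation paths used in the hunk above: tiktoken's model-specific encoding when the model name is one tiktoken recognizes, and the naive three-words-to-four-tokens fallback otherwise. The model names below are examples, not values from the PR.

```python
import tiktoken


def estimate_prompt_tokens(model_name: str, prompt: str) -> int:
    """Estimate the prompt's token count, mirroring the fallback logic above."""
    try:
        # tiktoken maps known OpenAI model names to their tokenizer
        encoding = tiktoken.encoding_for_model(model_name)
        return len(encoding.encode(prompt))
    except KeyError:
        # unknown model name: assume one token is roughly 3/4 of a word
        return int(len(prompt.split()) * 4 / 3)


print(estimate_prompt_tokens("gpt-3.5-turbo", "Estimate how many tokens this prompt uses."))
print(estimate_prompt_tokens("not-a-known-model", "Estimate how many tokens this prompt uses."))
```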
Review thread:

**Reviewer:** Some models return the prompt in their output. In these cases, `max_tokens` and `context_len` are only related if `deprefix` is asserted.

**Author:** The OpenAI client `create()` does not accept `deprefix` as a named param, so it will not be passed by the generator call. If support for passing `deprefix` in some way is added to the generator in the future, we can rethink this calculation.
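To make the prompt-echo concern concrete, here is a rough illustration with assumed numbers; it is not garak's actual `deprefix` handling. When a completions-style model repeats the prompt at the start of its output, the echoed prompt is itself counted against the generation budget, so fewer genuinely new tokens fit.

```python
# Hypothetical numbers, for illustration only.
context_len = 2048
prompt_tokens = 600
max_tokens = 800   # generation budget requested from the API

# A model that echoes the prompt spends part of max_tokens repeating
# the input before producing any new content.
new_tokens_with_echo = max_tokens - prompt_tokens   # 200
new_tokens_without_echo = max_tokens                # 800

print(new_tokens_with_echo, new_tokens_without_echo)
```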