Feature: use hf chat support #1047
Changes from all commits
87ab1ca
9449258
9ec5794
99c760a
900fc8f
86de116
@@ -79,6 +79,9 @@ def _load_client(self):
             set_seed(_config.run.seed)

         pipeline_kwargs = self._gather_hf_params(hf_constructor=pipeline)
+        pipeline_kwargs["truncation"] = (
+            True # this is forced to maintain existing pipeline expectations
+        )
         self.generator = pipeline("text-generation", **pipeline_kwargs)
         if self.generator.tokenizer is None:
             # account for possible model without a stored tokenizer
@@ -87,6 +90,11 @@ def _load_client(self):
             self.generator.tokenizer = AutoTokenizer.from_pretrained(
                 pipeline_kwargs["model"]
             )
+        if not hasattr(self, "use_chat"):
+            self.use_chat = (
+                hasattr(self.generator.tokenizer, "chat_template")
+                and self.generator.tokenizer.chat_template is not None
+            )
         if not hasattr(self, "deprefix_prompt"):
             self.deprefix_prompt = self.name in models_to_deprefix

         if _config.loaded:
@@ -98,6 +106,9 @@ def _load_client(self):
     def _clear_client(self):
         self.generator = None

+    def _format_chat_prompt(self, prompt: str) -> List[dict]:
+        return [{"role": "user", "content": prompt}]
+
     def _call_model(
         self, prompt: str, generations_this_call: int = 1
     ) -> List[Union[str, None]]:
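For context on the new helper, here is a minimal sketch (not part of the diff) of what the chat-format path produces and how a chat template renders it. The model name is only an illustrative assumption; any tokenizer that ships a chat_template behaves similarly.

```python
from transformers import AutoTokenizer

def _format_chat_prompt(prompt: str) -> list[dict]:
    # single user turn in the role/content shape HF chat templates expect
    return [{"role": "user", "content": prompt}]

# illustrative model choice, not something this PR pins
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
rendered = tokenizer.apply_chat_template(
    _format_chat_prompt("Say hello."), tokenize=False, add_generation_prompt=True
)
print(rendered)  # the prompt wrapped in the model's chat markup, ending with the assistant cue
```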
@@ -106,13 +117,16 @@ def _call_model(
             warnings.simplefilter("ignore", category=UserWarning)
             try:
                 with torch.no_grad():
-                    # workaround for pipeline to truncate the input
-                    encoded_prompt = self.generator.tokenizer(prompt, truncation=True)
-                    truncated_prompt = self.generator.tokenizer.decode(
-                        encoded_prompt["input_ids"], skip_special_tokens=True
-                    )
+                    # according to docs https://huggingface.co/docs/transformers/main/en/chat_templating
+                    # chat template should be automatically utilized if the pipeline tokenizer has support
+                    # and a properly formatted list[dict] is supplied
+                    if self.use_chat:
+                        formatted_prompt = self._format_chat_prompt(prompt)
+                    else:
+                        formatted_prompt = prompt
+
Comment on lines +123 to +126: Note for future Erick and Jeffrey: we will want to merge …
                     raw_output = self.generator(
-                        truncated_prompt,
+                        formatted_prompt,
                         pad_token_id=self.generator.tokenizer.eos_token_id,
                         max_new_tokens=self.max_tokens,
                         num_return_sequences=generations_this_call,
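As a sketch of the pipeline behaviour the in-diff comment relies on (the model name is an assumption for illustration, not part of this PR): passing a list[dict] of messages to a text-generation pipeline whose tokenizer has a chat template makes the pipeline apply that template and return the conversation with the assistant reply appended, which is what the output handling below unpacks.

```python
from transformers import pipeline

# small chat-tuned model chosen only to keep the example light
generator = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")

raw_output = generator(
    [{"role": "user", "content": "Say hello."}],
    max_new_tokens=16,
)

# with chat input, generated_text is the message list plus the new assistant turn
reply = raw_output[0]["generated_text"][-1]["content"].strip()
print(reply)
```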
@@ -127,10 +141,15 @@ def _call_model(
                 i["generated_text"] for i in raw_output
             ] # generator returns 10 outputs by default in __init__

+        if self.use_chat:
+            text_outputs = [_o[-1]["content"].strip() for _o in outputs]
+        else:
+            text_outputs = outputs
+
Comment on lines +144 to +148: Comment not for this PR: there are multiple ways of representing conversations in garak, and attempt should be canonical for conversation history. I'm starting to consider patterns where attempt still holds the conversation history, but where it can be read & written through other interfaces, like an OpenAI API messages dict list, or this HF-style format. But maybe the work is so simple that this current pattern works fine.
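A rough sketch of the pattern that comment gestures at; ConversationView and its turn storage are hypothetical illustrations, not garak's attempt API.

```python
from typing import List, Tuple

class ConversationView:
    def __init__(self, turns: List[Tuple[str, str]]):
        # canonical storage: ordered (role, content) pairs
        self.turns = turns

    def as_messages(self) -> List[dict]:
        # OpenAI-style and HF chat-template style share this list-of-dicts shape
        return [{"role": role, "content": content} for role, content in self.turns]

history = ConversationView([("user", "Say hello."), ("assistant", "Hello!")])
print(history.as_messages())
```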
         if not self.deprefix_prompt:
-            return outputs
+            return text_outputs
         else:
-            return [re.sub("^" + re.escape(prompt), "", _o) for _o in outputs]
+            return [re.sub("^" + re.escape(prompt), "", _o) for _o in text_outputs]


 class OptimumPipeline(Pipeline, HFCompatible):
@@ -468,6 +487,13 @@ def _load_client(self):
                 self.name, padding_side="left"
             )

+        if not hasattr(self, "use_chat"):
+            # test tokenizer for `apply_chat_template` support
+            self.use_chat = (
+                hasattr(self.tokenizer, "chat_template")
+                and self.tokenizer.chat_template is not None
+            )
+
         self.generation_config = transformers.GenerationConfig.from_pretrained(
             self.name
         )
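A minimal illustration of the detection logic added here, assuming gpt2 as an example of a tokenizer that does not ship a chat template; the check mirrors the diff and is not additional PR code.

```python
from transformers import AutoTokenizer

# gpt2 ships no chat_template; chat-tuned checkpoints set one in tokenizer_config.json
tok = AutoTokenizer.from_pretrained("gpt2")
use_chat = hasattr(tok, "chat_template") and tok.chat_template is not None
print(use_chat)  # False here; True for models that define a chat template
```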
@@ -492,14 +518,27 @@ def _call_model(
         if self.top_k is not None:
             self.generation_config.top_k = self.top_k

-        text_output = []
+        raw_text_output = []
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UserWarning)
             with torch.no_grad():
+                if self.use_chat:
+                    formatted_prompt = self.tokenizer.apply_chat_template(
+                        self._format_chat_prompt(prompt),
+                        tokenize=False,
+                        add_generation_prompt=True,
+                    )
+                else:
+                    formatted_prompt = prompt
+
                 inputs = self.tokenizer(
-                    prompt, truncation=True, return_tensors="pt"
+                    formatted_prompt, truncation=True, return_tensors="pt"
                 ).to(self.device)
+                prefix_prompt = self.tokenizer.decode(
+                    inputs["input_ids"][0], skip_special_tokens=True
+                )
+
                 try:
                     outputs = self.model.generate(
                         **inputs, generation_config=self.generation_config
@@ -512,14 +551,22 @@ def _call_model(
                         return returnval
                     else:
                         raise e
-                text_output = self.tokenizer.batch_decode(
+                raw_text_output = self.tokenizer.batch_decode(
                     outputs, skip_special_tokens=True, device=self.device
                 )

+                if self.use_chat:
+                    text_output = [
+                        re.sub("^" + re.escape(prefix_prompt), "", i).strip()
+                        for i in raw_text_output
+                    ]
+                else:
+                    text_output = raw_text_output
+
Comment on lines +558 to +564: Part of me (albeit a part with limited convictions) feels like there HAS to be a better way to handle this. HF seems to be managing it internally with their …

Reply: idk man, the format uses … Another option could be to do … My experience also is that things like …
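One possible alternative to regex-stripping the decoded prefix, offered only as a sketch of the idea the thread circles around and not as what this PR does: slice the generated ids at the input length and decode just the new tokens. The model here is gpt2 purely for illustration.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Say hello.", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        **inputs, max_new_tokens=8, pad_token_id=tokenizer.eos_token_id
    )

# keep only tokens produced after the prompt, so no prefix stripping is needed
new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))
```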
         if not self.deprefix_prompt:
             return text_output
         else:
-            return [re.sub("^" + re.escape(prompt), "", i) for i in text_output]
+            return [re.sub("^" + re.escape(prefix_prompt), "", i) for i in text_output]


 class LLaVA(Generator, HFCompatible):
Comment: Sometimes I wish this were just the default for HF, but c'est la vie. We may want to consider whether truncation=True requires max_len to be set -- I've had HF yell at me for not specifying both, but it may be a corner case that I encountered.

Reply: HF is kinda yell-y, and over the course of garak dev, HF has varied what it yells about. It can also be the case that some models require certain param combos to operate, while others will find the same param combo utterly intolerable. This has led to a style where one tries to do the right thing in garak, and tries to listen to HF warnings a little less.
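On the truncation point, a small hedged illustration (gpt2 and the input are examples only): pairing truncation=True with an explicit max_length avoids relying on tokenizer defaults and the warning about an unspecified truncation length.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
encoded = tokenizer(
    "a very long prompt " * 500,
    truncation=True,
    max_length=tokenizer.model_max_length,  # be explicit rather than relying on defaults
)
print(len(encoded["input_ids"]))  # capped at the model's maximum context length (1024 for gpt2)
```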