New QWEN 2 VLM (#3247)
teetone authored Jan 5, 2025
1 parent e2e7270 commit ee8cf38
Showing 7 changed files with 232 additions and 8 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
@@ -204,6 +204,10 @@ vlm =
    # For metrics
    pycocoevalcap~=1.2

    # For Qwen2
    transformers~=4.45.2
    qwen-vl-utils~=0.0.8

ibm-enterprise-scenarios =
    openpyxl~=3.1

7 changes: 7 additions & 0 deletions src/helm/benchmark/run_spec_factory.py
@@ -138,6 +138,13 @@ def alter_run_spec(run_spec: RunSpec) -> RunSpec:
    ):
        run_spec = singleton(IncreaseMaxTokensRunExpander(value=1).expand(run_spec))

    if model.name == "openai/o1-2024-12-17":
        # From https://platform.openai.com/docs/guides/reasoning,
        # "OpenAI recommends reserving at least 25,000 tokens for reasoning and outputs when you start
        # experimenting with these models. As you become familiar with the number of reasoning tokens your
        # prompts require, you can adjust this buffer accordingly."
        run_spec = singleton(IncreaseMaxTokensRunExpander(value=25_000).expand(run_spec))

    # IDEFICS special handling
    if IDEFICS_MODEL_TAG in model.tags:
        if IDEFICS_INSTRUCT_MODEL_TAG in model.tags:
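
For context on the expander used above, here is a minimal sketch of the intended effect, assuming IncreaseMaxTokensRunExpander simply adds its value to the adapter spec's max_tokens; the AdapterSpec stand-in and helper below are illustrative, not HELM's exact implementation.

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class AdapterSpec:
    # Illustrative stand-in for HELM's AdapterSpec with only the field that matters here.
    max_tokens: int = 1000


def increase_max_tokens(adapter_spec: AdapterSpec, value: int) -> AdapterSpec:
    # Reserve extra output budget on top of what the scenario already requested,
    # e.g. the 25,000-token reasoning buffer recommended for o1-2024-12-17.
    return replace(adapter_spec, max_tokens=adapter_spec.max_tokens + value)


assert increase_max_tokens(AdapterSpec(max_tokens=1000), value=25_000).max_tokens == 26_000
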
4 changes: 2 additions & 2 deletions src/helm/benchmark/static/schema_vhelm.yaml
@@ -726,7 +726,7 @@ run_groups:
      - accuracy
      - general_information
    environment:
      main_name: exact_match
      main_name: quasi_prefix_exact_match
      main_split: test
    taxonomy:
      task: short-answer question answering
@@ -902,7 +902,7 @@
      - accuracy
      - general_information
    environment:
      main_name: exact_match
      main_name: quasi_prefix_exact_match
      main_split: test
    taxonomy:
      task: short-answer question answering
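
For context on the metric change above: exact_match requires the generated answer to equal the reference verbatim, while quasi_prefix_exact_match roughly checks whether the normalized reference is a prefix of the normalized generation, which is more forgiving for short-answer VQA outputs that add trailing words. The sketch below is a rough illustration of that idea; the normalization is a simplified assumption, not HELM's exact implementation.

import re
import string


def normalize(text: str) -> str:
    # Simplified normalization: lowercase, strip punctuation and articles, squeeze whitespace.
    text = "".join(ch for ch in text.lower() if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(gold: str, pred: str) -> bool:
    return gold == pred


def quasi_prefix_exact_match(gold: str, pred: str) -> bool:
    return normalize(pred).startswith(normalize(gold))


print(exact_match("Paris", "Paris, the capital of France"))               # False
print(quasi_prefix_exact_match("Paris", "Paris, the capital of France"))  # True
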
20 changes: 14 additions & 6 deletions src/helm/clients/openai_client.py
@@ -28,6 +28,7 @@ class OpenAIClient(CachingClient):

    # Error OpenAI throws when the image in the prompt violates their content policy
    INAPPROPRIATE_IMAGE_ERROR: str = "Your input image may contain content that is not allowed by our safety system"
    INAPPROPRIATE_PROMPT_ERROR: str = "Invalid prompt: your prompt was flagged"

    # Set the finish reason to this if the prompt violates OpenAI's content policy
    CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
@@ -171,11 +172,6 @@ def _make_chat_request(self, request: Request) -> RequestResult:
"frequency_penalty": request.frequency_penalty,
}

# OpenAI's vision API doesn't allow None values for stop.
# Fails with "body -> stop: none is not an allowed value" if None is passed.
if is_vlm(request.model) and raw_request["stop"] is None:
raw_request.pop("stop")

# Special handling for o1 models.
# Refer to the "Reasoning models" documentation further discussion of o1 model limitations:
# https://platform.openai.com/docs/guides/reasoning
@@ -191,6 +187,18 @@
            if raw_request["stop"] is None:
                raw_request.pop("stop")

        if request.model_engine == "o1-2024-12-17":
            # Avoid error:
            # "Error code: 400 - {'error': {'message': "Unsupported parameter: 'temperature' is
            # not supported with this model.", 'type': 'invalid_request_error', 'param': 'temperature',
            # 'code': 'unsupported_parameter'}}"
            raw_request.pop("temperature", None)
        elif is_vlm(request.model):
            # Avoid error:
            # "Invalid type for 'stop': expected an unsupported value, but got null instead."
            if raw_request["stop"] is None:
                raw_request.pop("stop")

        # Special handling for gpt-4o-audio-preview
        # See: https://platform.openai.com/docs/guides/audio
        if request.model_engine.startswith("gpt-4o-audio-preview"):
@@ -208,7 +216,7 @@ def do_it() -> Dict[str, Any]:
            cache_key = self._get_cache_key(raw_request, request)
            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
        except openai.OpenAIError as e:
            if self.INAPPROPRIATE_IMAGE_ERROR in str(e):
            if self.INAPPROPRIATE_IMAGE_ERROR in str(e) or self.INAPPROPRIATE_PROMPT_ERROR in str(e):
                hlog(f"Failed safety check: {str(request)}")
                empty_completion = GeneratedOutput(
                    text="",
175 changes: 175 additions & 0 deletions src/helm/clients/vision_language/qwen2_vlm_client.py
@@ -0,0 +1,175 @@
from threading import Lock
from typing import Any, Dict, List, Optional
from dataclasses import dataclass

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

from helm.common.cache import CacheConfig
from helm.common.gpu_utils import get_torch_device_name
from helm.common.hierarchical_logger import hlog, htrack_block
from helm.common.media_object import TEXT_TYPE
from helm.common.request import Request, RequestResult, GeneratedOutput, Token
from helm.common.request import wrap_request_time
from helm.clients.client import CachingClient, generate_uid_for_multimodal_prompt


@dataclass(frozen=True)
class LoadedQwen2ModelProcessor:
    model: Qwen2VLForConditionalGeneration
    processor: AutoProcessor


_models_lock: Lock = Lock()
_models: Dict[str, Optional[LoadedQwen2ModelProcessor]] = {
    "Qwen/Qwen2-VL-7B-Instruct": None,
    "Qwen/Qwen2-VL-72B-Instruct": None,
}


class Qwen2VLMClient(CachingClient):
    def __init__(self, cache_config: CacheConfig):
        super().__init__(cache_config=cache_config)
        self._device: str = get_torch_device_name()

    def _get_model_name(self, helm_model_name: str) -> str:
        if helm_model_name == "qwen2-vl-7b-instruct":
            return "Qwen/Qwen2-VL-7B-Instruct"
        elif helm_model_name == "qwen2-vl-72b-instruct":
            return "Qwen/Qwen2-VL-72B-Instruct"
        else:
            raise ValueError(f"Unhandled model name: {helm_model_name}")

    def _get_model(self, helm_model_name: str) -> LoadedQwen2ModelProcessor:
        global _models_lock
        global _models

        model_name = self._get_model_name(helm_model_name)

        with _models_lock:
            loaded = _models[model_name]
            if loaded is None:
                hlog(f"Loading model {model_name} and caching in memory...")
                # https://huggingface.co/docs/transformers/model_doc/qwen2_vl#flash-attention-2-to-speed-up-generation
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                    attn_implementation="flash_attention_2",
                ).eval()
                processor = AutoProcessor.from_pretrained(model_name)
                loaded = LoadedQwen2ModelProcessor(model=model, processor=processor)
                _models[model_name] = loaded

        return loaded

    def make_request(self, request: Request) -> RequestResult:
        assert request.multimodal_prompt is not None, "Multimodal prompt is required"
        loaded = self._get_model(request.model_engine)
        model = loaded.model
        processor = loaded.processor

        # Build Qwen2 messages
        # We assume all media objects go into a single "user" message:
        # messages = [
        #     {
        #         "role": "user",
        #         "content": [
        #             {"type": "image", "image": "file:///path/to/image1.jpg"},
        #             {"type": "image", "image": "file:///path/to/image2.jpg"},
        #             {"type": "text", "text": "Describe these images."}
        #         ]
        #     }
        # ]
        message_content = []
        for media_object in request.multimodal_prompt.media_objects:
            if media_object.is_type("image") and media_object.location:
                message_content.append({"type": "image", "image": media_object.location})
            elif media_object.is_type(TEXT_TYPE):
                if media_object.text is None:
                    raise ValueError("MediaObject of text type has missing text field value")
                message_content.append({"type": "text", "text": media_object.text})
            else:
                raise ValueError(f"Unrecognized MediaObject type {media_object.type}")

        messages = [{"role": "user", "content": message_content}]

        # Prepare text and vision inputs
        text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self._device)

        generation_args = {
            "max_new_tokens": request.max_tokens,
        }

        completions: List[GeneratedOutput] = []
        request_time: float = 0
        request_datetime: Optional[int] = None
        all_cached: bool = True

        with htrack_block(f"Generating for prompt: {text}"):
            for completion_index in range(request.num_completions):
                try:

                    def do_it() -> Dict[str, Any]:
                        generated_ids = model.generate(**inputs, **generation_args)
                        # Remove the input prefix from outputs
                        generated_ids_trimmed = [
                            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                        ]
                        output_text = processor.batch_decode(
                            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                        )
                        # There's only one batch element
                        completion = output_text[0]
                        # For simplicity, split the completion on whitespace.
                        # A more faithful tokenization would require the Qwen2 tokenizer.
                        tokens = completion.split()
                        return {"output": (completion, tokens)}

                    cache_key = CachingClient.make_cache_key(
                        raw_request={
                            "completion_index": completion_index,
                            "model": request.model,
                            "prompt": generate_uid_for_multimodal_prompt(request.multimodal_prompt),
                            **generation_args,
                        },
                        request=request,
                    )
                    result, cached = self.cache.get(cache_key, wrap_request_time(do_it))
                except RuntimeError as model_error:
                    return RequestResult(
                        success=False, cached=False, error=str(model_error), completions=[], embedding=[]
                    )

                text_out, tokens = result["output"]
                completions.append(
                    GeneratedOutput(
                        text=text_out,
                        logprob=0,
                        tokens=[Token(text=str(token), logprob=0) for token in tokens],
                    )
                )
                hlog(f"Generated: {text_out}")

                request_time += result["request_time"]
                request_datetime = request_datetime or result.get("request_datetime")
                all_cached = all_cached and cached

        return RequestResult(
            success=True,
            cached=all_cached,
            request_time=request_time,
            request_datetime=request_datetime,
            completions=completions,
            embedding=[],
        )
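
One possible refinement to the whitespace token splitting noted in the client above: the Qwen2-VL processor wraps the model's own tokenizer, which could supply real subword boundaries. A small sketch, under the assumption that the AutoProcessor loaded for this model exposes that tokenizer as its tokenizer attribute.

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
completion = "A red bicycle is leaning against a brick wall."
# Assumption: the processor's underlying text tokenizer is available as `.tokenizer`.
token_ids = processor.tokenizer.encode(completion, add_special_tokens=False)
tokens = processor.tokenizer.convert_ids_to_tokens(token_ids)
print(tokens)  # Model-specific subword pieces instead of whitespace-split words
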
14 changes: 14 additions & 0 deletions src/helm/config/model_deployments.yaml
@@ -2733,6 +2733,20 @@ model_deployments:
    client_spec:
      class_name: "helm.clients.vision_language.qwen_vlm_client.QwenVLMClient"

  - name: huggingface/qwen2-vl-7b-instruct
    model_name: qwen/qwen2-vl-7b-instruct
    tokenizer_name: qwen/qwen-vl-chat
    max_sequence_length: 8191
    client_spec:
      class_name: "helm.clients.vision_language.qwen2_vlm_client.Qwen2VLMClient"

  - name: huggingface/qwen2-vl-72b-instruct
    model_name: qwen/qwen2-vl-72b-instruct
    tokenizer_name: qwen/qwen-vl-chat
    max_sequence_length: 8191
    client_spec:
      class_name: "helm.clients.vision_language.qwen2_vlm_client.Qwen2VLMClient"

  - name: huggingface/qwen-audio-chat
    model_name: qwen/qwen-audio-chat
    tokenizer_name: qwen/qwen-audio-chat
16 changes: 16 additions & 0 deletions src/helm/config/model_metadata.yaml
@@ -2827,6 +2827,22 @@ models:
    release_date: 2023-08-24
    tags: [VISION_LANGUAGE_MODEL_TAG, FULL_FUNCTIONALITY_VLM_TAG]

  - name: qwen/qwen2-vl-7b-instruct
    display_name: Qwen2-VL Instruct (7B)
    description: The second generation of the Qwen-VL model series ([paper](https://arxiv.org/abs/2409.12191)).
    creator_organization_name: Alibaba Group
    access: open
    release_date: 2024-08-29
    tags: [VISION_LANGUAGE_MODEL_TAG, FULL_FUNCTIONALITY_VLM_TAG]

  - name: qwen/qwen2-vl-72b-instruct
    display_name: Qwen2-VL Instruct (72B)
    description: The second generation of the Qwen-VL model series ([paper](https://arxiv.org/abs/2409.12191)).
    creator_organization_name: Alibaba Group
    access: open
    release_date: 2024-08-29
    tags: [VISION_LANGUAGE_MODEL_TAG, FULL_FUNCTIONALITY_VLM_TAG]

  - name: qwen/qwen-audio-chat
    display_name: Qwen-Audio Chat
    description: Auditory multimodal version of the Qwen large language model series ([paper](https://arxiv.org/abs/2311.07919)).