Skip to content

Commit

Permalink
added new dragonfly files
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Thapa committed Oct 10, 2024
1 parent 2e814a9 commit 6ed24a3
Show file tree
Hide file tree
Showing 15 changed files with 78 additions and 254 deletions.
35 changes: 15 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@

## 🔥 News
- [Our paper](https://arxiv.org/abs/2406.00977) is out on arxiv.
- Check out [our blogpost](https://www.together.ai/blog/dragonfly-v1).
- Our model checkpoints are out on huggingface 🤗 🚀:
- General: [`togethercomputer/Llama-3-8B-Dragonfly-v1`](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-v1)
- Biomed: [`togethercomputer/Llama-3-8B-Dragonfly-Med-v1`](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-Med-v1)
- General: [`togethercomputer/Llama-3.1-8B-Dragonfly-v1`](https://huggingface.co/togethercomputer/Llama-3.1-8B-Dragonfly-v1)
- Biomed: [`togethercomputer/Llama-3.1-8B-Dragonfly-Med-v1`](https://huggingface.co/togethercomputer/Llama-3.1-8B-Dragonfly-Med-v1)


## 📖 Introduction

![Dragonfly framework](assets/model_overview.png)

Recent advances in large multimodal models (LMMs) suggest that higher image resolution enhances the fine-grained understanding of image details, crucial for tasks such as visual commonsense reasoning and analyzing biomedical images. However, increasing input resolution poses two main challenges: 1) It extends the context length required by the language model, leading to inefficiencies and hitting the model's context limit; 2) It increases the complexity of visual features, necessitating more training data or more complex architecture. We introduce Dragonfly, a new LMM architecture that enhances fine-grained visual understanding and reasoning about image regions to address these challenges. Dragonfly employs two key strategies: multi-resolution visual encoding and zoom-in patch selection. These strategies allow the model to process high-resolution images efficiently while maintaining reasonable context length. Our experiments on eight popular benchmarks demonstrate that Dragonfly achieves competitive or better performance compared to other architectures, highlighting the effectiveness of our design. Additionally, we finetuned Dragonfly on biomedical instructions, achieving state-of-the-art results on multiple biomedical tasks requiring fine-grained visual understanding, including 92.3% accuracy on the Path-VQA dataset (compared to 83.3% for Med-Gemini) and the highest reported results on biomedical image captioning. To support model training, we curated a visual instruction-tuning dataset with 5.5 million image-instruction samples in the general domain and 1.4 million samples in the biomedical domain. We also conducted ablation studies to characterize the impact of various architectural designs and image resolutions, providing insights for future research on visual instruction alignment.
Recent advances in vision-language models (VLMs) have demonstrated the advantages of processing images at higher resolutions and utilizing multi-crop features to preserve native resolution details. However, existing vision transformers (ViTs) often struggle to capture fine-grained details from less prominent objects, charts, and embedded text, limiting their effectiveness in certain tasks. In this paper, we go beyond recent high-resolution and multi-crop techniques by not only preserving the native resolution but also zooming in beyond it and extracting features from a large number of image sub-crops. This enhancement allows our model to better capture fine-grained details, overcoming the limitations of current ViTs. To manage the increased token count and computational complexity, we demonstrate that a simple mean-pooling aggregation over tokens is effective. Our model, Dragonfly, achieves competitive performance on general-domain tasks such as ScienceQA and AI2D, and excels in tasks requiring fine-grained image understanding, including TextVQA and ChartQA. On average, Dragonfly ranks at the top across ten general-domain benchmarks, outperforming models that are significantly larger or trained on much larger datasets. Our biomedical version, Dragonfly-Med, sets new benchmarks on several medical tasks, achieving 91.6\% accuracy on SLAKE (compared to 84.8\% for Med-Gemini), a 67.1\% token F1 score on Path-VQA (compared to 62.7\% for Med-PaLM M), and attains state-of-the-art results across the majority of performance metrics. Overall, our work establishes a new paradigm for extracting high-resolution fine-grained features from images, significantly enhancing the capabilities of VLMs in both general and specialized domains.


# 📖 Table of Contents
Expand Down Expand Up @@ -54,7 +53,7 @@ pip install --upgrade -e .

*Note: These models are released under [Llama 3 Community License Agreement](LICENSE)*

We release two huggingface model checkpoints: [`togethercomputer/Llama-3-8B-Dragonfly-v1`](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-v1) and [`togethercomputer/Llama-3-8B-Dragonfly-Med-v1`](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-Med-v1). Please follow the script [`test_dragonfly.py`](test_dragonfly.py) for more details. We provide a brief description on how to use them below.
We release two huggingface model checkpoints: [`togethercomputer/Llama-3.1-8B-Dragonfly-v1`](https://huggingface.co/togethercomputer/Llama-3.1-8B-Dragonfly-v1) and [`togethercomputer/Llama-3.1-8B-Dragonfly-Med-v1`](https://huggingface.co/togethercomputer/Llama-3.1-8B-Dragonfly-Med-v1). Please follow the script [`test_dragonfly.py`](test_dragonfly.py) for more details. We provide a brief description on how to use them below.

<a name="inference"/>

Expand All @@ -64,9 +63,9 @@ If you have successfully completed the [Installation](#installation) process, th

We provide two test examples inside [`test_images`](test_images).

Question: Summarize the visual content of the image.
Question: What is so funny about this image?

![Skateboard](test_images/skateboard.png)
![Skateboard](test_images/monalisa_dog.jpg)

Load necessary packages
```python
Expand All @@ -83,26 +82,26 @@ Instantiate the tokenizer, processor, and model.
```python
device = torch.device("cuda:0")

tokenizer = AutoTokenizer.from_pretrained("togethercomputer/Llama-3-8B-Dragonfly-v1")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/Llama-3.1-8B-Dragonfly-v1")
clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
image_processor = clip_processor.image_processor
processor = DragonflyProcessor(image_processor=image_processor, tokenizer=tokenizer, image_encoding_style="llava-hd")

model = DragonflyForCausalLM.from_pretrained("togethercomputer/Llama-3-8B-Dragonfly-v1")
model = DragonflyForCausalLM.from_pretrained("togethercomputer/Llama-3.1-8B-Dragonfly-v1")
model = model.to(torch.bfloat16)
model = model.to(device)
```

Now, lets load the image and process them.
```python
image = Image.open("./test_images/skateboard.png")
image = Image.open("./test_images/monalisa_dog.jpg")
image = image.convert("RGB")
images = [image]
# images = [None] # if you do not want to pass any images

text_prompt = "<|start_header_id|>user<|end_header_id|>\n\nSummarize the visual content of the image.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

inputs = processor(text=[text_prompt], images=images, max_length=2048, return_tensors="pt", is_generate=True)
inputs = processor(text=[text_prompt], images=images, max_length=4096, return_tensors="pt", is_generate=True)
inputs = inputs.to(device)
```

Expand All @@ -118,11 +117,7 @@ generation_text = processor.batch_decode(generation_output, skip_special_tokens=

An example response.
```plaintext
In the heart of a vibrant skatepark, a skateboarder is caught in a moment of pure exhilaration. The skateboarder, dressed in a black t-shirt adorned with a yellow graphic and black pants, is suspended in mid-air, performing an impressive trick on a concrete ramp. The skateboarder's arms are outstretched, adding balance to the daring stunt.
The skatepark itself is a concrete playground, with the skateboarder's ramp being the main focus. In the background, palm trees sway gently, adding a touch of nature to the urban setting. A few spectators can be seen in the distance, their attention riveted on the airborne skateboarder.
The image captures not just a moment, but a story of skill, courage, and the joy of skateboarding.<|eot_id|>
The humor in this image comes from the surreal juxtaposition of a dog's face with the body of the Mona Lisa, a famous painting by Leonardo da Vinci. The Mona Lisa is known for her enigmatic smile and is often considered one of the most famous paintings in the world. By combining the dog's face with the body of the Mona Lisa, the artist has created a whimsical and amusing image that plays on the viewer 's expectations and familiarity with the original paintings. The contrast between the dog's natural, expressive features and the serene, mysterious expression of the Mona Lisa creates a humerous effect that is likely to elicit laughter<|eot_id|>
```

<a name="dataset"/>
Expand Down Expand Up @@ -179,7 +174,7 @@ Describe the content in the image.<|eot_id|><|start_header_id|>assistant<|end_he
We would like to acknowledge the following resources that were instrumental in the development of Dragonfly:

- [Meta Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B): We utilized the Llama 3 model as our foundational language model.
- [CLIP](https://huggingface.co/openai/clip-vit-base-patch32): Our vision backbone is CLIP model from OpenAI.
- [CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336): Our vision backbone is CLIP model from OpenAI.
- Our codebase is built upon the following two codebases:
- [Otter: A Multi-Modal Model with In-Context Instruction Tuning](https://github.com/Luodian/Otter)
- [LLaVA-UHD: an LMM Perceiving Any Aspect Ratio and High-Resolution Images](https://github.com/thunlp/LLaVA-UHD)
Expand All @@ -190,8 +185,8 @@ We would like to acknowledge the following resources that were instrumental in t

```bibtex
@misc{chen2024dragonfly,
title={Dragonfly: Multi-Resolution Zoom Supercharges Large Visual-Language Model},
author={Kezhen Chen and Rahul Thapa and Rahul Chalamala and Ben Athiwaratkun and Shuaiwen Leon Song and James Zou},
title={Dragonfly: Multi-Resolution Zoom-In Encoding Enhances Vision-Language Models},
author={Rahul Thapa and Kezhen Chen and Ian Covert and Rahul Chalamala and Ben Athiwaratkun and Shuaiwen Leon Song and James Zou},
year={2024},
eprint={2406.00977},
archivePrefix={arXiv},
Expand Down
Binary file removed assets/dragonfly_icon.png
Binary file not shown.
Binary file modified assets/model_overview.pdf
Binary file not shown.
Binary file removed assets/model_overview.png
Binary file not shown.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: dragonfly_env
channels:
- defaults
dependencies:
- python=3.9
- python=3.10
- conda-forge::openjdk
- pip
- pip:
Expand Down
7 changes: 0 additions & 7 deletions pipeline/data_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@
USER_AGENT = get_datasets_user_agent()

Image.MAX_IMAGE_PIXELS = 1000000000
MAX_NUM_TOKENS = 256
MAX_NUM_IMAGES = 5
TINY_IMAGE_SIZE_THRESHOLD = 1
NUM_BACKUP_SPLIT = 5000
N_CHANNELS = 3
INTERLEAVED_IMAGE_SIZE = 224
MIN_KB = 10

IMAGE_CAP_INSTRUCT = [
"Analyze the image in a comprehensive and detailed manner.",
Expand Down
3 changes: 3 additions & 0 deletions pipeline/train/training.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def train_one_epoch(
total=total_training_steps,
initial=current_global_steps,
):

data_time_m.update(time.time() - end)
global_step = num_steps + current_global_steps

Expand Down Expand Up @@ -423,6 +424,8 @@ def mask_embedding(m):
lr_scheduler.step()
optimizer.zero_grad()

# print(f"Step 3: Beginning Step: {num_steps}; Global Step: {global_step}")

# step time and reset end outside of rank 0
step_time_m.update(time.time() - end)
end = time.time()
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
package_dir={"": "src"},
install_requires=requirements,
author="Together AI",
author_email="kezhen@together.ai",
description="Dragonfly: Multi-Resolution Zoom Supercharges Large Visual-Language Model",
author_email="rthapa84@stanford.edu",
description="Dragonfly: Multi-Resolution Zoom-In Encoding Enhances Vision-Language Models",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
url="https://github.com/togethercomputer/Dragonfly",
Expand Down
77 changes: 18 additions & 59 deletions src/dragonfly/models/modeling_dragonfly.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,9 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
topk: Optional[int] = 3,
topk: Optional[int] = 5,
region_token_interval: Optional[int] = 6,
steps=100,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -277,67 +278,25 @@ def forward(
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:

# first, lets extract the high resolution image patches and generate embeddings from them using the image encoder
# we also project them through the projection layer
query_hd_patches = [full_img_patches[5:] for full_img_patches in image_patches]
query_outputs = [self.image_encoder(patch_pixel_values.to(self.vision_embed_tokens.weight.dtype), output_hidden_states=True) for patch_pixel_values in query_hd_patches]
query_image_patches = [item.hidden_states[-2] for item in query_outputs]
query_image_patches = [self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)) for patch in query_image_patches]
query_ranks = [torch.mean(query_item, 1) for query_item in query_image_patches]

# now, lets extract the low and mid resolution image patches and generate embeddings from them using the image encoder
# we similarly project them through the projection layer
image_patches = [full_img_patches[:5] for full_img_patches in image_patches]
# breakpoint()

ie_outputs = [self.image_encoder(patch_pixel_values.to(self.vision_embed_tokens.weight.dtype), output_hidden_states=True).hidden_states[-2] for patch_pixel_values in image_patches]
ie_outputs = [self.vision_embed_tokens(patch_pixel_values.to(self.vision_embed_tokens.weight.dtype)) for patch_pixel_values in ie_outputs]

"""Now, for each mid resolution region, we select the top k high resolution regions using
the dot product of the region embeddings with the query embeddings. The query embeddings
are the mean of the high resolution region embeddings and the region embeddings are the
mean of the mid resolution region embeddings
"""

# region 1
abstract_query1 = [torch.mean(patch_item[1], 0) for patch_item in ie_outputs]
query_ranks1 = [torch.matmul(qr, abstract.unsqueeze(-1)).squeeze(-1) for qr, abstract in zip(query_ranks, abstract_query1)]
query_ranks_mask1 = torch.zeros(query_ranks[0].size()[:-1]).to(query_ranks[0].device)
query_ranks_mask1.scatter_(0, torch.arange(region_token_interval, region_token_interval * 4).to(query_ranks[0].device), float("-inf"))
query_ranks1 = [T + query_ranks_mask1 for T in query_ranks1]
query_ranks1 = [torch.topk(item, topk).indices for item in query_ranks1]

# region 2
abstract_query2 = [torch.mean(patch_item[2], 0) for patch_item in ie_outputs]
query_ranks2 = [torch.matmul(qr, abstract.unsqueeze(-1)).squeeze(-1) for qr, abstract in zip(query_ranks, abstract_query2)]
query_ranks_mask2 = torch.zeros(query_ranks[0].size()[:-1]).to(query_ranks[0].device)
query_ranks_mask2.scatter_(0, torch.concat([torch.arange(0, region_token_interval).to(query_ranks[0].device), torch.arange(region_token_interval * 2, region_token_interval * 4).to(query_ranks[0].device)]), float("-inf"))
query_ranks2 = [T + query_ranks_mask2 for T in query_ranks2]
query_ranks2 = [torch.topk(item, topk).indices for item in query_ranks2]

# region 3
abstract_query3 = [torch.mean(patch_item[3], 0) for patch_item in ie_outputs]
query_ranks3 = [torch.matmul(qr, abstract.unsqueeze(-1)).squeeze(-1) for qr, abstract in zip(query_ranks, abstract_query3)]
query_ranks_mask3 = torch.zeros(query_ranks[0].size()[:-1]).to(query_ranks[0].device) # add
query_ranks_mask3.scatter_(
0, torch.concat([torch.arange(0, region_token_interval * 2).to(query_ranks[0].device), torch.arange(region_token_interval * 3, region_token_interval * 4).to(query_ranks[0].device)]), float("-inf")
) # add
query_ranks3 = [T + query_ranks_mask3 for T in query_ranks3] # add
query_ranks3 = [torch.topk(item, topk).indices for item in query_ranks3]

# region 4
abstract_query4 = [torch.mean(patch_item[4], 0) for patch_item in ie_outputs]
query_ranks4 = [torch.matmul(qr, abstract.unsqueeze(-1)).squeeze(-1) for qr, abstract in zip(query_ranks, abstract_query4)]
query_ranks_mask4 = torch.zeros(query_ranks[0].size()[:-1]).to(query_ranks[0].device) # add
query_ranks_mask4.scatter_(0, torch.arange(0, region_token_interval * 3).to(query_ranks[0].device), float("-inf")) # add
query_ranks4 = [T + query_ranks_mask4 for T in query_ranks4] # add
query_ranks4 = [torch.topk(item, topk).indices for item in query_ranks4]

# Construct visual encoding
query_ranks = [torch.concat([q1, q2, q3, q4]) for q1, q2, q3, q4 in zip(query_ranks1, query_ranks2, query_ranks3, query_ranks4)]
query_ranks = [torch.sort(item).values for item in query_ranks]
selected_image_patches = [torch.index_select(q_image_patches, 0, q_ranks) for (q_image_patches, q_ranks) in zip(query_image_patches, query_ranks)]

# concat
patch_embeddings = [torch.concat([ie_output.view(-1, self.config.hidden_size), s_image_patches.view(-1, self.config.hidden_size)]) for ie_output, s_image_patches in zip(ie_outputs, selected_image_patches)]
low_res_embeddings = [ie_output[:1] for ie_output in ie_outputs]
high_res_embeddings = [ie_output[1:] for ie_output in ie_outputs]

high_patch_embeddings = []
for ie_output in high_res_embeddings:
cls_token = ie_output[:, :1, :] # Shape: [B, 1, 1024]
img_tokens = ie_output[:, 1:, :] # Shape: [B, 576, 1024]
img_tokens_reshaped = img_tokens.view(ie_output.size(0), 24, 24, -1)
img_tokens_pooled = F.avg_pool2d(img_tokens_reshaped.permute(0, 3, 1, 2), kernel_size=4, stride=4) # Shape: [B, 1024, 6, 6]
img_tokens_pooled = img_tokens_pooled.permute(0, 2, 3, 1).view(ie_output.size(0), -1, ie_output.size(2)) # Shape: [B, 36, 1024]
output = torch.cat([cls_token, img_tokens_pooled], dim=1)
high_patch_embeddings.append(output)

patch_embeddings = [torch.concat([low_res.view(-1, self.config.hidden_size), high_res.view(-1, self.config.hidden_size)]) for low_res, high_res in zip(low_res_embeddings, high_patch_embeddings)]

inputs_embeds = self.gather_continuous_embeddings(
word_embeddings=inputs_embeds,
Expand Down
Loading

0 comments on commit 6ed24a3

Please sign in to comment.