Merge pull request #6 from togethercomputer/rchalamala-patch-1
Format and fix small bugs
rthapa84 authored Jun 6, 2024
2 parents 38d98f1 + 4d50c3b commit b4c5643
Showing 11 changed files with 255 additions and 304 deletions.
47 changes: 20 additions & 27 deletions README.md
````diff
@@ -32,12 +32,6 @@ Recent advances in large multimodal models (LMMs) suggest that higher image reso
 
 ## 💿 Installation
 
-Clone this repository and navigate to Dragonfly folder
-```bash
-git clone https://github.com/togethercomputer/Dragonfly.git
-cd Dragonfly
-```
-
 Create a conda environment and install necessary packages
 ```bash
 conda env create -f environment.yml
````
````diff
@@ -76,50 +70,49 @@ Question: Summarize the visual content of the image.
 
 Load necessary packages
 ```python
-import sys
-from dragonfly.models.modeling_dragonfly import *
-from dragonfly.models.processing_dragonfly import *
-from transformers import AutoProcessor, AutoTokenizer
-from PIL import Image
 import torch
+from PIL import Image
+from transformers import AutoProcessor, AutoTokenizer
+
+from dragonfly.models.modeling_dragonfly import DragonflyForCausalLM
+from dragonfly.models.processing_dragonfly import DragonflyProcessor
 from pipeline.train.train_utils import random_seed
 ```
 
 Instantiate the tokenizer, processor, and model.
 ```python
+device = torch.device("cuda:0")
+
 tokenizer = AutoTokenizer.from_pretrained("togethercomputer/Llama-3-8B-Dragonfly-v1")
-clip_processor = AutoProcessor.from_pretrained('openai/clip-vit-base-patch32')
+clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
 image_processor = clip_processor.image_processor
-processor = DragonflyProcessor(image_processor=image_processor, tokenizer=tokenizer, image_encoding_style='llava-hd')
-model = DragonflyForCausalLM.from_pretrained(
-    "togethercomputer/Llama-3-8B-Dragonfly-v1"
-)
+processor = DragonflyProcessor(image_processor=image_processor, tokenizer=tokenizer, image_encoding_style="llava-hd")
+
+model = DragonflyForCausalLM.from_pretrained("togethercomputer/Llama-3-8B-Dragonfly-v1")
 model = model.to(torch.bfloat16)
-model = model.to("cuda:0")
+model = model.to(device)
 ```
 
 Now, let's load and process the image.
 ```python
 image = Image.open("./test_images/skateboard.png")
-image = image.convert('RGB')
+image = image.convert("RGB")
 images = [image]
-# images = None # if you do not want to pass any images
+# images = [None] # if you do not want to pass any images
 
 text_prompt = "<|start_header_id|>user<|end_header_id|>\n\nSummarize the visual content of the image.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-inputs = processor(text=[text_prompt], images=[image], max_length=2048, return_tensors="pt", is_generate=True)
-inputs = inputs.to("cuda:0")
+
+inputs = processor(text=[text_prompt], images=images, max_length=2048, return_tensors="pt", is_generate=True)
+inputs = inputs.to(device)
 ```
 
 Finally, let's generate the response from the model.
 ```python
 temperature = 0
 
 with torch.inference_mode():
-    generation_output = model.generate(**inputs,
-                                       max_new_tokens=1024,
-                                       eos_token_id=tokenizer.encode('<|eot_id|>'),
-                                       do_sample=temperature > 0,
-                                       temperature=temperature,
-                                       use_cache=True)
+    generation_output = model.generate(**inputs, max_new_tokens=1024, eos_token_id=tokenizer.encode("<|eot_id|>"), do_sample=temperature > 0, temperature=temperature, use_cache=True)
 
 generation_text = processor.batch_decode(generation_output, skip_special_tokens=False)
 ```
````
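Since `skip_special_tokens=False` keeps the chat-template markers in the decoded string, here is a minimal sketch of recovering just the assistant's reply, assuming the Llama-3 template shown above:

```python
# generation_text comes from the snippet above. With skip_special_tokens=False
# the decoded string keeps the template markers, so take the text after the
# last assistant header and drop the trailing <|eot_id|>.
response = generation_text[0].split("<|end_header_id|>")[-1]
print(response.replace("<|eot_id|>", "").strip())
```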

1 change: 0 additions & 1 deletion pipeline/accelerate_configs/accelerate_config_zero2.yaml
```diff
@@ -10,7 +10,6 @@ distributed_type: DEEPSPEED
 fsdp_config: {}
 machine_rank: 0
 main_process_ip: null
-main_process_port: null
 main_training_function: main
 mixed_precision: bf16
 num_machines: 1
```
115 changes: 55 additions & 60 deletions pipeline/data_utils/data.py
```diff
@@ -1,33 +1,34 @@
+import base64
 import io
 import json
 import os
-import re
 import random
+import re
 import sys
-import base64
 import urllib
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from multiprocessing import Value
-from src.dragonfly.models.processing_dragonfly import DragonflyProcessor
 from functools import partial
+from multiprocessing import Value
 
 import numpy as np
 import torch
 import torch.utils
-from concurrent.futures import ThreadPoolExecutor
-from datasets import load_dataset, interleave_datasets
+from datasets import interleave_datasets, load_dataset
 from datasets.distributed import split_dataset_by_node
 from PIL import Image
 
+from src.dragonfly.models.processing_dragonfly import DragonflyProcessor
+
 sys.path.append("../..")
 import json
 import os
 
 import yaml
+from datasets.utils.file_utils import get_datasets_user_agent
 from PIL import Image, ImageFile
 
 from pipeline.train.train_utils import DistributedProxySampler
-from datasets.utils.file_utils import get_datasets_user_agent
 
 USER_AGENT = get_datasets_user_agent()
```

```diff
@@ -39,71 +40,65 @@
 N_CHANNELS = 3
 INTERLEAVED_IMAGE_SIZE = 224
 MIN_KB = 10
-MAX_NUM_IMAGES = 5
 
 IMAGE_CAP_INSTRUCT = [
-    'Analyze the image in a comprehensive and detailed manner.',
+    "Analyze the image in a comprehensive and detailed manner.",
     "What's happening in the scene?",
-    'Write a terse but informative summary of the picture.',
-    'What are the key elements in this picture?',
+    "Write a terse but informative summary of the picture.",
+    "What are the key elements in this picture?",
     "Present a compact description of the photo's key features.",
-    'What do you think is going on in this snapshot?',
-    'Describe the following image.',
-    'What do you see happening in this image?',
-    'Provide a brief description of the given image.',
+    "What do you think is going on in this snapshot?",
+    "Describe the following image.",
+    "What do you see happening in this image?",
+    "Provide a brief description of the given image.",
     "What is this photo about'?",
-    'Summarize the visual content of the image.',
-    'What is in the photo?',
-    'Write a detailed description of the given image.',
-    'Can you elaborate on the elements of the picture provided?',
-    'Give a brief description of the image.',
-    'Explain the visual content of the image in great detail.',
-    'Render a clear and concise summary of the photo.',
-    'Describe the image concisely.',
-    'Give a short and clear explanation of the subsequent image.',
-    'Can you describe the main features of this image for me?',
-    'Share a concise interpretation of the image provided.',
-    'What is this?'
+    "Summarize the visual content of the image.",
+    "What is in the photo?",
+    "Write a detailed description of the given image.",
+    "Can you elaborate on the elements of the picture provided?",
+    "Give a brief description of the image.",
+    "Explain the visual content of the image in great detail.",
+    "Render a clear and concise summary of the photo.",
+    "Describe the image concisely.",
+    "Give a short and clear explanation of the subsequent image.",
+    "Can you describe the main features of this image for me?",
+    "Share a concise interpretation of the image provided.",
+    "What is this?",
 ]
 
 
 def format_llama3_prompt(item):
-    if 'text' in item:
+    if "text" in item:
         instruction = random.choice(IMAGE_CAP_INSTRUCT)
         formated_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{item['text']}<|eot_id|>"
     else:
-        formated_prompt = item['conversations']
+        formated_prompt = item["conversations"]
     return formated_prompt
 
 
 def prepare_dragonfly_sample(batch_data, processor, image_dir=None, max_length=None):
     def hq_dataset_pp(example):
-        if 'image_url' in example and example['image_url'] is not None and example['image_url'].strip() !="":
-            if example['image_url'].startswith('/data/'):
-                image_url = example['image_url']
+        if "image_url" in example and example["image_url"] is not None and example["image_url"].strip() != "":
+            if example["image_url"].startswith("/data/"):
+                image_url = example["image_url"]
             else:
-                image_url = os.path.join(image_dir, example['image_url'])
-            img = Image.open(image_url).convert('RGB')
+                image_url = os.path.join(image_dir, example["image_url"])
+            img = Image.open(image_url).convert("RGB")
         else:
             img = None
         return img
 
     for item in batch_data:
-        item['text'] = format_llama3_prompt(item)
-        if 'image' not in item:
-            item['image'] = hq_dataset_pp(item)
-    pil_images = [item['image']for item in batch_data]
+        item["text"] = format_llama3_prompt(item)
+        if "image" not in item:
+            item["image"] = hq_dataset_pp(item)
+
+    pil_images = [item["image"] for item in batch_data]
     if None in pil_images:
         pil_images = None
 
-    texts = [item['text'] for item in batch_data]
-    model_inputs = processor(
-        text=texts,
-        images=pil_images,
-        return_tensors="pt",
-        truncation=True,
-        max_length=max_length)
+    texts = [item["text"] for item in batch_data]
+    model_inputs = processor(text=texts, images=pil_images, return_tensors="pt", truncation=True, max_length=max_length)
     labels = processor.get_labels(input_ids=model_inputs["input_ids"], special_token_id=128000)
     input_ids, labels = processor.find_and_remove_tokens(input_ids=model_inputs["input_ids"], labels=labels, token_id=128000)
     model_inputs["input_ids"] = input_ids
```
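As an aside, a quick sketch of what `format_llama3_prompt` returns for a caption-style item; the item below is hypothetical:

```python
# Hypothetical item: when a "text" field is present, format_llama3_prompt
# pairs a random instruction from IMAGE_CAP_INSTRUCT with the caption.
item = {"text": "A skateboarder performing a trick on a ramp."}
prompt = format_llama3_prompt(item)
# prompt == "<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|>"
#           "<|start_header_id|>assistant<|end_header_id|>\n\n{caption}<|eot_id|>"
```

The token id 128000 passed to `get_labels` and `find_and_remove_tokens` is Llama 3's `<|begin_of_text|>`, which the processor appears to use as a boundary marker when building the label mask.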
```diff
@@ -113,17 +108,17 @@ def hq_dataset_pp(example):
 
 def load_dragonfly_val_dataset(args):
     processor = args.processor
-    data_files = args.val_files.split(',')
+    data_files = args.val_files.split(",")
 
     def hq_dataset_pp(example):
         imgs = []
-        for image_url in example['image_url']:
+        for image_url in example["image_url"]:
             image_url = os.path.join(args.image_dir, image_url)
-            imgs.append(Image.open(image_url).convert('RGB'))
-        example['image'] = imgs
+            imgs.append(Image.open(image_url).convert("RGB"))
+        example["image"] = imgs
         return example
 
-    val_dataset = load_dataset('parquet', data_files=data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
+    val_dataset = load_dataset("parquet", data_files=data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
     val_dataset = split_dataset_by_node(val_dataset, world_size=args.world_size, rank=args.rank)
     val_dataset = val_dataset.map(hq_dataset_pp, batched=True, batch_size=args.batch_size)
     val_dataset = val_dataset.remove_columns(["image_url", "source"])
```
```diff
@@ -134,16 +129,16 @@ def hq_dataset_pp(example):
 def load_dragonfly_pretrain_dataset(args):
     processor = args.processor
     data_dir = args.data_dir
-    data_files = [f for f in os.listdir(data_dir) if f.endswith('.parquet')]
+    data_files = [f for f in os.listdir(data_dir) if f.endswith(".parquet")]
     hq_data_files = []
     lq_data_files = []
     text_data_files = []
     math_data_files = []
-    together_hq_datasets = args.together_hq_datasets.split(',') if args.together_hq_datasets else []
-    together_lq_datasets = args.together_lq_datasets.split(',') if args.together_lq_datasets else []
-    together_text_datasets = args.together_text_datasets.split(',') if args.together_text_datasets else []
-    together_math_datasets = args.together_math_datasets.split(',') if args.together_math_datasets else []
+    together_hq_datasets = args.together_hq_datasets.split(",") if args.together_hq_datasets else []
+    together_lq_datasets = args.together_lq_datasets.split(",") if args.together_lq_datasets else []
+    together_text_datasets = args.together_text_datasets.split(",") if args.together_text_datasets else []
+    together_math_datasets = args.together_math_datasets.split(",") if args.together_math_datasets else []
 
     print(together_hq_datasets)
     print(together_text_datasets)
     print(together_math_datasets)
```
```diff
@@ -168,23 +163,23 @@ def load_dragonfly_pretrain_dataset(args):
     print(text_data_files)
     print(math_data_files)
 
-    hq_dataset = load_dataset('parquet', data_files=hq_data_files, split="train", cache_dir=args.data_cache_dir)
+    hq_dataset = load_dataset("parquet", data_files=hq_data_files, split="train", cache_dir=args.data_cache_dir)
     hq_dataset = split_dataset_by_node(hq_dataset, world_size=args.world_size, rank=args.rank)
     hq_dataset = hq_dataset.remove_columns(["source"])
     dataset = hq_dataset
 
     dataset = dataset.shuffle(seed=args.seed)
 
     if text_data_files:
-        text_dataset = load_dataset('parquet', data_files=text_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
+        text_dataset = load_dataset("parquet", data_files=text_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
         text_dataset = split_dataset_by_node(text_dataset, world_size=args.world_size, rank=args.rank)
         # text_dataset = text_dataset.remove_columns(["source"])
         text_dataset = text_dataset.shuffle(seed=args.seed, buffer_size=1000)
     else:
         text_dataset = None
 
     if math_data_files:
-        math_dataset = load_dataset('parquet', data_files=math_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
+        math_dataset = load_dataset("parquet", data_files=math_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
         math_dataset = split_dataset_by_node(math_dataset, world_size=args.world_size, rank=args.rank)
         math_dataset = math_dataset.shuffle(seed=args.seed, buffer_size=1000)
     else:
```
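The rest of the function is truncated here, but since the file imports `interleave_datasets`, a minimal sketch of how the streaming sources could be mixed; the probabilities are hypothetical:

```python
# Sketch only: `dataset` and `text_dataset` come from the function above;
# the mixing probabilities are hypothetical and must sum to 1.
from datasets import interleave_datasets

mixed = interleave_datasets([dataset, text_dataset], probabilities=[0.7, 0.3], seed=args.seed)
```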
1 change: 1 addition & 0 deletions pipeline/train/distributed.py
```diff
@@ -1,4 +1,5 @@
 import os
+
 import torch
 
 
```
8 changes: 6 additions & 2 deletions pipeline/train/train_utils.py
```diff
@@ -6,11 +6,14 @@
 
 import numpy as np
 import torch
-from torch.utils.data.distributed import DistributedSampler
 import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
 
 try:
-    from transformers.models.idefics.processing_idefics import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
+    from transformers.models.idefics.processing_idefics import (
+        image_attention_mask_for_packed_input_ids,
+        incremental_to_binary_attention_mask,
+    )
 except ImportError:
     print("Failed to import Idefics processing module.")
 
```
```diff
@@ -295,6 +298,7 @@ def get_next_dataloader(dataloader_iterators, weights):
     chosen_dataloader_index = np.random.choice(len(dataloader_iterators), p=weights)
     return dataloader_iterators[chosen_dataloader_index]
 
+
 def get_next_dataloader_torch(dataloaders, weights):
     weights_tensor = torch.tensor(weights, dtype=torch.float)  # Convert weights to a tensor
     dataloader_index = torch.multinomial(weights_tensor, 1).item()  # Sample an index
```
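A short usage sketch for `get_next_dataloader`; the two loaders are hypothetical stand-ins, and since `np.random.choice(p=weights)` is used, the weights must sum to 1:

```python
import numpy as np  # used inside get_next_dataloader
from torch.utils.data import DataLoader

# Two hypothetical sources; any iterables of batches would do.
hq_loader = DataLoader(range(100), batch_size=4)
text_loader = DataLoader(range(100), batch_size=4)

iterators = [iter(hq_loader), iter(text_loader)]
batch = next(get_next_dataloader(iterators, weights=[0.8, 0.2]))  # mostly HQ batches
```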