
format code
rchalamala committed Jun 6, 2024
1 parent 973a727 commit ed947cb
Showing 7 changed files with 205 additions and 232 deletions.
114 changes: 55 additions & 59 deletions pipeline/data_utils/data.py
@@ -1,33 +1,34 @@
import base64
import io
import json
import os
import re
import random
import re
import sys
import base64
import urllib
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from multiprocessing import Value
from src.dragonfly.models.processing_dragonfly import DragonflyProcessor
from functools import partial
from multiprocessing import Value

import numpy as np
import torch
import torch.utils
from concurrent.futures import ThreadPoolExecutor
from datasets import load_dataset, interleave_datasets
from datasets import interleave_datasets, load_dataset
from datasets.distributed import split_dataset_by_node
from PIL import Image

from src.dragonfly.models.processing_dragonfly import DragonflyProcessor

sys.path.append("../..")
import json
import os

import yaml
from datasets.utils.file_utils import get_datasets_user_agent
from PIL import Image, ImageFile

from pipeline.train.train_utils import DistributedProxySampler
from datasets.utils.file_utils import get_datasets_user_agent

USER_AGENT = get_datasets_user_agent()

@@ -41,68 +42,63 @@
MIN_KB = 10

IMAGE_CAP_INSTRUCT = [
'Analyze the image in a comprehensive and detailed manner.',
"Analyze the image in a comprehensive and detailed manner.",
"What's happening in the scene?",
'Write a terse but informative summary of the picture.',
'What are the key elements in this picture?',
"Write a terse but informative summary of the picture.",
"What are the key elements in this picture?",
"Present a compact description of the photo's key features.",
'What do you think is going on in this snapshot?',
'Describe the following image.',
'What do you see happening in this image?',
'Provide a brief description of the given image.',
"What do you think is going on in this snapshot?",
"Describe the following image.",
"What do you see happening in this image?",
"Provide a brief description of the given image.",
"What is this photo about'?",
'Summarize the visual content of the image.',
'What is in the photo?',
'Write a detailed description of the given image.',
'Can you elaborate on the elements of the picture provided?',
'Give a brief description of the image.',
'Explain the visual content of the image in great detail.',
'Render a clear and concise summary of the photo.',
'Describe the image concisely.',
'Give a short and clear explanation of the subsequent image.',
'Can you describe the main features of this image for me?',
'Share a concise interpretation of the image provided.',
'What is this?'
"Summarize the visual content of the image.",
"What is in the photo?",
"Write a detailed description of the given image.",
"Can you elaborate on the elements of the picture provided?",
"Give a brief description of the image.",
"Explain the visual content of the image in great detail.",
"Render a clear and concise summary of the photo.",
"Describe the image concisely.",
"Give a short and clear explanation of the subsequent image.",
"Can you describe the main features of this image for me?",
"Share a concise interpretation of the image provided.",
"What is this?",
]


def format_llama3_prompt(item):
if 'text' in item:
if "text" in item:
instruction = random.choice(IMAGE_CAP_INSTRUCT)
formated_prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{instruction}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{item['text']}<|eot_id|>"
else:
formated_prompt = item['conversations']
formated_prompt = item["conversations"]
return formated_prompt


def prepare_dragonfly_sample(batch_data, processor, image_dir=None, max_length=None):
def hq_dataset_pp(example):
if 'image_url' in example and example['image_url'] is not None and example['image_url'].strip() !="":
if example['image_url'].startswith('/data/'):
image_url = example['image_url']
if "image_url" in example and example["image_url"] is not None and example["image_url"].strip() != "":
if example["image_url"].startswith("/data/"):
image_url = example["image_url"]
else:
image_url = os.path.join(image_dir, example['image_url'])
img = Image.open(image_url).convert('RGB')
image_url = os.path.join(image_dir, example["image_url"])
img = Image.open(image_url).convert("RGB")
else:
img = None
return img

for item in batch_data:
item['text'] = format_llama3_prompt(item)
if 'image' not in item:
item['image'] = hq_dataset_pp(item)
pil_images = [item['image']for item in batch_data]
item["text"] = format_llama3_prompt(item)
if "image" not in item:
item["image"] = hq_dataset_pp(item)

pil_images = [item["image"] for item in batch_data]
if None in pil_images:
pil_images = None

texts = [item['text'] for item in batch_data]
model_inputs = processor(
text=texts,
images=pil_images,
return_tensors="pt",
truncation=True,
max_length=max_length)
texts = [item["text"] for item in batch_data]
model_inputs = processor(text=texts, images=pil_images, return_tensors="pt", truncation=True, max_length=max_length)
labels = processor.get_labels(input_ids=model_inputs["input_ids"], special_token_id=128000)
input_ids, labels = processor.find_and_remove_tokens(input_ids=model_inputs["input_ids"], labels=labels, token_id=128000)
model_inputs["input_ids"] = input_ids
@@ -112,17 +108,17 @@ def hq_dataset_pp(example):

def load_dragonfly_val_dataset(args):
processor = args.processor
data_files = args.val_files.split(',')
data_files = args.val_files.split(",")

def hq_dataset_pp(example):
imgs = []
for image_url in example['image_url']:
for image_url in example["image_url"]:
image_url = os.path.join(args.image_dir, image_url)
imgs.append(Image.open(image_url).convert('RGB'))
example['image'] = imgs
imgs.append(Image.open(image_url).convert("RGB"))
example["image"] = imgs
return example

val_dataset = load_dataset('parquet', data_files=data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
val_dataset = load_dataset("parquet", data_files=data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
val_dataset = split_dataset_by_node(val_dataset, world_size=args.world_size, rank=args.rank)
val_dataset = val_dataset.map(hq_dataset_pp, batched=True, batch_size=args.batch_size)
val_dataset = val_dataset.remove_columns(["image_url", "source"])
@@ -133,16 +129,16 @@ def hq_dataset_pp(example):
def load_dragonfly_pretrain_dataset(args):
processor = args.processor
data_dir = args.data_dir
data_files = [f for f in os.listdir(data_dir) if f.endswith('.parquet')]
data_files = [f for f in os.listdir(data_dir) if f.endswith(".parquet")]
hq_data_files = []
lq_data_files = []
text_data_files = []
math_data_files = []
together_hq_datasets = args.together_hq_datasets.split(',') if args.together_hq_datasets else []
together_lq_datasets = args.together_lq_datasets.split(',') if args.together_lq_datasets else []
together_text_datasets = args.together_text_datasets.split(',') if args.together_text_datasets else []
together_math_datasets = args.together_math_datasets.split(',') if args.together_math_datasets else []

together_hq_datasets = args.together_hq_datasets.split(",") if args.together_hq_datasets else []
together_lq_datasets = args.together_lq_datasets.split(",") if args.together_lq_datasets else []
together_text_datasets = args.together_text_datasets.split(",") if args.together_text_datasets else []
together_math_datasets = args.together_math_datasets.split(",") if args.together_math_datasets else []
print(together_hq_datasets)
print(together_text_datasets)
print(together_math_datasets)
@@ -167,23 +163,23 @@ def load_dragonfly_pretrain_dataset(args):
print(text_data_files)
print(math_data_files)

hq_dataset = load_dataset('parquet', data_files=hq_data_files, split="train", cache_dir=args.data_cache_dir)
hq_dataset = load_dataset("parquet", data_files=hq_data_files, split="train", cache_dir=args.data_cache_dir)
hq_dataset = split_dataset_by_node(hq_dataset, world_size=args.world_size, rank=args.rank)
hq_dataset = hq_dataset.remove_columns(["source"])
dataset = hq_dataset

dataset = dataset.shuffle(seed=args.seed)

if text_data_files:
text_dataset = load_dataset('parquet', data_files=text_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
text_dataset = load_dataset("parquet", data_files=text_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
text_dataset = split_dataset_by_node(text_dataset, world_size=args.world_size, rank=args.rank)
# text_dataset = text_dataset.remove_columns(["source"])
text_dataset = text_dataset.shuffle(seed=args.seed, buffer_size=1000)
else:
text_dataset = None

if math_data_files:
math_dataset = load_dataset('parquet', data_files=math_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
math_dataset = load_dataset("parquet", data_files=math_data_files, split="train", cache_dir=args.data_cache_dir, streaming=True)
math_dataset = split_dataset_by_node(math_dataset, world_size=args.world_size, rank=args.rank)
math_dataset = math_dataset.shuffle(seed=args.seed, buffer_size=1000)
else:
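
For reference, a minimal usage sketch of the format_llama3_prompt helper as it reads after this commit; the sample dict and caption below are illustrative assumptions, not data from the repository.

# Hedged sketch: a caption-style item routed through format_llama3_prompt.
# The dict contents are made up for illustration only.
item = {"text": "A dog running on a beach."}
prompt = format_llama3_prompt(item)
# `prompt` is a single Llama-3 chat-formatted string along the lines of:
# <|start_header_id|>user<|end_header_id|>\n\n{random IMAGE_CAP_INSTRUCT prompt}<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\nA dog running on a beach.<|eot_id|>
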
1 change: 1 addition & 0 deletions pipeline/train/distributed.py
@@ -1,4 +1,5 @@
import os

import torch


8 changes: 6 additions & 2 deletions pipeline/train/train_utils.py
@@ -6,11 +6,14 @@

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

try:
from transformers.models.idefics.processing_idefics import image_attention_mask_for_packed_input_ids, incremental_to_binary_attention_mask
from transformers.models.idefics.processing_idefics import (
image_attention_mask_for_packed_input_ids,
incremental_to_binary_attention_mask,
)
except ImportError:
print("Failed to import Idefics processing module.")

@@ -295,6 +298,7 @@ def get_next_dataloader(dataloader_iterators, weights):
chosen_dataloader_index = np.random.choice(len(dataloader_iterators), p=weights)
return dataloader_iterators[chosen_dataloader_index]


def get_next_dataloader_torch(dataloaders, weights):
weights_tensor = torch.tensor(weights, dtype=torch.float) # Convert weights to a tensor
dataloader_index = torch.multinomial(weights_tensor, 1).item() # Sample an index
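
A hedged sketch of the weighted dataloader selection shown in get_next_dataloader_torch above; the two stand-in iterators and the weights are illustrative assumptions, not values from this repository.

# Illustrative only: plain iterators stand in for torch DataLoader iterators.
import torch

hq_loader = iter([{"source": "hq", "idx": i} for i in range(3)])
text_loader = iter([{"source": "text", "idx": i} for i in range(3)])
dataloaders = [hq_loader, text_loader]
weights = [0.8, 0.2]  # draw from the first stream roughly 80% of the time

weights_tensor = torch.tensor(weights, dtype=torch.float)  # convert weights to a tensor
dataloader_index = torch.multinomial(weights_tensor, 1).item()  # sample an index
batch = next(dataloaders[dataloader_index])
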
(Diffs for the remaining four changed files did not load and are not shown here.)
