some cleanup/updates to caption stuff
victorchall committed Mar 22, 2024
1 parent 273ce23 commit 056de84
Showing 5 changed files with 228 additions and 81 deletions.
159 changes: 112 additions & 47 deletions caption_cog.py
@@ -36,7 +36,12 @@

from plugins.caption_plugins import load_prompt_alteration_plugin
from utils.patch_cog import patch_cog
-from data.gen_utils import image_generator, SUPPORTED_EXT
+from data.generators import image_path_generator, SUPPORTED_EXT

try:
from moai.load_moai import prepare_moai
except ImportError:
print("moai not found, skipping")

IMAGE_SIZE: int = 490
PATCH_SIZE: int = 14
@@ -107,7 +112,7 @@ def save_params(args, gen_kwargs):
with open(save_path, "w") as f:
f.write(pretty_print)

-def create_bnb_config(args):
+def create_bnb_config():
return BitsAndBytesConfig(
bnb_4bit_compute_dtype="float32",
bnb_4bit_quant_type="fp4",
@@ -121,58 +126,117 @@ def create_bnb_config(args):
quant_method="bitsandbytes"
)

class MoaiManager:
def __init__(self, model_name: str):
self.model_name = model_name
self.moai_model = None
self.moai_processor = None
self.seg_model = None
self.seg_processor = None
self.od_model = None
self.od_processor = None
self.sgg_model = None
self.ocr_model = None

def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str='fp16'):
moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
= prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
self.moai_model = moai_model
self.moai_processor = moai_processor
self.seg_model = seg_model
self.seg_processor = seg_processor
self.od_model = od_model
self.od_processor = od_processor
self.sgg_model = sgg_model
self.ocr_model = ocr_model

return moai_model, moai_processor

def get_inputs(self, image: Image.Image, prompt: str):
moai_inputs = self.moai_model.demo_process(image=image,
prompt=prompt,
processor=self.moai_processor,
seg_model=self.seg_model,
seg_processor=self.seg_processor,
od_model=self.od_model,
od_processor=self.od_processor,
sgg_model=self.sgg_model,
ocr_model=self.ocr_model,
device='cuda:0')
return moai_inputs

def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
with torch.inference_mode():
generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split('[U')[0]
return answer
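
# A minimal, hypothetical usage sketch for MoaiManager: it shows the intended
# load -> get_inputs -> generate flow. The checkpoint id and image path are
# illustrative only, and the optional moai package must be installed.
def _demo_moai_usage():
    manager = MoaiManager("BK-Lee/MoAI-7B")
    manager.load_model(bits=4, dtype='fp16')
    image = Image.open("example.jpg").convert("RGB")
    moai_inputs = manager.get_inputs(image, prompt="Describe this image.")
    print(manager(moai_inputs, temperature=0.9, top_p=0.95))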

class CogVLMManager:
def __init__(self, model_name: str):
self.model_name = model_name
self.tokenizer = None
self.model = None

def load_model(self):
self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True,
quantization_config=create_bnb_config()
)
return self.model, self.tokenizer

def get_inputs(self, prompt: str, history: List[Tuple[str, str]], images: List[Image.Image], starts_with: str):
return build_conversation_input_ids(self.tokenizer, query=prompt, history=history, images=images, starts_with=starts_with)

    def get_gen_kwargs(self, args):
        gen_kwargs = {
            "max_length": args.max_length,
            "do_sample": args.top_k is not None or args.top_p is not None or args.temp is not None,
            "length_penalty": args.length_penalty,
            "num_beams": args.num_beams,
            "temperature": args.temp,
            "top_k": args.top_k,
            "top_p": args.top_p,
            "repetition_penalty": args.repetition_penalty,
            "no_repeat_ngram_size": args.no_repeat_ngram_size,
            "min_new_tokens": args.min_new_tokens,
            "max_new_tokens": args.max_new_tokens,
        }
        logging.debug(gen_kwargs)

        if args.max_new_tokens is not None:
            logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
            del gen_kwargs["max_length"]

        if not gen_kwargs["do_sample"]:
            logging.info("** Using greedy sampling")
            del gen_kwargs["top_k"]
            del gen_kwargs["top_p"]
            del gen_kwargs["temperature"]
        else:
            logging.info("** Sampling enabled")
            # default any unset sampling knobs so generate() receives concrete values
            gen_kwargs["top_k"] = gen_kwargs["top_k"] or 50
            gen_kwargs["top_p"] = gen_kwargs["top_p"] or 1.0
            gen_kwargs["temperature"] = gen_kwargs["temperature"] or 1.0
        return gen_kwargs
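
# Hypothetical sketch of get_gen_kwargs behavior: with no sampling args set it
# falls back to greedy decoding and strips the sampling keys entirely. The
# Namespace below is illustrative, mirroring the script's argparse fields.
def _demo_gen_kwargs():
    ns = argparse.Namespace(max_length=2048, top_k=None, top_p=None, temp=None,
                            length_penalty=1.0, num_beams=1, repetition_penalty=1.0,
                            no_repeat_ngram_size=0, min_new_tokens=5, max_new_tokens=None)
    gen_kwargs = CogVLMManager("THUDM/cogvlm-chat-hf").get_gen_kwargs(ns)
    assert gen_kwargs["do_sample"] is False and "top_k" not in gen_kwargs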

def model_manager_factory(model_name: str):
if "moai" in model_name:
return MoaiManager(model_name)
else:
return CogVLMManager(model_name)
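
# Hypothetical sketch: the factory dispatches on a substring of the model name,
# so any id containing "moai" selects MoaiManager and everything else (e.g. the
# default "THUDM/cogvlm-chat-hf") falls through to CogVLMManager.
def _demo_factory_dispatch():
    assert isinstance(model_manager_factory("BK-Lee/MoAI-7B"), MoaiManager)
    assert isinstance(model_manager_factory("THUDM/cogvlm-chat-hf"), CogVLMManager)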

def main(args):
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
model_manager = model_manager_factory(args.model)

-    bnb_config = create_bnb_config(args)
-
-    tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
-    model = AutoModelForCausalLM.from_pretrained(
-        'THUDM/cogvlm-chat-hf',
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
-        #revision=... # no one is actually doing this
-        #load_in_4bit=not args.disable_4bit,
-        quantization_config=bnb_config,
-    )
-
-    do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
-    if do_sample:
-        args.top_k = args.top_k or 50
-        args.top_p = args.top_p or 1.0
-        args.temp = args.temp or 1.0
+    model, tokenizer = model_manager.load_model()

args.append = args.append or ""
if len(args.append) > 0:
args.append = " " + args.append.strip()

-    gen_kwargs = {
-        "max_length": args.max_length,
-        "do_sample": do_sample,
-        "length_penalty": args.length_penalty,
-        "num_beams": args.num_beams,
-        "temperature": args.temp,
-        "top_k": args.top_k,
-        "top_p": args.top_p,
-        "repetition_penalty": args.repetition_penalty,
-        "no_repeat_ngram_size": args.no_repeat_ngram_size,
-        "min_new_tokens": args.min_new_tokens,
-        "max_new_tokens": args.max_new_tokens,
-        "length_penalty": args.length_penalty,
-    }
-
-    if args.max_new_tokens is not None:
-        logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
-        del gen_kwargs["max_length"]
-
-    if not do_sample:
-        logging.info(f"** Using greedy sampling")
-        del gen_kwargs["top_k"]
-        del gen_kwargs["top_p"]
-        del gen_kwargs["temperature"]
-    else:
-        logging.info(f"** Sampling enabled")
+    gen_kwargs = model_manager.get_gen_kwargs(args)

force_words_ids = None
if args.force_words is not None:
@@ -195,7 +259,7 @@ def main(args):

starts_with = args.starts_with.strip() if args.starts_with is not None else ""

-    for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)):
+    for i, image_path in enumerate(image_path_generator(args.image_dir, do_recurse=not args.no_recurse)):
candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt")

if args.no_overwrite and os.path.exists(candidate_caption_path):
@@ -342,6 +406,7 @@ def configure_logging(args: argparse.Namespace):
argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
args = argparser.parse_args()

configure_logging(args)
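
# Hypothetical invocation (paths are illustrative; --model is the switch added in this commit):
#   python caption_cog.py --image_dir /data/images --model THUDM/cogvlm-chat-hf --temp 0.9 --top_p 0.95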
Empty file added data/__init__.py
15 changes: 0 additions & 15 deletions data/gen_utils.py

This file was deleted.

96 changes: 96 additions & 0 deletions data/generators.py
@@ -0,0 +1,96 @@
"""
Copyright [2022-2024] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from typing import Generator
from data.image_train_item import ImageTrainItem, ImageCaption
from PIL import Image, ImageOps
import tarfile
import logging

SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]

class BucketBatchedGenerator:
    """
    Yields items with the same aspect ratio (target_wh bucket) in batches, for use with batching dataloaders.
    """
    def __init__(self, batch_size: int=1, generator: Generator[ImageTrainItem, None, None]=None):
        self.batch_size = batch_size
        self.cache = {}
        self.generator = generator

    def __iter__(self):
        for item in self.generator:
            if item.target_wh:
                aspect_bucket_key = item.target_wh
                if aspect_bucket_key not in self.cache:
                    self.cache[aspect_bucket_key] = []
                self.cache[aspect_bucket_key].append(item)
                # yield a full batch as soon as a bucket fills up
                if len(self.cache[aspect_bucket_key]) >= self.batch_size:
                    for item in self.cache[aspect_bucket_key]:
                        yield item
                    self.cache[aspect_bucket_key] = []
        # note: items left in partially filled buckets are never yielded
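
# Hypothetical usage sketch: batch training items by aspect bucket; "my_images"
# is an illustrative directory path.
def _demo_bucket_batching():
    items = image_train_item_generator_from_files("my_images")
    for item in BucketBatchedGenerator(batch_size=4, generator=items):
        print(item.target_wh)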

# def image_train_item_generator_from_tar_pairs(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
# for root, dirs, files in os.walk(image_dir):
# for file in files:
# if file.endswith(".tar"):
# tar_path = os.path.join(root, file)
# with tarfile.open(tar_path, "r") as tar:
# for tarinfo in tar:
# if tarinfo.isfile() and any(tarinfo.name.endswith(ext) for ext in SUPPORTED_EXT):
# try:
# img = Image.open(tar.extractfile(tarinfo))
# txt = tar.extractfile(tarinfo.name.replace(os.path.splitext(tarinfo.name)[-1], ".txt"))
# caption = txt.read().decode("utf-8")
# img_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=256, use_weights=False)
# img = ImageOps.exif_transpose(img)
# iti = ImageTrainItem(img, img_caption)
# except Exception as e:
# logging.error(f"Failed to open {tarinfo.name}: {e}")
# continue
# yield iti

def image_train_item_generator_from_files(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
    for img_path in image_path_generator(image_dir, do_recurse):
        try:
            img = Image.open(img_path)
            img = ImageOps.exif_transpose(img)
        except Exception as e:
            logging.error(f"Failed to open {img_path}: {e}")
            continue
        caption = ""
        txt_cap_path = img_path.replace(os.path.splitext(img_path)[-1], ".txt")
        if os.path.exists(txt_cap_path):
            with open(txt_cap_path, "r") as f:
                caption = f.read()
        if not caption:
            # fall back to the file name, up to the first underscore, when no caption file exists
            caption = os.path.basename(img_path).split("_")[0]
        image_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=128, use_weights=False)
        yield ImageTrainItem(img, image_caption)

def image_path_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
if do_recurse:
for root, dirs, files in os.walk(image_dir):
for file in files:
if any(file.endswith(ext) for ext in SUPPORTED_EXT):
yield os.path.join(root, file)
else:
for file in os.listdir(image_dir):
if any(file.endswith(ext) for ext in SUPPORTED_EXT):
yield os.path.join(image_dir, file)
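
# Hypothetical usage sketch: list supported images under a directory without
# loading them ("sample_data" is an illustrative path).
def _demo_image_path_generator():
    for path in image_path_generator("sample_data", do_recurse=True):
        print(path)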