-
Notifications
You must be signed in to change notification settings - Fork 578
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: update TallyQA eval to use new inference and HF dataset
- Loading branch information
1 parent
9d5ef97
commit 4f472f0
Showing
1 changed file
with
56 additions
and
58 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,70 @@ | ||
# Expects Visual Genome to be downloaded to `data/vg` and the TallyQA test set | ||
# to be present at `data/tallyqa/test.json`. | ||
# | ||
# Steps to download Visual Genome and TallyQA: | ||
# | ||
# mkdir -p data/vg/VG_100K | ||
# mkdir -p data/vg/VG_100K_2 | ||
# mkdir -p data/tallyqa | ||
# wget -P data/vg/VG_100K_2/ https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip | ||
# wget -P data/vg/VG_100K/ https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip | ||
# wget -P data/tallyqa/ https://github.com/manoja328/TallyQA_dataset/raw/master/tallyqa.zip | ||
# unzip data/vg/VG_100K_2/images2.zip -d data/vg/ | ||
# unzip data/vg/VG_100K/images.zip -d data/vg/ | ||
# unzip data/tallyqa/tallyqa.zip -d data/tallyqa/ | ||
# rm data/vg/VG_100K_2/images2.zip | ||
# rm data/vg/VG_100K/images.zip | ||
# rm data/tallyqa/tallyqa.zip | ||
import argparse | ||
import datasets | ||
import torch | ||
|
||
import json | ||
|
||
from PIL import Image | ||
from tqdm import tqdm | ||
from transformers import AutoTokenizer | ||
|
||
from ..hf import Moondream, detect_device | ||
from ..torch.config import MoondreamConfig | ||
from ..torch.moondream import MoondreamModel | ||
from ..torch.weights import load_weights_into_model | ||
|
||
BATCH_SIZE = 16 | ||
DEVICE, DTYPE = detect_device() | ||
PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. " | ||
|
||
model_id = "vikhyatk/moondream2" | ||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) | ||
model = Moondream.from_pretrained( | ||
model_id, | ||
attn_implementation="flash_attention_2", | ||
torch_dtype=DTYPE, | ||
device_map={"": DEVICE}, | ||
) | ||
model.eval() | ||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--model", type=str, required=True) | ||
parser.add_argument("--debug", action="store_true") | ||
args = parser.parse_args() | ||
|
||
if torch.cuda.is_available(): | ||
torch.set_default_device("cuda") | ||
elif torch.backends.mps.is_available(): | ||
torch.set_default_device("mps") | ||
|
||
config = MoondreamConfig() | ||
model = MoondreamModel(config) | ||
load_weights_into_model(args.model, model) | ||
model.compile() | ||
|
||
dataset = datasets.load_dataset("vikhyatk/tallyqa-test", split="test") | ||
|
||
total = 0 | ||
total_simple = 0 | ||
correct = 0 | ||
correct_simple = 0 | ||
|
||
total = 0 | ||
total_simple = 0 | ||
correct = 0 | ||
correct_simple = 0 | ||
for row in tqdm(dataset, disable=args.debug): | ||
image = row["image"] | ||
encoded_image = model.encode_image(image) | ||
|
||
# Iterate over tallyqa_test in batches of BATCH_SIZE | ||
tallyqa_test = json.load(open("data/tallyqa/test.json")) | ||
for i in tqdm(range(0, len(tallyqa_test), BATCH_SIZE)): | ||
batch = tallyqa_test[i : i + BATCH_SIZE] | ||
for qa in row["qa"]: | ||
question = PREFIX + qa["question"] | ||
answer = str(qa["answer"]) | ||
is_simple = qa["is_simple"] | ||
|
||
images = [Image.open(f"data/vg/{item['image']}") for item in batch] | ||
questions = [ | ||
item["question"] + " Answer in a word or phrase only." for item in batch | ||
] | ||
model_answer = model.query(encoded_image, question)["answer"] | ||
|
||
answers = model.batch_answer( | ||
images=images, prompts=questions, tokenizer=tokenizer, max_new_tokens=10 | ||
) | ||
total += 1 | ||
if model_answer.strip().lower() == answer.strip().lower(): | ||
correct += 1 | ||
elif args.debug: | ||
print(f"Question: {qa['question']}") | ||
print(f"Answer: {answer}") | ||
print(f"Model Answer: {model_answer}") | ||
|
||
for answer, item in zip(answers, batch): | ||
is_simple = item["issimple"] | ||
is_correct = 1 if str(item["answer"]) == answer else 0 | ||
if is_simple: | ||
total_simple += 1 | ||
if model_answer.strip().lower() == answer.strip().lower(): | ||
correct_simple += 1 | ||
|
||
total += 1 | ||
correct += is_correct | ||
if is_simple: | ||
total_simple += 1 | ||
correct_simple += is_correct | ||
if args.debug: | ||
print(f"Simple - Correct: {correct_simple}, Total: {total_simple}") | ||
print(f"Simple Accuracy: {correct_simple * 100 / total_simple:.2f}") | ||
print(f"All - Correct: {correct}, Total: {total}") | ||
print(f"All Accuracy: {correct * 100 / total:.2f}") | ||
print("---------") | ||
|
||
print( | ||
f"Simple: {total_simple}, Correct: {correct_simple}, Accuracy: {correct_simple*100.0/total_simple}" | ||
f"Simple: {total_simple}, Correct: {correct_simple}, Accuracy: {correct_simple*100.0/total_simple:.2f}" | ||
) | ||
print(f"Total: {total}, Correct: {correct}, Accuracy: {correct*100.0/total}") | ||
print(f"Total: {total}, Correct: {correct}, Accuracy: {correct*100.0/total:.2f}") |