Skip to content

Commit

Permalink
refactor evals
Browse files Browse the repository at this point in the history
Update ChartQA, CountBenchQA, DocVQA, and GazeFollow to each export a reusable
eval function. Remove GQA since we're no longer going to use it.
  • Loading branch information
vikhyat committed Jan 13, 2025
1 parent 101c255 commit 17bb88e
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 231 deletions.
52 changes: 29 additions & 23 deletions moondream/eval/chartqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,16 @@

PREFIX = "Analyze the chart carefully, consider both visual features and data values, and provide a precise answer without any additional explanation or formatting. "

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)
model.compile()

def eval_chartqa(model, debug=False):
dataset = datasets.load_dataset("vikhyatk/chartqa", split="test")

correct = 0
total = 0
human_correct = 0
human_total = 0

for row in tqdm(dataset, disable=args.debug):
for row in tqdm(dataset, disable=debug):
image = row["image"]
encoded_image = model.encode_image(image)

Expand All @@ -49,20 +35,40 @@
correct += 1
if qa["source"] == "human":
human_correct += 1
elif args.debug:
elif debug:
print(f"Question: {qa['question']}")
print(f"Answer: {answer}")
print(f"Model Answer: {model_answer}")
if args.debug:
if debug:
print(
f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}"
)
print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}")
print(f"Total Accuracy: {correct * 100 / total:.2f}")
print("---------")

print(
f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}"
)
print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}")
print(f"Total Accuracy: {correct * 100 / total:.2f}")
return {
"human_acc": human_correct * 100 / human_total,
"total_acc": correct * 100 / total,
}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)
model.compile()

results = eval_chartqa(model, args.debug)
print(f"Human Accuracy: {results['human_acc']:.2f}")
print(f"Total Accuracy: {results['total_acc']:.2f}")
49 changes: 30 additions & 19 deletions moondream/eval/countbenchqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,14 @@

PREFIX = "Look at the image carefully and count the objects. Answer with just a number, without any additional text. "

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)

def eval_countbenchqa(model, debug=False):
dataset = datasets.load_dataset("vikhyatk/CountBenchQA", split="test")

correct = 0
total = 0

for row in tqdm(dataset, disable=args.debug):
for row in tqdm(dataset, disable=debug):
image = row["image"]
encoded_image = model.encode_image(image)

Expand All @@ -41,14 +28,38 @@
total += 1
if model_answer.strip().lower() == answer.strip().lower():
correct += 1
elif args.debug:
elif debug:
print(f"Question: {row['question']}")
print(f"Answer: {answer}")
print(f"Model Answer: {model_answer}")
if args.debug:
if debug:
print(f"Correct: {correct}, Total: {total}")
print(f"Accuracy: {correct * 100 / total:.2f}")
print("---------")

print(f"Correct: {correct}, Total: {total}")
print(f"Accuracy: {correct * 100 / total:.2f}")
return {
"acc": correct * 100 / total,
"correct_count": correct,
"total_count": total,
}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)

result = eval_countbenchqa(model, args.debug)

print(f"Accuracy: {result['acc']:.2f}")
print(f"Correct: {result['correct_count']}, Total: {result['total_count']}")
42 changes: 25 additions & 17 deletions moondream/eval/docvqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,7 @@ def get_anls(s1, s2):
return anls


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)
model.compile()

def eval_docvqa(model, debug=False):
docvqa_val = load_dataset("vikhyatk/docvqa-val", split="validation")

scores = []
Expand All @@ -58,4 +43,27 @@ def get_anls(s1, s2):
print(f"Current Average ANLS: {sum(scores) / len(scores):.4f}")
print("---------")

print("ANLS:", sum(scores) / len(scores))
return {
"anls": sum(scores) / len(scores),
}


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)
model.compile()

result = eval_docvqa(model, args.debug)

print(f"ANLS: {result['anls']:.4f}")
161 changes: 89 additions & 72 deletions moondream/eval/gazefollow.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,87 +8,104 @@
from ..torch.moondream import MoondreamModel
from ..torch.weights import load_weights_into_model

dataset = datasets.load_dataset("vikhyatk/gazefollow", split="test")

torch.set_default_device("cuda")
model = MoondreamModel(MoondreamConfig())
load_weights_into_model("model.pt", model)
def eval_gazefollow(model, debug=False):
dataset = datasets.load_dataset("vikhyatk/gazefollow", split="test")

mean_l2_error = []
min_l2_error = []
total = 0

for i, row in tqdm(enumerate(dataset), total=len(dataset)):
heads = []

for gaze in row["gazes"]:
head_bbox = gaze["head_bbox"] # xmin, ymin, xmax, ymax
eye_coord = (gaze["eye"]["x"], gaze["eye"]["y"])
mean_target_gaze = (gaze["gaze"]["x"], gaze["gaze"]["y"])

# Check if a head already exists with the same approximate bbox.
# If so, use that head instead of creating a new one.
for head in heads:
if (
abs(head["head_bbox"]["xmin"] - head_bbox["xmin"]) < 0.001
and abs(head["head_bbox"]["xmax"] - head_bbox["xmax"]) < 0.001
and abs(head["head_bbox"]["ymin"] - head_bbox["ymin"]) < 0.001
and abs(head["head_bbox"]["ymax"] - head_bbox["ymax"]) < 0.001
):
head["gazes"].append(mean_target_gaze)
break
else:
heads.append(
{
"head_bbox": head_bbox,
"eye_coord": eye_coord,
"gazes": [mean_target_gaze],
}
)

mean_l2_error = []
min_l2_error = []
total = 0
for head in heads:
pred_gaze = model.detect_gaze(
row["image"],
eye=head["eye_coord"],
face={
"x_min": head["head_bbox"]["xmin"],
"y_min": head["head_bbox"]["ymin"],
"x_max": head["head_bbox"]["xmax"],
"y_max": head["head_bbox"]["ymax"],
},
unstable_settings={"force_detect": True},
)["gaze"]

mean_target_gaze = (
sum(gaze[0] for gaze in head["gazes"]) / len(head["gazes"]),
sum(gaze[1] for gaze in head["gazes"]) / len(head["gazes"]),
)
mean_l2 = math.sqrt(
(mean_target_gaze[0] - pred_gaze["x"]) ** 2
+ (mean_target_gaze[1] - pred_gaze["y"]) ** 2
)
min_l2 = min(
math.sqrt(
(target_gaze[0] - pred_gaze["x"]) ** 2
+ (target_gaze[1] - pred_gaze["y"]) ** 2
)
for target_gaze in head["gazes"]
)

mean_l2_error.append(mean_l2)
min_l2_error.append(min_l2)
total += 1

for i, row in tqdm(enumerate(dataset), total=len(dataset)):
encoded_image = model.encode_image(row["image"])
if i % 100 == 0 and debug:
print("Mean L2 error:", sum(mean_l2_error) / total)
print("Min L2 error:", sum(min_l2_error) / total)

heads = []
return {
"mean_l2": sum(mean_l2_error) / total,
"min_l2": sum(min_l2_error) / total,
}

for gaze in row["gazes"]:
head_bbox = gaze["head_bbox"] # xmin, ymin, xmax, ymax
eye_coord = (gaze["eye"]["x"], gaze["eye"]["y"])
mean_target_gaze = (gaze["gaze"]["x"], gaze["gaze"]["y"])

# Check if a head already exists with the same approximate bbox.
# If so, use that head instead of creating a new one.
for head in heads:
if (
abs(head["head_bbox"]["xmin"] - head_bbox["xmin"]) < 0.001
and abs(head["head_bbox"]["xmax"] - head_bbox["xmax"]) < 0.001
and abs(head["head_bbox"]["ymin"] - head_bbox["ymin"]) < 0.001
and abs(head["head_bbox"]["ymax"] - head_bbox["ymax"]) < 0.001
):
head["gazes"].append(mean_target_gaze)
break
else:
heads.append(
{
"head_bbox": head_bbox,
"eye_coord": eye_coord,
"gazes": [mean_target_gaze],
}
)
if __name__ == "__main__":
import argparse

for head in heads:
pred_gaze = model.detect_gaze(
row["image"],
eye=head["eye_coord"],
face={
"x_min": head["head_bbox"]["xmin"],
"y_min": head["head_bbox"]["ymin"],
"x_max": head["head_bbox"]["xmax"],
"y_max": head["head_bbox"]["ymax"],
},
unstable_settings={"force_detect": True},
)["gaze"]

mean_target_gaze = (
sum(gaze[0] for gaze in head["gazes"]) / len(head["gazes"]),
sum(gaze[1] for gaze in head["gazes"]) / len(head["gazes"]),
)
mean_l2 = math.sqrt(
(mean_target_gaze[0] - pred_gaze["x"]) ** 2
+ (mean_target_gaze[1] - pred_gaze["y"]) ** 2
)
min_l2 = min(
math.sqrt(
(target_gaze[0] - pred_gaze["x"]) ** 2
+ (target_gaze[1] - pred_gaze["y"]) ** 2
)
for target_gaze in head["gazes"]
)
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True)

parser.add_argument("--debug", action="store_true")
args = parser.parse_args()

mean_l2_error.append(mean_l2)
min_l2_error.append(min_l2)
total += 1
if torch.cuda.is_available():
torch.set_default_device("cuda")
elif torch.backends.mps.is_available():
torch.set_default_device("mps")

if i % 100 == 0:
print("Mean L2 error:", sum(mean_l2_error) / total)
print("Min L2 error:", sum(min_l2_error) / total)
config = MoondreamConfig()
model = MoondreamModel(config)
load_weights_into_model(args.model, model)

results = eval_gazefollow(model, debug=args.debug)

print()
print("Single prediction mode")
print("Final score:")
print("Mean L2 error:", sum(mean_l2_error) / total)
print("Min L2 error:", sum(min_l2_error) / total)
print(f"Mean L2 error: {results['mean_l2']:.4f}")
print(f"Min L2 error: {results['min_l2']:.4f}")
Loading

0 comments on commit 17bb88e

Please sign in to comment.