update eval scripts
vikhyat committed Feb 6, 2025
1 parent 57915fb · commit f981200
Showing 7 changed files with 78 additions and 9 deletions.
moondream/eval/chartqa.py: 16 additions & 4 deletions
@@ -63,11 +63,13 @@ def eval_chartqa(model, debug=False):
     total = 0
     human_correct = 0
     human_total = 0
+    results = []
 
     for row in tqdm(dataset, disable=debug, desc="ChartQA"):
         image = row["image"]
         encoded_image = model.encode_image(image)
 
+        result = []
         for qa in row["qa"]:
             question = PREFIX + qa["question"]
             answer = qa["answer"]
@@ -92,6 +94,7 @@ def eval_chartqa(model, debug=False):
             if qa["source"] == "human":
                 human_total += 1
 
+            is_correct = False
             if all(
                 relaxed_correctness(
                     str(cur_answer).strip().lower(),
@@ -102,21 +105,30 @@ def eval_chartqa(model, debug=False):
                 correct += 1
                 if qa["source"] == "human":
                     human_correct += 1
-            elif debug:
-                print(f"Question: {qa['question']}")
-                print(f"Answer: {answer}")
-                print(f"Model Answer: {model_answer}")
+                is_correct = True
             if debug:
                 print(
                     f"Correct: {correct}, Total: {total}, Human Correct: {human_correct}, Human Total: {human_total}"
                 )
                 print(f"Human Accuracy: {human_correct * 100 / human_total:.2f}")
                 print(f"Total Accuracy: {correct * 100 / total:.2f}")
                 print("---------")
+            result.append(
+                {
+                    "question": question,
+                    "ground_truth": answer_list,
+                    "model_answer": model_answer_list,
+                    "is_correct": is_correct,
+                    "source": qa["source"],
+                }
+            )
+        results.append(result)
+
 
     return {
         "human_acc": human_correct * 100 / human_total,
         "total_acc": correct * 100 / total,
+        "results": results,
     }


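Note on scoring: the all(relaxed_correctness(...)) check above compares each predicted answer against its ground truth using ChartQA-style "relaxed accuracy". The repository's relaxed_correctness helper is not shown in this diff; purely as a sketch of the standard definition (numeric answers may deviate by up to 5% relative error, everything else must match exactly), it behaves roughly like:

def relaxed_correctness(prediction: str, target: str,
                        max_relative_change: float = 0.05) -> bool:
    # Sketch of ChartQA-style relaxed accuracy; the repo's own helper
    # may differ in detail (e.g. percent-sign handling).
    def to_float(text: str):
        try:
            return float(text.rstrip("%"))  # treat "12%" as 12.0
        except ValueError:
            return None

    p, t = to_float(prediction), to_float(target)
    if p is not None and t is not None:
        if t == 0.0:
            return p == t
        return abs(p - t) / abs(t) <= max_relative_change
    return prediction == target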
moondream/eval/countbenchqa.py: 13 additions & 1 deletion
@@ -16,6 +16,7 @@ def eval_countbenchqa(model, debug=False):
 
     correct = 0
     total = 0
+    results = []
 
     for row in tqdm(dataset, disable=debug, desc="CountBenchQA"):
         image = row["image"]
@@ -24,9 +25,19 @@ def eval_countbenchqa(model, debug=False):
         question = PREFIX + row["question"]
         answer = str(row["number"])
         model_answer = model.query(encoded_image, question)["answer"]
+        is_correct = model_answer.strip().lower() == answer.strip().lower()
+
+        results.append(
+            {
+                "question": question,
+                "ground_truth": answer,
+                "model_answer": model_answer,
+                "is_correct": is_correct,
+            }
+        )
 
         total += 1
-        if model_answer.strip().lower() == answer.strip().lower():
+        if is_correct:
             correct += 1
         elif debug:
             print(f"Question: {row['question']}")
@@ -41,6 +52,7 @@ def eval_countbenchqa(model, debug=False):
         "acc": correct * 100 / total,
         "correct_count": correct,
         "total_count": total,
+        "results": results,
     }


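Because eval_countbenchqa now returns a flat results list, failure cases can be exported for later inspection. A hypothetical snippet (the import path follows the file layout above; model is assumed to be a loaded moondream model, and the output file name is a placeholder):

import json

from moondream.eval.countbenchqa import eval_countbenchqa

outcome = eval_countbenchqa(model)
with open("countbenchqa_failures.jsonl", "w") as f:
    for record in outcome["results"]:
        if not record["is_correct"]:  # keep only the misses
            f.write(json.dumps(record) + "\n")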
moondream/eval/docvqa.py: 12 additions & 0 deletions
@@ -23,9 +23,13 @@ def eval_docvqa(model, debug=False):
     docvqa_val = load_dataset("vikhyatk/docvqa-val", split="validation")
 
     scores = []
+    results = []
 
     for row in tqdm(docvqa_val, disable=debug, desc="DocVQA"):
         image = row["image"]
+        encoded_image = model.encode_image(image)
+
+        result = []
         for qa in row["qa"]:
             question = qa["question"]
             answers = qa["answers"]
@@ -34,6 +38,12 @@ def eval_docvqa(model, debug=False):
             model_answer = model.query(encoded_image, prompt)["answer"]
             anls = max(get_anls(model_answer, gt) for gt in answers)
             scores.append(anls)
+            result.append({
+                "question": question,
+                "ground_truth": answers,
+                "model_answer": model_answer,
+                "anls": anls,
+            })
 
             if debug:
                 print(f"Question: {question}")
@@ -42,9 +52,11 @@ def eval_docvqa(model, debug=False):
                 print(f"ANLS: {anls}")
                 print(f"Current Average ANLS: {sum(scores) / len(scores):.4f}")
                 print("---------")
+        results.append(result)
 
     return {
         "anls": sum(scores) / len(scores),
+        "results": results,
     }


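DocVQA is scored with ANLS (Average Normalized Levenshtein Similarity); get_anls itself is outside this diff. A common formulation, shown here only as a sketch of what the helper presumably computes, scores 1 minus the normalized edit distance and zeroes out anything below a 0.5 similarity threshold:

def get_anls(prediction: str, ground_truth: str, threshold: float = 0.5) -> float:
    # Sketch of the standard ANLS definition; the repository's get_anls
    # may differ (e.g. in normalization or threshold handling).
    def levenshtein(a: str, b: str) -> int:
        prev = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            cur = [i]
            for j, cb in enumerate(b, 1):
                cur.append(min(prev[j] + 1,                  # deletion
                               cur[j - 1] + 1,               # insertion
                               prev[j - 1] + (ca != cb)))    # substitution
            prev = cur
        return prev[-1]

    a, b = prediction.strip().lower(), ground_truth.strip().lower()
    if not a and not b:
        return 1.0
    nl = levenshtein(a, b) / max(len(a), len(b))
    return 1.0 - nl if nl < threshold else 0.0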
moondream/eval/eval_all.py: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def eval_all(model, skip=[]):
     results = {}
     for name, eval_fn in evals.items():
         results[name] = eval_fn(model)
-        pprint(results[name])
+        pprint({k: v for k, v in results[name].items() if k != "results"})
 
     return results

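The one-line change above keeps the console summary readable: pprint now drops the potentially large per-question "results" lists, while eval_all still returns them to the caller. A hypothetical way to persist the full records (the file name is a placeholder; model is assumed to be loaded elsewhere):

import json

results = eval_all(model)
with open("eval_results.json", "w") as f:
    # default=str guards against any non-JSON-serializable values
    json.dump(results, f, indent=2, default=str)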
moondream/eval/mmstar.py: 14 additions & 1 deletion
@@ -16,12 +16,14 @@ def eval_mmstar(model, debug=False):
     correct = 0
     total = 0
     category_stats = {}
+    results = []
 
     for row in tqdm(dataset, disable=debug, desc="MMStar"):
         image = row["image"]
         question = row["question"] + SUFFIX
         answer = row["answer"]
         model_answer = model.query(image, question)["answer"]
+        is_correct = model_answer.strip().lower() == answer.strip().lower()
 
         category = f"{row['category']} / {row['l2_category']}"
         if category not in category_stats:
@@ -30,7 +32,17 @@ def eval_mmstar(model, debug=False):
         total += 1
         category_stats[category]["total"] += 1
 
-        if model_answer.strip().lower() == answer.strip().lower():
+        results.append(
+            {
+                "question": question,
+                "ground_truth": answer,
+                "model_answer": model_answer,
+                "is_correct": is_correct,
+                "category": category,
+            }
+        )
+
+        if is_correct:
             correct += 1
             category_stats[category]["correct"] += 1
         elif debug:
@@ -52,6 +64,7 @@ def eval_mmstar(model, debug=False):
         "correct_count": correct,
         "total_count": total,
         "category_stats": category_stats,
+        "results": results,
     }


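The flat records carry the same information as category_stats, so downstream tooling can re-derive per-category accuracy without rerunning the eval. An illustrative helper (not part of the commit):

from collections import defaultdict

def category_accuracy(records):
    # Rebuild per-category accuracy from the "results" records above.
    stats = defaultdict(lambda: {"correct": 0, "total": 0})
    for r in records:
        stats[r["category"]]["total"] += 1
        stats[r["category"]]["correct"] += r["is_correct"]
    return {c: s["correct"] * 100 / s["total"] for c, s in stats.items()}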
moondream/eval/realworldqa.py: 11 additions & 1 deletion
@@ -14,15 +14,24 @@ def eval_realworldqa(model, debug=False):
 
     correct = 0
     total = 0
+    results = []
 
     for row in tqdm(dataset, disable=debug, desc="RealWorldQA"):
         image = row["image"]
         question = row["question"]
         answer = row["answer"]
         model_answer = model.query(image, question)["answer"]
+        is_correct = model_answer.strip().lower() == answer.strip().lower()
+
+        results.append({
+            "question": question,
+            "ground_truth": answer,
+            "model_answer": model_answer,
+            "is_correct": is_correct,
+        })
 
         total += 1
-        if model_answer.strip().lower() == answer.strip().lower():
+        if is_correct:
             correct += 1
         elif debug:
             print(f"Image: {row['image_path']}")
@@ -38,6 +47,7 @@ def eval_realworldqa(model, debug=False):
         "acc": correct * 100 / total,
         "correct_count": correct,
         "total_count": total,
+        "results": results,
     }


moondream/eval/textvqa.py: 11 additions & 1 deletion
@@ -19,6 +19,7 @@ def eval_textvqa(model, debug=False):
 
     total_score = 0
     total_samples = 0
+    results = []
 
     for row in tqdm(dataset, disable=debug, desc="TextVQA"):
         image = row["image"]
@@ -30,6 +31,15 @@ def eval_textvqa(model, debug=False):
         total_score += score
         total_samples += 1
 
+        results.append(
+            {
+                "question": question,
+                "ground_truth": row["answers"],
+                "model_answer": model_answer,
+                "score": score,
+            }
+        )
+
         if debug:
             print(f"Question: {row['question']}")
             print(f"Ground Truth Answers: {row['answers']}")
@@ -38,7 +48,7 @@ def eval_textvqa(model, debug=False):
             print(f"Running Average Score: {total_score * 100 / total_samples:.2f}")
             print("---------")
 
-    return {"score": total_score * 100 / total_samples}
+    return {"score": total_score * 100 / total_samples, "results": results}
 
 
 if __name__ == "__main__":

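The score accumulated in eval_textvqa comes from lines hidden in this diff; TextVQA is conventionally graded with the VQAv2 soft metric, where an answer matching n of the human annotations earns min(n / 3, 1). Purely as a sketch of that convention, not necessarily the repository's implementation:

def vqa_score(model_answer: str, answers: list[str]) -> float:
    # Soft VQA accuracy: agreement with 3 or more annotators = full credit.
    prediction = model_answer.strip().lower()
    matches = sum(prediction == a.strip().lower() for a in answers)
    return min(matches / 3, 1.0)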