diff --git a/test_dragonfly_attention.py b/test_dragonfly_attention.py index 7e09e4b..e6eea09 100644 --- a/test_dragonfly_attention.py +++ b/test_dragonfly_attention.py @@ -110,12 +110,10 @@ def main(): patch_indices = model_outputs["query_ranks"][0].cpu().tolist() highlighted_image = combine_patches(high_patches, high_image_padded.size, se, patch_indices) + highlighted_image.resize(image.size) prefix = "/".join(image_path.split(".")[:-1]) image_type = image_path.split(".")[-1] - - highlighted_image.resize(image.size) - save_path = f"{prefix}_highlighted.{image_type}" highlighted_image.save(save_path)