Skip to content

Commit

Permalink
Apply black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
openhands-agent committed Jan 3, 2025
1 parent ddc91e2 commit 8299eb6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
24 changes: 14 additions & 10 deletions moondream/torch/image_crops.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,29 @@ def overlap_crop_image(
)

# Convert to vips for resizing
vips_image = pyvips.Image.new_from_memory(image.tobytes(),
image.shape[1], image.shape[0],
image.shape[2], 'uchar')
vips_image = pyvips.Image.new_from_memory(
image.tobytes(), image.shape[1], image.shape[0], image.shape[2], "uchar"
)

# Resize using vips
scale_x = target_size[1] / image.shape[1]
scale_y = target_size[0] / image.shape[0]
resized = vips_image.resize(scale_x, vscale=scale_y)
image = np.ndarray(buffer=resized.write_to_memory(),
dtype=np.uint8,
shape=[resized.height, resized.width, resized.bands])
image = np.ndarray(
buffer=resized.write_to_memory(),
dtype=np.uint8,
shape=[resized.height, resized.width, resized.bands],
)

# Create global crop
scale_x = base_size[1] / vips_image.width
scale_y = base_size[0] / vips_image.height
global_vips = vips_image.resize(scale_x, vscale=scale_y)
global_crop = np.ndarray(buffer=global_vips.write_to_memory(),
dtype=np.uint8,
shape=[global_vips.height, global_vips.width, global_vips.bands])
global_crop = np.ndarray(
buffer=global_vips.write_to_memory(),
dtype=np.uint8,
shape=[global_vips.height, global_vips.width, global_vips.bands],
)

# Extract crops with overlap
crops = []
Expand Down
17 changes: 11 additions & 6 deletions tests/test_image_crops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from moondream.torch.image_crops import overlap_crop_image, reconstruct_from_crops


def test_overlap_crop_basic():
# Create a test image
test_image = np.zeros((800, 600, 3), dtype=np.uint8)
Expand All @@ -16,16 +17,18 @@ def test_overlap_crop_basic():
assert all(crop.shape == (378, 378, 3) for crop in result["local_crops"])
assert len(result["tiling"]) == 2


def test_overlap_crop_small_image():
    """An image smaller than the crop size should still produce one valid crop.

    The cropper is expected to fall back to a single 1x1 tiling while the
    global crop keeps its fixed 378x378x3 output shape.
    """
    small = np.zeros((300, 200, 3), dtype=np.uint8)
    out = overlap_crop_image(small, overlap_margin=4, max_crops=12)

    assert out["tiling"] == (1, 1)
    assert len(out["local_crops"]) == 1
    assert out["global_crop"].shape == (378, 378, 3)


def test_reconstruction():
# Create a test image
test_image = np.zeros((800, 600, 3), dtype=np.uint8)
Expand All @@ -35,7 +38,9 @@ def test_reconstruction():
# Crop and reconstruct
result = overlap_crop_image(test_image, overlap_margin=4, max_crops=12)
crops_tensor = [torch.from_numpy(crop) for crop in result["local_crops"]]
reconstructed = reconstruct_from_crops(crops_tensor, result["tiling"], overlap_margin=4)
reconstructed = reconstruct_from_crops(
crops_tensor, result["tiling"], overlap_margin=4
)

# Convert back to numpy for comparison
reconstructed_np = reconstructed.numpy()
Expand All @@ -45,9 +50,9 @@ def test_reconstruction():
# but the white rectangle should still be visible in the middle
center_original = test_image[300:500, 200:400].mean()
center_reconstructed = reconstructed_np[
reconstructed_np.shape[0]//2-100:reconstructed_np.shape[0]//2+100,
reconstructed_np.shape[1]//2-100:reconstructed_np.shape[1]//2+100
reconstructed_np.shape[0] // 2 - 100 : reconstructed_np.shape[0] // 2 + 100,
reconstructed_np.shape[1] // 2 - 100 : reconstructed_np.shape[1] // 2 + 100,
].mean()

# The center region should be significantly brighter than the edges
assert center_reconstructed > reconstructed_np[:100, :100].mean() + 100
assert center_reconstructed > reconstructed_np[:100, :100].mean() + 100

0 comments on commit 8299eb6

Please sign in to comment.