Commit d2b0c26

Update

[ghstack-poisoned]

vikhyat committed Jan 3, 2025
1 parent 9f49364 commit d2b0c26
Showing 3 changed files with 57 additions and 46 deletions.
2 changes: 2 additions & 0 deletions moondream/torch/config.py
@@ -19,6 +19,8 @@ class VisionConfig:
enc_n_heads: int = 16
crop_size: int = 378
in_channels: int = 3
+ max_crops: int = 12
+ overlap_margin: int = 4


@dataclass(frozen=True)
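For context, a minimal sketch of how VisionConfig reads after this change, based only on the fields visible in the hunk above (other fields elided). The decorator is assumed to match the frozen-dataclass pattern visible at the end of the hunk; the two new values replace the max_crops=12 and overlap_margin=4 literals that were previously hard-coded in vision.py.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # assumed: matches the frozen-dataclass pattern in this file
class VisionConfig:
    # Only fields visible in the hunk above; remaining fields elided.
    enc_n_heads: int = 16
    crop_size: int = 378
    in_channels: int = 3
    max_crops: int = 12      # added in this commit
    overlap_margin: int = 4  # added in this commit
```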
36 changes: 15 additions & 21 deletions moondream/torch/image_crops.py
@@ -43,8 +43,7 @@ def select_tiling(


class OverlapCropOutput(TypedDict):
- global_crop: np.ndarray
- local_crops: list[np.ndarray]
+ crops: np.ndarray
tiling: tuple[int, int]


@@ -79,8 +78,8 @@ def overlap_crop_image(
Returns:
OverlapCropOutput: Dictionary containing:
- - global_crop: Single resized crop of full image
- - crops: List of overlapping cropped regions
+ - crops: A numpy array containing the global crop of the full image (index 0)
+   followed by the overlapping cropped regions (indices 1+)
- tiling: Tuple of (height,width) tile counts
"""
original_h, original_w = image.shape[:2]
@@ -102,6 +101,12 @@
max_crops,
)

+ # Pre-allocate crops.
+ n_crops = tiling[0] * tiling[1] + 1  # 1 = global crop
+ crops = np.zeros(
+     (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
+ )

# Resize image to fit tiling
target_size = (
tiling[0] * crop_window_size + total_margin_pixels,
@@ -110,8 +115,6 @@

# Convert to vips for resizing
vips_image = pyvips.Image.new_from_array(image)
-
- # Resize using vips
scale_x = target_size[1] / image.shape[1]
scale_y = target_size[0] / image.shape[0]
resized = vips_image.resize(scale_x, vscale=scale_y)
@@ -121,10 +124,7 @@
scale_x = base_size[1] / vips_image.width
scale_y = base_size[0] / vips_image.height
global_vips = vips_image.resize(scale_x, vscale=scale_y)
- global_crop = global_vips.numpy()
-
- # Extract crops with overlap
- crops = []
+ crops[0] = global_vips.numpy()

for i in range(tiling[0]):
for j in range(tiling[1]):
@@ -136,18 +136,12 @@
y_end = min(y0 + base_size[0], image.shape[0])
x_end = min(x0 + base_size[1], image.shape[1])

- crop = np.zeros(
-     (base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
- )
crop_region = image[y0:y_end, x0:x_end]
- crop[: crop_region.shape[0], : crop_region.shape[1]] = crop_region
- crops.append(crop)

- return {
-     "global_crop": global_crop,
-     "local_crops": crops,
-     "tiling": tiling,
- }
+ crops[
+     1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
+ ] = crop_region

+ return {"crops": crops, "tiling": tiling}


def reconstruct_from_crops(
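To make the new return contract concrete, here is a short usage sketch of overlap_crop_image as it behaves after this diff; the input shape is illustrative, and the import path simply mirrors the file location above.

```python
import numpy as np

from moondream.torch.image_crops import overlap_crop_image

# Illustrative RGB input in HWC uint8 layout.
image = np.zeros((1080, 1920, 3), dtype=np.uint8)

out = overlap_crop_image(image, max_crops=12, overlap_margin=4)

crops = out["crops"]    # single (n_crops, H, W, C) uint8 array
tiling = out["tiling"]  # (height, width) tile counts

global_crop = crops[0]   # index 0: resized crop of the full image
local_crops = crops[1:]  # indices 1+: overlapping local tiles

# Matches the pre-allocation above: one slot per tile plus the global crop.
assert crops.shape[0] == tiling[0] * tiling[1] + 1
```

Callers that previously stacked out["global_crop"] with out["local_crops"] can now consume out["crops"] directly, as prepare_crops does in the vision.py changes below.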
65 changes: 40 additions & 25 deletions moondream/torch/vision.py
@@ -2,6 +2,7 @@
import torch.nn.functional as F
import numpy as np

+ from typing import Union, Tuple
from einops import rearrange
from PIL import Image

@@ -20,39 +21,23 @@ def adaptive_avg_pool2d(input, output_size):
adaptive_avg_pool2d = F.adaptive_avg_pool2d


- def encode_image(
-     image: Image.Image, weights: VisionModel, config: VisionConfig
- ) -> torch.Tensor:
+ def prepare_crops(
+     image: Image.Image, config: VisionConfig, device: Union[str, torch.device, int]
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
np_image = np.array(image.convert("RGB"))
- crops = overlap_crop_image(np_image, max_crops=12, overlap_margin=4)
- all_crops = np.stack([crops["global_crop"], *crops["local_crops"]], axis=0)
+ overlap_crops = overlap_crop_image(
+     np_image, max_crops=config.max_crops, overlap_margin=config.overlap_margin
+ )
+ all_crops = overlap_crops["crops"]
all_crops = np.transpose(all_crops, (0, 3, 1, 2))
all_crops = (
torch.from_numpy(all_crops)
- .to(device=weights.pos_emb.device, dtype=torch.float16)
+ .to(device=device, dtype=torch.float16)
.div_(255.0)
.sub_(0.5)
.div_(0.5)
)

- outputs = vision_encoder(all_crops, weights, config)
-
- global_features = outputs[0]
- local_features = outputs[1:].view(-1, 27, 27, 1152)
-
- reconstructed = reconstruct_from_crops(
-     local_features,
-     crops["tiling"],
-     patch_size=1,
-     overlap_margin=4,
- )
-
- reconstructed = reconstructed.permute(2, 0, 1)
- reconstructed = adaptive_avg_pool2d(reconstructed, output_size=(27, 27))
- reconstructed = reconstructed.permute(1, 2, 0).view(729, 1152)
- final_features = torch.cat([global_features, reconstructed], dim=-1)
-
- return mlp(final_features, weights.proj_mlp)
+ return all_crops, overlap_crops["tiling"]


def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel, config: VisionConfig):
@@ -71,3 +56,33 @@ def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel, config: VisionConfig):
x = layer_norm(x, w.post_ln)

return x


+ def vision_projection(
+     global_features: torch.Tensor, reconstructed: torch.Tensor, w: VisionModel
+ ):
+     reconstructed = reconstructed.permute(2, 0, 1)
+     reconstructed = adaptive_avg_pool2d(reconstructed, output_size=(27, 27))
+     reconstructed = reconstructed.permute(1, 2, 0).view(729, 1152)
+     final_features = torch.cat([global_features, reconstructed], dim=-1)
+     return mlp(final_features, w.proj_mlp)
+
+
+ def encode_image(
+     image: Image.Image, weights: VisionModel, config: VisionConfig
+ ) -> torch.Tensor:
+     # This is split into sub-functions to allow sections to be compiled without
+     # graph breaks, which is needed if we want to enable reduce-overhead mode.
+     # `vision_encoder` and `vision_projection` can be compiled if needed.
+
+     all_crops, tiling = prepare_crops(image, config, device=weights.pos_emb.device)
+
+     outputs = vision_encoder(all_crops, weights, config)
+
+     global_features = outputs[0]
+     local_features = outputs[1:].view(-1, 27, 27, 1152)
+     reconstructed = reconstruct_from_crops(
+         local_features, tiling, patch_size=1, overlap_margin=config.overlap_margin
+     )
+
+     return vision_projection(global_features, reconstructed, weights)
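The comment in encode_image notes that vision_encoder and vision_projection can be compiled; below is a hedged sketch of how a caller might opt in. The wrapping shown here is an assumption and not part of the diff, which does not itself wire up torch.compile.

```python
import torch

from moondream.torch import vision

# Compile only the pure-tensor stages; prepare_crops is left eager here because
# it does PIL/numpy preprocessing that is not useful to trace.
vision.vision_encoder = torch.compile(vision.vision_encoder, mode="reduce-overhead")
vision.vision_projection = torch.compile(vision.vision_projection, mode="reduce-overhead")

# encode_image keeps its original signature, so callers are unchanged:
# features = vision.encode_image(pil_image, weights, config)
```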
