From 9f49364a9ec607d7ff32228e2a6c1255975b5a72 Mon Sep 17 00:00:00 2001 From: vik Date: Fri, 3 Jan 2025 13:03:12 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- moondream/torch/config.py | 44 +++++++ moondream/torch/layers.py | 10 +- moondream/torch/moondream.py | 171 +++++++++++++++++++++++++++ moondream/torch/sample.py | 69 +++++++---- moondream/torch/text.py | 13 +- moondream/torch/vision.py | 15 ++- moondream/torch/weights.py | 223 ++++++++++++++++++++++++++++++++++- 7 files changed, 498 insertions(+), 47 deletions(-) create mode 100644 moondream/torch/config.py create mode 100644 moondream/torch/moondream.py diff --git a/moondream/torch/config.py b/moondream/torch/config.py new file mode 100644 index 00000000..a67ee4ed --- /dev/null +++ b/moondream/torch/config.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass + + +@dataclass(frozen=True) +class TextConfig: + dim: int = 2048 + n_layers: int = 24 + vocab_size: int = 51200 + max_context: int = 2048 + n_heads: int = 32 + + +@dataclass(frozen=True) +class VisionConfig: + enc_dim: int = 1152 + enc_patch_size: int = 14 + enc_n_layers: int = 27 + enc_ff_dim: int = 4304 + enc_n_heads: int = 16 + crop_size: int = 378 + in_channels: int = 3 + + +@dataclass(frozen=True) +class RegionConfig: + dim: int = 2048 + coord_feat_dim: int = 256 + coord_out_dim: int = 1024 + size_feat_dim: int = 512 + size_out_dim: int = 2048 + + +@dataclass +class MoondreamConfig: + text: TextConfig = TextConfig() + vision: VisionConfig = VisionConfig() + region: RegionConfig = RegionConfig() + + @classmethod + def from_dict(cls, config_dict: dict): + text_config = TextConfig(**config_dict.get("text", {})) + vision_config = VisionConfig(**config_dict.get("vision", {})) + region_config = RegionConfig(**config_dict.get("region", {})) + return cls(text=text_config, vision=vision_config, region=region_config) diff --git a/moondream/torch/layers.py b/moondream/torch/layers.py index eee5bbf1..a515293f 100644 --- a/moondream/torch/layers.py +++ b/moondream/torch/layers.py @@ -38,10 +38,7 @@ class MLPWeights: def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor: x = linear(x, w.fc1) - if w.act == "gelu_approx": - x = gelu_approx(x) - else: - raise NotImplementedError(f"Activation function {w.act} not implemented.") + x = gelu_approx(x) x = linear(x, w.fc2) return x @@ -50,12 +47,11 @@ def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor: class AttentionWeights: qkv: LinearWeights proj: LinearWeights - n_heads: int -def attn(x: torch.Tensor, w: AttentionWeights) -> torch.Tensor: +def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor: bsz, q_len, d_model = x.shape - n_heads, head_dim = w.n_heads, d_model // w.n_heads + head_dim = d_model // n_heads q, k, v = [ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) diff --git a/moondream/torch/moondream.py b/moondream/torch/moondream.py new file mode 100644 index 00000000..97beaf5b --- /dev/null +++ b/moondream/torch/moondream.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn + +from .config import MoondreamConfig + + +class MoondreamModel(nn.Module): + def __init__(self, config: MoondreamConfig, dtype=torch.float16): + super().__init__() + self.config = config + + # Vision Model + patch_dim = ( + config.vision.enc_patch_size + * config.vision.enc_patch_size + * config.vision.in_channels + ) + grid_size = config.vision.crop_size // config.vision.enc_patch_size + num_patches = grid_size * grid_size + + self.vision = nn.ModuleDict( + { + "patch_emb": nn.Linear(patch_dim, 
config.vision.enc_dim, dtype=dtype), + "blocks": nn.ModuleList( + [ + nn.ModuleDict( + { + "ln1": nn.LayerNorm(config.vision.enc_dim, dtype=dtype), + "attn": nn.ModuleDict( + { + "qkv": nn.Linear( + config.vision.enc_dim, + 3 * config.vision.enc_dim, + dtype=dtype, + ), + "proj": nn.Linear( + config.vision.enc_dim, + config.vision.enc_dim, + dtype=dtype, + ), + } + ), + "ln2": nn.LayerNorm(config.vision.enc_dim, dtype=dtype), + "mlp": nn.ModuleDict( + { + "fc1": nn.Linear( + config.vision.enc_dim, + config.vision.enc_ff_dim, + dtype=dtype, + ), + "fc2": nn.Linear( + config.vision.enc_ff_dim, + config.vision.enc_dim, + dtype=dtype, + ), + } + ), + } + ) + for _ in range(config.vision.enc_n_layers) + ] + ), + "post_ln": nn.LayerNorm(config.vision.enc_dim, dtype=dtype), + "proj_mlp": nn.ModuleDict( + { + "fc1": nn.Linear( + config.vision.enc_dim * 2, config.text.dim * 4, dtype=dtype + ), + "fc2": nn.Linear( + config.text.dim * 4, config.text.dim, dtype=dtype + ), + } + ), + } + ) + self.vision.pos_emb = nn.Parameter( + torch.zeros(1, num_patches, config.vision.enc_dim, dtype=dtype) + ) + + # Text Model + self.text = nn.ModuleDict( + { + "blocks": nn.ModuleList( + [ + nn.ModuleDict( + { + "ln": nn.LayerNorm(config.text.dim, dtype=dtype), + "attn": nn.ModuleDict( + { + "qkv": nn.Linear( + config.text.dim, + 3 * config.text.dim, + dtype=dtype, + ), + "proj": nn.Linear( + config.text.dim, + config.text.dim, + dtype=dtype, + ), + } + ), + "mlp": nn.ModuleDict( + { + "fc1": nn.Linear( + config.text.dim, + 4 * config.text.dim, + dtype=dtype, + ), + "fc2": nn.Linear( + 4 * config.text.dim, + config.text.dim, + dtype=dtype, + ), + } + ), + } + ) + for _ in range(config.text.n_layers) + ] + ), + "post_ln": nn.LayerNorm(config.text.dim, dtype=dtype), + "lm_head": nn.Linear( + config.text.dim, config.text.vocab_size, dtype=dtype + ), + } + ) + self.text.wte = nn.Parameter( + torch.empty(config.text.vocab_size, config.text.dim, dtype=dtype) + ) + + # Region Model + self.region = nn.ModuleDict( + { + "coord_encoder": nn.Linear( + config.region.coord_feat_dim, config.region.dim, dtype=dtype + ), + "coord_decoder": nn.ModuleDict( + { + "fc1": nn.Linear( + config.region.dim, config.region.dim * 4, dtype=dtype + ), + "fc2": nn.Linear( + config.region.dim * 4, + config.region.coord_out_dim, + dtype=dtype, + ), + } + ), + "size_encoder": nn.Linear( + config.region.size_feat_dim, config.region.dim, dtype=dtype + ), + "size_decoder": nn.ModuleDict( + { + "fc1": nn.Linear( + config.region.dim, config.region.dim * 4, dtype=dtype + ), + "fc2": nn.Linear( + config.region.dim * 4, + config.region.size_out_dim, + dtype=dtype, + ), + } + ), + } + ) + self.region.coord_features = nn.Parameter( + torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T + ) + self.region.size_features = nn.Parameter( + torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T + ) diff --git a/moondream/torch/sample.py b/moondream/torch/sample.py index f25c5dfd..071dcbd0 100644 --- a/moondream/torch/sample.py +++ b/moondream/torch/sample.py @@ -9,7 +9,8 @@ from .rope import precompute_freqs_cis from .text import lm_head, text_decoder, text_encoder from .vision import encode_image -from .weights import load_from_pt, load_from_safetensors +from .weights import load_weights_into_model +from .moondream import MoondreamModel, MoondreamConfig if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -28,49 +29,69 @@ # Load config. config = json.loads(args.config) - text_n_heads = config.get("text_n_heads", 32) # Load model. 
- model_path = args.model - if not os.path.exists(model_path): - raise FileNotFoundError(f"Model not found at {model_path}") - if model_path.endswith(".pt"): - model = load_from_pt(model_path, **config) - elif model_path.endswith(".safetensors"): - model = load_from_safetensors(model_path, **config) - else: - raise ValueError(f"Invalid model format: {model_path}") + # model_path = args.model + # if not os.path.exists(model_path): + # raise FileNotFoundError(f"Model not found at {model_path}") + # if model_path.endswith(".pt"): + # model = load_from_pt(model_path, **config) + # elif model_path.endswith(".safetensors"): + # model = load_from_safetensors(model_path, **config) + # else: + # raise ValueError(f"Invalid model format: {model_path}") + + # Load model. + config = MoondreamConfig() + model = MoondreamModel(config) + load_weights_into_model(args.model, model) # Encode image. image_path = args.image if not os.path.exists(image_path): raise FileNotFoundError(f"Image not found at {image_path}") image = Image.open(image_path) - image_tensor = encode_image(image, model.vision) + with torch.no_grad(): + image_tensor = encode_image(image, model.vision, config.vision) # Encode text, and create inputs_embeds. tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2") prompt = f"\n\nQuestion: {args.prompt}\n\nAnswer:" input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"] - input_ids = torch.cat([torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1) - inputs_embeds = text_encoder(input_ids, model.text) - inputs_embeds = torch.cat( - [ - inputs_embeds[:, 0:1, :], - image_tensor.unsqueeze(0), - inputs_embeds[:, 1:, :], - ], - dim=1, - ) + with torch.no_grad(): + input_ids = torch.cat( + [torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1 + ) + inputs_embeds = text_encoder(input_ids, model.text) + inputs_embeds = torch.cat( + [ + inputs_embeds[:, 0:1, :], + image_tensor.unsqueeze(0), + inputs_embeds[:, 1:, :], + ], + dim=1, + ) - kv_cache = torch.empty(24, 2, 1, text_n_heads, 2048, 64, dtype=torch.float16) + kv_cache = torch.empty( + config.text.n_layers, + 2, # k, v + 1, # bsz + config.text.n_heads, + config.text.max_context, + config.text.dim // config.text.n_heads, + dtype=torch.float16, + ) freqs_cis = precompute_freqs_cis(32, 2048) pos = 0 for _ in range(args.max_tokens): with torch.no_grad(): hidden, kv_cache_update = text_decoder( - inputs_embeds, model.text, kv_cache[:, :, :, :, :pos, :], freqs_cis + inputs_embeds, + model.text, + kv_cache[:, :, :, :, :pos, :], + freqs_cis, + config.text, ) logits = lm_head(hidden, model.text) kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = ( diff --git a/moondream/torch/text.py b/moondream/torch/text.py index ddbe51df..f5ae3949 100644 --- a/moondream/torch/text.py +++ b/moondream/torch/text.py @@ -4,6 +4,7 @@ from .layers import layer_norm, linear, mlp from .rope import apply_rotary_emb from .weights import AttentionWeights, TextModel +from .config import TextConfig def text_encoder(input_ids: torch.Tensor, w: TextModel): @@ -35,10 +36,11 @@ def attn( w: AttentionWeights, freqs_cis: torch.Tensor, layer_kv_cache: torch.Tensor, + n_heads: int, ): bsz, q_len, d_model = x.shape pos = 0 if layer_kv_cache is None else layer_kv_cache.shape[3] - n_heads, head_dim = w.n_heads, d_model // w.n_heads + head_dim = d_model // n_heads q, k, v = [ t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2) @@ -46,8 +48,8 @@ def attn( ] position_ids = torch.arange(pos, pos + q_len, dtype=torch.long) - q = apply_rotary_emb(q, 
freqs_cis, position_ids, w.n_heads) - k = apply_rotary_emb(k, freqs_cis, position_ids, w.n_heads) + q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads) + k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads) k_, v_ = k, v if layer_kv_cache is not None: @@ -70,13 +72,16 @@ def text_decoder( w: TextModel, kv_cache: torch.Tensor, freqs_cis: torch.Tensor, + config: TextConfig, ): hidden_BTC = inputs_embeds new_kv_cache = [torch.empty(0)] * len(w.blocks) for i, block in enumerate(w.blocks): l_in = layer_norm(hidden_BTC, block.ln) - l_attn, new_kv_cache[i] = attn(l_in, block.attn, freqs_cis, kv_cache[i]) + l_attn, new_kv_cache[i] = attn( + l_in, block.attn, freqs_cis, kv_cache[i], n_heads=config.n_heads + ) l_mlp = mlp(l_in, block.mlp) hidden_BTC = hidden_BTC + l_attn + l_mlp diff --git a/moondream/torch/vision.py b/moondream/torch/vision.py index 76ca896d..33b57501 100644 --- a/moondream/torch/vision.py +++ b/moondream/torch/vision.py @@ -8,6 +8,7 @@ from .layers import attn, layer_norm, linear, mlp from .weights import VisionModel from .image_crops import overlap_crop_image, reconstruct_from_crops +from .config import VisionConfig if torch.backends.mps.is_available(): # Non-divisible input sizes are not implemented on MPS device yet. @@ -19,7 +20,9 @@ def adaptive_avg_pool2d(input, output_size): adaptive_avg_pool2d = F.adaptive_avg_pool2d -def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor: +def encode_image( + image: Image.Image, weights: VisionModel, config: VisionConfig +) -> torch.Tensor: np_image = np.array(image.convert("RGB")) crops = overlap_crop_image(np_image, max_crops=12, overlap_margin=4) all_crops = np.stack([crops["global_crop"], *crops["local_crops"]], axis=0) @@ -32,7 +35,7 @@ def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor: .div_(0.5) ) - outputs = vision_encoder(all_crops, weights) + outputs = vision_encoder(all_crops, weights, config) global_features = outputs[0] local_features = outputs[1:].view(-1, 27, 27, 1152) @@ -52,18 +55,18 @@ def encode_image(image: Image.Image, weights: VisionModel) -> torch.Tensor: return mlp(final_features, weights.proj_mlp) -def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel): +def vision_encoder(input_BCHW: torch.Tensor, w: VisionModel, config: VisionConfig): x = rearrange( input_BCHW, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", - p1=w.patch_size, - p2=w.patch_size, + p1=config.enc_patch_size, + p2=config.enc_patch_size, ) # B3HW -> B(HxW)(3xP1xP2), aka BTC x = linear(x, w.patch_emb) x = x + w.pos_emb for block in w.blocks: - x = x + attn(layer_norm(x, block.ln1), block.attn) + x = x + attn(layer_norm(x, block.ln1), block.attn, n_heads=config.enc_n_heads) x = x + mlp(layer_norm(x, block.ln2), block.mlp) x = layer_norm(x, w.post_ln) diff --git a/moondream/torch/weights.py b/moondream/torch/weights.py index 08850c4b..9f559a50 100644 --- a/moondream/torch/weights.py +++ b/moondream/torch/weights.py @@ -1,4 +1,3 @@ -import math from contextlib import contextmanager from dataclasses import dataclass from typing import Callable, List @@ -7,6 +6,7 @@ import torch from .layers import AttentionWeights, LayerNormWeights, LinearWeights, MLPWeights +from .moondream import MoondreamModel as MoondreamModule @dataclass @@ -19,7 +19,6 @@ class VisionBlock: @dataclass class VisionModel: - patch_size: int patch_emb: LinearWeights pos_emb: torch.Tensor blocks: List[VisionBlock] @@ -72,6 +71,11 @@ def safetensors_open(safetensors_file: str): def get_tensor(name: str) -> torch.Tensor: return 
st.get_tensor(name) + def get_keys() -> List[str]: + return st.keys() + + get_tensor.keys = get_keys + yield get_tensor @@ -87,7 +91,6 @@ def load_model( patch_emb = LinearWeights( weight=get_tensor(f"{prefix}.weight"), bias=get_tensor(f"{prefix}.bias") ) - patch_size = int(math.sqrt(patch_emb.weight.shape[1] // 3)) pos_emb = get_tensor("vision_encoder.encoder.model.visual.pos_embed") post_ln = LayerNormWeights( weight=get_tensor("vision_encoder.encoder.model.visual.norm.weight"), @@ -111,7 +114,6 @@ def load_model( weight=get_tensor(f"{prefix}.attn.proj.weight"), bias=get_tensor(f"{prefix}.attn.proj.bias"), ), - n_heads=vision_n_heads, ), ln2=LayerNormWeights( weight=get_tensor(f"{prefix}.norm2.weight"), @@ -141,7 +143,6 @@ def load_model( act="gelu_approx", ) vision = VisionModel( - patch_size=patch_size, patch_emb=patch_emb, pos_emb=pos_emb, blocks=blocks, @@ -177,7 +178,6 @@ def load_model( weight=get_tensor(f"{prefix}.mixer.out_proj.weight"), bias=get_tensor(f"{prefix}.mixer.out_proj.bias"), ), - n_heads=text_n_heads, ), mlp=MLPWeights( fc1=LinearWeights( @@ -254,6 +254,217 @@ def load_from_pt( return load_model(lambda x: tensors[x], vision_blocks, text_blocks, **kwargs) +def _load_weights( + get_tensor: Callable[[str], torch.Tensor], model: MoondreamModule +) -> None: + """Internal function to load weights using a tensor getter function.""" + model = model.to(dtype=torch.float16) + + # Vision Model + model.vision["patch_emb"].weight.data.copy_( + get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.weight") + ) + model.vision["patch_emb"].bias.data.copy_( + get_tensor("vision_encoder.encoder.model.visual.patch_embed.linear.bias") + ) + model.vision.pos_emb.data.copy_( + get_tensor("vision_encoder.encoder.model.visual.pos_embed") + ) + + for i in range(len(model.vision["blocks"])): + prefix = f"vision_encoder.encoder.model.visual.blocks.{i}" + + # Layer norms + model.vision["blocks"][i]["ln1"].weight.data.copy_( + get_tensor(f"{prefix}.norm1.weight") + ) + model.vision["blocks"][i]["ln1"].bias.data.copy_( + get_tensor(f"{prefix}.norm1.bias") + ) + model.vision["blocks"][i]["ln2"].weight.data.copy_( + get_tensor(f"{prefix}.norm2.weight") + ) + model.vision["blocks"][i]["ln2"].bias.data.copy_( + get_tensor(f"{prefix}.norm2.bias") + ) + + # Attention + model.vision["blocks"][i]["attn"]["qkv"].weight.data.copy_( + get_tensor(f"{prefix}.attn.qkv.weight") + ) + model.vision["blocks"][i]["attn"]["qkv"].bias.data.copy_( + get_tensor(f"{prefix}.attn.qkv.bias") + ) + model.vision["blocks"][i]["attn"]["proj"].weight.data.copy_( + get_tensor(f"{prefix}.attn.proj.weight") + ) + model.vision["blocks"][i]["attn"]["proj"].bias.data.copy_( + get_tensor(f"{prefix}.attn.proj.bias") + ) + + # MLP + model.vision["blocks"][i]["mlp"]["fc1"].weight.data.copy_( + get_tensor(f"{prefix}.mlp.fc1.weight") + ) + model.vision["blocks"][i]["mlp"]["fc1"].bias.data.copy_( + get_tensor(f"{prefix}.mlp.fc1.bias") + ) + model.vision["blocks"][i]["mlp"]["fc2"].weight.data.copy_( + get_tensor(f"{prefix}.mlp.fc2.weight") + ) + model.vision["blocks"][i]["mlp"]["fc2"].bias.data.copy_( + get_tensor(f"{prefix}.mlp.fc2.bias") + ) + + model.vision["post_ln"].weight.data.copy_( + get_tensor("vision_encoder.encoder.model.visual.norm.weight") + ) + model.vision["post_ln"].bias.data.copy_( + get_tensor("vision_encoder.encoder.model.visual.norm.bias") + ) + + model.vision["proj_mlp"]["fc1"].weight.data.copy_( + get_tensor("vision_encoder.projection.mlp.fc1.weight") + ) + model.vision["proj_mlp"]["fc1"].bias.data.copy_( + 
get_tensor("vision_encoder.projection.mlp.fc1.bias") + ) + model.vision["proj_mlp"]["fc2"].weight.data.copy_( + get_tensor("vision_encoder.projection.mlp.fc2.weight") + ) + model.vision["proj_mlp"]["fc2"].bias.data.copy_( + get_tensor("vision_encoder.projection.mlp.fc2.bias") + ) + + # Text Model + model.text.wte.data.copy_(get_tensor("text_model.transformer.embd.wte.weight")) + + for i in range(len(model.text["blocks"])): + prefix = f"text_model.transformer.h.{i}" + + # Layer norm + model.text["blocks"][i]["ln"].weight.data.copy_( + get_tensor(f"{prefix}.ln.weight") + ) + model.text["blocks"][i]["ln"].bias.data.copy_(get_tensor(f"{prefix}.ln.bias")) + + # Attention + model.text["blocks"][i]["attn"]["qkv"].weight.data.copy_( + get_tensor(f"{prefix}.mixer.Wqkv.weight") + ) + model.text["blocks"][i]["attn"]["qkv"].bias.data.copy_( + get_tensor(f"{prefix}.mixer.Wqkv.bias") + ) + model.text["blocks"][i]["attn"]["proj"].weight.data.copy_( + get_tensor(f"{prefix}.mixer.out_proj.weight") + ) + model.text["blocks"][i]["attn"]["proj"].bias.data.copy_( + get_tensor(f"{prefix}.mixer.out_proj.bias") + ) + + # MLP + model.text["blocks"][i]["mlp"]["fc1"].weight.data.copy_( + get_tensor(f"{prefix}.mlp.fc1.weight") + ) + model.text["blocks"][i]["mlp"]["fc1"].bias.data.copy_( + get_tensor(f"{prefix}.mlp.fc1.bias") + ) + model.text["blocks"][i]["mlp"]["fc2"].weight.data.copy_( + get_tensor(f"{prefix}.mlp.fc2.weight") + ) + model.text["blocks"][i]["mlp"]["fc2"].bias.data.copy_( + get_tensor(f"{prefix}.mlp.fc2.bias") + ) + + model.text["post_ln"].weight.data.copy_(get_tensor("text_model.lm_head.ln.weight")) + model.text["post_ln"].bias.data.copy_(get_tensor("text_model.lm_head.ln.bias")) + + model.text["lm_head"].weight.data.copy_( + get_tensor("text_model.lm_head.linear.weight") + ) + model.text["lm_head"].bias.data.copy_(get_tensor("text_model.lm_head.linear.bias")) + + # Region Model + model.region.coord_features.data.copy_( + get_tensor("region_model.coordinate_features.weight").T + ) + model.region["coord_encoder"].weight.data.copy_( + get_tensor("region_model.coordinate_encoder.weight") + ) + model.region["coord_encoder"].bias.data.copy_( + get_tensor("region_model.coordinate_encoder.bias") + ) + + model.region["coord_decoder"]["fc1"].weight.data.copy_( + get_tensor("region_model.coordinate_decoder.fc1.weight") + ) + model.region["coord_decoder"]["fc1"].bias.data.copy_( + get_tensor("region_model.coordinate_decoder.fc1.bias") + ) + model.region["coord_decoder"]["fc2"].weight.data.copy_( + get_tensor("region_model.coordinate_decoder.fc2.weight") + ) + model.region["coord_decoder"]["fc2"].bias.data.copy_( + get_tensor("region_model.coordinate_decoder.fc2.bias") + ) + + model.region.size_features.data.copy_( + get_tensor("region_model.size_features.weight").T + ) + model.region["size_encoder"].weight.data.copy_( + get_tensor("region_model.size_encoder.weight") + ) + model.region["size_encoder"].bias.data.copy_( + get_tensor("region_model.size_encoder.bias") + ) + + model.region["size_decoder"]["fc1"].weight.data.copy_( + get_tensor("region_model.size_decoder.fc1.weight") + ) + model.region["size_decoder"]["fc1"].bias.data.copy_( + get_tensor("region_model.size_decoder.fc1.bias") + ) + model.region["size_decoder"]["fc2"].weight.data.copy_( + get_tensor("region_model.size_decoder.fc2.weight") + ) + model.region["size_decoder"]["fc2"].bias.data.copy_( + get_tensor("region_model.size_decoder.fc2.bias") + ) + + +def load_weights_from_safetensors(weights_file: str, model: MoondreamModule) -> None: + """Load 
weights from a safetensors file into a MoondreamModel instance.""" + with safetensors_open(weights_file) as get_tensor: + # Wrap the get_tensor function to handle key normalization + name_map = {k.replace("._orig_mod", ""): k for k in get_tensor.keys()} + _load_weights(lambda x: get_tensor(name_map[x]).to(dtype=torch.float16), model) + + +def load_weights_from_pt(weights_file: str, model: MoondreamModule) -> None: + """Load weights from a PyTorch file into a MoondreamModel instance.""" + device = str(torch.empty(0).device) + tensors = torch.load(weights_file, map_location=device, weights_only=True) + tensors = { + k.replace("._orig_mod", ""): v.to(dtype=torch.float16) + for k, v in tensors.items() + } + _load_weights(lambda x: tensors[x], model) + + +def load_weights_into_model(weights_file: str, model: MoondreamModule) -> None: + """ + Load weights from either a safetensors or PyTorch file directly into a MoondreamModel instance. + + Args: + weights_file: Path to weights file (either .safetensors or .pt) + model: MoondreamModel instance to load weights into + """ + if weights_file.endswith(".safetensors"): + load_weights_from_safetensors(weights_file, model) + else: + load_weights_from_pt(weights_file, model) + + if __name__ == "__main__": weights = load_from_safetensors("model.safetensors") print(weights)
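
Usage sketch (illustrative, not part of the diff): the snippet below strings together the pieces this patch introduces -- MoondreamConfig.from_dict, the MoondreamModel module skeleton, load_weights_into_model, and the config-driven KV-cache shape -- in the same order sample.py exercises them. It assumes the package is importable as moondream.torch and uses placeholder paths ("model.safetensors", "demo.jpg"); it is a sketch of the intended usage, not code taken from the diff.

import json

import torch
from PIL import Image

from moondream.torch.config import MoondreamConfig
from moondream.torch.moondream import MoondreamModel
from moondream.torch.vision import encode_image
from moondream.torch.weights import load_weights_into_model

# Build the config: MoondreamConfig() uses the defaults from config.py, while
# from_dict() lets a JSON config override fields per sub-config. The JSON
# string here is a placeholder override, not a file shipped with the repo.
config = MoondreamConfig.from_dict(json.loads('{"text": {"max_context": 2048}}'))

# Construct the module skeleton, then copy weights into it; the loader picks
# load_weights_from_safetensors or load_weights_from_pt by file extension.
model = MoondreamModel(config)
load_weights_into_model("model.safetensors", model)  # placeholder path

# The KV cache shape is now derived from the config rather than hard-coded:
# (n_layers, k/v, batch, n_heads, max_context, head_dim), where
# head_dim = dim // n_heads = 2048 // 32 = 64 -- the same shape as the old
# torch.empty(24, 2, 1, 32, 2048, 64) in sample.py.
kv_cache = torch.empty(
    config.text.n_layers,
    2,
    1,
    config.text.n_heads,
    config.text.max_context,
    config.text.dim // config.text.n_heads,
    dtype=torch.float16,
)

# encode_image and the attention helpers now take the config (patch size,
# head counts) explicitly instead of reading them off the weight containers.
with torch.no_grad():
    image_emb = encode_image(Image.open("demo.jpg"), model.vision, config.vision)

# Prompt encoding and the decode loop over text_decoder/lm_head then proceed
# exactly as in the updated sample.py.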