Commit 9f49364

Update (base update)

[ghstack-poisoned]

vikhyat committed Jan 3, 2025
1 parent 9ceccdb commit 9f49364
Showing 7 changed files with 498 additions and 47 deletions.
44 changes: 44 additions & 0 deletions moondream/torch/config.py
@@ -0,0 +1,44 @@
from dataclasses import dataclass


@dataclass(frozen=True)
class TextConfig:
dim: int = 2048
n_layers: int = 24
vocab_size: int = 51200
max_context: int = 2048
n_heads: int = 32


@dataclass(frozen=True)
class VisionConfig:
enc_dim: int = 1152
enc_patch_size: int = 14
enc_n_layers: int = 27
enc_ff_dim: int = 4304
enc_n_heads: int = 16
crop_size: int = 378
in_channels: int = 3


@dataclass(frozen=True)
class RegionConfig:
dim: int = 2048
coord_feat_dim: int = 256
coord_out_dim: int = 1024
size_feat_dim: int = 512
size_out_dim: int = 2048


@dataclass
class MoondreamConfig:
text: TextConfig = TextConfig()
vision: VisionConfig = VisionConfig()
region: RegionConfig = RegionConfig()

@classmethod
def from_dict(cls, config_dict: dict):
text_config = TextConfig(**config_dict.get("text", {}))
vision_config = VisionConfig(**config_dict.get("vision", {}))
region_config = RegionConfig(**config_dict.get("region", {}))
return cls(text=text_config, vision=vision_config, region=region_config)
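
For context, a minimal sketch of how these config classes compose; the override value passed to from_dict below is hypothetical:

from moondream.torch.config import MoondreamConfig

# Defaults come straight from the dataclass fields above.
config = MoondreamConfig()
assert config.text.n_heads == 32 and config.vision.crop_size == 378

# from_dict applies per-component overrides and leaves everything else at
# its default; {"text": {"n_layers": 12}} is an illustrative value only.
config = MoondreamConfig.from_dict({"text": {"n_layers": 12}})
assert config.text.n_layers == 12
assert config.region.coord_feat_dim == 256  # untouched default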
10 changes: 3 additions & 7 deletions moondream/torch/layers.py
@@ -38,10 +38,7 @@ class MLPWeights:

 def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
     x = linear(x, w.fc1)
-    if w.act == "gelu_approx":
-        x = gelu_approx(x)
-    else:
-        raise NotImplementedError(f"Activation function {w.act} not implemented.")
+    x = gelu_approx(x)
     x = linear(x, w.fc2)
     return x

@@ -50,12 +47,11 @@ def mlp(x: torch.Tensor, w: MLPWeights) -> torch.Tensor:
 class AttentionWeights:
     qkv: LinearWeights
     proj: LinearWeights
-    n_heads: int


-def attn(x: torch.Tensor, w: AttentionWeights) -> torch.Tensor:
+def attn(x: torch.Tensor, w: AttentionWeights, n_heads: int) -> torch.Tensor:
     bsz, q_len, d_model = x.shape
-    n_heads, head_dim = w.n_heads, d_model // w.n_heads
+    head_dim = d_model // n_heads

     q, k, v = [
         t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
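
The net effect of this change is that n_heads travels with the call rather than the weights. A self-contained sketch of the same convention, using PyTorch's built-in scaled_dot_product_attention instead of the repo's linear helpers:

import torch
import torch.nn.functional as F

def attn_sketch(x, qkv, proj, n_heads: int) -> torch.Tensor:
    # The caller supplies n_heads explicitly, mirroring the new signature.
    bsz, q_len, d_model = x.shape
    head_dim = d_model // n_heads
    q, k, v = [
        t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
        for t in qkv(x).chunk(3, dim=-1)
    ]
    out = F.scaled_dot_product_attention(q, k, v)
    return proj(out.transpose(1, 2).reshape(bsz, q_len, d_model))

# Vision-encoder-sized shapes: enc_dim=1152, enc_n_heads=16, 27*27=729 patches.
x = torch.randn(1, 729, 1152)
qkv = torch.nn.Linear(1152, 3 * 1152)
proj = torch.nn.Linear(1152, 1152)
assert attn_sketch(x, qkv, proj, n_heads=16).shape == (1, 729, 1152)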
171 changes: 171 additions & 0 deletions moondream/torch/moondream.py
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn

from .config import MoondreamConfig


class MoondreamModel(nn.Module):
def __init__(self, config: MoondreamConfig, dtype=torch.float16):
super().__init__()
self.config = config

# Vision Model
patch_dim = (
config.vision.enc_patch_size
* config.vision.enc_patch_size
* config.vision.in_channels
)
grid_size = config.vision.crop_size // config.vision.enc_patch_size
num_patches = grid_size * grid_size

self.vision = nn.ModuleDict(
{
"patch_emb": nn.Linear(patch_dim, config.vision.enc_dim, dtype=dtype),
"blocks": nn.ModuleList(
[
nn.ModuleDict(
{
"ln1": nn.LayerNorm(config.vision.enc_dim, dtype=dtype),
"attn": nn.ModuleDict(
{
"qkv": nn.Linear(
config.vision.enc_dim,
3 * config.vision.enc_dim,
dtype=dtype,
),
"proj": nn.Linear(
config.vision.enc_dim,
config.vision.enc_dim,
dtype=dtype,
),
}
),
"ln2": nn.LayerNorm(config.vision.enc_dim, dtype=dtype),
"mlp": nn.ModuleDict(
{
"fc1": nn.Linear(
config.vision.enc_dim,
config.vision.enc_ff_dim,
dtype=dtype,
),
"fc2": nn.Linear(
config.vision.enc_ff_dim,
config.vision.enc_dim,
dtype=dtype,
),
}
),
}
)
for _ in range(config.vision.enc_n_layers)
]
),
"post_ln": nn.LayerNorm(config.vision.enc_dim, dtype=dtype),
"proj_mlp": nn.ModuleDict(
{
"fc1": nn.Linear(
config.vision.enc_dim * 2, config.text.dim * 4, dtype=dtype
),
"fc2": nn.Linear(
config.text.dim * 4, config.text.dim, dtype=dtype
),
}
),
}
)
self.vision.pos_emb = nn.Parameter(
torch.zeros(1, num_patches, config.vision.enc_dim, dtype=dtype)
)

# Text Model
self.text = nn.ModuleDict(
{
"blocks": nn.ModuleList(
[
nn.ModuleDict(
{
"ln": nn.LayerNorm(config.text.dim, dtype=dtype),
"attn": nn.ModuleDict(
{
"qkv": nn.Linear(
config.text.dim,
3 * config.text.dim,
dtype=dtype,
),
"proj": nn.Linear(
config.text.dim,
config.text.dim,
dtype=dtype,
),
}
),
"mlp": nn.ModuleDict(
{
"fc1": nn.Linear(
config.text.dim,
4 * config.text.dim,
dtype=dtype,
),
"fc2": nn.Linear(
4 * config.text.dim,
config.text.dim,
dtype=dtype,
),
}
),
}
)
for _ in range(config.text.n_layers)
]
),
"post_ln": nn.LayerNorm(config.text.dim, dtype=dtype),
"lm_head": nn.Linear(
config.text.dim, config.text.vocab_size, dtype=dtype
),
}
)
self.text.wte = nn.Parameter(
torch.empty(config.text.vocab_size, config.text.dim, dtype=dtype)
)

# Region Model
self.region = nn.ModuleDict(
{
"coord_encoder": nn.Linear(
config.region.coord_feat_dim, config.region.dim, dtype=dtype
),
"coord_decoder": nn.ModuleDict(
{
"fc1": nn.Linear(
config.region.dim, config.region.dim * 4, dtype=dtype
),
"fc2": nn.Linear(
config.region.dim * 4,
config.region.coord_out_dim,
dtype=dtype,
),
}
),
"size_encoder": nn.Linear(
config.region.size_feat_dim, config.region.dim, dtype=dtype
),
"size_decoder": nn.ModuleDict(
{
"fc1": nn.Linear(
config.region.dim, config.region.dim * 4, dtype=dtype
),
"fc2": nn.Linear(
config.region.dim * 4,
config.region.size_out_dim,
dtype=dtype,
),
}
),
}
)
self.region.coord_features = nn.Parameter(
torch.empty(config.region.coord_feat_dim // 2, 1, dtype=dtype).T
)
self.region.size_features = nn.Parameter(
torch.empty(config.region.size_feat_dim // 2, 2, dtype=dtype).T
)
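
A minimal instantiation sketch (import paths assumed from this diff). Note that text.wte and the region features are created with torch.empty, so weights must be loaded before inference:

import torch
from moondream.torch.config import MoondreamConfig
from moondream.torch.moondream import MoondreamModel

config = MoondreamConfig()
model = MoondreamModel(config, dtype=torch.float16)

# Nested ModuleDicts give state_dict keys that mirror the module tree,
# e.g. "vision.patch_emb.weight" or "text.blocks.0.attn.qkv.weight".
print(sum(p.numel() for p in model.parameters()))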
69 changes: 45 additions & 24 deletions moondream/torch/sample.py
@@ -9,7 +9,8 @@
 from .rope import precompute_freqs_cis
 from .text import lm_head, text_decoder, text_encoder
 from .vision import encode_image
-from .weights import load_from_pt, load_from_safetensors
+from .weights import load_weights_into_model
+from .moondream import MoondreamModel, MoondreamConfig

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -28,49 +29,69 @@

     # Load config.
     config = json.loads(args.config)
-    text_n_heads = config.get("text_n_heads", 32)

-    # Load model.
-    model_path = args.model
-    if not os.path.exists(model_path):
-        raise FileNotFoundError(f"Model not found at {model_path}")
-    if model_path.endswith(".pt"):
-        model = load_from_pt(model_path, **config)
-    elif model_path.endswith(".safetensors"):
-        model = load_from_safetensors(model_path, **config)
-    else:
-        raise ValueError(f"Invalid model format: {model_path}")
+    # model_path = args.model
+    # if not os.path.exists(model_path):
+    #     raise FileNotFoundError(f"Model not found at {model_path}")
+    # if model_path.endswith(".pt"):
+    #     model = load_from_pt(model_path, **config)
+    # elif model_path.endswith(".safetensors"):
+    #     model = load_from_safetensors(model_path, **config)
+    # else:
+    #     raise ValueError(f"Invalid model format: {model_path}")
+
+    # Load model.
+    config = MoondreamConfig()
+    model = MoondreamModel(config)
+    load_weights_into_model(args.model, model)

     # Encode image.
     image_path = args.image
     if not os.path.exists(image_path):
         raise FileNotFoundError(f"Image not found at {image_path}")
     image = Image.open(image_path)
-    image_tensor = encode_image(image, model.vision)
+    with torch.no_grad():
+        image_tensor = encode_image(image, model.vision, config.vision)

     # Encode text, and create inputs_embeds.
     tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream2")
     prompt = f"\n\nQuestion: {args.prompt}\n\nAnswer:"
     input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
-    input_ids = torch.cat([torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1)
-    inputs_embeds = text_encoder(input_ids, model.text)
-    inputs_embeds = torch.cat(
-        [
-            inputs_embeds[:, 0:1, :],
-            image_tensor.unsqueeze(0),
-            inputs_embeds[:, 1:, :],
-        ],
-        dim=1,
-    )
+    with torch.no_grad():
+        input_ids = torch.cat(
+            [torch.tensor([[tokenizer.eos_token_id]]), input_ids], dim=1
+        )
+        inputs_embeds = text_encoder(input_ids, model.text)
+        inputs_embeds = torch.cat(
+            [
+                inputs_embeds[:, 0:1, :],
+                image_tensor.unsqueeze(0),
+                inputs_embeds[:, 1:, :],
+            ],
+            dim=1,
+        )

-    kv_cache = torch.empty(24, 2, 1, text_n_heads, 2048, 64, dtype=torch.float16)
+    kv_cache = torch.empty(
+        config.text.n_layers,
+        2,  # k, v
+        1,  # bsz
+        config.text.n_heads,
+        config.text.max_context,
+        config.text.dim // config.text.n_heads,
+        dtype=torch.float16,
+    )
     freqs_cis = precompute_freqs_cis(32, 2048)
     pos = 0

     for _ in range(args.max_tokens):
         with torch.no_grad():
             hidden, kv_cache_update = text_decoder(
-                inputs_embeds, model.text, kv_cache[:, :, :, :, :pos, :], freqs_cis
+                inputs_embeds,
+                model.text,
+                kv_cache[:, :, :, :, :pos, :],
+                freqs_cis,
+                config.text,
             )
             logits = lm_head(hidden, model.text)
             kv_cache[:, :, :, :, pos : pos + kv_cache_update.size(-2), :] = (
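
With the default TextConfig, the config-derived cache shape reduces to exactly the literal that was removed; a quick check:

from moondream.torch.config import MoondreamConfig

config = MoondreamConfig()
head_dim = config.text.dim // config.text.n_heads  # 2048 // 32 == 64
shape = (
    config.text.n_layers,     # 24
    2,                        # k, v
    1,                        # bsz
    config.text.n_heads,      # 32
    config.text.max_context,  # 2048
    head_dim,                 # 64
)
assert shape == (24, 2, 1, 32, 2048, 64)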
13 changes: 9 additions & 4 deletions moondream/torch/text.py
@@ -4,6 +4,7 @@
 from .layers import layer_norm, linear, mlp
 from .rope import apply_rotary_emb
 from .weights import AttentionWeights, TextModel
+from .config import TextConfig


 def text_encoder(input_ids: torch.Tensor, w: TextModel):
@@ -35,19 +36,20 @@ def attn(
     w: AttentionWeights,
     freqs_cis: torch.Tensor,
     layer_kv_cache: torch.Tensor,
+    n_heads: int,
 ):
     bsz, q_len, d_model = x.shape
     pos = 0 if layer_kv_cache is None else layer_kv_cache.shape[3]
-    n_heads, head_dim = w.n_heads, d_model // w.n_heads
+    head_dim = d_model // n_heads

     q, k, v = [
         t.view(bsz, q_len, n_heads, head_dim).transpose(1, 2)
         for t in linear(x, w.qkv).chunk(3, dim=-1)
     ]

     position_ids = torch.arange(pos, pos + q_len, dtype=torch.long)
-    q = apply_rotary_emb(q, freqs_cis, position_ids, w.n_heads)
-    k = apply_rotary_emb(k, freqs_cis, position_ids, w.n_heads)
+    q = apply_rotary_emb(q, freqs_cis, position_ids, n_heads)
+    k = apply_rotary_emb(k, freqs_cis, position_ids, n_heads)

     k_, v_ = k, v
     if layer_kv_cache is not None:
@@ -70,13 +72,16 @@ def text_decoder(
     w: TextModel,
     kv_cache: torch.Tensor,
     freqs_cis: torch.Tensor,
+    config: TextConfig,
 ):
     hidden_BTC = inputs_embeds
     new_kv_cache = [torch.empty(0)] * len(w.blocks)

     for i, block in enumerate(w.blocks):
         l_in = layer_norm(hidden_BTC, block.ln)
-        l_attn, new_kv_cache[i] = attn(l_in, block.attn, freqs_cis, kv_cache[i])
+        l_attn, new_kv_cache[i] = attn(
+            l_in, block.attn, freqs_cis, kv_cache[i], n_heads=config.n_heads
+        )
         l_mlp = mlp(l_in, block.mlp)
         hidden_BTC = hidden_BTC + l_attn + l_mlp
[Diffs for the remaining 2 changed files did not load.]
