Make max prefix self-attention token count configurable
vikhyat committed Jan 5, 2025
1 parent 47b48de commit ceb8236
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions moondream/torch/config.py
@@ -9,6 +9,7 @@ class TextConfig:
     vocab_size: int = 51200
     max_context: int = 2048
     n_heads: int = 32
+    prefix_attn: int = 730


 @dataclass(frozen=True)
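Because the new field carries a default of 730, existing configs keep the previous behaviour unchanged. A minimal sketch of overriding it, assuming TextConfig's remaining fields also have defaults so it can be constructed with keyword overrides alone:

from moondream.torch.config import TextConfig

# Assumed usage: only the override is passed, everything else keeps its default.
causal_only = TextConfig(prefix_attn=0)       # skip the prefix block; the mask stays purely causal
longer_prefix = TextConfig(prefix_attn=1024)  # widen the bidirectional prefix region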
3 changes: 2 additions & 1 deletion moondream/torch/text.py
@@ -160,7 +160,8 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
     attn_mask = torch.tril(
         torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
     )
-    attn_mask[..., :730, :730] = 1
+    if config.prefix_attn != 0:
+        attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
     text.register_buffer("attn_mask", attn_mask, persistent=False)

     return text
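For context on what the changed lines do: the mask starts as a standard causal (lower-triangular) mask over max_context positions, and the first prefix_attn positions are then unmasked against each other so the prefix tokens attend bidirectionally while the rest of the sequence stays causal. A standalone sketch of that behaviour; the helper name and toy sizes are illustrative, not from the repo:

import torch

def build_prefix_causal_mask(max_context: int, prefix_attn: int) -> torch.Tensor:
    # Standard causal mask: position i may attend to positions <= i.
    attn_mask = torch.tril(
        torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
    )
    if prefix_attn != 0:
        # Unmask the leading block so prefix tokens attend to each other in both directions.
        attn_mask[..., :prefix_attn, :prefix_attn] = 1
    return attn_mask

mask = build_prefix_causal_mask(max_context=6, prefix_attn=3)
assert mask[0, 0, 0, 2].item()      # position 0 now sees position 2: the prefix is bidirectional
assert not mask[0, 0, 3, 5].item()  # outside the prefix, attention stays causal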
