diff --git a/moondream/torch/config.py b/moondream/torch/config.py
index ef3f7875..acfe7e02 100644
--- a/moondream/torch/config.py
+++ b/moondream/torch/config.py
@@ -9,6 +9,7 @@ class TextConfig:
     vocab_size: int = 51200
     max_context: int = 2048
     n_heads: int = 32
+    prefix_attn: int = 730
 
 
 @dataclass(frozen=True)
diff --git a/moondream/torch/text.py b/moondream/torch/text.py
index d370d871..744e62f2 100644
--- a/moondream/torch/text.py
+++ b/moondream/torch/text.py
@@ -160,7 +160,8 @@ def build_text_model(config: TextConfig, dtype: torch.dtype) -> nn.Module:
     attn_mask = torch.tril(
         torch.ones(1, 1, config.max_context, config.max_context, dtype=torch.bool)
     )
-    attn_mask[..., :730, :730] = 1
+    if config.prefix_attn != 0:
+        attn_mask[..., : config.prefix_attn, : config.prefix_attn] = 1
     text.register_buffer("attn_mask", attn_mask, persistent=False)
     return text
 
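
For context, a minimal standalone sketch of the mask this change produces, using toy sizes in place of max_context=2048 and prefix_attn=730. Treating the prefix region as bidirectionally attending prefix tokens (e.g. image embeddings) is an assumption for illustration; the diff itself only parameterizes the existing hard-coded 730.

# Standalone sketch, not part of the diff: mirrors the updated mask
# construction in build_text_model with small stand-in values.
import torch

max_context = 8   # stand-in for config.max_context (2048 in the diff)
prefix_attn = 4   # stand-in for config.prefix_attn (730 in the diff)

# Start from an ordinary causal (lower-triangular) mask ...
attn_mask = torch.tril(
    torch.ones(1, 1, max_context, max_context, dtype=torch.bool)
)
# ... then, when prefix_attn is non-zero, allow the first prefix_attn
# positions to attend to each other bidirectionally.
if prefix_attn != 0:
    attn_mask[..., :prefix_attn, :prefix_attn] = 1

print(attn_mask[0, 0].int())
# Rows 0-3 form a fully connected prefix block; rows 4-7 keep the
# standard causal pattern over all earlier positions.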