add some hint typing
eljandoubi committed Aug 25, 2024
1 parent 428b83b commit a25de92
Showing 3 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions src/models/gemma.py
@@ -55,7 +55,7 @@ def forward(self,
         # self.inv_freq.to(x.device)
         # Copy the inv_freq tensor for batch in the sequence
         # inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
-        inv_freq_expanded = self.inv_freq[None, :, None].expand(
+        inv_freq_expanded: torch.FloatTensor = self.inv_freq[None, :, None].expand(
             position_ids.shape[0], -1, 1)
         # position_ids_expanded: [Batch_Size, 1, Seq_Len]
         position_ids_expanded = position_ids[:, None, :].float()
@@ -123,11 +123,11 @@ def forward(self,
         # [Batch_Size, Seq_Len, Hidden_Size]
         bsz, q_len, _ = hidden_states.size()
         # [Batch_Size, Seq_Len, Num_Heads_Q * Head_Dim]
-        query_states = self.q_proj(hidden_states)
+        query_states: torch.FloatTensor = self.q_proj(hidden_states)
         # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
-        key_states = self.k_proj(hidden_states)
+        key_states: torch.FloatTensor = self.k_proj(hidden_states)
         # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
-        value_states = self.v_proj(hidden_states)
+        value_states: torch.FloatTensor = self.v_proj(hidden_states)
         # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
         query_states = query_states.view(bsz, q_len,
                                          self.num_heads,
@@ -155,7 +155,7 @@ def forward(self,
             key_states, value_states = kv_cache.update(
                 key_states, value_states, self.layer_idx)

-        attn_output = F.scaled_dot_product_attention(query_states,
+        attn_output: torch.FloatTensor = F.scaled_dot_product_attention(query_states,
                                                      key_states,
                                                      value_states,
                                                      attn_mask=attention_mask,
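For context, the annotated attn_output line uses torch.nn.functional.scaled_dot_product_attention, which returns a floating-point tensor shaped like the query. Below is a minimal, self-contained sketch of the same call pattern; the shapes and mask here are illustrative and not taken from the repository. Python does not enforce these annotations at runtime, so torch.FloatTensor acts as documentation for readers and tools rather than a check.

import torch
import torch.nn.functional as F

# Illustrative shapes: [Batch_Size, Num_Heads, Seq_Len, Head_Dim]
query: torch.FloatTensor = torch.randn(2, 8, 16, 64)
key: torch.FloatTensor = torch.randn(2, 8, 16, 64)
value: torch.FloatTensor = torch.randn(2, 8, 16, 64)

# Boolean mask broadcastable to [Batch_Size, Num_Heads, Seq_Len, Seq_Len];
# True means "attend", False means "mask out".
mask = torch.ones(2, 1, 16, 16, dtype=torch.bool)

# Returns a float tensor of shape [Batch_Size, Num_Heads, Seq_Len, Head_Dim],
# consistent with the torch.FloatTensor hint added in the diff.
attn_output: torch.FloatTensor = F.scaled_dot_product_attention(
    query, key, value, attn_mask=mask, dropout_p=0.0)
print(attn_output.shape)  # torch.Size([2, 8, 16, 64])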
2 changes: 1 addition & 1 deletion src/models/paligemma.py
@@ -156,7 +156,7 @@ def forward(self,
         assert torch.all(attention_mask == 1), "The input can not be padded "
         # 1. Extract the input embeddings
         # shape: (Batch_Size, Seq_Len, Hidden_Size)
-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        inputs_embeds: torch.FloatTensor = self.language_model.get_input_embeddings()(input_ids)
         # 2. Merge text and images
         # [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
         selected_image_feature = self.vision_tower(
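The annotated inputs_embeds line calls the language model's input-embedding module on the token ids, turning integer ids into float hidden states. A minimal sketch of that pattern follows, with a plain nn.Embedding standing in for get_input_embeddings(); the vocabulary and hidden sizes are illustrative, not read from the repository's config.

import torch
from torch import nn

# Hypothetical stand-in for self.language_model.get_input_embeddings():
# the module that maps token ids to hidden states.
embed_tokens = nn.Embedding(num_embeddings=256_000, embedding_dim=2048)

# [Batch_Size, Seq_Len] integer token ids
input_ids = torch.randint(0, 256_000, (2, 16))

# Calling the embedding module yields [Batch_Size, Seq_Len, Hidden_Size]
# float activations, which is what the torch.FloatTensor hint documents.
inputs_embeds: torch.FloatTensor = embed_tokens(input_ids)
print(inputs_embeds.shape)  # torch.Size([2, 16, 2048])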
10 changes: 5 additions & 5 deletions src/models/siglip.py
@@ -53,7 +53,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
         # [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
         # where Num_Patches_H = height // patch_size and Num_Patches_W = width
         # // patch_size
-        patch_embeds = self.patch_embedding(pixel_values)
+        patch_embeds: torch.FloatTensor = self.patch_embedding(pixel_values)
         # [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
         # -> [Batch_Size, Embed_Dim, Num_Patches]
         # where Num_Patches = Num_Patches_H * Num_Patches_W
@@ -77,7 +77,7 @@ def __init__(self, config: SiglipVisionConfig):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         # Equivalent to 1 / sqrt(self.head_dim)
-        self.scale = self.head_dim**-0.5
+        self.scale: float = self.head_dim**-0.5
         self.dropout = config.attention_dropout

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
@@ -91,9 +91,9 @@ def forward(self,
         """Forward method"""
         batch_size, seq_len, _ = hidden_states.size()
         # [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states: torch.FloatTensor = self.q_proj(hidden_states)
+        key_states: torch.FloatTensor = self.k_proj(hidden_states)
+        value_states: torch.FloatTensor = self.v_proj(hidden_states)
         # [Batch_Size, Num_Heads, Num_Patches, Head_Dim]
         query_states = query_states.view(batch_size, seq_len,
                                          self.num_heads, self.head_dim
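The annotated patch_embeds line follows the usual ViT/SigLIP pattern: a Conv2d whose stride equals the patch size maps each image patch to one embedding, after which the spatial grid is flattened into a patch sequence, exactly as the shape comments in the hunk describe. A minimal sketch under assumed, illustrative sizes (not read from the repository's SiglipVisionConfig):

import torch
from torch import nn

# Illustrative SigLIP-style values.
embed_dim, patch_size, image_size = 768, 16, 224

# Patch embedding: one Conv2d step per non-overlapping patch.
patch_embedding = nn.Conv2d(in_channels=3, out_channels=embed_dim,
                            kernel_size=patch_size, stride=patch_size)

# [Batch_Size, Channels, Height, Width]
pixel_values = torch.randn(2, 3, image_size, image_size)

# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
patch_embeds: torch.FloatTensor = patch_embedding(pixel_values)

# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
embeddings = patch_embeds.flatten(2).transpose(1, 2)
print(embeddings.shape)  # torch.Size([2, 196, 768])

The scale annotation in the same file is a plain Python float: head_dim ** -0.5 is the 1 / sqrt(head_dim) factor applied to attention scores.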
