From a25de92d42ea44bc18fdfb3a44a891ef3adfc144 Mon Sep 17 00:00:00 2001
From: eljandoubi
Date: Sun, 25 Aug 2024 11:24:46 +0200
Subject: [PATCH] add some type hints

---
 src/models/gemma.py     | 10 +++++-----
 src/models/paligemma.py |  2 +-
 src/models/siglip.py    | 10 +++++-----
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/models/gemma.py b/src/models/gemma.py
index 8b381b0..a034c2d 100644
--- a/src/models/gemma.py
+++ b/src/models/gemma.py
@@ -55,7 +55,7 @@ def forward(self,
         # self.inv_freq.to(x.device)
         # Copy the inv_freq tensor for batch in the sequence
         # inv_freq_expanded: [Batch_Size, Head_Dim // 2, 1]
-        inv_freq_expanded = self.inv_freq[None, :, None].expand(
+        inv_freq_expanded: torch.FloatTensor = self.inv_freq[None, :, None].expand(
             position_ids.shape[0], -1, 1)
         # position_ids_expanded: [Batch_Size, 1, Seq_Len]
         position_ids_expanded = position_ids[:, None, :].float()
@@ -123,11 +123,11 @@ def forward(self,
         # [Batch_Size, Seq_Len, Hidden_Size]
         bsz, q_len, _ = hidden_states.size()
         # [Batch_Size, Seq_Len, Num_Heads_Q * Head_Dim]
-        query_states = self.q_proj(hidden_states)
+        query_states: torch.FloatTensor = self.q_proj(hidden_states)
         # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
-        key_states = self.k_proj(hidden_states)
+        key_states: torch.FloatTensor = self.k_proj(hidden_states)
         # [Batch_Size, Seq_Len, Num_Heads_KV * Head_Dim]
-        value_states = self.v_proj(hidden_states)
+        value_states: torch.FloatTensor = self.v_proj(hidden_states)
         # [Batch_Size, Num_Heads_Q, Seq_Len, Head_Dim]
         query_states = query_states.view(bsz, q_len, self.num_heads,
                                          self.head_dim
@@ -155,7 +155,7 @@
             key_states, value_states = kv_cache.update(
                 key_states, value_states, self.layer_idx)

-        attn_output = F.scaled_dot_product_attention(query_states,
+        attn_output: torch.FloatTensor = F.scaled_dot_product_attention(query_states,
                                                      key_states,
                                                      value_states,
                                                      attn_mask=attention_mask,
diff --git a/src/models/paligemma.py b/src/models/paligemma.py
index 081cdc1..310a509 100644
--- a/src/models/paligemma.py
+++ b/src/models/paligemma.py
@@ -156,7 +156,7 @@ def forward(self,
         assert torch.all(attention_mask == 1), "The input can not be padded "
         # 1. Extract the input embeddings
         # shape: (Batch_Size, Seq_Len, Hidden_Size)
-        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        inputs_embeds: torch.FloatTensor = self.language_model.get_input_embeddings()(input_ids)
         # 2. Merge text and images
         # [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
         selected_image_feature = self.vision_tower(
diff --git a/src/models/siglip.py b/src/models/siglip.py
index abc287f..5974462 100644
--- a/src/models/siglip.py
+++ b/src/models/siglip.py
@@ -53,7 +53,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
         # [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
         # where Num_Patches_H = height // patch_size and Num_Patches_W = width
         # // patch_size
-        patch_embeds = self.patch_embedding(pixel_values)
+        patch_embeds: torch.FloatTensor = self.patch_embedding(pixel_values)
         # [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
         # -> [Batch_Size, Embed_Dim, Num_Patches]
         # where Num_Patches = Num_Patches_H * Num_Patches_W
@@ -77,7 +77,7 @@ def __init__(self, config: SiglipVisionConfig):
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
         # Equivalent to 1 / sqrt(self.head_dim)
-        self.scale = self.head_dim**-0.5
+        self.scale: float = self.head_dim**-0.5
         self.dropout = config.attention_dropout

         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
@@ -91,9 +91,9 @@ def forward(self,
         """Forward method"""
         batch_size, seq_len, _ = hidden_states.size()
         # [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        query_states: torch.FloatTensor = self.q_proj(hidden_states)
+        key_states: torch.FloatTensor = self.k_proj(hidden_states)
+        value_states: torch.FloatTensor = self.v_proj(hidden_states)
         # [Batch_Size, Num_Heads, Num_Patches, Head_Dim]
         query_states = query_states.view(batch_size, seq_len, self.num_heads,
                                          self.head_dim
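
Note for reviewers: the snippet below is a minimal, self-contained sketch of the annotation style this patch applies, shown in isolation. TinyAttention and its dimensions are hypothetical and do not come from src/models/; the sketch only illustrates annotating projection outputs with torch.FloatTensor and the scale attribute with float, as the diff does.

# Hypothetical sketch of the annotation pattern used in this patch.
# TinyAttention and its sizes are illustrative, not taken from the repository.
import torch
import torch.nn.functional as F
from torch import nn


class TinyAttention(nn.Module):
    def __init__(self, embed_dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Equivalent to 1 / sqrt(self.head_dim); kept to mirror the patch, and
        # unused below because F.scaled_dot_product_attention defaults to it.
        self.scale: float = self.head_dim**-0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        batch_size, seq_len, _ = hidden_states.size()
        # The hints document the intended tensor type; at runtime nn.Linear
        # still returns a plain torch.Tensor.
        query_states: torch.FloatTensor = self.q_proj(hidden_states)
        key_states: torch.FloatTensor = self.k_proj(hidden_states)
        value_states: torch.FloatTensor = self.v_proj(hidden_states)
        # [Batch_Size, Num_Heads, Seq_Len, Head_Dim]
        shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        query_states = query_states.view(*shape).transpose(1, 2)
        key_states = key_states.view(*shape).transpose(1, 2)
        value_states = value_states.view(*shape).transpose(1, 2)
        attn_output: torch.FloatTensor = F.scaled_dot_product_attention(
            query_states, key_states, value_states)
        return attn_output


if __name__ == "__main__":
    out = TinyAttention()(torch.randn(2, 16, 64))
    print(out.shape)  # torch.Size([2, 4, 16, 16])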