diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py index 1ad684a..075fb0b 100644 --- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py @@ -251,14 +251,14 @@ def process_image_features( if num_width_tiles == 0 or num_height_tiles == 0: break - num_tiles_in_image = num_width_tiles * num_height_tiles + num_tiles_in_image = (num_width_tiles * num_height_tiles).tolist() # Get global features [hw, D] global_features = images_embeds[tile_index] # Get local features [num_height_tiles * num_width_tiles, hw, D] local_features = images_embeds[ - tile_index + 1 : tile_index + 1 + int(num_tiles_in_image) + tile_index + 1 : tile_index + 1 + num_tiles_in_image ] tile_index += num_tiles_in_image + 1 @@ -378,14 +378,18 @@ def get_input_embeddings( batch_num_tiles = [0 for _ in range(bs)] total_tiles = [] + + # Total number of tiles in each batch for idx in range(bs): for jdx in range(max_n_images): num_width_tiles, num_height_tiles = images_spatial_crop[idx][jdx] if num_width_tiles == 0 or num_height_tiles == 0: break - batch_num_tiles[idx] += 1 + num_width_tiles * num_height_tiles + batch_num_tiles[idx] += ( + 1 + num_width_tiles * num_height_tiles + ).tolist() - total_tiles.append(pixel_values[idx, : int(batch_num_tiles[idx])]) + total_tiles.append(pixel_values[idx, : batch_num_tiles[idx]]) total_tiles = mx.concatenate(total_tiles, axis=0) assert total_tiles.shape[0] == sum(batch_num_tiles)