From 050a6d74118998714426e652b5d39bd8539075d8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 24 Dec 2024 21:29:52 +0100 Subject: [PATCH] Fix multi-image and 2x speed improvements (DS-VL2) (#157) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix multi-image and use .tolist() for (2.16× prompt and 1.83x for generation speedup) * format --- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py index 1ad684a..075fb0b 100644 --- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py @@ -251,14 +251,14 @@ def process_image_features( if num_width_tiles == 0 or num_height_tiles == 0: break - num_tiles_in_image = num_width_tiles * num_height_tiles + num_tiles_in_image = (num_width_tiles * num_height_tiles).tolist() # Get global features [hw, D] global_features = images_embeds[tile_index] # Get local features [num_height_tiles * num_width_tiles, hw, D] local_features = images_embeds[ - tile_index + 1 : tile_index + 1 + int(num_tiles_in_image) + tile_index + 1 : tile_index + 1 + num_tiles_in_image ] tile_index += num_tiles_in_image + 1 @@ -378,14 +378,18 @@ def get_input_embeddings( batch_num_tiles = [0 for _ in range(bs)] total_tiles = [] + + # Total number of tiles in each batch for idx in range(bs): for jdx in range(max_n_images): num_width_tiles, num_height_tiles = images_spatial_crop[idx][jdx] if num_width_tiles == 0 or num_height_tiles == 0: break - batch_num_tiles[idx] += 1 + num_width_tiles * num_height_tiles + batch_num_tiles[idx] += ( + 1 + num_width_tiles * num_height_tiles + ).tolist() - total_tiles.append(pixel_values[idx, : int(batch_num_tiles[idx])]) + total_tiles.append(pixel_values[idx, : batch_num_tiles[idx]]) total_tiles = mx.concatenate(total_tiles, axis=0) assert total_tiles.shape[0] == sum(batch_num_tiles)