diff --git a/README.md b/README.md
index d66bfc6d..4ead895e 100644
--- a/README.md
+++ b/README.md
@@ -112,6 +112,29 @@ print(output)
 python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg
 ```
+## Video Understanding
+
+MLX-VLM also supports video analysis, such as captioning and summarization, with select models.
+
+### Supported Models
+
+The following models support video chat:
+
+1. Qwen2-VL
+2. Qwen2.5-VL
+3. Idefics3
+4. LLaVA
+
+More models are coming soon.
+
+### Usage Examples
+
+#### Command Line
+```sh
+python -m mlx_vlm.video_generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Describe this video" --video path/to/video.mp4 --max-pixels 224 224 --fps 1.0
+```
+
 These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks.
 
 # Fine-tuning
diff --git a/examples/video_understanding.ipynb b/examples/video_understanding.ipynb
new file mode 100644
index 00000000..5630c0a7
--- /dev/null
+++ b/examples/video_understanding.ipynb
@@ -0,0 +1,204 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Video Understanding\n",
+ "\n",
+ "In this example, we will generate a description of a video using `Qwen2-VL`, `Qwen2.5-VL`, `LLaVA`, and `Idefics3`, with more models coming soon.\n",
+ "\n",
+ "This feature is currently in beta and may not work as expected.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Install Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -U mlx-vlm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Import Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/homebrew/Caskroom/miniconda/base/envs/mlx_code/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "This is a beta version of the video understanding. 
It may not work as expected.\n" + ] + } + ], + "source": [ + "from pprint import pprint\n", + "from mlx_vlm import load\n", + "from mlx_vlm.utils import generate\n", + "from mlx_vlm.video_generate import process_vision_info\n", + "\n", + "import mlx.core as mx" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the model and processor\n", + "model, processor = load(\"mlx-community/Qwen2.5-VL-7B-Instruct-4bit\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "numpy reader: video_path=videos/fastmlx_local_ai_hub.mp4, total_frames=1134, video_fps=59.941855343141576, time=0.000s\n" + ] + } + ], + "source": [ + "# Messages containing a video and a text query\n", + "messages = [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": [\n", + " {\n", + " \"type\": \"video\",\n", + " \"video\": \"videos/fastmlx_local_ai_hub.mp4\",\n", + " \"max_pixels\": 360 * 360,\n", + " \"fps\": 1.0,\n", + " },\n", + " {\"type\": \"text\", \"text\": \"Describe this video.\"},\n", + " ],\n", + " }\n", + "]\n", + "\n", + "# Preparation for inference\n", + "text = processor.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + ")\n", + "image_inputs, video_inputs = process_vision_info(messages)\n", + "inputs = processor(\n", + " text=[text],\n", + " images=image_inputs,\n", + " videos=video_inputs,\n", + " padding=True,\n", + " return_tensors=\"pt\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert inputs to mlx arrays\n", + "input_ids = mx.array(inputs['input_ids'])\n", + "pixel_values = mx.array(inputs['pixel_values_videos'])\n", + "mask = mx.array(inputs['attention_mask'])\n", + "image_grid_thw = mx.array(inputs['video_grid_thw'])\n", + "\n", + "kwargs = {\n", + " \"image_grid_thw\": image_grid_thw,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "kwargs[\"video\"] = \"videos/fastmlx_local_ai_hub.mp4\"\n", + "kwargs[\"input_ids\"] = input_ids\n", + "kwargs[\"pixel_values\"] = pixel_values\n", + "kwargs[\"mask\"] = mask\n", + "response = generate(model, processor, prompt=text, temp=0.7, max_tokens=100, **kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('The video appears to be a live stream or a recording of a coding session, '\n", + " 'likely on a platform like Discord, as indicated by the presence of text '\n", + " \"chats and a streamer's interface. The video is primarily focused on a \"\n", + " 'computer screen displaying a code editor with various programming languages '\n", + " 'and snippets of code. 
The coder seems to be explaining or demonstrating '\n", + " 'something related to the code, possibly working through a programming '\n", + " 'problem, explaining the logic, or showing the process of solving a problem.\\n'\n", + " '\\n'\n", + " 'Here are some key observations from')\n" + ] + } + ], + "source": [ + "pprint(response)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# open video and play it\n", + "from ipywidgets import Video\n", + "Video.from_file(\"videos/fastmlx_local_ai_hub.mp4\", width=320, height=240)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mlx_code", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/videos/fastmlx_local_ai_hub.mp4 b/examples/videos/fastmlx_local_ai_hub.mp4 new file mode 100644 index 00000000..dcf641e2 Binary files /dev/null and b/examples/videos/fastmlx_local_ai_hub.mp4 differ diff --git a/mlx_vlm/__init__.py b/mlx_vlm/__init__.py index 2d716f6b..b6c515b0 100644 --- a/mlx_vlm/__init__.py +++ b/mlx_vlm/__init__.py @@ -8,3 +8,4 @@ quantize_model, ) from .version import __version__ +from .video_generate import VideoFrameExtractor, process_vision_info diff --git a/mlx_vlm/chat.py b/mlx_vlm/chat.py new file mode 100644 index 00000000..c6b895a8 --- /dev/null +++ b/mlx_vlm/chat.py @@ -0,0 +1,233 @@ +import argparse +import os +import sys +import time +from typing import Dict, List + +import mlx.core as mx +from rich import print as rprint +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.prompt import Prompt + +from mlx_vlm import load +from mlx_vlm.prompt_utils import get_message_json +from mlx_vlm.utils import generate_step, load_image + + +class MLXVisionChat: + def __init__( + self, + model_path: str = "mlx-community/idefics2-8b-chatty-4bit", + temperature: float = 0.7, + max_tokens: int = 1000, + verbose: bool = False, + ): + self.console = Console() + self.verbose = verbose + self.temperature = temperature + self.max_tokens = max_tokens + self.history: List[Dict] = [] + self.current_image = None + + with self.console.status("[bold green]Loading model..."): + self.model, self.processor = load(model_path) + + rprint("[bold green]Model loaded successfully![/bold green]") + self.print_help() + + def print_help(self) -> None: + """Print available commands.""" + help_text = """ +[bold yellow]Available Commands:[/bold yellow] +• /image - Load a new image for discussion +• /clear - Clear conversation history +• /help - Show this help message +• /exit - Exit the chat +• Any other input will be treated as a question or comment about the current image + """ + rprint(Panel(help_text, title="Help", border_style="blue")) + + def process_image(self, image_path: str) -> bool: + """Process an image and prepare it for the model. 
Returns True if successful.""" + try: + if not os.path.exists(image_path): + rprint( + f"[bold red]Error:[/bold red] Image file not found: {image_path}" + ) + return False + + self.current_image = load_image(image_path) + rprint(f"[bold blue]Loaded image:[/bold blue] {image_path}") + return True + except Exception as e: + rprint(f"[bold red]Error loading image:[/bold red] {str(e)}") + return False + + def add_to_history(self, role: str, text: str) -> None: + """Add a message to the conversation history.""" + content = [{"type": "text", "text": text}] + self.history.append({"role": role, "content": content}) + + def generate_response(self) -> str: + """Generate a response from the model based on the conversation history.""" + if self.current_image is None: + return "Please load an image first using the /image command." + + messages = [] + for i, message in enumerate(self.history): + skip_token = True + if i == len(self.history) - 1 and message["role"] == "user": + skip_token = False + messages.append( + get_message_json( + self.model.config.model_type, + message["content"][0]["text"], + role=message["role"], + skip_image_token=skip_token, + ) + ) + + text_prompt = self.processor.apply_chat_template( + messages, add_generation_prompt=True + ) + + inputs = self.processor( + text=[text_prompt], + images=[self.current_image], + padding=True, + return_tensors="np", + ) + + pixel_values = mx.array(inputs["pixel_values"]) + input_ids = mx.array(inputs["input_ids"]) + mask = mx.array(inputs["attention_mask"]) + + detokenizer = self.processor.detokenizer + detokenizer.reset() + + tic = time.perf_counter() + + generator = generate_step( + input_ids, + self.model, + pixel_values, + mask, + temp=self.temperature, + ) + + # Use print instead of rprint to avoid rich console's automatic newlines + rprint("[bold green]Assistant:[/bold green]", end=" ", flush=True) + for (token, prob), n in zip(generator, range(self.max_tokens)): + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + + if token == self.processor.tokenizer.eos_token_id and n > 0: + break + + detokenizer.add_token(token) + + if self.verbose: + rprint(detokenizer.last_segment, end="", flush=True) + + detokenizer.finalize() + return detokenizer.text + + def handle_command(self, command: str, args: str) -> bool: + """Handle special commands. 
Returns True if should continue chat, False if should exit.""" + if command == "/exit": + rprint("[bold yellow]Goodbye![/bold yellow]") + return False + elif command == "/help": + self.print_help() + elif command == "/clear": + self.history.clear() + rprint("[bold blue]Conversation history cleared.[/bold blue]") + elif command == "/image": + if not args: + rprint("[bold red]Error:[/bold red] Please provide an image path") + return True + self.process_image(args.strip()) + else: + rprint(f"[bold red]Unknown command:[/bold red] {command}") + return True + + def chat_loop(self) -> None: + """Main chat loop for interaction.""" + while True: + try: + user_input = Prompt.ask("\n[bold cyan]You[/bold cyan]").strip() + + # Handle commands + if user_input.startswith("/"): + parts = user_input.split(maxsplit=1) + command = parts[0].lower() + args = parts[1] if len(parts) > 1 else "" + if not self.handle_command(command, args): + break + continue + # Handle regular chat input + if self.current_image is None: + rprint( + "[bold yellow]Please load an image first using the /image command[/bold yellow]" + ) + continue + + self.add_to_history("user", user_input) + response = self.generate_response() + + if not self.verbose: + rprint(Panel(Markdown(response), border_style="green")) + + # Remove the eos token from the response + response = response.replace("", "") + + self.add_to_history("assistant", response) + + except KeyboardInterrupt: + rprint( + "\n[bold yellow]Interrupted by user. Type /exit to quit.[/bold yellow]" + ) + continue + except Exception as e: + rprint(f"[bold red]Error:[/bold red] {str(e)}") + continue + + +def main(): + parser = argparse.ArgumentParser(description="MLX Vision Chat CLI") + parser.add_argument( + "--model", + default="mlx-community/idefics2-8b-chatty-4bit", + help="Path to the model or model identifier", + ) + parser.add_argument("--verbose", action="store_false", help="Enable verbose output") + parser.add_argument( + "--temperature", type=float, default=0.7, help="Temperature for the model" + ) + parser.add_argument( + "--max-tokens", + type=int, + default=1000, + help="Maximum number of new tokens to generate", + ) + + args = parser.parse_args() + + try: + chat = MLXVisionChat( + model_path=args.model, + temperature=args.temperature, + max_tokens=args.max_tokens, + verbose=args.verbose, + ) + chat.chat_loop() + except Exception as e: + rprint(f"[bold red]Fatal error:[/bold red] {str(e)}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/mlx_vlm/chat_ui.py b/mlx_vlm/chat_ui.py index 210597f0..11ddf144 100644 --- a/mlx_vlm/chat_ui.py +++ b/mlx_vlm/chat_ui.py @@ -4,7 +4,7 @@ from mlx_vlm import load -from .prompt_utils import apply_chat_template +from .prompt_utils import get_chat_template, get_message_json from .utils import load, load_config, load_image_processor, stream_generate @@ -28,17 +28,40 @@ def parse_arguments(): def chat(message, history, temperature, max_tokens): - chat = [] - if len(message["files"]) >= 1: - chat.append({"role": "user", "content": message["text"]}) - else: - raise gr.Error("Please upload an image. 
Text only chat is not supported.") + if config["model_type"] != "paligemma": + if len(message["files"]) >= 1: + chat_history = [] + for item in history: + chat_history.append({"role": "user", "content": item[0]}) + if item[1] is not None: + chat_history.append({"role": "assistant", "content": item[1]}) + + chat_history.append({"role": "user", "content": message["text"]}) + + messages = [] + for i, m in enumerate(chat_history): + skip_token = True + if i == len(chat_history) - 1 and m["role"] == "user": + skip_token = False + messages.append( + get_message_json( + config["model_type"], + m["content"], + role=m["role"], + skip_image_token=skip_token, + ) + ) - file = message["files"][-1] - if model.config.model_type != "paligemma": - prompt = apply_chat_template(processor, config, chat) + messages = get_chat_template( + processor, messages, add_generation_prompt=True + ) + + else: + raise gr.Error("Please upload an image. Text only chat is not supported.") else: - prompt = message.text + messages = message["text"] + + files = message["files"][-1]["path"] response = "" for chunk in stream_generate( diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 696d9dba..50041510 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -110,27 +110,16 @@ def _merge_input_ids_with_image_features( # Positions of tokens in input_ids, assuming batch size is 1 image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + num_images, _, vision_hidden_size = image_features.shape - if len(image_positions) != num_images: - raise ValueError( - f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." - ) - - text_segments = [] - start_idx = 0 + reshaped_image_hidden_states = image_features.reshape(-1, vision_hidden_size) - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 - - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] - - # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + # cast to the dtype of the input_embeds to support quantized models + reshaped_image_hidden_states = reshaped_image_hidden_states.astype( + inputs_embeds.dtype + ) + inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states + return inputs_embeds def __call__( self, diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index 730d4dcc..5363752f 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -161,6 +161,8 @@ def get_input_embeddings( def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + batch_size, seq_length, embed_dim = inputs_embeds.shape num_images, num_image_patches, _ = image_features.shape @@ -184,6 +186,8 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id # (batch_size, num_image_patches + sequence_len, embed_dim) return mx.concatenate(final_embeddings, axis=0) + # Create a final embedding of shape + def __call__( self, input_ids: mx.array, diff --git a/mlx_vlm/models/llava_next/llava_next.py 
b/mlx_vlm/models/llava_next/llava_next.py index f10649ff..20a43137 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -100,6 +100,8 @@ def get_input_embeddings( # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) + + # Add a newline token to the image features if self.image_newline is not None: self.image_newline = np.array(self.image_newline)[None, None, :] self.image_newline = np.broadcast_to( @@ -118,9 +120,10 @@ def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index + num_images, num_image_patches, embed_dim = image_features.shape + + image_positions = np.where(input_ids == image_token_index)[1].tolist() - # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() text_segments = [] start_idx = 0 diff --git a/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py b/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py index d9a0f984..31f7e299 100644 --- a/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +++ b/mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py @@ -87,9 +87,12 @@ def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids ): image_token_id = self.config.image_token_id - + video_token_id = self.config.video_token_id # Positions of tokens in input_ids, assuming batch size is 1 image_positions = input_ids == image_token_id + if mx.sum(image_positions) == 0: + image_positions = input_ids == video_token_id + image_indices = np.where(image_positions)[1].tolist() inputs_embeds[:, image_indices, :] = image_features return inputs_embeds @@ -106,13 +109,12 @@ def __call__( second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) video_grid_thw = kwargs.pop("video_grid_thw", None) position_ids = kwargs.pop("position_ids", None) + grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw if image_grid_thw is not None: image_grid_thw = mx.array(image_grid_thw) - inputs_embeds = self.get_input_embeddings( - input_ids, pixel_values, image_grid_thw - ) + inputs_embeds = self.get_input_embeddings(input_ids, pixel_values, grid_thw) logits = self.language_model(None, cache=cache, inputs_embeds=inputs_embeds) return logits diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index bc907628..1de7ddfc 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -21,6 +21,7 @@ class ModelConfig: model_type: str ignore_index: int = -100 image_token_index: int = 151655 + video_token_index: int = 151656 vision_feature_select_strategy: str = "default" vision_feature_layer: int = -2 vocab_size: int = 32000 @@ -83,11 +84,19 @@ def _merge_input_ids_with_image_features( self, image_features, inputs_embeds, input_ids ): image_token_index = self.config.image_token_index + video_token_index = self.config.video_token_index # Positions of tokens in input_ids, assuming batch size is 1 image_positions = input_ids == image_token_index - image_indices = np.where(image_positions)[1].tolist() - inputs_embeds[:, image_indices, :] = image_features + if mx.sum(image_positions) == 0: + image_positions = input_ids == video_token_index + + image_features = image_features.astype(mx.float32) + pad_size = inputs_embeds.shape[1] - image_features.shape[1] + image_features = mx.pad(image_features, ((0, 0), (0, pad_size), (0, 0))) + inputs_embeds = mx.where( + 
image_positions[:, :, None], image_features, inputs_embeds + ) return inputs_embeds @@ -99,6 +108,7 @@ def __call__( cache=None, **kwargs, ): + image_grid_thw = kwargs.pop("image_grid_thw", None) if image_grid_thw is not None: image_grid_thw = mx.array(image_grid_thw) diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 4a915875..725e811f 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -1,5 +1,5 @@ def get_message_json( - model_name, prompt, role="user", skip_image_token=False, num_images=1 + model_name, prompt, role="user", skip_image_token=False, num_images=1, **kwargs ): """ Get the appropriate JSON message based on the specified model. @@ -16,69 +16,111 @@ def get_message_json( """ model_name = model_name.lower() - def create_message(role, prompt): - return {"role": role, "content": prompt} + # Base message creation + def create_text_message(text): + return {"type": "text", "text": text} - def add_image_tokens(message, token_format): - if role == "system": - return message + def create_text_content_message(text): + return {"type": "text", "content": text} + + def create_video_message(video_path, max_pixels=224 * 224, fps=1): + return { + "type": "video", + "video": video_path, + "max_pixels": max_pixels, + "fps": fps, + } + + # Message format handlers + def handle_list_with_image(): + content = [create_text_message(prompt)] if role == "user" and not skip_image_token: - if isinstance(message["content"], list): - if model_name in ["pixtral", "idefics3"]: - message["content"] = [{"type": "image"}] * num_images + message[ - "content" - ] - else: - message["content"].extend([{"type": "image"}] * num_images) - else: - if model_name == "phi3_v": - message["content"] = f"{token_format}{message['content']}" - else: - message["content"] = ( - f"{token_format * num_images}{message['content']}" - ) - if role == "assistant" and model_name == "pixtral": + image_tokens = [{"type": "image"}] * num_images + content = ( + image_tokens + content + if model_name in ["pixtral", "idefics3"] + else content + image_tokens + ) + return {"role": role, "content": content} + + def handle_list_with_image_type(): + message = {"role": role, "content": [create_text_content_message(prompt)]} + if role == "user" and not skip_image_token: + message["content"] = [{"type": "image"}] * num_images + message["content"] + if role == "assistant": message["content"] = message["content"][0]["content"] return message + def handle_image_token(token_format): + content = prompt + if role != "system" and role == "user" and not skip_image_token: + prefix = ( + token_format * num_images if model_name != "phi3_v" else token_format + ) + content = f"{prefix}{content}" + return {"role": role, "content": content} + + def handle_video_with_text(): + return { + "role": "user", + "content": [ + create_video_message( + kwargs["video"], + kwargs.get("max_pixels", 224 * 224), + kwargs.get("fps", 1), + ), + create_text_message(prompt), + ], + } + + # Message format mapping message_formats = { - "message_list_with_image": lambda: add_image_tokens( - {"role": role, "content": [{"type": "text", "text": prompt}]}, "" - ), - "message_list_with_image_type": lambda: add_image_tokens( - {"role": role, "content": [{"type": "text", "content": prompt}]}, "" - ), - "message_with_image_token": lambda: add_image_tokens( - create_message(role, prompt), "" - ), - "message_with_image_token_new_line": lambda: add_image_tokens( - create_message(role, prompt), "\n" - ), - "message_with_numbered_image_tokens": lambda: 
add_image_tokens( - create_message(role, prompt), - " ".join([f"<|image_{i+1}|>" for i in range(num_images)]), + "message_list_with_image": handle_list_with_image, + "message_list_with_image_type": handle_list_with_image_type, + "message_with_image_token": lambda: handle_image_token(""), + "message_with_image_token_new_line": lambda: handle_image_token("\n"), + "message_with_numbered_image_tokens": lambda: handle_image_token( + " ".join([f"<|image_{i+1}|>" for i in range(num_images)]) ), "prompt_only": lambda: prompt, "prompt_with_image_token": lambda: "" * num_images + prompt, + "message_video_with_text": handle_video_with_text, } + # Model to format mapping model_to_format = { + # Models using message_list_with_image format "idefics2": "message_list_with_image", "idefics3": "message_list_with_image", - "qwen2_vl": "message_list_with_image", - "qwen2_5_vl": "message_list_with_image", "llava": "message_list_with_image", "llava_next": "message_list_with_image", + "mllama": "message_list_with_image", + # Models that can handle both image and video formats + "qwen2_vl": ( + "message_video_with_text" + if kwargs.get("video") + else "message_list_with_image" + ), + "qwen2_5_vl": ( + "message_video_with_text" + if kwargs.get("video") + else "message_list_with_image" + ), + # Models using message_with_image_token_new_line format "llava-qwen2": "message_with_image_token_new_line", "bunny-llama": "message_with_image_token_new_line", + # Models using message_with_numbered_image_tokens format "phi3_v": "message_with_numbered_image_tokens", + # Models using message_with_image_token format "multi_modality": "message_with_image_token", + "deepseek_vl_v2": "message_with_image_token_new_line", + # Models using message_list_with_image_type format "pixtral": "message_list_with_image_type", + # Models using prompt_with_image_token format "paligemma": "prompt_with_image_token", + # Models using prompt_only format "florence2": "prompt_only", - "mllama": "message_list_with_image", "molmo": "prompt_only", - "deepseek_vl_v2": "message_with_image_token_new_line", } if num_images > 1 and model_name in [ @@ -100,6 +142,25 @@ def add_image_tokens(message, token_format): raise ValueError(f"Unsupported model: {model_name}") +def get_chat_template(processor, messages, add_generation_prompt, tokenize=False): + if "chat_template" in processor.__dict__.keys(): + return processor.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + ) + elif "tokenizer" in processor.__dict__.keys(): + return processor.tokenizer.apply_chat_template( + messages, + tokenize=tokenize, + add_generation_prompt=add_generation_prompt, + ) + else: + raise ValueError( + "Error: processor does not have 'chat_template' or 'tokenizer' attribute." 
+ ) + + def apply_chat_template( processor, config, @@ -107,6 +168,7 @@ def apply_chat_template( add_generation_prompt=True, return_messages=False, num_images=1, + **kwargs, ): config = config if isinstance(config, dict) else config.__dict__ @@ -117,6 +179,7 @@ def process_single_prompt(p, is_first=True): p, skip_image_token=not is_first, num_images=num_images, + **kwargs, ) elif isinstance(p, dict) and "role" in p: return get_message_json( @@ -125,6 +188,7 @@ def process_single_prompt(p, is_first=True): p["role"], skip_image_token=not is_first, num_images=num_images, + **kwargs, ) else: raise ValueError("Invalid prompt type") @@ -155,21 +219,4 @@ def process_single_prompt(p, is_first=True): if config["model_type"] in ["paligemma", "molmo", "florence2"]: return messages[-1] - if "chat_template" in processor.__dict__.keys(): - return processor.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=add_generation_prompt, - ) - - elif "tokenizer" in processor.__dict__.keys(): - return processor.tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=add_generation_prompt, - ) - - else: - raise ValueError( - "Error: processor does not have 'chat_template' or 'tokenizer' attribute." - ) + return get_chat_template(processor, messages, add_generation_prompt) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index e1d0c37d..4acff3e8 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -1008,22 +1008,27 @@ def stream_generate( resize_shape = kwargs.pop("resize_shape", None) image_token_index = getattr(model.config, "image_token_index", None) - if not image: - input_ids = prompt_tokens[None, :] - pixel_values = mask = None + if kwargs.get("pixel_values") is None: + if not image: + input_ids = prompt_tokens[None, :] + pixel_values = mask = None + else: + inputs = prepare_inputs( + processor, image, prompt, image_token_index, resize_shape + ) + input_ids = inputs["input_ids"] + pixel_values = inputs["pixel_values"] + mask = inputs["attention_mask"] + data_kwargs = { + k: v + for k, v in inputs.items() + if k not in ["input_ids", "pixel_values", "attention_mask"] + } + kwargs.update(data_kwargs) else: - inputs = prepare_inputs( - processor, image, prompt, image_token_index, resize_shape - ) - input_ids = inputs["input_ids"] - pixel_values = inputs["pixel_values"] - mask = inputs["attention_mask"] - data_kwargs = { - k: v - for k, v in inputs.items() - if k not in ["input_ids", "pixel_values", "attention_mask"] - } - kwargs.update(data_kwargs) + input_ids = kwargs.pop("input_ids") + pixel_values = kwargs.pop("pixel_values") + mask = kwargs.pop("mask") detokenizer = processor.detokenizer detokenizer.reset() @@ -1093,11 +1098,20 @@ def generate( if verbose: print("=" * 10) - print("Image:", image, "\n") + if image is not None: + input_path = image + elif kwargs.get("video") is not None: + input_path = kwargs.get("video") + else: + input_path = None + + print(f"Files: {input_path}", "\n") + print("Prompt:", prompt) text = "" last_response = None + for response in stream_generate(model, processor, prompt, image, **kwargs): if verbose: print(response.text, end="", flush=True) diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py index 74acd0ef..3cb7d95e 100644 --- a/mlx_vlm/version.py +++ b/mlx_vlm/version.py @@ -1 +1 @@ -__version__ = "0.1.12" +__version__ = "0.1.13" diff --git a/mlx_vlm/video_generate.py b/mlx_vlm/video_generate.py new file mode 100644 index 00000000..1afd07e1 --- /dev/null +++ b/mlx_vlm/video_generate.py @@ -0,0 +1,602 @@ +from __future__ 
import annotations + +import argparse +import base64 +import logging +import math +import os +import sys +import time +from io import BytesIO +from typing import List + +import cv2 +import mlx.core as mx +import numpy as np +import requests +from PIL import Image + +from .utils import generate, load, load_image + +# This is a beta version of the video generation script. +# It is not fully tested and may not work as expected. + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + +logger.info( + "This is a beta version of the video understanding. It may not work as expected." +) + +IMAGE_FACTOR = 28 +MIN_PIXELS = 4 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +VIDEO_MIN_PIXELS = 128 * 28 * 28 +VIDEO_MAX_PIXELS = 768 * 28 * 28 +FRAME_FACTOR = 2 +FPS = 2.0 +FPS_MIN_FRAMES = 4 +FPS_MAX_FRAMES = 768 + +# Set the maximum number of video token inputs. +VIDEO_TOTAL_PIXELS = int( + float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9)) +) + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = IMAGE_FACTOR, + min_pixels: int = MIN_PIXELS, + max_pixels: int = MAX_PIXELS, +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. 
+ """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def to_rgb(pil_image: Image.Image) -> Image.Image: + if pil_image.mode == "RGBA": + white_background = Image.new("RGB", pil_image.size, (255, 255, 255)) + white_background.paste( + pil_image, mask=pil_image.split()[3] + ) # Use alpha channel as mask + return white_background + else: + return pil_image.convert("RGB") + + +def fetch_image( + ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR +) -> Image.Image: + if "image" in ele: + image = ele["image"] + else: + image = ele["image_url"] + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + response = requests.get(image, stream=True) + image_obj = Image.open(BytesIO(response.content)) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + image_obj = Image.open(BytesIO(data)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError( + f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" + ) + image = to_rgb(image_obj) + ## resize + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=size_factor, + ) + else: + width, height = image.size + min_pixels = ele.get("min_pixels", MIN_PIXELS) + max_pixels = ele.get("max_pixels", MAX_PIXELS) + resized_height, resized_width = smart_resize( + height, + width, + factor=size_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + image = image.resize((resized_width, resized_height)) + return image + + +def smart_nframes( + ele: dict, + total_frames: int, + video_fps: int | float, +) -> int: + """Calculate the number of frames for the video to be used as model inputs. + + Either a fixed 'nframes' is provided in ele or 'fps' is used to calculate how many frames to sample. 
+ """ + assert not ( + "fps" in ele and "nframes" in ele + ), "Only accept either `fps` or `nframes`" + if "nframes" in ele: + nframes = round_by_factor(ele["nframes"], FRAME_FACTOR) + else: + fps = ele.get("fps", FPS) + min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR) + max_frames = floor_by_factor( + ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR + ) + nframes = total_frames / video_fps * fps + if nframes > total_frames: + logger.warning( + f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]" + ) + nframes = min(min(max(nframes, min_frames), max_frames), total_frames) + nframes = floor_by_factor(nframes, FRAME_FACTOR) + if not (FRAME_FACTOR <= nframes and nframes <= total_frames): + raise ValueError( + f"nframes should be in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}." + ) + return nframes + + +def load_video( + ele: dict, +) -> (np.ndarray, float): + """ + Read video using cv2.VideoCapture. + + The video is read as a NumPy array with shape (T, C, H, W) where T is the number of frames, + C is the number of channels, and H, W are the frame dimensions. + """ + video_path = ele["video"] + if video_path.startswith("file://"): + video_path = video_path[7:] + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Cannot open video: {video_path}") + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + video_fps = cap.get(cv2.CAP_PROP_FPS) or 1.0 # default to 1.0 if fps returns 0 + st = time.time() + logger.info( + f"numpy reader: video_path={video_path}, total_frames={total_frames}, video_fps={video_fps}, time={time.time()-st:.3f}s" + ) + nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) + indices = np.linspace(0, total_frames - 1, nframes).round().astype(int) + frames = [] + for idx in indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, idx) + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + cap.release() + if not frames: + raise ValueError("No frames read from the video.") + # Stack frames into a numpy array: (T, H, W, C) + video_np = np.stack(frames, axis=0) + # Rearrange to (T, C, H, W) + video_np = np.transpose(video_np, (0, 3, 1, 2)) + sample_fps = nframes / max(total_frames, 1e-6) * video_fps + return video_np, sample_fps + + +def fetch_video( + ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False +) -> np.ndarray | list[Image.Image]: + if isinstance(ele["video"], str): + video, sample_fps = load_video(ele) + nframes, _, height, width = video.shape + min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) + total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) + max_pixels = max( + min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), + int(min_pixels * 1.05), + ) + max_pixels_supposed = ele.get("max_pixels", max_pixels) + if max_pixels_supposed > max_pixels: + logger.warning( + f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]." 
+ ) + max_pixels = min(max_pixels_supposed, max_pixels) + if "resized_height" in ele and "resized_width" in ele: + resized_height, resized_width = smart_resize( + ele["resized_height"], + ele["resized_width"], + factor=image_factor, + ) + else: + resized_height, resized_width = smart_resize( + height, + width, + factor=image_factor, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + # Resize each frame using OpenCV (similar to torchvision.transforms.functional.resize with BICUBIC) + resized_frames = [] + # video is (T, C, H, W) so we need to process each frame + for frame in video: + # Rearrange from (C, H, W) to (H, W, C) + frame_np = np.transpose(frame, (1, 2, 0)) + # cv2.resize expects size as (width, height) + resized = cv2.resize( + frame_np, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC + ) + # Convert back to (C, H, W) + resized = np.transpose(resized, (2, 0, 1)) + resized_frames.append(resized) + video = np.stack(resized_frames, axis=0).astype(np.float32) + if return_video_sample_fps: + return video, sample_fps + return video + else: + # Assume video is provided as a list/tuple of image objects. + process_info = ele.copy() + process_info.pop("type", None) + process_info.pop("video", None) + images = [ + fetch_image( + {"image": video_element, **process_info}, size_factor=image_factor + ) + for video_element in ele["video"] + ] + nframes = ceil_by_factor(len(images), FRAME_FACTOR) + if len(images) < nframes: + images.extend([images[-1]] * (nframes - len(images))) + if return_video_sample_fps: + return images, process_info.pop("fps", 2.0) + return images + + +def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]: + vision_infos = [] + if isinstance(conversations[0], dict): + conversations = [conversations] + for conversation in conversations: + for message in conversation: + if isinstance(message["content"], list): + for ele in message["content"]: + if ( + "image" in ele + or "image_url" in ele + or "video" in ele + or ele["type"] in ("image", "image_url", "video") + ): + vision_infos.append(ele) + return vision_infos + + +def process_vision_info( + conversations: list[dict] | list[list[dict]], + return_video_kwargs: bool = False, +) -> tuple[ + list[Image.Image] | None, list[np.ndarray | list[Image.Image]] | None, dict | None +]: + vision_infos = extract_vision_info(conversations) + ## Read images or videos + image_inputs = [] + video_inputs = [] + video_sample_fps_list = [] + for vision_info in vision_infos: + if "image" in vision_info or "image_url" in vision_info: + image_inputs.append(fetch_image(vision_info)) + elif "video" in vision_info: + video_input, video_sample_fps = fetch_video( + vision_info, return_video_sample_fps=True + ) + video_sample_fps_list.append(video_sample_fps) + video_inputs.append(video_input) + else: + raise ValueError("Content must include image, image_url, or video.") + if len(image_inputs) == 0: + image_inputs = None + if len(video_inputs) == 0: + video_inputs = None + if return_video_kwargs: + return image_inputs, video_inputs, {"fps": video_sample_fps_list} + return image_inputs, video_inputs + + +class VideoFrameExtractor: + def __init__(self, max_frames: int = 50): + self.max_frames = max_frames + + def resize_and_center_crop( + self, image: Image.Image, target_size: int + ) -> Image.Image: + # Get current dimensions + width, height = image.size + + # Calculate new dimensions keeping aspect ratio + if width < height: + new_width = target_size + new_height = int(height * (target_size / width)) + 
else: + new_height = target_size + new_width = int(width * (target_size / height)) + + # Resize + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # Center crop + left = (new_width - target_size) // 2 + top = (new_height - target_size) // 2 + right = left + target_size + bottom = top + target_size + + return image.crop((left, top, right, bottom)) + + def extract_frames(self, video_path: str) -> List[Image.Image]: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Could not open video: {video_path}") + + # Get video properties + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = int(cap.get(cv2.CAP_PROP_FPS)) + + # Calculate frame indices to extract (1fps) + frame_indices = list(range(0, total_frames, fps)) + + # If we have more frames than max_frames, sample evenly + if len(frame_indices) > self.max_frames: + indices = np.linspace(0, len(frame_indices) - 1, self.max_frames, dtype=int) + frame_indices = [frame_indices[i] for i in indices] + + frames = [] + for frame_idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame) + pil_image = self.resize_and_center_crop(pil_image, 384) + frames.append(pil_image) + + cap.release() + return frames + + +def is_video_model(model): + return hasattr(model.config, "video_token_id") or hasattr( + model.config, "video_token_index" + ) + + +def is_video_file(video_path: List[str]) -> bool: + video_extensions = [".mp4", ".avi", ".mov"] + for path in video_path: + if not any(path.endswith(ext) for ext in video_extensions): + return False + return True + + +def main(): + parser = argparse.ArgumentParser(description="Video Description CLI") + parser.add_argument( + "--video", type=str, nargs="+", required=True, help="Path to the video file" + ) + parser.add_argument( + "--max-pixels", + type=int, + nargs=2, + default=224 * 224, + help="Maximum number of pixels", + ) + parser.add_argument( + "--max-frames", type=int, default=None, help="Maximum number of frames" + ) + parser.add_argument("--fps", type=float, default=1.0, help="Frames per second") + parser.add_argument( + "--prompt", default="Describe this video.", help="Text prompt for the model" + ) + parser.add_argument( + "--temp", type=float, default=0.7, help="Temperature for generation" + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--model", + default="mlx-community/Qwen2.5-VL-7B-Instruct-4bit", + help="Select the model to use", + ) + parser.add_argument("--verbose", action="store_false", help="Print verbose output") + + args = parser.parse_args() + + print(f"\033[32mLoading model:\033[0m {args.model}") + model, processor = load(args.model) + + # Validate the model + if not is_video_model(model): + logger.warning( + "Warning: The model selected doesn't natively support video inputs. Performance may be degraded." 
+ ) + + if isinstance(args.max_pixels, tuple): + max_pixels = args.max_pixels[0] * args.max_pixels[1] + else: + max_pixels = args.max_pixels + + kwargs = {} + if is_video_model(model): + + # Checke if video is image or video + if is_video_file(args.video): + messages = [ + { + "role": "user", + "content": [ + { + "type": "video", + "video": args.video[0], + "max_pixels": max_pixels, + "fps": args.fps, + }, + {"type": "text", "text": args.prompt}, + ], + } + ] + else: + messages = [ + { + "role": "user", + "content": [ + *[{"type": "image", "image": image} for image in args.video], + {"type": "text", "text": args.prompt}, + ], + } + ] + + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + image_inputs, video_inputs, fps = process_vision_info(messages, True) + + if args.max_frames is not None: + video_inputs = video_inputs[: args.max_frames] + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="np", + ) + + input_ids = mx.array(inputs["input_ids"]) + pixel_values = inputs.get( + "pixel_values_videos", inputs.get("pixel_values", None) + ) + if pixel_values is None: + raise ValueError("Please provide a valid video or image input.") + pixel_values = mx.array(pixel_values) + + mask = mx.array(inputs["attention_mask"]) + if inputs.get("video_grid_thw", None) is not None: + kwargs["video_grid_thw"] = mx.array(inputs["video_grid_thw"]) + if inputs.get("image_grid_thw", None) is not None: + kwargs["image_grid_thw"] = mx.array(inputs["image_grid_thw"]) + + else: + if is_video_file(args.video): + if len(args.video) > 1: + raise ValueError("Only one video is supported for video models.") + else: + frame_extractor = VideoFrameExtractor(args.max_frames) + frames = frame_extractor.extract_frames(args.video[0]) + else: + frames = [load_image(image) for image in args.video] + + # Create prompt with frames + image_tokens = [{"type": "image"} for _ in range(len(frames))] + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Answer briefly."}, + *image_tokens, + {"type": "text", "text": args.prompt}, + ], + } + ] + + text = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Configure processor for video frames + processor.image_processor.size = ( + args.max_pixels + if isinstance(args.max_pixels, tuple) + else (args.max_pixels, args.max_pixels) + ) + if hasattr(processor.image_processor, "do_resize"): + processor.image_processor.do_resize = False + if hasattr(processor.image_processor, "do_image_splitting"): + processor.image_processor.do_image_splitting = False + + # Process inputs + inputs = processor( + text=text, images=[img for img in frames], return_tensors="np" + ) + + input_ids = mx.array(inputs["input_ids"]) + pixel_values = mx.array(inputs["pixel_values"]) + mask = mx.array(inputs["attention_mask"]) + + logger.info("\033[32mGenerating response...\033[0m") + + kwargs["video"] = args.video + kwargs["input_ids"] = input_ids + kwargs["pixel_values"] = pixel_values + kwargs["mask"] = mask + kwargs["temp"] = args.temp + kwargs["max_tokens"] = args.max_tokens + + response = generate( + model, + processor, + prompt=text, + verbose=args.verbose, + **kwargs, + ) + + if not args.verbose: + print(response) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index a076bcfa..e9edefbc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ scipy==1.13.1 gradio>=4.44.0 Pillow>=10.3.0 
requests>=2.31.0 +opencv-python==4.10.0.84