Skip to content

Commit

Permalink
Add torch model and GPU support. (#195)
Browse files Browse the repository at this point in the history
* Setup TorchVL

* fix build, setup cpu, gpu, cloud variants

* remove torch files

* update .gitignore

* update tests

* fix torchVL class and add streaming tests

* format

* remove space in torch max tokens

* format

* Add black install

* fix cloud test

* isort imports

* fix test build ci

* Remove build py, update readme

* update readme

* Add back torch deps

* update readme

* update readme

* update readme
  • Loading branch information
calebjohn24 authored Jan 9, 2025
1 parent ad09284 commit 40a90c0
Show file tree
Hide file tree
Showing 11 changed files with 271 additions and 123 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-client-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ jobs:
- name: Install dependencies
working-directory: ./clients/python
run: |
python -m pip install --upgrade pip
poetry install --all-extras
- name: Format code
working-directory: ./clients/python
run: |
poetry run pip install black
poetry run black tests/test_local_inference.py --check
- name: Run tests
Expand All @@ -54,4 +54,4 @@ jobs:
MOONDREAM_API_KEY: ${{ secrets.MOONDREAM_API_KEY }}
run: |
poetry run pip install pytest pytest-asyncio
poetry run pytest tests/test_*.py -v
poetry run pytest tests/test_api_inference.py -v
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ data
/pyproject.toml
poetry.lock
dist
clients/python/moondream/torch
26 changes: 24 additions & 2 deletions clients/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,23 @@ local inference and cloud-based API access.
Install the package from PyPI:

```bash
pip install moondream==0.0.5
pip install moondream==0.0.6
```

To install the CPU dependencies for local inference, run:

```bash
pip install moondream[cpu]
```

To install the GPU dependencies for local inference, run:

```bash
# Copy the torch implementation from the root moondream repo into the moondream/torch directory
cp -r moondream/torch clients/python/moondream/torch

# Install the GPU dependencies
pip install moondream[gpu]
```

## Quick Start
Expand Down Expand Up @@ -160,7 +176,13 @@ All methods return typed dictionaries:
- CUDA (GPU) and MPS (Apple Silicon) support coming soon
- For optimal performance with GPU/MPS, use the PyTorch implementation for now

## Development Notes

- Copy the torch implementation from the root moondream repo into the `torch` directory
- Run `poetry install --extras "gpu"` to install the GPU dependencies
- Run `poetry install --extras "cpu"` to install the CPU dependencies

## Links

- [Website](https://moondream.ai/)
- [Demo](https://moondream.ai/playground)
- [Demo](https://moondream.ai/playground)
93 changes: 0 additions & 93 deletions clients/python/build.py

This file was deleted.

15 changes: 13 additions & 2 deletions clients/python/moondream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

from .cloud_vl import CloudVL
from .onnx_vl import OnnxVL
from .types import VLM

DEFAULT_API_URL = "https://api.moondream.ai/v1"
Expand All @@ -14,7 +13,19 @@ def vl(
api_url: Optional[str] = None,
) -> VLM:
if model:
return OnnxVL.from_path(model)
model_filetype = model.split(".")[-1]
if model_filetype == "safetensors":
from .torch_vl import TorchVL

return TorchVL(model=model)
elif model_filetype == "mf":
from .onnx_vl import OnnxVL

return OnnxVL.from_path(model)

raise ValueError(
"Unsupported model filetype. Please use a .safetensors model for GPU use or .mf model for CPU use."
)

if api_key:
if not api_url:
Expand Down
110 changes: 110 additions & 0 deletions clients/python/moondream/torch_vl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Literal, Optional, Union

import torch
from PIL import Image

from .torch.moondream import MoondreamConfig, MoondreamModel
from .torch.weights import load_weights_into_model
from .types import (
VLM,
Base64EncodedImage,
CaptionOutput,
DetectOutput,
EncodedImage,
PointOutput,
QueryOutput,
SamplingSettings,
)
from .version import __version__


class TorchVL(VLM):
def __init__(
self,
*,
model: str,
):
config = MoondreamConfig()
self.model = MoondreamModel(config)
load_weights_into_model(model, self.model)
self.model.eval()
# Move model to the appropriate device
if torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.model.to(self.device)

def encode_image(
self, image: Union[Image.Image, EncodedImage]
) -> Base64EncodedImage:
if isinstance(image, EncodedImage):
assert type(image) == Base64EncodedImage
return image

if not self.model:
raise ValueError("No local model loaded")

return self.model.encode_image(image)

def caption(
self,
image: Union[Image.Image, EncodedImage],
length: Literal["normal", "short"] = "normal",
stream: bool = False,
settings: Optional[SamplingSettings] = None,
) -> CaptionOutput:
if not self.model:
raise ValueError("No local model loaded")

encoded_image = (
self.model.encode_image(image) if isinstance(image, Image.Image) else image
)
return self.model.caption(
encoded_image, length=length, stream=stream, settings=settings
)

def query(
self,
image: Union[Image.Image, EncodedImage],
question: str,
stream: bool = False,
settings: Optional[SamplingSettings] = None,
) -> QueryOutput:
if not self.model:
raise ValueError("No local model loaded")

encoded_image = (
self.model.encode_image(image) if isinstance(image, Image.Image) else image
)
return self.model.query(
encoded_image, question, stream=stream, settings=settings
)

def detect(
self,
image: Union[Image.Image, EncodedImage],
object: str,
) -> DetectOutput:
if not self.model:
raise ValueError("No local model loaded")

encoded_image = (
self.model.encode_image(image) if isinstance(image, Image.Image) else image
)
return self.model.detect(encoded_image, object)

def point(
self,
image: Union[Image.Image, EncodedImage],
object: str,
) -> PointOutput:
if not self.model:
raise ValueError("No local model loaded")

encoded_image = (
self.model.encode_image(image) if isinstance(image, Image.Image) else image
)
return self.model.point(encoded_image, object)
55 changes: 35 additions & 20 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,32 +1,47 @@
[build-system]
requires = [ "poetry-core",]
build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "moondream"
version = "0.0.5"
version = "0.0.2"
description = "Python client library for moondream"
authors = ["vik <vik@moondream.ai>"]
authors = [ "M87 Labs <contact@moondream.ai>",]
readme = "README.md"
[[tool.poetry.packages]]
include = "moondream"
from = "."

[tool.pyright]
venvPath = "."
venv = ".venv"
reportMissingParameterType = false

[tool.poetry.dependencies]
python = "^3.10"
pillow = "^10.4.0"
onnxruntime = "^1.19.2"
numpy = "^2.1.2"
onnx = "^1.17.0"
tokenizers = "^0.20.1"
onnxruntime = { version = ">=1.19.2", optional = true }
tokenizers = { version = ">=0.20.1", optional = true }
torch = { version = ">=2.5.0", optional = true }
safetensors = { version = ">=0.4.2", optional = true }
einops = { version = ">=0.7.0", optional = true }
pyvips-binary = { version = ">=8.16.0", optional = true }
pyvips = { version = ">=2.2.1", optional = true }

[tool.poetry.extras]
cpu = [
"onnxruntime",
"tokenizers"
]
gpu = [
"torch",
"safetensors",
"einops",
"pyvips-binary",
"pyvips",
"tokenizers"
]

[tool.poetry.scripts]
moondream = "moondream.cli:main"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.4"
pytest-asyncio = "^0.25.1"
requests = "^2.32.3"
black = "^24.10.0"

[tool.pyright]
venvPath = "."
venv = ".venv"
reportMissingParameterType = false

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
2 changes: 0 additions & 2 deletions clients/python/tests/test_api_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def test_api_initialization(model):
assert isinstance(model, md.cloud_vl.CloudVL)


@pytest.mark.skip(reason="API returning 502 errors, needs investigation")
def test_image_captioning(model, test_image):
# Test normal length caption
result = model.caption(test_image, length="normal")
Expand Down Expand Up @@ -56,7 +55,6 @@ def test_streaming_caption(model, test_image):
assert len(caption) > 0


@pytest.mark.skip(reason="API returning 502 errors, needs investigation")
def test_query_answering(model, test_image):
# Test basic question answering
result = model.query(test_image, "What is in this image?")
Expand Down
3 changes: 2 additions & 1 deletion clients/python/tests/test_local_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_invalid_caption_length(model, test_image):

def test_invalid_model_path():
with pytest.raises(
ValueError, match="Model path is invalid or file does not exist"
ValueError,
match="Unsupported model filetype. Please use a .safetensors for GPU use or .mf for CPU use.",
):
md.vl(model="invalid/path/to/model.bin")
Loading

0 comments on commit 40a90c0

Please sign in to comment.