From fc8a50eeaf1e48148a23f529159284e9325955ab Mon Sep 17 00:00:00 2001 From: Caleb John <45307388+calebjohn24@users.noreply.github.com> Date: Thu, 5 Dec 2024 14:38:59 -0800 Subject: [PATCH] Add sidecar python server (#164) * Setup sidecar python server * Fix ports * Add local python client support * Clean up code, fix cloudvl args * fix build.py * Fix typo in init exception * Fix import order * Cleanup server and init --- clients/python/build.py | 1 + clients/python/moondream/__init__.py | 29 ++- clients/python/moondream/cli.py | 48 +++++ clients/python/moondream/cloud_vl.py | 35 ++-- clients/python/moondream/server.py | 208 ++++++++++++++++++++ clients/python/pyproject.toml | 2 + clients/python/scripts/test_local_server.py | 35 ++++ 7 files changed, 341 insertions(+), 17 deletions(-) create mode 100644 clients/python/moondream/cli.py create mode 100644 clients/python/moondream/server.py create mode 100644 clients/python/scripts/test_local_server.py diff --git a/clients/python/build.py b/clients/python/build.py index 761e8075..460a7e52 100644 --- a/clients/python/build.py +++ b/clients/python/build.py @@ -18,6 +18,7 @@ "numpy": "^2.1.2", "tokenizers": "^0.20.1", }, + "scripts": {"moondream": "moondream.cli:main"}, }, "pyright": { "venvPath": ".", diff --git a/clients/python/moondream/__init__.py b/clients/python/moondream/__init__.py index a38e16e1..feb01791 100644 --- a/clients/python/moondream/__init__.py +++ b/clients/python/moondream/__init__.py @@ -1,14 +1,31 @@ from typing import Optional -from .types import VLM -from .onnx_vl import OnnxVL + from .cloud_vl import CloudVL +from .onnx_vl import OnnxVL +from .types import VLM +DEFAULT_API_URL = "https://api.moondream.ai/v1" -def vl(*, model: Optional[str] = None, api_key: Optional[str] = None) -> VLM: - if api_key: - return CloudVL(api_key) +def vl( + *, + model: Optional[str] = None, + api_key: Optional[str] = None, + api_url: Optional[str] = None, +) -> VLM: if model: return OnnxVL.from_path(model) - raise ValueError("Either model_path or api_key must be provided.") + if api_key: + if not api_url: + api_url = DEFAULT_API_URL + + return CloudVL(api_key=api_key, api_url=api_url) + + if api_url and api_url == DEFAULT_API_URL: + if not api_key: + raise ValueError("An api_key is required for cloud inference.") + + return CloudVL(api_url=api_url) + + raise ValueError("At least one of `model`, `api_key`, or `api_url` is required.") diff --git a/clients/python/moondream/cli.py b/clients/python/moondream/cli.py new file mode 100644 index 00000000..81d7b8f9 --- /dev/null +++ b/clients/python/moondream/cli.py @@ -0,0 +1,48 @@ +import argparse +import sys +from http import server + +from .onnx_vl import OnnxVL +from .server import MoondreamHandler + + +def main(): + parser = argparse.ArgumentParser(description="Moondream CLI") + subparsers = parser.add_subparsers(dest="command", help="Command to run") + + # Server command + server_parser = subparsers.add_parser("serve", help="Start the Moondream server") + server_parser.add_argument("--model", type=str, help="Path to the model file") + server_parser.add_argument( + "--host", type=str, default="localhost", help="Host to bind to" + ) + server_parser.add_argument( + "--port", type=int, default=3475, help="Port to listen on" + ) + + args = parser.parse_args() + + if args.command == "serve": + if args.model: + model = OnnxVL.from_path(args.model) + else: + parser.error("Model path is required") + + MoondreamHandler.model = model + server_address = (args.host, args.port) + try: + httpd = server.HTTPServer(server_address, MoondreamHandler) + print(f"Starting Moondream server on http://{args.host}:{args.port}") + httpd.serve_forever() + except KeyboardInterrupt: + print("\nShutting down server...") + httpd.server_close() + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/clients/python/moondream/cloud_vl.py b/clients/python/moondream/cloud_vl.py index 2ed83d32..df18fa8b 100644 --- a/clients/python/moondream/cloud_vl.py +++ b/clients/python/moondream/cloud_vl.py @@ -1,27 +1,32 @@ +import base64 import json import urllib.request -import base64 -from PIL import Image from io import BytesIO -from typing import Union, Optional, Literal +from typing import Literal, Optional, Union +from PIL import Image from .types import ( VLM, Base64EncodedImage, CaptionOutput, - EncodedImage, - QueryOutput, DetectOutput, + EncodedImage, PointOutput, + QueryOutput, SamplingSettings, ) class CloudVL(VLM): - def __init__(self, api_key: str): + def __init__( + self, + *, + api_url: str = "https://api.moondream.ai/v1", + api_key: Optional[str] = None, + ): self.api_key = api_key - self.api_url = "https://api.moondream.ai/v1" + self.api_url = api_url def encode_image( self, image: Union[Image.Image, EncodedImage] @@ -80,7 +85,9 @@ def caption( } data = json.dumps(payload).encode("utf-8") - headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["X-Moondream-Auth"] = self.api_key req = urllib.request.Request( f"{self.api_url}/caption", data=data, @@ -114,7 +121,9 @@ def query( } data = json.dumps(payload).encode("utf-8") - headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["X-Moondream-Auth"] = self.api_key req = urllib.request.Request( f"{self.api_url}/query", data=data, @@ -137,7 +146,9 @@ def detect( payload = {"image_url": encoded_image.image_url, "object": object} data = json.dumps(payload).encode("utf-8") - headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["X-Moondream-Auth"] = self.api_key req = urllib.request.Request( f"{self.api_url}/detect", data=data, @@ -157,7 +168,9 @@ def point( payload = {"image_url": encoded_image.image_url, "object": object} data = json.dumps(payload).encode("utf-8") - headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"} + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["X-Moondream-Auth"] = self.api_key req = urllib.request.Request( f"{self.api_url}/point", data=data, diff --git a/clients/python/moondream/server.py b/clients/python/moondream/server.py new file mode 100644 index 00000000..5edeea30 --- /dev/null +++ b/clients/python/moondream/server.py @@ -0,0 +1,208 @@ +import base64 +import io +import json +import logging +import urllib.parse +from http import server +from typing import Any, Dict + +from PIL import Image + +from .onnx_vl import OnnxVL + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +class MoondreamHandler(server.BaseHTTPRequestHandler): + model: OnnxVL = None # Will be set when starting server + + def send_json_response(self, data: Dict[str, Any], status: int = 200) -> None: + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + self.wfile.write(json.dumps(data).encode()) + + def send_error_response(self, error: str, status: int = 400) -> None: + self.send_json_response({"error": error}, status) + + def handle_image_request(self) -> Image.Image: + content_length = int(self.headers.get("Content-Length", 0)) + if content_length == 0: + raise ValueError("No image data received") + + image_data = self.rfile.read(content_length) + return Image.open(io.BytesIO(image_data)) + + def send_streaming_response(self) -> None: + self.send_response(200) + self.send_header("Content-Type", "text/event-stream") + self.send_header("Access-Control-Allow-Origin", "*") + self.send_header("Cache-Control", "no-cache") + self.end_headers() + + def stream_tokens(self, chunk: str, completed: bool = False) -> None: + data = {"chunk": chunk, "completed": completed} + self.wfile.write(f"data: {json.dumps(data)}\n\n".encode()) + self.wfile.flush() + + def do_POST(self) -> None: + try: + if self.headers.get("Content-Type") != "application/json": + raise ValueError("Content-Type must be application/json") + + content_length = int(self.headers.get("Content-Length", 0)) + if content_length == 0: + raise ValueError("No data received") + + payload = json.loads(self.rfile.read(content_length)) + + image_url = payload.get("image_url") + if not image_url: + raise ValueError("image_url is required") + + parsed_path = urllib.parse.urlparse(self.path) + endpoint = parsed_path.path + + # Convert base64 image for all endpoints + image = self.decode_base64_image(image_url) + + if endpoint == "/caption": + try: + length = payload.get("length", "normal") + stream = payload.get("stream", False) + + if length not in ["normal", "short"]: + raise ValueError("Length parameter must be 'normal' or 'short'") + + if stream: + self.send_streaming_response() + try: + for tokens in self.model.caption( + image, length=length, stream=True + )["caption"]: + self.stream_tokens(tokens, completed=False) + self.stream_tokens("", completed=True) + except Exception as e: + logger.error( + "Error during caption streaming", exc_info=True + ) + self.stream_tokens( + "An error occurred during caption generation.", + completed=True, + ) + else: + result = self.model.caption(image, length=length) + self.send_json_response(result) + except Exception as e: + logger.error("Caption generation error", exc_info=True) + self.send_error_response("Caption generation failed.") + + elif endpoint == "/query": + try: + question = payload.get("question") + if not question: + raise ValueError("question is required") + + stream = payload.get("stream", False) + if stream: + self.send_streaming_response() + try: + for tokens in self.model.query( + image, question, stream=True + )["answer"]: + self.stream_tokens(tokens, completed=False) + self.stream_tokens("", completed=True) + except Exception as e: + logger.error("Error during query streaming", exc_info=True) + self.stream_tokens( + "An error occurred during query processing.", + completed=True, + ) + else: + result = self.model.query(image, question) + self.send_json_response(result) + except Exception as e: + logger.error("Query processing error", exc_info=True) + self.send_error_response("Query processing failed.") + + elif endpoint == "/detect": + try: + object_name = payload.get("object") + if not object_name: + raise ValueError("object is required") + result = self.model.detect(image, object_name) + self.send_json_response(result) + except Exception as e: + logger.error("Object detection error", exc_info=True) + self.send_error_response("Object detection failed.") + + elif endpoint == "/point": + try: + object_name = payload.get("object") + if not object_name: + raise ValueError("object is required") + result = self.model.point(image, object_name) + self.send_json_response(result) + except Exception as e: + logger.error("Object pointing error", exc_info=True) + self.send_error_response("Object pointing failed.") + + except Exception as e: + logger.error("Unexpected error in request handling", exc_info=True) + self.send_error_response("An unexpected error occurred.") + + def do_GET(self) -> None: + if self.path == "/": + self.send_response(200) + self.send_header("Content-Type", "text/html") + self.send_header("Access-Control-Allow-Origin", "*") + self.end_headers() + html = """ + + +
+Visit the Moondream documentation to learn more.
+ + + """ + self.wfile.write(html.encode()) + else: + self.send_error_response("Method not allowed", 405) + + def decode_base64_image(self, base64_string: str) -> Image.Image: + """Convert a base64 image string to a PIL Image object. + + Args: + base64_string: Base64 encoded image string, may include data URI prefix + + Returns: + PIL Image object + + Raises: + ValueError: If the base64 string is invalid + """ + # Remove data URI prefix if present (e.g., "data:image/jpeg;base64,") + if "base64," in base64_string: + base64_string = base64_string.split("base64,")[1] + + try: + # Decode base64 string to bytes + image_bytes = base64.b64decode(base64_string) + # Convert bytes to PIL Image + return Image.open(io.BytesIO(image_bytes)) + except Exception as e: + raise ValueError(f"Invalid base64 image: {str(e)}") from e diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 7b1cc58c..c158be5c 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -13,6 +13,8 @@ numpy = "^2.1.2" onnx = "^1.17.0" tokenizers = "^0.20.1" +[tool.poetry.scripts] +moondream = "moondream.cli:main" [tool.pyright] venvPath = "." diff --git a/clients/python/scripts/test_local_server.py b/clients/python/scripts/test_local_server.py new file mode 100644 index 00000000..d146544d --- /dev/null +++ b/clients/python/scripts/test_local_server.py @@ -0,0 +1,35 @@ +import moondream as md +from PIL import Image + +base_url = "http://localhost:3475" +local_server_client = md.vl(api_url=base_url) + +image_path = "../../assets/demo-1.jpg" +image = Image.open(image_path) + +print("# Pointing") +object = "person" +print("Local Server:", local_server_client.point(image, object)) + +print("# Captioning") +print("Local Server:", local_server_client.caption(image)) + +print("# Querying") +question = "What is the character eating?" +print("Local Server:", local_server_client.query(image, question)) + +print("# Detecting") +object_to_detect = "burger" +print("Local Server:", local_server_client.detect(image, object_to_detect)) + +print("# Captioning Stream") +print("Local Server:") +for tok in local_server_client.caption(image, stream=True)["caption"]: + print(tok, end="", flush=True) +print() + +print("# Querying Stream") +print("Local Server:") +for tok in local_server_client.query(image, question, stream=True)["answer"]: + print(tok, end="", flush=True) +print()