Skip to content

Commit

Permalink
Add sidecar python server (#164)
Browse files Browse the repository at this point in the history
* Set up sidecar Python server

* Fix ports

* Add local python client support

* Clean up code, fix cloudvl args

* fix build.py

* Fix typo in init exception

* Fix import order

* Cleanup server and init
  • Loading branch information
calebjohn24 authored Dec 5, 2024
1 parent 9446f7d commit fc8a50e
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 17 deletions.
1 change: 1 addition & 0 deletions clients/python/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"numpy": "^2.1.2",
"tokenizers": "^0.20.1",
},
"scripts": {"moondream": "moondream.cli:main"},
},
"pyright": {
"venvPath": ".",
Expand Down
29 changes: 23 additions & 6 deletions clients/python/moondream/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
from typing import Optional
from .types import VLM
from .onnx_vl import OnnxVL

from .cloud_vl import CloudVL
from .onnx_vl import OnnxVL
from .types import VLM

DEFAULT_API_URL = "https://api.moondream.ai/v1"

def vl(*, model: Optional[str] = None, api_key: Optional[str] = None) -> VLM:
if api_key:
return CloudVL(api_key)

def vl(
    *,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    api_url: Optional[str] = None,
) -> VLM:
    """Create a vision-language model client.

    The first applicable option wins:

    1. ``model`` -- load a local model with ``OnnxVL.from_path``.
    2. ``api_key`` -- talk to ``api_url`` (``DEFAULT_API_URL`` when omitted),
       authenticating with the key.
    3. ``api_url`` only -- talk to that server without an API key, e.g. a
       locally running inference server.

    Args:
        model: Path to a local model file.
        api_key: API key for the hosted service.
        api_url: Base URL of the inference API.

    Returns:
        An ``OnnxVL`` or ``CloudVL`` instance.

    Raises:
        ValueError: If no argument is given, or if ``api_url`` is the hosted
            default URL and no ``api_key`` was supplied.
    """
    if model:
        return OnnxVL.from_path(model)

    if api_key:
        return CloudVL(api_key=api_key, api_url=api_url or DEFAULT_API_URL)

    if api_url:
        # No key was supplied at this point; the hosted endpoint cannot be
        # used without one.
        if api_url == DEFAULT_API_URL:
            raise ValueError("An api_key is required for cloud inference.")

        # Any other URL is used as-is; CloudVL only sends the auth header
        # when a key is set.
        return CloudVL(api_url=api_url)

    raise ValueError("At least one of `model`, `api_key`, or `api_url` is required.")
48 changes: 48 additions & 0 deletions clients/python/moondream/cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import argparse
import sys
from http import server

from .onnx_vl import OnnxVL
from .server import MoondreamHandler


def main(argv=None):
    """Run the ``moondream`` command-line interface.

    The only sub-command is ``serve``, which loads a model with
    ``OnnxVL.from_path`` and serves it over HTTP through ``MoondreamHandler``.
    Without a sub-command the help text is printed.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Moondream CLI")
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # Server command
    server_parser = subparsers.add_parser("serve", help="Start the Moondream server")
    server_parser.add_argument("--model", type=str, help="Path to the model file")
    server_parser.add_argument(
        "--host", type=str, default="localhost", help="Host to bind to"
    )
    server_parser.add_argument(
        "--port", type=int, default=3475, help="Port to listen on"
    )

    args = parser.parse_args(argv)

    if args.command != "serve":
        parser.print_help()
        return

    if not args.model:
        # parser.error() prints the usage message and exits with status 2.
        parser.error("Model path is required")

    # The handler class carries the model because HTTPServer instantiates a
    # new handler object for every request.
    MoondreamHandler.model = OnnxVL.from_path(args.model)

    try:
        # The context manager closes the listening socket on every exit path,
        # not only after Ctrl-C.
        with server.HTTPServer((args.host, args.port), MoondreamHandler) as httpd:
            print(f"Starting Moondream server on http://{args.host}:{args.port}")
            try:
                httpd.serve_forever()
            except KeyboardInterrupt:
                print("\nShutting down server...")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
35 changes: 24 additions & 11 deletions clients/python/moondream/cloud_vl.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,32 @@
import base64
import json
import urllib.request
import base64
from PIL import Image
from io import BytesIO
from typing import Union, Optional, Literal
from typing import Literal, Optional, Union

from PIL import Image

from .types import (
VLM,
Base64EncodedImage,
CaptionOutput,
EncodedImage,
QueryOutput,
DetectOutput,
EncodedImage,
PointOutput,
QueryOutput,
SamplingSettings,
)


class CloudVL(VLM):
def __init__(self, api_key: str):
def __init__(
self,
*,
api_url: str = "https://api.moondream.ai/v1",
api_key: Optional[str] = None,
):
self.api_key = api_key
self.api_url = "https://api.moondream.ai/v1"
self.api_url = api_url

def encode_image(
self, image: Union[Image.Image, EncodedImage]
Expand Down Expand Up @@ -80,7 +85,9 @@ def caption(
}

data = json.dumps(payload).encode("utf-8")
headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"}
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-Moondream-Auth"] = self.api_key
req = urllib.request.Request(
f"{self.api_url}/caption",
data=data,
Expand Down Expand Up @@ -114,7 +121,9 @@ def query(
}

data = json.dumps(payload).encode("utf-8")
headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"}
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-Moondream-Auth"] = self.api_key
req = urllib.request.Request(
f"{self.api_url}/query",
data=data,
Expand All @@ -137,7 +146,9 @@ def detect(
payload = {"image_url": encoded_image.image_url, "object": object}

data = json.dumps(payload).encode("utf-8")
headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"}
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-Moondream-Auth"] = self.api_key
req = urllib.request.Request(
f"{self.api_url}/detect",
data=data,
Expand All @@ -157,7 +168,9 @@ def point(
payload = {"image_url": encoded_image.image_url, "object": object}

data = json.dumps(payload).encode("utf-8")
headers = {"X-Moondream-Auth": self.api_key, "Content-Type": "application/json"}
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-Moondream-Auth"] = self.api_key
req = urllib.request.Request(
f"{self.api_url}/point",
data=data,
Expand Down
208 changes: 208 additions & 0 deletions clients/python/moondream/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
import base64
import io
import json
import logging
import urllib.parse
from http import server
from typing import Any, Dict

from PIL import Image

from .onnx_vl import OnnxVL

# Configure logging
# NOTE(review): basicConfig() is called at module level, so it runs as soon as
# this module is imported (the CLI imports it to get MoondreamHandler).
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger used by MoondreamHandler to record request failures.
logger = logging.getLogger(__name__)


class MoondreamHandler(server.BaseHTTPRequestHandler):
    """HTTP handler exposing a locally loaded model through a small JSON API.

    ``POST /caption``, ``/query``, ``/detect`` and ``/point`` accept a JSON
    object holding a base64 ``image_url`` plus endpoint-specific fields.
    ``GET /`` returns a static status page.

    Error responses are JSON objects of the form ``{"error": "<message>"}``:
    400 for invalid requests, 404 for unknown POST paths, 500 for failures
    while running the model and 503 when no model has been assigned.
    """

    # Shared by all requests; assigned on the class before the server starts.
    model: "OnnxVL | None" = None

    # POST path -> name of the method that serves it.
    _ROUTES: Dict[str, str] = {
        "/caption": "_handle_caption",
        "/query": "_handle_query",
        "/detect": "_handle_detect",
        "/point": "_handle_point",
    }

    def send_json_response(self, data: Dict[str, Any], status: int = 200) -> None:
        """Send ``data`` as a JSON body with the given HTTP status."""
        # Serialise before writing anything so that a non-serialisable payload
        # fails while an error response can still be sent cleanly.
        body = json.dumps(data).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.send_header("Access-Control-Allow-Origin", "*")
        self.end_headers()
        self.wfile.write(body)

    def send_error_response(self, error: str, status: int = 400) -> None:
        """Send ``{"error": error}`` with the given status (400 by default)."""
        self.send_json_response({"error": error}, status)

    def handle_image_request(self) -> "Image.Image":
        """Read the raw request body and open it as an image.

        Raises:
            ValueError: If the request carries no body.
        """
        content_length = int(self.headers.get("Content-Length", 0))
        if content_length == 0:
            raise ValueError("No image data received")

        image_data = self.rfile.read(content_length)
        return Image.open(io.BytesIO(image_data))

    def send_streaming_response(self) -> None:
        """Send the status line and headers that open an event stream."""
        self.send_response(200)
        self.send_header("Content-Type", "text/event-stream")
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Cache-Control", "no-cache")
        self.end_headers()

    def stream_tokens(self, chunk: str, completed: bool = False) -> None:
        """Write one ``data:`` event carrying ``chunk`` and flush it."""
        data = {"chunk": chunk, "completed": completed}
        self.wfile.write(f"data: {json.dumps(data)}\n\n".encode())
        self.wfile.flush()

    def do_POST(self) -> None:
        """Validate the request and dispatch it to the matching endpoint."""
        route = self._ROUTES.get(urllib.parse.urlparse(self.path).path)
        if route is None:
            # Unknown paths must still get a response, otherwise the client
            # receives an empty reply.
            self.send_error_response("Not found", 404)
            return

        if self.model is None:
            self.send_error_response("Model is not loaded.", 503)
            return

        try:
            payload = self._read_json_payload()
            image_url = payload.get("image_url")
            if not image_url:
                raise ValueError("image_url is required")

            # Every endpoint works on the decoded image.
            image = self.decode_base64_image(image_url)
            getattr(self, route)(image, payload)
        except ValueError as e:
            # Raised only by request validation: the endpoint helpers catch
            # everything coming out of the model themselves.
            self.send_error_response(str(e))
        except Exception:
            logger.error("Unexpected error in request handling", exc_info=True)
            self.send_error_response("An unexpected error occurred.", 500)

    def _read_json_payload(self) -> Dict[str, Any]:
        """Return the request body parsed as a JSON object.

        Raises:
            ValueError: If the content type, length or body is not acceptable.
        """
        # get_content_type() ignores parameters such as "; charset=utf-8".
        if self.headers.get_content_type() != "application/json":
            raise ValueError("Content-Type must be application/json")

        try:
            content_length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            raise ValueError("Invalid Content-Length header") from None
        if content_length <= 0:
            raise ValueError("No data received")

        try:
            payload = json.loads(self.rfile.read(content_length))
        except ValueError:
            # Covers both malformed JSON and undecodable bytes.
            raise ValueError("Request body must be valid JSON") from None
        if not isinstance(payload, dict):
            raise ValueError("Request body must be a JSON object")
        return payload

    def _send_model_result(
        self, compute: Any, log_message: str, client_message: str
    ) -> None:
        """Call ``compute()`` and send its result as JSON, or a 500 on failure."""
        try:
            result = compute()
        except Exception:
            logger.error(log_message, exc_info=True)
            self.send_error_response(client_message, 500)
            return
        self.send_json_response(result)

    def _stream_model_result(
        self, make_chunks: Any, log_message: str, client_message: str
    ) -> None:
        """Stream the chunks produced by ``make_chunks()`` as events."""
        self.send_streaming_response()
        try:
            for chunk in make_chunks():
                self.stream_tokens(chunk, completed=False)
            self.stream_tokens("", completed=True)
        except Exception:
            logger.error(log_message, exc_info=True)
            # Headers are already sent, so the failure is reported in-stream.
            self.stream_tokens(client_message, completed=True)

    def _handle_caption(self, image: "Image.Image", payload: Dict[str, Any]) -> None:
        """Serve ``/caption``; honours the ``length`` and ``stream`` fields."""
        length = payload.get("length", "normal")
        if length not in ("normal", "short"):
            raise ValueError("Length parameter must be 'normal' or 'short'")

        if payload.get("stream", False):
            self._stream_model_result(
                lambda: self.model.caption(image, length=length, stream=True)[
                    "caption"
                ],
                "Error during caption streaming",
                "An error occurred during caption generation.",
            )
        else:
            self._send_model_result(
                lambda: self.model.caption(image, length=length),
                "Caption generation error",
                "Caption generation failed.",
            )

    def _handle_query(self, image: "Image.Image", payload: Dict[str, Any]) -> None:
        """Serve ``/query``; requires ``question`` and honours ``stream``."""
        question = payload.get("question")
        if not question:
            raise ValueError("question is required")

        if payload.get("stream", False):
            self._stream_model_result(
                lambda: self.model.query(image, question, stream=True)["answer"],
                "Error during query streaming",
                "An error occurred during query processing.",
            )
        else:
            self._send_model_result(
                lambda: self.model.query(image, question),
                "Query processing error",
                "Query processing failed.",
            )

    def _handle_detect(self, image: "Image.Image", payload: Dict[str, Any]) -> None:
        """Serve ``/detect``; requires the ``object`` field."""
        object_name = self._required_object(payload)
        self._send_model_result(
            lambda: self.model.detect(image, object_name),
            "Object detection error",
            "Object detection failed.",
        )

    def _handle_point(self, image: "Image.Image", payload: Dict[str, Any]) -> None:
        """Serve ``/point``; requires the ``object`` field."""
        object_name = self._required_object(payload)
        self._send_model_result(
            lambda: self.model.point(image, object_name),
            "Object pointing error",
            "Object pointing failed.",
        )

    @staticmethod
    def _required_object(payload: Dict[str, Any]) -> Any:
        """Return ``payload["object"]`` or raise ``ValueError`` if it is missing."""
        object_name = payload.get("object")
        if not object_name:
            raise ValueError("object is required")
        return object_name

    def do_GET(self) -> None:
        """Serve the static status page at ``/``; answer 405 elsewhere."""
        if self.path == "/":
            self.send_response(200)
            self.send_header("Content-Type", "text/html")
            self.send_header("Access-Control-Allow-Origin", "*")
            self.end_headers()
            html = """
<!DOCTYPE html>
<html>
<head>
<title>Moondream Local Inference Server</title>
<link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 100 100%22><text y=%22.9em%22 font-size=%2290%22>🌙</text></svg>">
<style>
body { font-family: system-ui, sans-serif; max-width: 1200px; margin: 40px auto; padding: 0 20px; }
a { color: #0066cc; }
</style>
</head>
<body>
<h1>Moondream Local Inference Server is Running!</h1>
<p>Visit the <a href="https://docs.moondream.ai">Moondream documentation</a> to learn more.</p>
</body>
</html>
"""
            self.wfile.write(html.encode())
        else:
            self.send_error_response("Method not allowed", 405)

    def decode_base64_image(self, base64_string: str) -> "Image.Image":
        """Convert a base64 image string to a PIL Image object.

        Args:
            base64_string: Base64 encoded image string, may include data URI prefix

        Returns:
            PIL Image object

        Raises:
            ValueError: If the value is not a string or cannot be decoded
                into an image
        """
        if not isinstance(base64_string, str):
            raise ValueError("Invalid base64 image: expected a string")

        # Remove data URI prefix if present (e.g., "data:image/jpeg;base64,")
        if "base64," in base64_string:
            base64_string = base64_string.split("base64,", 1)[1]

        try:
            # Decode base64 string to bytes
            image_bytes = base64.b64decode(base64_string)
            # Convert bytes to PIL Image
            return Image.open(io.BytesIO(image_bytes))
        except Exception as e:
            raise ValueError(f"Invalid base64 image: {str(e)}") from e
2 changes: 2 additions & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ numpy = "^2.1.2"
onnx = "^1.17.0"
tokenizers = "^0.20.1"

[tool.poetry.scripts]
moondream = "moondream.cli:main"

[tool.pyright]
venvPath = "."
Expand Down
Loading

0 comments on commit fc8a50e

Please sign in to comment.