From f77317effbfe391cfd4132e09ef34716550adaa4 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 31 Jan 2025 14:19:52 +0000 Subject: [PATCH] add chatgpt bridge example --- examples/chatgpt/.env.example | 5 + examples/chatgpt/.gitignore | 14 ++ examples/chatgpt/.python-version | 1 + examples/chatgpt/.vscode/launch.json | 21 ++ examples/chatgpt/README.md | 13 ++ examples/chatgpt/pyproject.toml | 81 +++++++ examples/chatgpt/requirements.lock | 91 ++++++++ .../chatgpt/src/chatgpt_bridge/__init__.py | 115 ++++++++++ examples/chatgpt/src/chatgpt_bridge/env.py | 20 ++ examples/chatgpt/src/chatgpt_bridge/py.typed | 0 .../src/chatgpt_bridge/skills/__init__.py | 0 .../chatgpt/src/chatgpt_bridge/skills/box.py | 83 +++++++ .../src/chatgpt_bridge/skills/cutout.py | 126 +++++++++++ .../src/chatgpt_bridge/skills/erase.py | 127 +++++++++++ .../src/chatgpt_bridge/skills/recolor.py | 206 ++++++++++++++++++ .../src/chatgpt_bridge/skills/shadow.py | 126 +++++++++++ .../chatgpt/src/chatgpt_bridge/skills/undo.py | 48 ++++ .../src/chatgpt_bridge/skills/upscale.py | 64 ++++++ examples/chatgpt/src/chatgpt_bridge/utils.py | 106 +++++++++ examples/chatgpt/tests/__init__.py | 0 examples/chatgpt/tests/conftest.py | 26 +++ examples/chatgpt/tests/test_box.py | 93 ++++++++ examples/chatgpt/tests/test_cutout.py | 153 +++++++++++++ examples/chatgpt/tests/test_erase.py | 133 +++++++++++ examples/chatgpt/tests/test_recolor.py | 200 +++++++++++++++++ examples/chatgpt/tests/test_shadow.py | 171 +++++++++++++++ examples/chatgpt/tests/test_upscale.py | 60 +++++ examples/chatgpt/tests/utils.py | 26 +++ 28 files changed, 2109 insertions(+) create mode 100644 examples/chatgpt/.env.example create mode 100644 examples/chatgpt/.gitignore create mode 100644 examples/chatgpt/.python-version create mode 100644 examples/chatgpt/.vscode/launch.json create mode 100644 examples/chatgpt/README.md create mode 100644 examples/chatgpt/pyproject.toml create mode 100644 examples/chatgpt/requirements.lock create mode 100644 
examples/chatgpt/src/chatgpt_bridge/__init__.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/env.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/py.typed create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/__init__.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/box.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/cutout.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/erase.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/recolor.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/shadow.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/undo.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/skills/upscale.py create mode 100644 examples/chatgpt/src/chatgpt_bridge/utils.py create mode 100644 examples/chatgpt/tests/__init__.py create mode 100644 examples/chatgpt/tests/conftest.py create mode 100644 examples/chatgpt/tests/test_box.py create mode 100644 examples/chatgpt/tests/test_cutout.py create mode 100644 examples/chatgpt/tests/test_erase.py create mode 100644 examples/chatgpt/tests/test_recolor.py create mode 100644 examples/chatgpt/tests/test_shadow.py create mode 100644 examples/chatgpt/tests/test_upscale.py create mode 100644 examples/chatgpt/tests/utils.py diff --git a/examples/chatgpt/.env.example b/examples/chatgpt/.env.example new file mode 100644 index 0000000..891a259 --- /dev/null +++ b/examples/chatgpt/.env.example @@ -0,0 +1,5 @@ +LOGLEVEL='INFO' +APP_LOGLEVEL='INFO' +FG_API_USER='myuser@something.com' +FG_API_PASSWORD='VERY_SECURE_PASSWORD' +CHATGPT_AUTH_TOKEN='VERY_SECURE_TOKEN' diff --git a/examples/chatgpt/.gitignore b/examples/chatgpt/.gitignore new file mode 100644 index 0000000..ea1c79d --- /dev/null +++ b/examples/chatgpt/.gitignore @@ -0,0 +1,14 @@ +# python generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + + +.env +.venv +.coverage +coverage.xml +requirements-dev.lock diff --git 
a/examples/chatgpt/.python-version b/examples/chatgpt/.python-version new file mode 100644 index 0000000..04e2079 --- /dev/null +++ b/examples/chatgpt/.python-version @@ -0,0 +1 @@ +3.12.8 diff --git a/examples/chatgpt/.vscode/launch.json b/examples/chatgpt/.vscode/launch.json new file mode 100644 index 0000000..1f2b129 --- /dev/null +++ b/examples/chatgpt/.vscode/launch.json @@ -0,0 +1,21 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Module", + "type": "debugpy", + "request": "launch", + "module": "quart", + "args": [ + "run", + "--port", + "8601", + "--host", + "0.0.0.0", + ], + "env": { + "QUART_APP": "chatgpt_bridge:app", + } + } + ] +} diff --git a/examples/chatgpt/README.md b/examples/chatgpt/README.md new file mode 100644 index 0000000..86da69f --- /dev/null +++ b/examples/chatgpt/README.md @@ -0,0 +1,13 @@ +# ChatGPT bridge + +Use the Finegrain API right from ChatGPT. + +## Usage + +```bash +rye sync --all-features +``` + +```bash +QUART_APP=chatgpt_bridge quart run +``` diff --git a/examples/chatgpt/pyproject.toml b/examples/chatgpt/pyproject.toml new file mode 100644 index 0000000..f33d4fc --- /dev/null +++ b/examples/chatgpt/pyproject.toml @@ -0,0 +1,81 @@ +[project] +authors = [ + { name = "Laurent Fainsin", email = "laurent@lagon.tech" }, +] +dependencies = [ + "environs>=14.1.0", + "finegrain @ git+https://github.com/finegrain-ai/finegrain-python.git#subdirectory=finegrain", + "pydantic>=2.10.5", + "quart>=0.20.0", + "pillow>=11.1.0", +] +description = "Finegrain API ChatGPT bridge" +name = "chatgpt-bridge" +readme = "README.md" +requires-python = ">= 3.12" +version = "0.1.0" + +[tool.rye] +dev-dependencies = [ + "pyright>=1.1.392", + "pytest-asyncio>=0.25.2", + "pytest>=8.3.4", + "ruff>=0.9.2", + "typos>=1.29.4", + "pytest-cov>=6.0.0", +] +managed = true + +[build-system] +build-backend = "hatchling.build" +requires = ["hatchling"] + +[tool.hatch.metadata] +allow-direct-references = true + 
+[tool.hatch.build.targets.wheel] +packages = ["src/chatgpt_bridge"] + +[tool.ruff] +line-length = 120 +target-version = "py312" + +[tool.ruff.lint] +ignore = [ + "C901", # is too complex + "E731", # do-not-assign-lambda + "EM101", # exception must not use a string literal + "EM102", # f-string literal in exception message + "G004", # f-string literal in logging message + "N812", # imported as non-lowercase + "S101", # use of assert + "S311", # non secure cryptographic random +] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "E", # pycodestyle errors + "EM", # flake8-errmsg + "F", # Pyflakes + "G", # flake8-logging-format + "I", # isort + "N", # pep8-naming + "PIE", # flake8-pie + "PTH", # flake8-use-pathlib + "RUF", # ruff + "S", # flake8-bandit + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warnings +] + +[tool.pyright] +pythonVersion = "3.12" +reportMissingImports = true +reportMissingTypeStubs = false +reportPrivateUsage = false +reportUntypedFunctionDecorator = false + +[tool.pytest.ini_options] +asyncio_mode = "auto" +log_cli = true diff --git a/examples/chatgpt/requirements.lock b/examples/chatgpt/requirements.lock new file mode 100644 index 0000000..63491de --- /dev/null +++ b/examples/chatgpt/requirements.lock @@ -0,0 +1,91 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: true +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. 
+aiofiles==24.1.0 + # via quart +annotated-types==0.7.0 + # via pydantic +anyio==4.8.0 + # via httpx +blinker==1.9.0 + # via flask + # via quart +certifi==2024.12.14 + # via httpcore + # via httpx +click==8.1.8 + # via flask + # via quart +environs==14.1.0 + # via chatgpt-bridge +finegrain @ git+https://github.com/finegrain-ai/finegrain-python.git@f0760005675992988c2913afee37cdec7d7d9327#subdirectory=finegrain + # via chatgpt-bridge +flask==3.1.0 + # via quart +h11==0.14.0 + # via httpcore + # via hypercorn + # via wsproto +h2==4.1.0 + # via hypercorn +hpack==4.1.0 + # via h2 +httpcore==1.0.7 + # via httpx +httpx==0.28.1 + # via finegrain +httpx-sse==0.4.0 + # via finegrain +hypercorn==0.17.3 + # via quart +hyperframe==6.1.0 + # via h2 +idna==3.10 + # via anyio + # via httpx +itsdangerous==2.2.0 + # via flask + # via quart +jinja2==3.1.5 + # via flask + # via quart +markupsafe==3.0.2 + # via jinja2 + # via quart + # via werkzeug +marshmallow==3.26.0 + # via environs +packaging==24.2 + # via marshmallow +pillow==11.1.0 + # via chatgpt-bridge +priority==2.0.0 + # via hypercorn +pydantic==2.10.6 + # via chatgpt-bridge +pydantic-core==2.27.2 + # via pydantic +python-dotenv==1.0.1 + # via environs +quart==0.20.0 + # via chatgpt-bridge +sniffio==1.3.1 + # via anyio +typing-extensions==4.12.2 + # via anyio + # via pydantic + # via pydantic-core +werkzeug==3.1.3 + # via flask + # via quart +wsproto==1.2.0 + # via hypercorn diff --git a/examples/chatgpt/src/chatgpt_bridge/__init__.py b/examples/chatgpt/src/chatgpt_bridge/__init__.py new file mode 100644 index 0000000..e3005b1 --- /dev/null +++ b/examples/chatgpt/src/chatgpt_bridge/__init__.py @@ -0,0 +1,115 @@ +import logging +from typing import Any + +from finegrain import EditorAPIContext +from quart import Quart, Response, jsonify, request + +from chatgpt_bridge.env import ( + APP_LOGLEVEL, + CHATGPT_AUTH_TOKEN, + FG_API_PASSWORD, + FG_API_PRIORITY, + FG_API_TIMEOUT, + FG_API_URL, + FG_API_USER, + LOGLEVEL, +) +from 
# Shared Finegrain API client; every skill handler below receives this same instance.
ctx = EditorAPIContext(
    base_url=FG_API_URL,
    user=FG_API_USER,
    password=FG_API_PASSWORD,
    priority=FG_API_PRIORITY,
    default_timeout=FG_API_TIMEOUT,
)

app = Quart(__name__)

# The root logger and the app logger get their levels from separate settings.
logging.basicConfig(level=LOGLEVEL)
app.logger.setLevel(APP_LOGLEVEL)
# Log the effective configuration at import time (password and auth token are not logged).
app.logger.info(f"LOGLEVEL: {LOGLEVEL}")
app.logger.info(f"FG_API_URL: {FG_API_URL}")
app.logger.info(f"FG_API_USER: {FG_API_USER}")
app.logger.info(f"FG_API_TIMEOUT: {FG_API_TIMEOUT}")
app.logger.info(f"FG_API_PRIORITY: {FG_API_PRIORITY}")


# NOTE(review): login() is registered before sse_start(); presumably Quart runs
# before_serving hooks in registration order -- confirm before reordering.
@app.before_serving
async def login() -> None:
    """Log the shared API context in before the server starts serving."""
    await ctx.login()


@app.before_serving
async def sse_start() -> None:
    """Start the API context's SSE handling before the server starts serving."""
    await ctx.sse_start()


@app.after_serving
async def sse_stop() -> None:
    """Stop the API context's SSE handling once the server stops serving."""
    await ctx.sse_stop()


@app.before_request
async def log_request() -> None:
    """Log the method and path of every incoming request at debug level."""
    app.logger.debug(f"Incoming request: {request.method} {request.path}")


@app.errorhandler(RuntimeError)
async def handle_runtime_error(error: RuntimeError) -> Response:
    """Log a RuntimeError raised by a handler and return it as a JSON error response."""
    app.logger.error(f"RuntimeError: {error}")
    # no status is passed, so the response uses json_error's default status
    return json_error(str(error))


@app.route("/health")
async def health() -> Response:
    """Unauthenticated liveness probe."""
    return jsonify({"status": "healthy"})


# Every skill route below requires the ChatGPT auth token and simply delegates
# to the matching skill handler with the shared context and the current request.
@app.post("/upscale")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def upscale() -> Any:
    return await _upscale(ctx, request)


@app.post("/box")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def infer_bbox() -> Any:
    return await _box(ctx, request)


@app.post("/cutout")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def cutout() -> Any:
    return await _cutout(ctx, request)


@app.post("/erase")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def erase() -> Any:
    return await _eraser(ctx, request)


@app.post("/recolor")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def recolor() -> Any:
    return await _recolor(ctx, request)


@app.post("/shadow")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def shadow() -> Any:
    return await _shadow(ctx, request)


@app.post("/undo")
@require_basic_auth_token(CHATGPT_AUTH_TOKEN)
async def undo() -> Any:
    return await _undo(ctx, request)
# Environment-driven settings for the ChatGPT bridge.
# Everything is read once, at import time.
from typing import cast

from environs import Env
from finegrain import Priority

env = Env()
# load variables from a .env file (if any) before reading them below
env.read_env()

# Finegrain API settings: each name below is looked up with the "FG_" prefix,
# e.g. "API_USER" reads the FG_API_USER variable.
with env.prefixed("FG_"):
    FG_API_URL: str = str(env.str("API_URL", "https://api.finegrain.ai/editor"))
    # no default is given for the user and password
    FG_API_USER: str = env.str("API_USER")
    FG_API_PASSWORD: str = env.str("API_PASSWORD")
    # NOTE(review): the cast only informs the type checker; the lowercased value is
    # not validated against the values Priority allows -- confirm a typo is caught elsewhere.
    FG_API_PRIORITY: Priority = cast(Priority, env.str("API_PRIORITY", "low").lower())
    # passed to EditorAPIContext as default_timeout; presumably seconds -- TODO confirm
    FG_API_TIMEOUT: int = env.int("API_TIMEOUT", 60)

# token expected from ChatGPT on authenticated routes (read as CHATGPT_AUTH_TOKEN, no default)
with env.prefixed("CHATGPT_"):
    CHATGPT_AUTH_TOKEN: str = env.str("AUTH_TOKEN")

# log levels are uppercased; both default to INFO
LOGLEVEL = env.str("LOGLEVEL", "INFO").upper()
APP_LOGLEVEL = env.str("APP_LOGLEVEL", "INFO").upper()
class BoxParams(BaseModel):
    """Request payload accepted by the box skill."""

    # files attached to the request; only refs that carry a download link are used
    openaiFileIdRefs: list[OpenaiFileIdRef] | None = None  # noqa: N815
    # already-created state ids; when non-empty they take precedence over openaiFileIdRefs
    stateids_input: list[StateID] | None = None
    # one object name per input image, matched by position
    object_names: list[str] | None = None


class BoxOutput(BaseModel):
    """Response payload: one bounding box per processed input."""

    bounding_boxes: list[BoundingBox]


async def process(
    ctx: EditorAPIContext,
    stateid_input: StateID,
    object_name: str,
) -> BoundingBox:
    """Return the bounding box found for ``object_name`` in one input state.

    Runs the ``infer-bbox`` skill on ``stateid_input`` and reads the ``bbox``
    entry from the resulting state's metadata.
    """
    # queue skills/infer-bbox
    stateid_bbox = await ctx.ensure_skill(
        url=f"infer-bbox/{stateid_input}",
        params={"product_name": object_name},
    )
    app.logger.debug(f"stateid_bbox: {stateid_bbox}")

    # get bbox state/meta
    metadata_bbox = await ctx.get_meta(stateid_bbox)
    return metadata_bbox["bbox"]


async def _box(ctx: EditorAPIContext, request: Request) -> Response:
    """Handle a box request: resolve the inputs, validate them, return the boxes.

    Returns a 400 JSON error when the body is not a JSON object, fails schema
    validation, or has missing/empty/mismatched ``object_names``.
    """
    # local import so this block does not depend on an edit to the module imports
    from pydantic import ValidationError

    # parse input data
    input_json = await request.get_json()
    app.logger.debug(f"json payload: {input_json}")
    # get_json() can yield None (non-JSON body) and valid JSON may be a list or scalar
    if not isinstance(input_json, dict):
        return json_error("request body must be a JSON object", 400)
    try:
        input_data = BoxParams(**input_json)
    except ValidationError as exc:
        # report which fields are wrong instead of failing with an unhandled exception
        details = "; ".join(
            f"{'.'.join(str(part) for part in err['loc'])}: {err['msg']}"  # e.g. "object_names.0: ..."
            for err in exc.errors()
        )
        return json_error(f"invalid payload: {details}", 400)
    app.logger.debug(f"parsed payload: {input_data}")

    # get stateids_input, or create them from openaiFileIdRefs
    stateids_input: list[StateID]
    if input_data.stateids_input:
        stateids_input = input_data.stateids_input
    elif input_data.openaiFileIdRefs:
        # NOTE(review): refs without a download link are skipped, so positional
        # alignment with object_names relies on every ref being downloadable.
        stateids_input = [
            await create_state(ctx, oai_ref.download_link)
            for oai_ref in input_data.openaiFileIdRefs
            if oai_ref.download_link
        ]
    else:
        return json_error("stateids_input or openaiFileIdRefs is required", 400)
    app.logger.debug(f"stateids_input: {stateids_input}")

    # validate object_names
    if input_data.object_names is None:
        return json_error("object_names is required", 400)
    if not input_data.object_names:
        return json_error("object_names cannot be empty", 400)
    if not all(input_data.object_names):
        return json_error("object name cannot be empty", 400)
    if len(input_data.object_names) != len(stateids_input):
        return json_error("stateids_input and object_names must have the same length", 400)

    # process the inputs one after the other
    bounding_boxes = [
        await process(ctx, stateid_input, object_name)
        for stateid_input, object_name in zip(stateids_input, input_data.object_names, strict=True)
    ]

    # build output response
    output_data = BoxOutput(bounding_boxes=bounding_boxes)
    app.logger.debug(f"output payload: {output_data}")
    return jsonify(output_data.model_dump())
class CutoutParams(BaseModel):
    """Request payload accepted by the cutout skill."""

    # files attached to the request; only refs that carry a download link are used
    openaiFileIdRefs: list[OpenaiFileIdRef] | None = None  # noqa: N815
    # already-created state ids; when non-empty they take precedence over openaiFileIdRefs
    stateids_input: list[StateID] | None = None
    # one background color per input image; defaults to white for every image
    background_colors: list[str] | None = None
    # one object name per input image, matched by position
    object_names: list[str] | None = None


class CutoutOutput(BaseModel):
    """Response payload: the rendered cutouts and the state ids to undo to."""

    openaiFileResponse: list[OpenaiFileResponse]  # noqa: N815
    stateids_undo: list[StateID]


async def process(
    ctx: EditorAPIContext,
    stateid_input: StateID,
    background_color: str,
    object_name: str,
) -> Image.Image:
    """Cut ``object_name`` out of one input state and center it on a colored canvas.

    Chains the ``infer-bbox``, ``segment`` and ``cutout`` skills, downloads the
    cutout, and pastes it (using itself as the paste mask) in the middle of a
    canvas 1.618 times its size filled with ``background_color``. The result is
    converted to RGB.
    """
    # queue skills/infer-bbox
    stateid_bbox = await ctx.ensure_skill(
        url=f"infer-bbox/{stateid_input}",
        params={"product_name": object_name},
    )
    app.logger.debug(f"stateid_bbox: {stateid_bbox}")

    # queue skills/segment
    stateid_mask = await ctx.ensure_skill(url=f"segment/{stateid_bbox}")
    app.logger.debug(f"stateid_mask: {stateid_mask}")

    # queue skills/cutout
    stateid_cutout = await ctx.ensure_skill(url=f"cutout/{stateid_input}/{stateid_mask}")
    app.logger.debug(f"stateid_cutout: {stateid_cutout}")

    # download cutout from API
    cutout = await download_image(ctx=ctx, stateid=stateid_cutout)

    # paste cutout onto a blank image, with margins
    canvas = Image.new(
        mode="RGBA",
        size=(int(1.618 * cutout.width), int(1.618 * cutout.height)),
        color=background_color,
    )
    left = (canvas.width - cutout.width) // 2
    top = (canvas.height - cutout.height) // 2
    canvas.paste(cutout, (left, top, left + cutout.width, top + cutout.height), cutout)

    return canvas.convert("RGB")


async def _cutout(ctx: EditorAPIContext, request: Request) -> Response:
    """Handle a cutout request: resolve the inputs, validate them, render the cutouts.

    Returns a 400 JSON error when the body is not a JSON object, fails schema
    validation, has missing/empty/mismatched ``object_names``, or carries a
    background color that is empty or cannot be parsed.
    """
    # local imports so this block does not depend on an edit to the module imports
    from PIL import ImageColor
    from pydantic import ValidationError

    # parse input data
    input_json = await request.get_json()
    app.logger.debug(f"json payload: {input_json}")
    # get_json() can yield None (non-JSON body) and valid JSON may be a list or scalar
    if not isinstance(input_json, dict):
        return json_error("request body must be a JSON object", 400)
    try:
        input_data = CutoutParams(**input_json)
    except ValidationError as exc:
        # report which fields are wrong instead of failing with an unhandled exception
        details = "; ".join(
            f"{'.'.join(str(part) for part in err['loc'])}: {err['msg']}"  # e.g. "object_names.0: ..."
            for err in exc.errors()
        )
        return json_error(f"invalid payload: {details}", 400)
    app.logger.debug(f"parsed payload: {input_data}")

    # get stateids_input, or create them from openaiFileIdRefs
    stateids_input: list[str]
    if input_data.stateids_input:
        stateids_input = input_data.stateids_input
    elif input_data.openaiFileIdRefs:
        # NOTE(review): refs without a download link are skipped, so positional
        # alignment with object_names relies on every ref being downloadable.
        stateids_input = [
            await create_state(ctx, oai_ref.download_link)
            for oai_ref in input_data.openaiFileIdRefs
            if oai_ref.download_link
        ]
    else:
        return json_error("stateids_input or openaiFileIdRefs is required", 400)
    app.logger.debug(f"stateids_input: {stateids_input}")

    # validate object_names
    if input_data.object_names is None:
        return json_error("object_names is required", 400)
    if len(stateids_input) != len(input_data.object_names):
        return json_error("stateids_input and object_names must have the same length", 400)
    if not all(input_data.object_names):
        return json_error("object name cannot be empty", 400)

    # validate background_colors
    if input_data.background_colors is None:
        input_data.background_colors = ["#ffffff"] * len(stateids_input)
    if len(input_data.background_colors) != len(stateids_input):
        return json_error("stateids_input and background_colors must have the same length", 400)
    for background_color in input_data.background_colors:
        if not background_color:
            return json_error("background color cannot be empty", 400)
        # reject colors Pillow cannot parse here, before any skill is queued,
        # rather than letting Image.new() raise ValueError inside process()
        try:
            ImageColor.getrgb(background_color)
        except ValueError:
            return json_error(f"invalid background color: {background_color}", 400)

    # process the inputs one after the other
    cutouts = [
        await process(
            ctx=ctx,
            object_name=object_name,
            stateid_input=stateid_input,
            background_color=background_color,
        )
        for stateid_input, object_name, background_color in zip(
            stateids_input,
            input_data.object_names,
            input_data.background_colors,
            strict=True,
        )
    ]

    # build output response
    output_data = CutoutOutput(
        openaiFileResponse=[
            OpenaiFileResponse.from_image(image=cutout, name=f"cutout_{i}")
            for i, cutout in enumerate(cutouts)
        ],
        stateids_undo=stateids_input,
    )
    app.logger.debug(f"output_json: {output_data}")
    return jsonify(output_data.model_dump())
class EraseParams(BaseModel):
    """Request payload accepted by the erase skill."""

    # files attached to the request; only refs that carry a download link are used
    openaiFileIdRefs: list[OpenaiFileIdRef] | None = None  # noqa: N815
    # already-created state ids; when non-empty they take precedence over openaiFileIdRefs
    stateids_input: list[StateID] | None = None
    # for each input image (by position), the names of the objects to erase
    object_names: list[list[str]] | None = None


class EraseOutput(BaseModel):
    """Response payload: erased images, their state ids, and the state ids to undo to."""

    openaiFileResponse: list[OpenaiFileResponse]  # noqa: N815
    stateids_output: list[StateID]
    stateids_undo: list[StateID]


async def process(
    ctx: EditorAPIContext,
    stateid_input: StateID,
    object_names: list[str],
) -> StateID:
    """Erase every object in ``object_names`` from one input state.

    Runs ``infer-bbox`` then ``segment`` for each name, merges the masks with a
    ``union`` when there is more than one, and queues ``erase`` in ``free``
    mode with the resulting mask. Returns the state id of the erased image.
    """
    # queue skills/infer-bbox
    stateids_bbox = []
    for name in object_names:
        stateid_bbox = await ctx.ensure_skill(
            url=f"infer-bbox/{stateid_input}",
            params={"product_name": name},
        )
        app.logger.debug(f"stateid_bbox: {stateid_bbox}")
        stateids_bbox.append(stateid_bbox)
    app.logger.debug(f"stateids_bbox: {stateids_bbox}")

    # queue skills/segment
    stateids_mask = []
    for stateid_bbox in stateids_bbox:
        stateid_mask = await ctx.ensure_skill(url=f"segment/{stateid_bbox}")
        app.logger.debug(f"stateid_mask: {stateid_mask}")
        stateids_mask.append(stateid_mask)
    app.logger.debug(f"stateids_mask: {stateids_mask}")

    # queue skills/merge-masks, unless a single mask makes the union pointless
    if len(stateids_mask) == 1:
        stateid_mask_union = stateids_mask[0]
    else:
        stateid_mask_union = await ctx.ensure_skill(
            url="merge-masks",
            params={
                "operation": "union",
                "states": stateids_mask,
            },
        )
    # label fixed: it used to read "stateid_mask_positive", which does not exist in this skill
    app.logger.debug(f"stateid_mask_union: {stateid_mask_union}")

    # queue skills/erase
    stateid_erased = await ctx.ensure_skill(
        url=f"erase/{stateid_input}/{stateid_mask_union}",
        params={"mode": "free"},
    )
    app.logger.debug(f"stateid_erased: {stateid_erased}")

    return stateid_erased


async def _eraser(ctx: EditorAPIContext, request: Request) -> Response:
    """Handle an erase request: resolve the inputs, validate them, erase the objects.

    Returns a 400 JSON error when the body is not a JSON object, fails schema
    validation, or has missing/empty/mismatched ``object_names``.
    """
    # local import so this block does not depend on an edit to the module imports
    from pydantic import ValidationError

    # parse input data
    input_json = await request.get_json()
    app.logger.debug(f"json payload: {input_json}")
    # get_json() can yield None (non-JSON body) and valid JSON may be a list or scalar
    if not isinstance(input_json, dict):
        return json_error("request body must be a JSON object", 400)
    try:
        input_data = EraseParams(**input_json)
    except ValidationError as exc:
        # report which fields are wrong instead of failing with an unhandled exception
        details = "; ".join(
            f"{'.'.join(str(part) for part in err['loc'])}: {err['msg']}"  # e.g. "object_names.0: ..."
            for err in exc.errors()
        )
        return json_error(f"invalid payload: {details}", 400)
    app.logger.debug(f"parsed payload: {input_data}")

    # get stateids_input, or create them from openaiFileIdRefs
    stateids_input: list[StateID]
    if input_data.stateids_input:
        stateids_input = input_data.stateids_input
    elif input_data.openaiFileIdRefs:
        # NOTE(review): refs without a download link are skipped, so positional
        # alignment with object_names relies on every ref being downloadable.
        stateids_input = [
            await create_state(ctx, oai_ref.download_link)
            for oai_ref in input_data.openaiFileIdRefs
            if oai_ref.download_link
        ]
    else:
        return json_error("stateids_input or openaiFileIdRefs is required", 400)
    app.logger.debug(f"stateids_input: {stateids_input}")

    # validate the inputs
    if input_data.object_names is None:
        return json_error("object_names is required", 400)
    if len(stateids_input) != len(input_data.object_names):
        return json_error("stateids_input and object_names must have the same length", 400)
    for object_names in input_data.object_names:
        if not object_names:
            return json_error("object list cannot be empty", 400)
        if not all(object_names):
            return json_error("object name cannot be empty", 400)

    # process the inputs one after the other
    stateids_erased = [
        await process(ctx, stateid_input, object_names)
        for stateid_input, object_names in zip(stateids_input, input_data.object_names, strict=True)
    ]
    app.logger.debug(f"stateids_erased: {stateids_erased}")

    # download images from API
    erased_imgs = [await download_image(ctx=ctx, stateid=stateid_erased) for stateid_erased in stateids_erased]

    # build output response
    output_data = EraseOutput(
        openaiFileResponse=[
            OpenaiFileResponse.from_image(image=erased_img, name=f"erased_{i}")
            for i, erased_img in enumerate(erased_imgs)
        ],
        stateids_output=stateids_erased,
        stateids_undo=stateids_input,
    )
    app.logger.debug(f"output payload: {output_data}")
    return jsonify(output_data.model_dump())
class RecolorParams(BaseModel):
    """Request payload accepted by the recolor skill."""

    # files attached to the request; only refs that carry a download link are used
    openaiFileIdRefs: list[OpenaiFileIdRef] | None = None  # noqa: N815
    # already-created state ids; when non-empty they take precedence over openaiFileIdRefs
    stateids_input: list[StateID] | None = None
    # for each input image (by position), the objects whose masks are recolored
    positive_object_names: list[list[str]] | None = None
    # for each input image (by position), the objects whose masks are subtracted; optional
    negative_object_names: list[list[str]] | None = None
    # one target color per input image
    object_colors: list[str] | None = None


class RecolorOutput(BaseModel):
    """Response payload: recolored images, their state ids, and the state ids to undo to."""

    openaiFileResponse: list[OpenaiFileResponse]  # noqa: N815
    stateids_output: list[StateID]
    stateids_undo: list[StateID]


async def _infer_bboxes(ctx: EditorAPIContext, stateid_input: str, object_names: list[str], kind: str) -> list[str]:
    """Queue ``infer-bbox`` for each name; ``kind`` only labels the debug logs."""
    stateids_bbox = []
    for name in object_names:
        stateid_bbox = await ctx.ensure_skill(
            url=f"infer-bbox/{stateid_input}",
            params={"product_name": name},
        )
        app.logger.debug(f"stateid_bbox_{kind}: {stateid_bbox}")
        stateids_bbox.append(stateid_bbox)
    app.logger.debug(f"stateids_bbox_{kind}: {stateids_bbox}")
    return stateids_bbox


async def _segment_bboxes(ctx: EditorAPIContext, stateids_bbox: list[str], kind: str) -> list[str]:
    """Queue ``segment`` for each bbox state; ``kind`` only labels the debug logs."""
    stateids_mask = []
    for stateid_bbox in stateids_bbox:
        stateid_mask = await ctx.ensure_skill(url=f"segment/{stateid_bbox}")
        app.logger.debug(f"stateid_mask_{kind}: {stateid_mask}")
        stateids_mask.append(stateid_mask)
    app.logger.debug(f"stateids_mask_{kind}: {stateids_mask}")
    return stateids_mask


async def _union_masks(ctx: EditorAPIContext, stateids_mask: list[str]) -> str | None:
    """Return the union of the masks: None if empty, the mask itself if single."""
    if not stateids_mask:
        return None
    if len(stateids_mask) == 1:
        return stateids_mask[0]
    return await ctx.ensure_skill(
        url="merge-masks",
        params={
            "operation": "union",
            "states": stateids_mask,
        },
    )


async def process(
    ctx: EditorAPIContext,
    object_color: str,
    stateid_input: str,
    positive_object_names: list[str],
    negative_object_names: list[str],
) -> str:
    """Recolor the positive objects of one input state, excluding the negative ones.

    Builds one mask per object name (``infer-bbox`` then ``segment``), unions
    the positive masks and the negative masks separately, subtracts the
    negative union from the positive one when negative names were given, and
    queues ``recolor`` with the resulting mask and ``object_color``. Returns
    the state id of the recolored image.

    Raises:
        ValueError: if ``positive_object_names`` is empty, since there would be
            no mask to recolor.
    """
    if not positive_object_names:
        raise ValueError("positive_object_names cannot be empty")

    # queue skills/infer-bbox for positive, then negative objects
    stateids_bbox_positive = await _infer_bboxes(ctx, stateid_input, positive_object_names, "positive")
    stateids_bbox_negative = await _infer_bboxes(ctx, stateid_input, negative_object_names, "negative")

    # queue skills/segment for positive, then negative objects
    stateids_mask_positive = await _segment_bboxes(ctx, stateids_bbox_positive, "positive")
    stateids_mask_negative = await _segment_bboxes(ctx, stateids_bbox_negative, "negative")

    # queue skills/merge-masks for positive objects
    stateid_mask_positive_union = await _union_masks(ctx, stateids_mask_positive)
    app.logger.debug(f"stateid_mask_positive_union: {stateid_mask_positive_union}")

    # queue skills/merge-masks for negative objects (None when there are none)
    stateid_mask_negative_union = await _union_masks(ctx, stateids_mask_negative)
    app.logger.debug(f"stateid_mask_negative_union: {stateid_mask_negative_union}")

    # queue skills/merge-masks for difference between positive and negative masks
    if stateid_mask_negative_union is not None:
        stateid_mask_difference = await ctx.ensure_skill(
            url="merge-masks",
            params={
                "operation": "difference",
                "states": [stateid_mask_positive_union, stateid_mask_negative_union],
            },
        )
    else:
        stateid_mask_difference = stateid_mask_positive_union
    app.logger.debug(f"stateid_mask_difference: {stateid_mask_difference}")

    # queue skills/recolor
    stateid_recolor = await ctx.ensure_skill(
        url=f"recolor/{stateid_input}/{stateid_mask_difference}",
        params={"color": object_color},
    )
    app.logger.debug(f"stateid_recolor: {stateid_recolor}")

    return stateid_recolor


async def _recolor(ctx: EditorAPIContext, request: Request) -> Response:
    """Handle a recolor request: resolve the inputs, validate them, recolor the objects.

    Returns a 400 JSON error when the body is not a JSON object, fails schema
    validation, or has missing/empty/mismatched colors or object names.
    """
    # local import so this block does not depend on an edit to the module imports
    from pydantic import ValidationError

    # parse input data
    input_json = await request.get_json()
    app.logger.debug(f"json payload: {input_json}")
    # get_json() can yield None (non-JSON body) and valid JSON may be a list or scalar
    if not isinstance(input_json, dict):
        return json_error("request body must be a JSON object", 400)
    try:
        input_data = RecolorParams(**input_json)
    except ValidationError as exc:
        # report which fields are wrong instead of failing with an unhandled exception
        details = "; ".join(
            f"{'.'.join(str(part) for part in err['loc'])}: {err['msg']}"  # e.g. "object_colors.0: ..."
            for err in exc.errors()
        )
        return json_error(f"invalid payload: {details}", 400)
    app.logger.debug(f"parsed payload: {input_data}")

    # get stateids_input, or create them from openaiFileIdRefs
    stateids_input: list[str]
    if input_data.stateids_input:
        stateids_input = input_data.stateids_input
    elif input_data.openaiFileIdRefs:
        # NOTE(review): refs without a download link are skipped, so positional
        # alignment with the per-image lists relies on every ref being downloadable.
        stateids_input = [
            await create_state(ctx, oai_ref.download_link)
            for oai_ref in input_data.openaiFileIdRefs
            if oai_ref.download_link
        ]
    else:
        return json_error("stateids_input or openaiFileIdRefs is required", 400)
    app.logger.debug(f"stateids_input: {stateids_input}")

    # validate object_colors
    if input_data.object_colors is None:
        return json_error("object_colors is required", 400)
    if len(stateids_input) != len(input_data.object_colors):
        return json_error("stateids_input and object_colors must have the same length", 400)

    # validate positive_object_names
    if input_data.positive_object_names is None:
        return json_error("positive_object_names is required", 400)
    if len(stateids_input) != len(input_data.positive_object_names):
        return json_error("stateids_input and positive_object_names must have the same length", 400)
    for object_names in input_data.positive_object_names:
        if not object_names:
            return json_error("positive object list cannot be empty", 400)
        if not all(object_names):
            return json_error("positive object name cannot be empty", 400)

    # validate negative_object_names
    if input_data.negative_object_names is None:
        # one distinct empty list per input ([[]] * n would alias a single list)
        input_data.negative_object_names = [[] for _ in stateids_input]
    if len(stateids_input) != len(input_data.negative_object_names):
        return json_error("stateids_input and negative_object_names must have the same length", 400)
    for object_names in input_data.negative_object_names:
        if not all(object_names):
            return json_error("negative object name cannot be empty", 400)

    # process the inputs one after the other
    stateids_recolor = [
        await process(
            ctx=ctx,
            object_color=object_color,
            stateid_input=stateid_input,
            positive_object_names=positive_object_names,
            negative_object_names=negative_object_names,
        )
        for stateid_input, positive_object_names, negative_object_names, object_color in zip(
            stateids_input,
            input_data.positive_object_names,
            input_data.negative_object_names,
            input_data.object_colors,
            strict=True,
        )
    ]
    app.logger.debug(f"stateids_recolor: {stateids_recolor}")

    # download the output images
    recolor_imgs = [await download_image(ctx=ctx, stateid=stateid_recolor) for stateid_recolor in stateids_recolor]

    # build output response
    output_data = RecolorOutput(
        openaiFileResponse=[
            OpenaiFileResponse.from_image(image=recolor_img, name=f"recolored_{i}")
            for i, recolor_img in enumerate(recolor_imgs)
        ],
        stateids_output=stateids_recolor,
        stateids_undo=stateids_input,
    )
    app.logger.debug(f"output payload: {output_data}")
    return jsonify(output_data.model_dump())
async def _shadow(ctx: EditorAPIContext, request: Request) -> Response:
    """Handle a shadow request: one object per input image, rendered with a shadow.

    The input states come from `stateids_input` or, failing that, are created
    from the download links of `openaiFileIdRefs`. `object_names` is required
    and must hold one non-empty name per input. `background_colors` defaults to
    "#ffffff" for every input and must otherwise hold one non-empty value per
    input.

    Args:
        ctx: API context used to create states, run skills and download images.
        request: Incoming request carrying the JSON payload.

    Returns:
        A JSON response with the rendered images, the output state ids and the
        input state ids (for undo), or a JSON error response with status 400
        when the payload is missing, malformed or inconsistent.
    """
    # Local import: the module only pulls BaseModel from pydantic at top level.
    from pydantic import ValidationError

    # parse input data
    input_json = await request.get_json()
    app.logger.debug(f"json payload: {input_json}")
    # get_json() may yield None (non-JSON body) or a non-object JSON value;
    # unpacking either with ** would raise and surface as an HTTP 500.
    if not isinstance(input_json, dict):
        return json_error("a JSON object payload is required", 400)
    try:
        input_data = ShadowParams(**input_json)
    except ValidationError as exc:
        return json_error(f"invalid payload: {exc}", 400)
    app.logger.debug(f"parsed payload: {input_data}")

    # get stateids_input, or create them from openaiFileIdRefs
    stateids_input: list[StateID]
    if input_data.stateids_input:
        stateids_input = input_data.stateids_input
    elif input_data.openaiFileIdRefs:
        # references without a download link are skipped
        stateids_input = [
            await create_state(ctx, oai_ref.download_link)
            for oai_ref in input_data.openaiFileIdRefs
            if oai_ref.download_link
        ]
    else:
        return json_error("stateids_input or openaiFileIdRefs is required", 400)
    app.logger.debug(f"stateids_input: {stateids_input}")

    # validate object_names
    object_names = input_data.object_names
    if object_names is None:
        return json_error("object_names is required", 400)
    if len(stateids_input) != len(object_names):
        return json_error("stateids_input and object_names must have the same length", 400)
    if not all(object_names):
        return json_error("object name cannot be empty", 400)

    # validate background_colors (white for every input when omitted)
    background_colors = input_data.background_colors
    if background_colors is None:
        background_colors = ["#ffffff"] * len(stateids_input)
    if len(stateids_input) != len(background_colors):
        return json_error("stateids_input and background_colors must have the same length", 400)
    if not all(background_colors):
        return json_error("background color cannot be empty", 400)

    # process the inputs, one after the other
    stateids_shadow = [
        await process(
            ctx=ctx,
            stateid_input=stateid_input,
            object_name=object_name,
            background_color=background_color,
        )
        for stateid_input, object_name, background_color in zip(
            stateids_input,
            object_names,
            background_colors,
            strict=True,
        )
    ]
    app.logger.debug(f"stateids_shadow: {stateids_shadow}")

    # download output images
    shadow_imgs = [
        await download_image(ctx=ctx, stateid=stateid_shadow)
        for stateid_shadow in stateids_shadow
    ]

    # build output response
    output_data = ShadowOutput(
        openaiFileResponse=[
            OpenaiFileResponse.from_image(
                image=shadow_img,
                name=f"shadow_{i}",
            )
            for i, shadow_img in enumerate(shadow_imgs)
        ],
        stateids_output=stateids_shadow,
        stateids_undo=stateids_input,
    )
    app.logger.debug(f"output payload: {output_data}")
    output_response = jsonify(output_data.model_dump())
    return output_response
async def _undo(ctx: EditorAPIContext, request: Request) -> Response:
    """Handle an undo request by re-sending the images of the given states.

    Args:
        ctx: API context used to download the state images.
        request: Incoming request carrying the JSON payload.

    Returns:
        A JSON response with one image per entry of `stateids_undo` and those
        same state ids as `stateids_output`, or a JSON error response with
        status 400 when the payload is missing, malformed or has no
        `stateids_undo`.
    """
    # Local import: the module only pulls BaseModel from pydantic at top level.
    from pydantic import ValidationError

    # parse input data
    input_json = await request.get_json()
    app.logger.debug(f"json payload: {input_json}")
    # get_json() may yield None (non-JSON body) or a non-object JSON value;
    # unpacking either with ** would raise and surface as an HTTP 500.
    if not isinstance(input_json, dict):
        return json_error("a JSON object payload is required", 400)
    try:
        input_data = UndoParams(**input_json)
    except ValidationError as exc:
        return json_error(f"invalid payload: {exc}", 400)
    app.logger.debug(f"parsed payload: {input_data}")

    # validate input data (None and an empty list are both rejected)
    if not input_data.stateids_undo:
        return json_error("stateids_undo is required", 400)

    # download the images, one state after the other
    images = [
        await download_image(ctx, stateid=stateid)
        for stateid in input_data.stateids_undo
    ]

    # build output response
    output_data = UndoOutput(
        openaiFileResponse=[
            OpenaiFileResponse.from_image(
                image=image,
                name=f"undo_{i}",
            )
            for i, image in enumerate(images)
        ],
        stateids_output=input_data.stateids_undo,
    )
    app.logger.debug(f"output payload: {output_data}")
    output_response = jsonify(output_data.model_dump())
    return output_response
app.logger.debug(f"json payload: {input_json}") + input_data = UpscaleParams(**input_json) + app.logger.debug(f"parsed payload: {input_data}") + + # get stateids_input, or create them from openaiFileIdRefs + if input_data.stateids_input: + stateids_input = input_data.stateids_input + elif input_data.openaiFileIdRefs: + stateids_input: list[StateID] = [] + for oai_ref in input_data.openaiFileIdRefs: + if oai_ref.download_link: + stateid_input = await create_state(ctx, oai_ref.download_link) + stateids_input.append(stateid_input) + else: + return json_error("stateids_input or openaiFileIdRefs is required", 400) + app.logger.debug(f"stateids_input: {stateids_input}") + + # queue skills/upscale + stateids_upscaled = [ + await ctx.ensure_skill(url=f"upscale/{stateid_input}") # + for stateid_input in stateids_input + ] + app.logger.debug(f"stateids_upscaled: {stateids_upscaled}") + + # download output images + upscaled_images = [ + await download_image(ctx, stateid_upscaled) # + for stateid_upscaled in stateids_upscaled + ] + + # build output response + output_data = UpscaleOutput( + openaiFileResponse=[ + OpenaiFileResponse.from_image(image=upscaled_img, name=f"upscaled_{i}") + for i, upscaled_img in enumerate(upscaled_images) + ], + stateids_output=stateids_upscaled, + stateids_undo=stateids_input, + ) + app.logger.debug(f"output payload: {output_data}") + output_response = jsonify(output_data.model_dump()) + return output_response diff --git a/examples/chatgpt/src/chatgpt_bridge/utils.py b/examples/chatgpt/src/chatgpt_bridge/utils.py new file mode 100644 index 0000000..ea5f8bf --- /dev/null +++ b/examples/chatgpt/src/chatgpt_bridge/utils.py @@ -0,0 +1,106 @@ +import base64 +import io +from functools import wraps + +from finegrain import EditorAPIContext +from PIL import Image +from pydantic import BaseModel +from quart import Response, jsonify, request +from quart import current_app as app + +StateID = str + + +def json_error(message: str, status: int = 400) -> 
def require_basic_auth_token(token: str):
    """Build a decorator that rejects requests lacking the expected auth header.

    The wrapped coroutine only runs when the request's ``Authorization``
    header equals ``Basic <token>``; otherwise a 401 JSON error is returned.

    Args:
        token: The secret expected after the ``Basic `` prefix.

    Returns:
        A decorator for async view functions.
    """
    # Function-scope import keeps this block independent of the file's import list.
    import hmac

    expected = f"Basic {token}".encode("utf-8")

    def decorator(f):
        @wraps(f)
        async def decorated_function(*args, **kwargs):
            auth_header = request.headers.get("Authorization", "")
            # Constant-time comparison: a plain != leaks, through timing, how
            # many leading characters of the secret were guessed correctly.
            # Bytes are compared because compare_digest rejects non-ASCII str.
            if not hmac.compare_digest(auth_header.encode("utf-8"), expected):
                return json_error("Unauthorized", 401)
            return await f(*args, **kwargs)

        return decorated_function

    return decorator
@pytest.fixture()
async def test_client() -> TestClientProtocol:
    """Provide a logged-in test client whose POST requests carry the auth header."""
    await login()
    auth_headers = {"Authorization": f"Basic {CHATGPT_AUTH_TOKEN}"}
    authed_client = app.test_client()
    # Pre-bind the headers so individual tests do not have to pass them.
    authed_client.post = partial(authed_client.post, headers=auth_headers)
    return authed_client
@wrap_sse
async def test_box_no_images(test_client: TestClientProtocol) -> None:
    """A /box request without any image source is rejected with HTTP 400."""
    payload = BoxParams(object_names=["glass of water"]).model_dump()

    reply = await test_client.post("/box", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "stateids_input or openaiFileIdRefs is required"
@wrap_sse
async def test_cutout_empty_object_names(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
) -> None:
    """An empty object_names list does not match one image and yields HTTP 400."""
    payload = CutoutParams(object_names=[], openaiFileIdRefs=[example_ref]).model_dump()

    reply = await test_client.post("/cutout", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "stateids_input and object_names must have the same length"
@wrap_sse
async def test_cutout_wrong_cardinality2(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
) -> None:
    """Two images with a single object name are rejected with HTTP 400."""
    images = [example_ref, example_ref]
    payload = CutoutParams(object_names=["glass of water"], openaiFileIdRefs=images).model_dump()

    reply = await test_client.post("/cutout", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "stateids_input and object_names must have the same length"
@wrap_sse
async def test_cutout(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
    tmp_path: Path,
) -> None:
    """Cut out one object from each of two images and check the returned files.

    Each returned file is decoded and saved under `tmp_path`.
    """
    data = CutoutParams(
        object_names=["glass of water", "lamp"],
        openaiFileIdRefs=[example_ref, example_ref],
    )

    response = await test_client.post("/cutout", json=data.model_dump())
    response_json = await response.get_json()

    # Check the status before building the output model: on a failed request
    # the model construction would raise and hide the actual error payload.
    assert response.status_code == 200, response_json
    response_data = CutoutOutput(**response_json)

    assert len(response_data.openaiFileResponse) == 2
    for i, oai_file in enumerate(response_data.openaiFileResponse):
        assert oai_file.name == f"cutout_{i}.jpg"
        assert oai_file.mime_type == "image/jpeg"
        image_data = io.BytesIO(base64.b64decode(oai_file.content))
        image = Image.open(image_data)
        image.save(tmp_path / oai_file.name)
@wrap_sse
async def test_eraser_empty2_object_names(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
) -> None:
    """An empty per-image object list is rejected with HTTP 400."""
    payload = EraseParams(object_names=[[]], openaiFileIdRefs=[example_ref]).model_dump()

    reply = await test_client.post("/erase", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "object list cannot be empty"
@wrap_sse
async def test_eraser(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
    tmp_path: Path,
) -> None:
    """Erase objects from two images and check the returned files.

    Each returned file is decoded and saved under `tmp_path`.
    """
    data = EraseParams(
        object_names=[["glass of water"], ["glass of water", "lamp"]],
        openaiFileIdRefs=[example_ref, example_ref],
    )

    response = await test_client.post("/erase", json=data.model_dump())
    response_json = await response.get_json()

    # Check the status before building the output model: on a failed request
    # the model construction would raise and hide the actual error payload.
    assert response.status_code == 200, response_json
    response_data = EraseOutput(**response_json)

    assert len(response_data.openaiFileResponse) == 2
    for i, oai_file in enumerate(response_data.openaiFileResponse):
        assert oai_file.name == f"erased_{i}.jpg"
        assert oai_file.mime_type == "image/jpeg"
        image_data = io.BytesIO(base64.b64decode(oai_file.content))
        image = Image.open(image_data)
        image.save(tmp_path / oai_file.name)
@wrap_sse
async def test_recolor_empty_negative_object_name(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
) -> None:
    """An empty string among the negative object names is rejected with HTTP 400."""
    payload = RecolorParams(
        openaiFileIdRefs=[example_ref],
        object_colors=["#ff0000"],
        positive_object_names=[["glass of water"]],
        negative_object_names=[[""]],
    ).model_dump()

    reply = await test_client.post("/recolor", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "negative object name cannot be empty"
@wrap_sse
async def test_recolor_wrong_positive_object_color_cardinality(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
) -> None:
    """Two colors for a single image are rejected with HTTP 400."""
    payload = RecolorParams(
        openaiFileIdRefs=[example_ref],
        object_colors=["#ff0000", "#ff0000"],
        positive_object_names=[["glass of water"]],
    ).model_dump()

    reply = await test_client.post("/recolor", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "stateids_input and object_colors must have the same length"
@wrap_sse
async def test_recolor(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
    tmp_path: Path,
) -> None:
    """Recolor objects in three images and check the returned files.

    Each returned file is decoded and saved under `tmp_path`.
    """
    data = RecolorParams(
        openaiFileIdRefs=[example_ref, example_ref, example_ref],
        object_colors=["#ff0000", "#00ff00", "#0000ff"],
        positive_object_names=[["glass of water"], ["bowl", "glass of water"], ["bowl", "glass of water"]],
        negative_object_names=[[], [], ["bowl"]],
    )

    response = await test_client.post("/recolor", json=data.model_dump())
    response_json = await response.get_json()

    # Check the status before building the output model: on a failed request
    # the model construction would raise and hide the actual error payload.
    assert response.status_code == 200, response_json
    response_data = RecolorOutput(**response_json)

    assert len(response_data.openaiFileResponse) == 3
    for i, oai_file in enumerate(response_data.openaiFileResponse):
        assert oai_file.name == f"recolored_{i}.jpg"
        assert oai_file.mime_type == "image/jpeg"
        image_data = io.BytesIO(base64.b64decode(oai_file.content))
        image = Image.open(image_data)
        image_path = tmp_path / oai_file.name
        image.save(image_path)
        # lazy %-formatting: the message is only built if the record is emitted
        logging.info("Saved image to %s", image_path)
@wrap_sse
async def test_shadow_empty_images(
    test_client: TestClientProtocol,
) -> None:
    """An empty openaiFileIdRefs list counts as no image source and yields HTTP 400."""
    payload = ShadowParams(object_names=["glass of water"], openaiFileIdRefs=[]).model_dump()

    reply = await test_client.post("/shadow", json=payload)
    body = await reply.get_json()

    assert reply.status_code == 400
    assert body["error"] == "stateids_input or openaiFileIdRefs is required"
@wrap_sse
async def test_shadow(
    test_client: TestClientProtocol,
    example_ref: OpenaiFileIdRef,
    tmp_path: Path,
) -> None:
    """Render shadows for two images with explicit backgrounds and check the files.

    Each returned file is decoded and saved under `tmp_path`.
    """
    data = ShadowParams(
        object_names=["glass of water", "lamp"],
        openaiFileIdRefs=[example_ref, example_ref],
        background_colors=["#ff0000", "#00ff00"],
    )

    response = await test_client.post("/shadow", json=data.model_dump())
    response_json = await response.get_json()

    # Check the status before building the output model: on a failed request
    # the model construction would raise and hide the actual error payload.
    assert response.status_code == 200, response_json
    response_data = ShadowOutput(**response_json)

    assert len(response_data.openaiFileResponse) == 2
    for i, oai_file in enumerate(response_data.openaiFileResponse):
        assert oai_file.name == f"shadow_{i}.jpg"
        assert oai_file.mime_type == "image/jpeg"
        image_data = io.BytesIO(base64.b64decode(oai_file.content))
        image = Image.open(image_data)
        image.save(tmp_path / oai_file.name)
enumerate(response_data.openaiFileResponse): + assert oai_file.name == f"upscaled_{i}.jpg" + assert oai_file.mime_type == "image/jpeg" + image_data = io.BytesIO(base64.b64decode(oai_file.content)) + image = Image.open(image_data) + image.save(tmp_path / oai_file.name) diff --git a/examples/chatgpt/tests/utils.py b/examples/chatgpt/tests/utils.py new file mode 100644 index 0000000..9f3a411 --- /dev/null +++ b/examples/chatgpt/tests/utils.py @@ -0,0 +1,26 @@ +import logging +from functools import wraps +from pathlib import Path + +import httpx + +from chatgpt_bridge import sse_start, sse_stop + + +def download(url: str, path: Path): + response = httpx.get(url, timeout=10) + response.raise_for_status() + path.write_bytes(response.content) + logging.info(f"Downloaded {url} to {path}") + + +def wrap_sse(f): + @wraps(f) + async def decorated_function(*args, **kwargs): + await sse_start() + try: + await f(*args, **kwargs) + finally: + await sse_stop() + + return decorated_function