diff --git a/README.md b/README.md index c5c113c..7d44cb0 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,32 @@ Please [see its documentation](https://github.com/NERC-CEH/object_store_api) for `python src/os_api/api.py` +## Feature extraction API + +FastAPI wrapper around different models - POST an image URL, get back embeddings + +## Label Studio ML backend + +Pre-annotation backend for Label Studio following their standard pattern. + +Build an image embedding model which will assign a likely-detritus tag: + +``` +cd scripts +dvc repro +``` + +Application is in `src/label_studio_cyto_ml` + +[Setup documentation](src/label_studio_cyto_ml/README.md) + +Short version, for testing + +``` +cd src +label-studio-ml start ./label_studio_cyto_ml +``` + ## Pipelines ### DVC @@ -93,10 +119,14 @@ Please see [PIPELINES.md](PIPELINES) for detailed documentation about a pipeline ## Contents - ### Feature extraction -Experiment testing workflows by using [this plankton model from SciVision](https://sci.vision/#/model/resnet50-plankton) to extract features from images for use in similarity search, clustering, etc. +The repository contains work on _feature extraction_ from different off-the-shelf ML models that have been trained on datasets of plankton imagery. + +The approach is useful for image search, clustering based on image similarity, and potentially for timeseries analysis of features given an image collection that forms a timeseries. + +* [ResNet50 plankton model from SciVision](https://sci.vision/#/model/resnet50-plankton) +* [ResNet18 plankton model from Alan Turing Inst] ### Running Jupyter notebooks diff --git a/models/kmeans-untagged-images-lana.pkl b/models/kmeans-untagged-images-lana.pkl new file mode 100644 index 0000000..8e2568b Binary files /dev/null and b/models/kmeans-untagged-images-lana.pkl differ diff --git a/pyproject.toml b/pyproject.toml index 78fb802..e87babf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "cyto_ml" version = "0.2.0" -requires-python = ">=3.12" +requires-python = ">=3.9" description = "This package supports the processing and analysis of plankton sample data" readme = "README.md" dependencies = [ @@ -30,6 +30,7 @@ dependencies = [ "torchvision", "xarray", "resnet50-cefas@git+https://github.com/jmarshrossney/resnet50-cefas", + "label-studio-ml@git+https://github.com/HumanSignal/label-studio-ml-backend" ] [project.optional-dependencies] diff --git a/scripts/cluster.py b/scripts/cluster.py index d3d2fed..b6f916c 100644 --- a/scripts/cluster.py +++ b/scripts/cluster.py @@ -6,7 +6,7 @@ import yaml from sklearn.cluster import KMeans -from cyto_ml.data.vectorstore import embeddings, vector_store +from cyto_ml.data.vectorstore import vector_store def main() -> None: @@ -25,8 +25,8 @@ def main() -> None: n_clusters = 5 kmeans = KMeans(n_clusters=n_clusters, random_state=42) - store = vector_store(collection_name) - X = embeddings(store) + store = vector_store("sqlite", collection_name) + X = store.embeddings() kmeans.fit(X) # We supply a -o for output directory - this doesn't ensure we write there. diff --git a/scripts/dvc.yaml b/scripts/dvc.yaml index 855c878..ff64c77 100644 --- a/scripts/dvc.yaml +++ b/scripts/dvc.yaml @@ -1,6 +1,6 @@ stages: - index: - cmd: python image_metadata.py + # index: + # cmd: python image_metadata.py embeddings: cmd: python image_embeddings.py #outs: diff --git a/src/label_studio_cyto_ml/Dockerfile b/src/label_studio_cyto_ml/Dockerfile new file mode 100644 index 0000000..be98181 --- /dev/null +++ b/src/label_studio_cyto_ml/Dockerfile @@ -0,0 +1,48 @@ +# syntax=docker/dockerfile:1 +ARG PYTHON_VERSION=3.12 + +FROM python:${PYTHON_VERSION}-slim AS python-base +ARG TEST_ENV + +WORKDIR /app + +ENV PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 \ + PORT=${PORT:-9090} \ + PIP_CACHE_DIR=/.cache \ + WORKERS=1 \ + THREADS=8 + +# Update the base OS +RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \ + --mount=type=cache,target="/var/lib/apt/lists",sharing=locked \ + set -eux; \ + apt-get update; \ + apt-get upgrade -y; \ + apt install --no-install-recommends -y \ + git; \ + apt-get autoremove -y + +# install base requirements +COPY requirements-base.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements-base.txt + +# install custom requirements +COPY requirements.txt . +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + pip install -r requirements.txt + +# install test requirements if needed +COPY requirements-test.txt . +# build only when TEST_ENV="true" +RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \ + if [ "$TEST_ENV" = "true" ]; then \ + pip install -r requirements-test.txt; \ + fi + +COPY . . + +EXPOSE 9090 + +CMD gunicorn --preload --bind :$PORT --workers $WORKERS --threads $THREADS --timeout 0 _wsgi:app diff --git a/src/label_studio_cyto_ml/README.md b/src/label_studio_cyto_ml/README.md new file mode 100644 index 0000000..0bcfbca --- /dev/null +++ b/src/label_studio_cyto_ml/README.md @@ -0,0 +1,58 @@ +This guide describes the simplest way to start using ML backend with Label Studio. + +## Running with Docker (Recommended) + +1. Start Machine Learning backend on `http://localhost:9090` with prebuilt image: + +```bash +docker-compose up +``` + +2. Validate that backend is running + +```bash +$ curl http://localhost:9090/ +{"status":"UP"} +``` + +3. Connect to the backend from Label Studio running on the same host: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as a URL. + + +## Building from source (Advanced) + +To build the ML backend from source, you have to clone the repository and build the Docker image: + +```bash +docker-compose build +``` + +## Running without Docker (Advanced) + +To run the ML backend without Docker, you have to clone the repository and install all dependencies using pip: + +```bash +python -m venv ml-backend +source ml-backend/bin/activate +pip install -r requirements.txt +``` + +Then you can start the ML backend: + +```bash +label-studio-ml start ./dir_with_your_model +``` + +# Configuration +Parameters can be set in `docker-compose.yml` before running the container. + + +The following common parameters are available: +- `BASIC_AUTH_USER` - specify the basic auth user for the model server +- `BASIC_AUTH_PASS` - specify the basic auth password for the model server +- `LOG_LEVEL` - set the log level for the model server +- `WORKERS` - specify the number of workers for the model server +- `THREADS` - specify the number of threads for the model server + +# Customization + +The ML backend can be customized by adding your own models and logic inside the `./dir_with_your_model` directory. \ No newline at end of file diff --git a/src/label_studio_cyto_ml/_wsgi.py b/src/label_studio_cyto_ml/_wsgi.py new file mode 100644 index 0000000..f57ef1f --- /dev/null +++ b/src/label_studio_cyto_ml/_wsgi.py @@ -0,0 +1,128 @@ +import argparse +import json +import logging +import logging.config +import os +from typing import Any + +from label_studio_ml.api import init_app + +from label_studio_cyto_ml.model import NewModel + +# Set a default log level if LOG_LEVEL is not defined +log_level = os.getenv("LOG_LEVEL", "INFO") + +logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, # Prevent overriding existing loggers + "formatters": { + "standard": {"format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"} + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": log_level, + "stream": "ext://sys.stdout", + "formatter": "standard", + } + }, + "root": {"level": log_level, "handlers": ["console"], "propagate": True}, + } +) + + +_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.json") + + +def get_kwargs_from_config(config_path: dict = _DEFAULT_CONFIG_PATH) -> dict: + if not os.path.exists(config_path): + return dict() + with open(config_path) as f: + config = json.load(f) + assert isinstance(config, dict) + return config + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Label studio") + parser.add_argument("-p", "--port", dest="port", type=int, default=9090, help="Server port") + parser.add_argument("--host", dest="host", type=str, default="0.0.0.0", help="Server host") + parser.add_argument( + "--kwargs", + "--with", + dest="kwargs", + metavar="KEY=VAL", + nargs="+", + type=lambda kv: kv.split("="), + help="Additional LabelStudioMLBase model initialization kwargs", + ) + parser.add_argument("-d", "--debug", dest="debug", action="store_true", help="Switch debug mode") + parser.add_argument( + "--log-level", + dest="log_level", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + default=log_level, + help="Logging level", + ) + parser.add_argument( + "--model-dir", + dest="model_dir", + default=os.path.dirname(__file__), + help="Directory where models are stored (relative to the project directory)", + ) + parser.add_argument( + "--check", dest="check", action="store_true", help="Validate model instance before launching server" + ) + parser.add_argument( + "--basic-auth-user", default=os.environ.get("ML_SERVER_BASIC_AUTH_USER", None), help="Basic auth user" + ) + + parser.add_argument( + "--basic-auth-pass", default=os.environ.get("ML_SERVER_BASIC_AUTH_PASS", None), help="Basic auth pass" + ) + + args = parser.parse_args() + + # setup logging level + if args.log_level: + logging.root.setLevel(args.log_level) + + def isfloat(value: Any) -> bool: + try: + float(value) + return True + except ValueError: + return False + + def parse_kwargs() -> dict: + param = dict() + for k, v in args.kwargs: + if v.isdigit(): + param[k] = int(v) + elif v == "True" or v == "true": + param[k] = True + elif v == "False" or v == "false": + param[k] = False + elif isfloat(v): + param[k] = float(v) + else: + param[k] = v + return param + + kwargs = get_kwargs_from_config() + + if args.kwargs: + kwargs.update(parse_kwargs()) + + if args.check: + print('Check "' + NewModel.__name__ + '" instance creation..') + model = NewModel(**kwargs) + + app = init_app(model_class=NewModel, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass) + + app.run(host=args.host, port=args.port, debug=args.debug) + +else: + # for uWSGI use + app = init_app(model_class=NewModel) diff --git a/src/label_studio_cyto_ml/docker-compose.yml b/src/label_studio_cyto_ml/docker-compose.yml new file mode 100644 index 0000000..60ddc51 --- /dev/null +++ b/src/label_studio_cyto_ml/docker-compose.yml @@ -0,0 +1,35 @@ +version: "3.8" + +services: + ml-backend: + container_name: ml-backend + image: humansignal/ml-backend:v0 + build: + context: . + args: + TEST_ENV: ${TEST_ENV} + environment: + # specify these parameters if you want to use basic auth for the model server + - BASIC_AUTH_USER= + - BASIC_AUTH_PASS= + # set the log level for the model server + - LOG_LEVEL=DEBUG + # any other parameters that you want to pass to the model server + - ANY=PARAMETER + # specify the number of workers and threads for the model server + - WORKERS=1 + - THREADS=8 + # specify the model directory (likely you don't need to change this) + - MODEL_DIR=/data/models + + # Specify the Label Studio URL and API key to access + # uploaded, local storage and cloud storage files. + # Do not use 'localhost' as it does not work within Docker containers. + # Use prefix 'http://' or 'https://' for the URL always. + # Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows). + - LABEL_STUDIO_URL= + - LABEL_STUDIO_API_KEY= + ports: + - "9090:9090" + volumes: + - "./data/server:/data" diff --git a/src/label_studio_cyto_ml/model.py b/src/label_studio_cyto_ml/model.py new file mode 100644 index 0000000..5e1a616 --- /dev/null +++ b/src/label_studio_cyto_ml/model.py @@ -0,0 +1,129 @@ +import logging +import os +import pickle +from typing import Any, Dict, List, Literal, Optional + +from dotenv import load_dotenv +from label_studio_ml.model import LabelStudioMLBase +from label_studio_ml.response import ModelResponse +from resnet50_cefas import load_model + +from cyto_ml.data.image import load_image_from_url +from cyto_ml.models.utils import flat_embeddings + +# Set AWS_URL_ENDPOINT in here +# Used to convert s3:// URLs coming from Label Studio to https:// URLs +load_dotenv() + +# Label Studio ML limits our ability to manage sessions - +# see cyto_ml/models/api.py for a FastAPI version that's more considered +resnet50_model = load_model(strip_final_layer=True) + + +class ImageNotFoundError(Exception): + pass + + +class NewModel(LabelStudioMLBase): + """Custom ML Backend model""" + + def setup(self) -> None: + """Configure any parameters of your model here""" + self.set("model_version", "0.0.1") + + def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> ModelResponse: + """Write your inference logic here + :param tasks: [Label Studio tasks in JSON format](https://labelstud.io/guide/task_format.html) + :param context: [Label Studio context in JSON format](https://labelstud.io/guide/ml_create#Implement-prediction-logic) + :return model_response + ModelResponse(predictions=predictions) with + predictions: [Predictions array in JSON format](https://labelstud.io/guide/export.html#Label-Studio-JSON-format-of-annotated-tasks) + """ + logging.info(f"""\ + Run prediction on {tasks} + Received context: {context} + Project ID: {self.project_id} + Label config: {self.label_config} + Parsed JSON Label config: {self.parsed_label_config} + Extra params: {self.extra_params}""") + + # TODO as below, check what the response format should really be (try it!) + predictions = [] + for task in tasks: + try: + annotated = self.predict_task(task) + predictions.append(annotated) + except KeyError as err: + # Return 500 with detail + raise (err) + except ImageNotFoundError as err: + # Return 404 with detail + raise (err) + + # https://github.com/HumanSignal/label-studio-ml-backend/blob/master/label_studio_ml/response.py + return ModelResponse(predictions=predictions) + + def convert_url(self, url: str) -> str: + """Convert an s3:// URL to an https:// URL + Set AWS_URL_ENDPOINT in .env""" + if url.startswith("s3://"): + return url.replace("s3://", f"https://{os.getenv('AWS_URL_ENDPOINT')}/") + return url + + def predict_task(self, task: dict) -> dict: + """Receive a single task definition as described here https://labelstud.io/guide/task_format.html + Return the task decorated with predictions as described here + https://labelstud.io/guide/export.html#Label-Studio-JSON-format-of-annotated-tasks + """ + # We use two models here: + # Extract image embeddings with a ResNet + try: + image_url = task.get("data").get("image") + except KeyError as err: + raise (err) + + features = resnet50_model(load_image_from_url(self.convert_url(image_url))) + embeddings = flat_embeddings(features) + # Classify embeddings (KNN to start, many improvements possible!) and return a label + label = self.embeddings_predict(embeddings) + # TODO check what the return format should be - does ModelResponse handle this? + return label + + def embeddings_predict(self, embeddings: List[List[float]]) -> List[str]: + """Predict labels from embeddings + See cyto_ml/visualisation/pages/02_kmeans.py for usage for a collection + See scripts/cluster.py for the model build and save + """ + # TODO load this from config, add to Dockerfile + fitted = pickle.load(open("../models/kmeans-untagged-images-lana.pkl", "rb")) + label = fitted.predict(embeddings) + return label + + def fit( + self, + event: Literal["ANNOTATION_CREATED", "ANNOTATION_UPDATED", "START_TRAINING"], + data: Dict[str, Any], + **kwargs: Any, + ) -> None: + """ + This method is called each time an annotation is created or updated + You can run your logic here to update the model and persist it to the cache + It is not recommended to perform long-running operations here, as it will block the main thread + Instead, consider running a separate process or a thread (like RQ worker) to perform the training + :param event: event type can be ('ANNOTATION_CREATED', 'ANNOTATION_UPDATED', 'START_TRAINING') + :param data: the payload received from the event (check [Webhook event reference](https://labelstud.io/guide/webhook_reference.html)) + """ + + # use cache to retrieve the data from the previous fit() runs + old_data = self.get("my_data") + old_model_version = self.get("model_version") + print(f"Old data: {old_data}") + print(f"Old model version: {old_model_version}") + + # store new data to the cache + self.set("my_data", "my_new_data_value") + self.set("model_version", "my_new_model_version") + print(f"New data: {self.get('my_data')}") + print(f"New model version: {self.get('model_version')}") + + print("fit() completed successfully.") diff --git a/src/label_studio_cyto_ml/requirements-base.txt b/src/label_studio_cyto_ml/requirements-base.txt new file mode 100644 index 0000000..68ce357 --- /dev/null +++ b/src/label_studio_cyto_ml/requirements-base.txt @@ -0,0 +1,2 @@ +gunicorn==22.0.0 +label-studio-ml @ git+https://github.com/HumanSignal/label-studio-ml-backend.git \ No newline at end of file diff --git a/src/label_studio_cyto_ml/requirements-test.txt b/src/label_studio_cyto_ml/requirements-test.txt new file mode 100644 index 0000000..cffeec6 --- /dev/null +++ b/src/label_studio_cyto_ml/requirements-test.txt @@ -0,0 +1,2 @@ +pytest +pytest-cov \ No newline at end of file diff --git a/src/label_studio_cyto_ml/requirements.txt b/src/label_studio_cyto_ml/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/src/label_studio_cyto_ml/test_api.py b/src/label_studio_cyto_ml/test_api.py new file mode 100644 index 0000000..d2d4b46 --- /dev/null +++ b/src/label_studio_cyto_ml/test_api.py @@ -0,0 +1,57 @@ +""" +This file contains tests for the API of your model. You can run these tests by installing test requirements: + + ```bash + pip install -r requirements-test.txt + ``` +Then execute `pytest` in the directory of this file. + +- Change `NewModel` to the name of the class in your model.py file. +- Change the `request` and `expected_response` variables to match the input and output of your model. +""" + +import json +from typing import Generator + +import pytest +from flask.testing import FlaskClient + +from label_studio_cyto_ml.model import NewModel + + +@pytest.fixture +def client() -> Generator[FlaskClient, None, None]: + from _wsgi import init_app + + app = init_app(model_class=NewModel) + app.config["TESTING"] = True + with app.test_client() as client: + yield client + + +@pytest.mark.skip(reason="Skipping until we define a model") +def test_predict(client: FlaskClient) -> None: + request = { + "tasks": [ + { + "data": { + # Your input test data here + } + } + ], + # Your labeling configuration here + "label_config": "", + } + + expected_response = { + "results": [ + { + # Your expected result here + } + ] + } + + response = client.post("/predict", data=json.dumps(request), content_type="application/json") + assert response.status_code == 200 + response = json.loads(response.data) + assert response == expected_response