diff --git a/.github/release-drafter-config.yml b/.github/release-drafter-config.yaml similarity index 100% rename from .github/release-drafter-config.yml rename to .github/release-drafter-config.yaml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yaml similarity index 94% rename from .github/workflows/lint.yml rename to .github/workflows/lint.yaml index d112694..d72b272 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yaml @@ -6,6 +6,9 @@ on: branches: - main +env: + POETRY_VERSION: "1.8.3" + jobs: check: name: Style-check ${{ matrix.python-version }} @@ -32,7 +35,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: - version: 1.8.3 + version: ${{ env.POETRY_VERSION }} - name: Install dependencies run: | poetry install --all-extras diff --git a/.github/workflows/release-drafter.yml b/.github/workflows/release-drafter.yaml similarity index 93% rename from .github/workflows/release-drafter.yml rename to .github/workflows/release-drafter.yaml index 796c2ed..3f6728b 100644 --- a/.github/workflows/release-drafter.yml +++ b/.github/workflows/release-drafter.yaml @@ -19,6 +19,6 @@ jobs: - uses: release-drafter/release-drafter@v5 with: # (Optional) specify config name to use, relative to .github/. 
Default: release-drafter.yml - config-name: release-drafter-config.yml + config-name: release-drafter-config.yaml env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yaml similarity index 89% rename from .github/workflows/release.yml rename to .github/workflows/release.yaml index 244ddd8..4f2a1e1 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yaml @@ -23,10 +23,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: - version: 1.8.3 - - - name: Set Version - run: poetry version ${{ github.event.inputs.version }} + version: ${{ env.POETRY_VERSION }} - name: Build package run: poetry build @@ -52,7 +49,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: - version: 1.8.3 + version: ${{ env.POETRY_VERSION }} - uses: actions/download-artifact@v4 with: diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yaml similarity index 91% rename from .github/workflows/run_tests.yml rename to .github/workflows/run_tests.yaml index fff5dff..b36f192 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yaml @@ -7,6 +7,9 @@ on: branches: - main +env: + POETRY_VERSION: "1.8.3" + jobs: test: name: Python ${{ matrix.python-version }} - ${{ matrix.connection }} [redis-stack ${{matrix.redis-stack-version}}] @@ -15,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.9, 3.10, 3.11] + python-version: ["3.9", "3.10", "3.11"] connection: ['hiredis', 'plain'] redis-stack-version: ['6.2.6-v9', 'latest', 'edge'] @@ -36,7 +39,7 @@ jobs: - name: Install Poetry uses: snok/install-poetry@v1 with: - version: 1.8.3 + version: ${{ env.POETRY_VERSION }} - name: Install dependencies run: | diff --git a/.gitignore b/.gitignore index 56f914b..b144692 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,9 @@ docs/old_redis_model_store.ipynb docs/test.py __pycache__/ -.mypy_cache/ \ No newline 
at end of file +.mypy_cache/ +.pytest_cache/ +.coverage +htmlcov/ + +dump.rdb \ No newline at end of file diff --git a/README.md b/README.md index 5620a5c..2d8218b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,120 @@ -# Redis Model Store +# 🧠 Redis Model Store +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +![Language](https://img.shields.io/github/languages/top/redis-applied-ai/redis-model-store) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +![GitHub last commit](https://img.shields.io/github/last-commit/redis-applied-ai/redis-model-store) +[![pypi](https://badge.fury.io/py/redis-model-store.svg)](https://pypi.org/project/redis-model-store/) + +Store, version, and manage your ML models in Redis with ease. `redis-model-store` provides a simple yet powerful interface for handling machine learning model artifacts in Redis. + +## ✨ Features + +- **🔄 Automatic Versioning**: Track and manage multiple versions of your models +- **📦 Smart Storage**: Large models are automatically sharded for optimal storage +- **🔌 Pluggable Serialization**: Works with any Python object (NumPy, PyTorch, TensorFlow, etc.) 
+- **🏃‍♂️ High Performance**: Efficient storage and retrieval using Redis pipelining +- **🛡️ Safe Operations**: Atomic operations with automatic cleanup on failures + +## 🚀 Quick Start + +### Installation + +```bash +# Using pip +pip install redis-model-store + +# Or using poetry +poetry add redis-model-store +``` + +### Basic Usage + +Here's a simple example using scikit-learn: + +```python +from redis import Redis +from redis_model_store import ModelStore +from sklearn.ensemble import RandomForestClassifier + +# Connect to Redis and initialize store +redis = Redis(host="localhost", port=6379) +store = ModelStore(redis) + +# Train your model +model = RandomForestClassifier() +model.fit(X_train, y_train) + +# Save model with version tracking +version = store.save_model( + model, + name="my-classifier", + description="Random forest trained on dataset v1" +) + +# List available models +models = store.list_models() +print(f"Available models: {models}") + +# Load latest version +model = store.load_model("my-classifier") + +# Load specific version +model = store.load_model("my-classifier", version=version) + +# View all versions +versions = store.get_all_versions("my-classifier") +for v in versions: + print(f"Version: {v.version}, Created: {v.created_at}") +``` + +## 🛠️ Contributing + +We welcome contributions! Here's how to get started: + +### Development Setup + +1. Clone the repository: +```bash +git clone https://github.com/redis-applied-ai/redis-model-store.git +cd redis-model-store +``` + +2. Install poetry if you haven't: +```bash +curl -sSL https://install.python-poetry.org | python3 - +``` + +3. Install dependencies: +```bash +poetry install --all-extras +``` + +### Linting and Tests + +```bash +poetry run format +poetry run check-mypy +poetry run test +poetry run test-verbose +``` + +### Making Changes + +1. Create a new branch: +```bash +git checkout -b feat/your-feature-name +``` + +2. 
Make your changes and ensure: + - All tests pass (covering new functionality) + - Code is formatted + - Type hints are valid + - Examples/docs added as notebooks to the `docs/` directory. + +3. Push changes and open a PR + + +## 📚 Documentation + +For more usage examples, check out this [Example Notebook](docs/redis_model_store.ipynb). diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..984fd7b --- /dev/null +++ b/conftest.py @@ -0,0 +1,21 @@ +import os +import pytest + +from testcontainers.compose import DockerCompose + + +@pytest.fixture(scope="session", autouse=True) +def redis_container(): + # Set the default Redis version if not already set + os.environ.setdefault("REDIS_VERSION", "edge") + + compose = DockerCompose("tests", compose_file_name="docker-compose.yaml", pull=True) + compose.start() + + redis_host, redis_port = compose.get_service_host_and_port("redis", 6379) + redis_url = f"redis://{redis_host}:{redis_port}" + os.environ["REDIS_URL"] = redis_url + + yield compose + + compose.stop() diff --git a/docs/redis_model_store.ipynb b/docs/redis_model_store.ipynb index 88e13ec..016c52e 100644 --- a/docs/redis_model_store.ipynb +++ b/docs/redis_model_store.ipynb @@ -13,13 +13,13 @@ "- Builds a model metadata index for model version management\n", "- Handles model chunking, serialization, and deserialization to/from Redis using Pickle\n", "\n", - "Then we test with various Python ML-native data types and models.\n", + "Below we test with various Python ML-native data types and models.\n", "\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "0c3ead6c", "metadata": {}, "outputs": [], @@ -29,7 +29,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "307c25ed", "metadata": {}, "outputs": [], @@ -37,7 +37,7 @@ "import os\n", "import redis\n", "\n", - "from model_store import ModelStore\n", + "from redis_model_store import ModelStore\n", "\n", "# Replace values below with 
your own if using Redis Cloud instance\n", "REDIS_HOST = os.getenv(\"REDIS_HOST\", \"localhost\") # ex: \"redis-18374.c253.us-central1-1.gce.cloud.redislabs.com\"\n", @@ -53,42 +53,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "dfe10ba9", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12:28:23 redisvl.index.index INFO Index already exists, not overwriting.\n" - ] - } - ], + "outputs": [], "source": [ "# Initialize the ModelStore\n", - "model_store = ModelStore(redis_client)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "79463e37", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model_store.model_registry.clear()" + "model_store = ModelStore(redis_client, shard_size=1012*100) # ~100Kb sized keys in Redis" ] }, { @@ -143,12 +114,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:28:49.913 - model_store.store - INFO - Starting save operation for model 'random_forest'\n", - "2025-01-21 12:28:49.925 - model_store.store - INFO - Added version record (0.0116s)\n", - "2025-01-21 12:28:49.926 - model_store.store - INFO - Starting model serialization and storage\n", - "2025-01-21 12:28:49.930 - model_store.store - INFO - Stored model in 1 shards (0.0044s)\n", - "2025-01-21 12:28:49.932 - model_store.store - INFO - Set shards (0.0012s)\n", - "2025-01-21 12:28:49.932 - model_store.store - INFO - Total save operation completed in 0.0191s\n" + "2025-01-22 10:56:13.871 - model_store.store - INFO - Saving 'random_forest' model\n", + "2025-01-22 10:56:13.873 - model_store.store - INFO - Starting model serialization and storage\n", + "2025-01-22 10:56:13.880 - model_store.store - INFO - Stored model in 2 shards (0.0069s)\n", + "2025-01-22 10:56:13.883 - model_store.store - INFO - Save operation completed in 
0.0121s\n" ] } ], @@ -168,11 +137,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:28:49.936 - model_store.store - INFO - Starting load operation for model 'random_forest'\n", - "2025-01-21 12:28:49.938 - model_store.store - INFO - Retrieved version metadata (0.0013s)\n", - "2025-01-21 12:28:49.938 - model_store.store - INFO - Starting model reconstruction from shards\n", - "2025-01-21 12:28:49.943 - model_store.store - INFO - Loaded model from 1 shards (0.0041s)\n", - "2025-01-21 12:28:49.943 - model_store.store - INFO - Total load operation completed (0.0070s)\n" + "2025-01-22 10:56:13.887 - model_store.store - INFO - Loading 'random_forest' model\n", + "2025-01-22 10:56:13.889 - model_store.store - INFO - Starting model reconstruction from shards\n", + "2025-01-22 10:56:13.899 - model_store.store - INFO - Loaded model from 2 shards (0.0105s)\n", + "2025-01-22 10:56:13.899 - model_store.store - INFO - Load operation completed in 0.0130s\n" ] }, { @@ -227,12 +195,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:29:36.338 - model_store.store - INFO - Starting save operation for model 'numpy_array'\n", - "2025-01-21 12:29:36.343 - model_store.store - INFO - Added version record (0.0046s)\n", - "2025-01-21 12:29:36.344 - model_store.store - INFO - Starting model serialization and storage\n", - "2025-01-21 12:29:44.056 - model_store.store - INFO - Stored model in 1025 shards (7.7123s)\n", - "2025-01-21 12:29:44.060 - model_store.store - INFO - Set shards (0.0037s)\n", - "2025-01-21 12:29:44.061 - model_store.store - INFO - Total save operation completed in 7.7231s\n" + "2025-01-22 10:56:14.435 - model_store.store - INFO - Saving 'numpy_array' model\n", + "2025-01-22 10:56:14.438 - model_store.store - INFO - Starting model serialization and storage\n", + "2025-01-22 10:56:21.996 - model_store.store - INFO - Stored model in 10611 shards (7.5575s)\n", + "2025-01-22 10:56:22.004 - model_store.store - INFO - Save 
operation completed in 7.5689s\n" ] } ], @@ -250,11 +216,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:30:09.563 - model_store.store - INFO - Starting load operation for model 'numpy_array'\n", - "2025-01-21 12:30:09.568 - model_store.store - INFO - Retrieved version metadata (0.0038s)\n", - "2025-01-21 12:30:09.568 - model_store.store - INFO - Starting model reconstruction from shards\n", - "2025-01-21 12:30:13.246 - model_store.store - INFO - Loaded model from 1025 shards (3.6776s)\n", - "2025-01-21 12:30:13.340 - model_store.store - INFO - Total load operation completed (3.7763s)\n" + "2025-01-22 10:56:22.008 - model_store.store - INFO - Loading 'numpy_array' model\n", + "2025-01-22 10:56:22.013 - model_store.store - INFO - Starting model reconstruction from shards\n", + "2025-01-22 10:56:25.661 - model_store.store - INFO - Loaded model from 10611 shards (3.6480s)\n", + "2025-01-22 10:56:25.702 - model_store.store - INFO - Load operation completed in 3.6941s\n" ] } ], @@ -341,7 +306,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Prediction for input 4.0: 6.903447151184082\n" + "Prediction for input 4.0: 7.4401021003723145\n" ] } ], @@ -354,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "id": "86dc9f9d", "metadata": {}, "outputs": [ @@ -362,12 +327,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:32:58.244 - model_store.store - INFO - Starting save operation for model 'pytorch'\n", - "2025-01-21 12:32:58.247 - model_store.store - INFO - Added version record (0.0028s)\n", - "2025-01-21 12:32:58.248 - model_store.store - INFO - Starting model serialization and storage\n", - "2025-01-21 12:32:58.250 - model_store.store - INFO - Stored model in 1 shards (0.0019s)\n", - "2025-01-21 12:32:58.252 - model_store.store - INFO - Set shards (0.0018s)\n", - "2025-01-21 12:32:58.252 - model_store.store - INFO - Total save operation completed in 0.0083s\n" + 
"2025-01-22 10:56:32.978 - model_store.store - INFO - Saving 'pytorch' model\n", + "2025-01-22 10:56:32.982 - model_store.store - INFO - Starting model serialization and storage\n", + "2025-01-22 10:56:32.983 - model_store.store - INFO - Stored model in 1 shards (0.0012s)\n", + "2025-01-22 10:56:32.984 - model_store.store - INFO - Save operation completed in 0.0056s\n" ] }, { @@ -376,7 +339,7 @@ "'1.0'" ] }, - "execution_count": 15, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -390,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "id": "670f353f", "metadata": {}, "outputs": [ @@ -398,18 +361,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:33:00.741 - model_store.store - INFO - Starting load operation for model 'pytorch'\n", - "2025-01-21 12:33:00.744 - model_store.store - INFO - Retrieved version metadata (0.0018s)\n", - "2025-01-21 12:33:00.745 - model_store.store - INFO - Starting model reconstruction from shards\n", - "2025-01-21 12:33:00.747 - model_store.store - INFO - Loaded model from 1 shards (0.0019s)\n", - "2025-01-21 12:33:00.747 - model_store.store - INFO - Total load operation completed (0.0055s)\n" + "2025-01-22 10:56:32.987 - model_store.store - INFO - Loading 'pytorch' model\n", + "2025-01-22 10:56:32.988 - model_store.store - INFO - Starting model reconstruction from shards\n", + "2025-01-22 10:56:32.990 - model_store.store - INFO - Loaded model from 1 shards (0.0012s)\n", + "2025-01-22 10:56:32.990 - model_store.store - INFO - Load operation completed in 0.0027s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Prediction for input 4.0 with loaded model: 6.903447151184082\n" + "Prediction for input 4.0 with loaded model: 7.4401021003723145\n" ] } ], @@ -431,25 +393,17 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "id": "e8bd2de2", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": 
"stream", - "text": [ - "/Users/tyler.hutcherson/Library/Caches/pypoetry/virtualenvs/redis-model-store-__eJJx5C-py3.11/lib/python3.11/site-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n", - " super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n" - ] - }, { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 17, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -458,9 +412,9 @@ "import tensorflow as tf\n", "\n", "# Define a simple model\n", - "model = tf.keras.Sequential([\n", - " tf.keras.layers.Dense(1, input_shape=(1,))\n", - "])\n", + "inputs = tf.keras.Input(shape=(1,))\n", + "outputs = tf.keras.layers.Dense(1)(inputs)\n", + "model = tf.keras.Model(inputs=inputs, outputs=outputs)\n", "\n", "# Compile the model\n", "model.compile(optimizer='sgd', loss='mse')\n", @@ -475,7 +429,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "id": "c26f35ec", "metadata": {}, "outputs": [ @@ -483,7 +437,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Prediction for input 4.0: 7.460770606994629\n" + "Prediction for input 4.0: 7.760094165802002\n" ] } ], @@ -496,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "id": "5aeaf0ac", "metadata": {}, "outputs": [ @@ -504,12 +458,10 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:33:24.276 - model_store.store - INFO - Starting save operation for model 'tensorflow'\n", - "2025-01-21 12:33:24.280 - model_store.store - INFO - Added version record (0.0028s)\n", - "2025-01-21 12:33:24.280 - model_store.store - INFO - Starting model serialization and storage\n", - "2025-01-21 12:33:24.291 - model_store.store - INFO - Stored model in 1 shards (0.0107s)\n", - "2025-01-21 12:33:24.292 - model_store.store - 
INFO - Set shards (0.0014s)\n", - "2025-01-21 12:33:24.293 - model_store.store - INFO - Total save operation completed in 0.0163s\n" + "2025-01-22 10:56:36.920 - model_store.store - INFO - Saving 'tensorflow' model\n", + "2025-01-22 10:56:36.922 - model_store.store - INFO - Starting model serialization and storage\n", + "2025-01-22 10:56:36.932 - model_store.store - INFO - Stored model in 1 shards (0.0103s)\n", + "2025-01-22 10:56:36.934 - model_store.store - INFO - Save operation completed in 0.0138s\n" ] }, { @@ -518,7 +470,7 @@ "'1.0'" ] }, - "execution_count": 19, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -532,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "id": "117255d6", "metadata": {}, "outputs": [ @@ -540,18 +492,17 @@ "name": "stderr", "output_type": "stream", "text": [ - "2025-01-21 12:33:24.297 - model_store.store - INFO - Starting load operation for model 'tensorflow'\n", - "2025-01-21 12:33:24.300 - model_store.store - INFO - Retrieved version metadata (0.0027s)\n", - "2025-01-21 12:33:24.301 - model_store.store - INFO - Starting model reconstruction from shards\n", - "2025-01-21 12:33:24.312 - model_store.store - INFO - Loaded model from 1 shards (0.0108s)\n", - "2025-01-21 12:33:24.312 - model_store.store - INFO - Total load operation completed (0.0150s)\n" + "2025-01-22 10:56:36.941 - model_store.store - INFO - Loading 'tensorflow' model\n", + "2025-01-22 10:56:36.943 - model_store.store - INFO - Starting model reconstruction from shards\n", + "2025-01-22 10:56:36.954 - model_store.store - INFO - Loaded model from 1 shards (0.0106s)\n", + "2025-01-22 10:56:36.954 - model_store.store - INFO - Load operation completed in 0.0131s\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Prediction for input 4.0 with loaded model: 7.460770606994629\n" + "Prediction for input 4.0 with loaded model: 7.760094165802002\n" ] } ], @@ -568,7 +519,98 @@ "id": "db4db02e", 
"metadata": {}, "source": [ - "# Clear the model store" + "# Model versioning " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "84136352", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['numpy_array', 'pytorch', 'random_forest', 'tensorflow']" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# List all available models in the store\n", + "models = model_store.list_models()\n", + "models" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "cadbec7e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[ModelVersion(name='pytorch', description='', version='1.0', created_at=1737561392.98, shard_keys=['shard:pytorch:1.0:0'])]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# List model versions for a model\n", + "versions = model_store.get_all_versions(models[1])\n", + "versions" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "0962bc96", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Delete a model version\n", + "model_version = versions[0]\n", + "model_store.delete_version(model_version.name, model_version.version)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "28af2739", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "10617" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Clear all versions for all models\n", + "model_store.clear()" ] } ], diff --git a/model_store/registry.py b/model_store/registry.py deleted file mode 100644 index c5edfe6..0000000 --- a/model_store/registry.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -model_registry.py - -Implements a ModelRegistry class that tracks model 
versioning and metadata -using RedisVL. Also includes a Pydantic model for model version metadata. -""" - -import json -from typing import List, Optional - -from pydantic import BaseModel, Field -from redis import Redis -from redisvl.index import SearchIndex -from redisvl.query import FilterQuery -from redisvl.query.filter import FilterExpression, Tag - -from model_store.utils import current_timestamp, new_model_version - - -class ModelVersion(BaseModel): - """ - Metadata for a specific model version stored in Redis. - """ - - name: str = Field(..., description="Unique identifier for the model") - description: str = Field( - default="", description="Optional description of this model version" - ) - version: str = Field( - default_factory=new_model_version, - description="Version identifier (e.g. semantic version or UUID)", - ) - created_at: float = Field( - default_factory=current_timestamp, - description="Unix timestamp when this version was created", - ) - shard_keys: List[str] = Field( - default=[], description="Redis keys containing the serialized model data" - ) - - -class ModelRegistry: - """ - Manages versioning and metadata for machine learning models stored in Redis. - - Uses RedisVL for efficient querying and indexing of model metadata. Each model - version is stored as a JSON document with searchable fields for name, version, - description and creation time. - - The registry maintains an index of model versions and their metadata, allowing - for efficient querying and retrieval of model versions by name, version, or - creation time. - """ - - def __init__(self, redis_client: Redis) -> None: - """ - Initialize the model registry. - - Args: - redis_client (Redis): Initialized Redis client for - storage and querying. 
- """ - self.registry_idx = SearchIndex.from_dict( - { - "index": { - "name": "model_registry", - "prefix": "model", - "storage_type": "json", - "key_separator": ":", - }, - "fields": [ - {"name": "name", "type": "tag"}, - {"name": "description", "type": "text"}, - {"name": "version", "type": "tag"}, - {"name": "created_at", "type": "numeric"}, - ], - } - ) - self.registry_idx.set_client(redis_client) - # Create index only if it doesn't exist - self.registry_idx.create(overwrite=False, drop=False) - - @staticmethod - def model_version_key(name: str, version: str) -> str: - """ - Generate Redis key for storing a specific model version. - - Args: - name (str): Model name - version (str): Model version identifier - - Returns: - str: Formatted Redis key in the format "model:{name}:{version}" - """ - return f"model:{name}:{version}" - - def add_version(self, name: str, **kwargs) -> ModelVersion: - """ - Add a new model version to the registry. - - Args: - name (str): Model name. - **kwargs: Additional model metadata fields (version, - description, shard_keys). - - Returns: - ModelVersion: Model version metadata object. - - Raises: - ValidationError: If required fields are missing or invalid. - """ - model_version = ModelVersion(name=name, **kwargs) - key = self.model_version_key(model_version.name, model_version.version) - if not self.registry_idx.client.exists(key): - self.registry_idx.load(data=[model_version.model_dump()], keys=[key]) - return model_version - raise ValueError(f"Model {name} version {model_version.version} already exists") - - def set_model_shards(self, name: str, version: str, shard_keys: List[str]): - key = self.model_version_key(name, version) - if self.registry_idx.client.exists(key): - self.registry_idx.client.json().set(key, "$.shard_keys", shard_keys) - else: - raise ValueError( - f"Failed to set shard keys. 
Model {name} and version {version} does not exist" - ) - - def _query_version( - self, model_filter: FilterExpression, latest: bool = False - ) -> ModelVersion: - """ - Query model versions using a filter expression. - - Args: - model_filter (FilterExpression): RediSearch filter expression to match models. - latest (bool): If True, sort by creation time to get most recent version. - Defaults to False. - - Returns: - ModelVersion: Matching ModelVersion instance. - - Raises: - ValueError: If no matching model version is found. - """ - query = FilterQuery( - filter_expression=model_filter, - num_results=1, - return_fields=[ - "name", - "description", - "version", - "created_at", - "$.shard_keys", - ], - ) - - if latest: - query.sort_by("created_at", asc=False) - - results = self.registry_idx.query(query) - - if not results: - raise ValueError(f"No model version found matching filter: {model_filter}") - - return ModelVersion( - name=results[0]["name"], - description=results[0]["description"], - version=results[0]["version"], - created_at=results[0]["created_at"], - shard_keys=json.loads(results[0]["$.shard_keys"]), - ) - - def get_version(self, name: str, version: str) -> ModelVersion: - """ - Retrieve metadata for a specific model version. - - Args: - name (str): Model identifier. - version (str): Version identifier. - - Returns: - ModelVersion: ModelVersion instance containing metadata. - - Raises: - ValueError: If the requested model version is not found - """ - model_filter = (Tag("name") == name) & (Tag("version") == version) - return self._query_version(model_filter) - - def get_latest_version(self, name: str) -> ModelVersion: - """ - Get the most recently created version of a model. - - Args: - name (str): Model identifier. - - Returns: - ModelVersion: ModelVersion instance for the latest version. - - Raises: - ValueError: If no versions exist for the model. 
- """ - model_filter = Tag("name") == name - return self._query_version(model_filter, latest=True) - - def get_all_versions(self, name: str) -> List[ModelVersion]: - """ - Get all versions of a model sorted by creation time (newest first). - - Args: - name (str): Model identifier. - - Returns: - List[ModelVersion]: List of ModelVersion instances for - all versions of the model, sorted by creation time with - newest first. - - Raises: - ValueError: If no versions exist for the model. - """ - query = FilterQuery( - filter_expression=Tag("name") == name, - return_fields=[ - "name", - "description", - "version", - "created_at", - "$.shard_keys", - ], - ) - query.sort_by("created_at", asc=False) - - model_versions: List[ModelVersion] = [] - - for results in self.registry_idx.paginate(query, page_size=50): - if results: - model_versions.extend( - [ - ModelVersion( - name=result["name"], - description=result["description"], - version=result["version"], - created_at=result["created_at"], - shard_keys=json.loads(result["$.shard_keys"]), - ) - for result in results - ] - ) - - if not model_versions: - raise ValueError(f"No versions found for model: {name}") - - return model_versions - - def delete_version(self, name: str, version: str) -> int: - """ - Delete a specific model version from the registry. - - Args: - name (str): Model identifier. - version (str): Version identifier to delete. - - Returns: - int: Number of model versions deleted (0 or 1). - - Raises: - ValueError: If the specified version does not exist. - """ - version_key = self.model_version_key(name, version) - deleted_count = self.registry_idx.drop_keys([version_key]) - return deleted_count - - def clear(self, name: Optional[str] = None) -> int: - """ - Clear model versions from the registry. - - Args: - name (Optional[str]): If provided, only clear versions of this model. - If None, clear all model versions. - - Returns: - int: Number of model versions that were deleted. 
- """ - if name: - model_versions = self.get_all_versions(name) - model_version_keys = [ - self.model_version_key(model_version.name, model_version.version) - for model_version in model_versions - ] - self.registry_idx.drop_keys(model_version_keys) - deleted_count = len(model_version_keys) - else: - deleted_count = self.registry_idx.clear() - return deleted_count diff --git a/model_store/shard_manager.py b/model_store/shard_manager.py deleted file mode 100644 index e383e3a..0000000 --- a/model_store/shard_manager.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Any, Iterator, List - -from model_store.serialize import PickleSerializer, SerializationError, Serializer - - -class ModelShardManager: - """ - Manages sharding and serialization of model data. - - This class handles breaking large model objects into manageable shards and serializing - them for storage. It provides an abstraction layer between raw model objects and their - serialized/sharded representation ready for storage. - - Key features: - - Configurable shard size to optimize for different storage backends - - Pluggable serialization via the Serializer protocol - - Maintains data integrity across sharding/reassembly - - Example: - >>> manager = ModelShardManager() - >>> shards = manager.to_shards(large_model) # Split model into shards - >>> reconstructed = manager.from_shards(shards) # Reassemble model - """ - - def __init__(self, shard_size: int, serializer: Serializer = PickleSerializer()): - """ - Initialize the shard manager. - - Args: - shard_size (int): Maximum size in bytes for each shard. - serializer (Serializer): Serializer implementation to use. Defaults to pickle. - """ - if shard_size <= 0: - raise ValueError("Shard size must be positive") - - self.shard_size = shard_size - self.serializer = serializer - - def _shardify(self, data: bytes) -> Iterator[bytes]: - """ - Split serialized data into fixed-size shards. - - Args: - data (bytes): The full serialized model data. 
- - Yields: - bytes: Successive shards of the data, each up to shard_size in length. - """ - total_size = len(data) - for start in range(0, total_size, self.shard_size): - yield data[start : start + self.shard_size] - - @staticmethod - def shard_key(model_name: str, model_version: str, idx: int) -> str: - """ - Generate a storage key for a model shard. - - Args: - model_name (str): Name of the model. - model_version (str): Version identifier of the model. - idx (int): Shard index. - - Returns: - str: Formatted storage key for the model shard. - """ - return f"shard:{model_name}:{model_version}:{idx}" - - def to_shards(self, model: Any) -> List[bytes]: - """ - Convert model into smaller chunks (shards) ready for storage. - - Args: - model (Any): The model object to shard. - - Returns: - List[bytes]: List of binary shards derived from the model. - TODO -- returns a generator here - - Raises: - SerializationError: If model serialization fails. - """ - try: - serialized_data = self.serializer.dumps(model) - return self._shardify(serialized_data) - except Exception as e: - raise SerializationError(f"Failed to serialize model: {str(e)}") from e - - def from_shards(self, shards: List[bytes]) -> Any: - """ - Reconstruct a model from its shards. - - Args: - shards (List[bytes]): List of binary shards to reassemble. - - Returns: - Any: The reconstructed model object. - - Raises: - SerializationError: If model deserialization fails. 
- """ - try: - serialized_data = b"".join(shards) - return self.serializer.loads(serialized_data) - except Exception as e: - raise SerializationError(f"Failed to deserialize model: {str(e)}") from e diff --git a/model_store/store.py b/model_store/store.py deleted file mode 100644 index acaa01c..0000000 --- a/model_store/store.py +++ /dev/null @@ -1,227 +0,0 @@ -from typing import Any, List, Optional - -from redis import Redis - -from model_store.registry import ModelRegistry, ModelVersion -from model_store.shard_manager import ModelShardManager, SerializationError -from model_store.utils import PIPELINE_BATCH_SIZE, current_timestamp, setup_logger - -logger = setup_logger(__name__) - - -class ModelStoreError(Exception): - """Raised when model store I/O operations fail.""" - pass - - -class ModelStore: - """ - High-level interface for storing and retrieving AI/ML models from Redis. - - The ModelStore provides a simple API for saving and loading models, - handling versioning, metadata, and efficient storage internally. It uses the - ModelRegistry for version tracking and metadata management, and the ModelShardManager - for serialization and sharding. - - Example: - >>> store = ModelStore(redis_client) - >>> store.save_model(model, name="my_model", description="My trained model") - >>> loaded_model = store.load_model("my_model") - """ - - def __init__(self, redis_client: Redis, shard_size: int = 1024 * 1024): - """ - Initialize the ModelStore. - - Args: - redis_client (Redis): An initialized Redis client instance. - shard_size (int, optional): Maximum size in bytes for each model shard. - Defaults to 1MB. 
- """ - if not isinstance(redis_client, Redis): - raise TypeError("Must provide a valid Redis client instance.") - - self.redis_client = redis_client - self.model_registry = ModelRegistry(redis_client) - self.shard_manager = ModelShardManager(shard_size=shard_size) - - def _to_redis(self, model: Any, name: str, version: str) -> List[str]: - """ - Serialize, shard and store a model in Redis. - - Args: - model (Any): The model object to store. - name (str): Name for the model. - version (str): Version identifier for the model. - - Returns: - List[str]: List of Redis keys where the model shards were stored. - - Raises: - ModelStoreError: If serialization or storage fails. - """ - start_time = current_timestamp() - logger.info("Starting model serialization and storage") - - try: - shard_keys: List[str] = [] - - # Store shards in Redis - with self.redis_client.pipeline(transaction=False) as pipe: - # Serialize and shard the model - logger.debug("Serializing and sharding model") - for i, shard in enumerate(self.shard_manager.to_shards(model)): - skey = self.shard_manager.shard_key(name, version, i) - shard_keys.append(skey) - # Store under shard keys in Redis - pipe.set(skey, shard) - if i % PIPELINE_BATCH_SIZE == 0: - logger.debug("Executing pipeline batch") - pipe.execute() - pipe.execute() - - except SerializationError as e: - raise ModelStoreError(f"Failed to serialize model: {str(e)}") from e - except Exception as e: - # Clean up any shards that were stored before the error - logger.error("Error during storage, cleaning up shards") - self.redis_client.delete(*shard_keys) - raise ModelStoreError(f"Failed to store model shards: {str(e)}") from e - - duration = current_timestamp() - start_time - logger.info(f"Stored model in {len(shard_keys)} shards ({duration:.4f}s)") - return shard_keys - - def save_model(self, model: Any, name: str, **kwargs) -> str: - """ - Store a model in Redis with versioning and metadata. - - Saves the model by: - 1. 
Creating a model version record in the registry. - 2. Serializing and storing the model in chunks. - 3. Updating the registry with chunk locations. - - Args: - model (Any): The model object to store. - name (str): Name for the model. - **kwargs: Additional model metadata fields (version, - description, shard_keys). - - Returns: - str: The created model version. - - Raises: - ModelStoreError: If model storage or registration fails. - """ - total_start = current_timestamp() - model_version: Optional[ModelVersion] = None - logger.info(f"Starting save operation for model '{name}'") - - try: - # Create model version record - st = current_timestamp() - logger.debug("Creating model version record") - model_version = self.model_registry.add_version(name=name, **kwargs) - logger.info(f"Added version record ({current_timestamp()-st:.4f}s)") - - # Store model chunks and get their keys - logger.debug("Starting model storage") - shard_keys = self._to_redis(model, name, model_version.version) - - # Update model version with shard locations - st = current_timestamp() - logger.debug("Updating model version with shard locations") - self.model_registry.set_model_shards( - name, model_version.version, shard_keys - ) - logger.info(f"Set shards ({current_timestamp()-st:.4f}s)") - - total_duration = current_timestamp() - total_start - logger.info(f"Total save operation completed in {total_duration:.4f}s") - return model_version.version - - except Exception as e: - # Clean up the version record if storage fails - if model_version: - logger.error("Error during save, cleaning up version record") - self.model_registry.delete_version(name, model_version.version) - if not isinstance(e, ModelStoreError): - raise ModelStoreError(f"Failed to save model: {str(e)}") from e - raise - - def _from_redis(self, shard_keys: List[str]) -> Any: - """ - Load and reconstruct a model from its shards in Redis. - - Args: - shard_keys (List[str]): List of Redis keys containing the model shards. 
- - Returns: - Any: The reconstructed model object. - - Raises: - ModelStoreError: If shard retrieval or deserialization fails - """ - start_time = current_timestamp() - shards: List[bytes] = [] - logger.info("Starting model reconstruction from shards") - - try: - # Retrieve shards from Redis - logger.debug("Retrieving shards from Redis") - with self.redis_client.pipeline(transaction=False) as pipe: - for i, skey in enumerate(shard_keys): - pipe.get(skey) - if i % PIPELINE_BATCH_SIZE == 0: - shards.extend(pipe.execute()) - shards.extend(pipe.execute()) - - # Deserialize the model - logger.debug("Deserializing model from shards") - model = self.shard_manager.from_shards(shards) - - duration = current_timestamp() - start_time - logger.info(f"Loaded model from {len(shard_keys)} shards ({duration:.4f}s)") - return model - - except Exception as e: - raise ModelStoreError(f"Failed to load model: {str(e)}") from e - - def load_model(self, name: str, version: Optional[str] = None) -> Any: - """ - Load a model from Redis by name and optional version. - - Args: - name (str): Unique identifier for the model. - version (Optional[str]): Specific version to load. If None, - loads the latest version. - - Returns: - Any: The reconstructed model object. 
- - Raises: - ModelStoreError: If model loading fails - """ - total_start = current_timestamp() - logger.info(f"Starting load operation for model '{name}'") - - try: - st = current_timestamp() - if not version: - logger.debug("Retrieving latest version") - model_version = self.model_registry.get_latest_version(name) - else: - logger.debug(f"Retrieving specific version: {version}") - model_version = self.model_registry.get_version(name, version) - - logger.info(f"Retrieved version metadata ({current_timestamp()-st:.4f}s)") - model = self._from_redis(model_version.shard_keys) - - total_duration = current_timestamp() - total_start - logger.info(f"Total load operation completed ({total_duration:.4f}s)") - return model - - except Exception as e: - if not isinstance(e, ModelStoreError): - raise ModelStoreError(f"Failed to load model {name}: {str(e)}") from e - raise diff --git a/pyproject.toml b/pyproject.toml index 3607a74..abd7e22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "License :: OSI Approved :: MIT License", ] -packages = [{ include = "model_store", from = "." }] +packages = [{ include = "redis_model_store", from = "." 
}] [tool.poetry.dependencies] python = ">=3.9,<3.13" @@ -71,7 +71,6 @@ exclude = ''' [tool.pytest.ini_options] log_cli = true -asyncio_mode = "auto" [tool.mypy] warn_unused_configs = true diff --git a/model_store/__init__.py b/redis_model_store/__init__.py similarity index 100% rename from model_store/__init__.py rename to redis_model_store/__init__.py diff --git a/model_store/serialize.py b/redis_model_store/serialize.py similarity index 100% rename from model_store/serialize.py rename to redis_model_store/serialize.py diff --git a/redis_model_store/shard_manager.py b/redis_model_store/shard_manager.py new file mode 100644 index 0000000..fde2f1d --- /dev/null +++ b/redis_model_store/shard_manager.py @@ -0,0 +1,98 @@ +from typing import Any, Iterator, List + +from redis_model_store.serialize import PickleSerializer, SerializationError, Serializer + + +class ModelShardManager: + """ + Manages serialization and sharding of model data. + + Handles breaking large objects into manageable chunks and provides + serialization/deserialization. This provides an abstraction layer between + raw objects and their serialized/sharded representation. + + The manager supports: + - Configurable shard sizes + - Pluggable serialization formats + - Efficient streaming of shards + """ + + def __init__(self, shard_size: int, serializer: Serializer = PickleSerializer()): + """ + Initialize the shard manager. 
+ + Args: + shard_size: Maximum size in bytes for each shard + serializer: Serializer implementation to use (default: pickle) + + Raises: + ValueError: If shard_size is not positive + """ + if shard_size <= 0: + raise ValueError("Shard size must be positive") + + self.shard_size = shard_size + self.serializer = serializer + + @staticmethod + def _shardify(data: bytes, shard_size: int) -> Iterator[bytes]: + """Split serialized data into fixed-size shards.""" + total_size = len(data) + for start in range(0, total_size, shard_size): + yield data[start : start + shard_size] + + @staticmethod + def shard_key(model_name: str, model_version: str, idx: int) -> str: + """ + Generate a storage key for a model shard. + + Args: + model_name: Name of the model + model_version: Version identifier + idx: Shard index number + + Returns: + Formatted storage key for the shard + """ + return f"shard:{model_name}:{model_version}:{idx}" + + def to_shards(self, obj: Any) -> Iterator[bytes]: + """ + Convert object into shards ready for storage. + + The object is first serialized then split into fixed-size chunks. + Shards are yielded one at a time to minimize memory usage. + + Args: + obj: The object to shard + + Returns: + Iterator yielding binary shards + + Raises: + SerializationError: If the object cannot be serialized + """ + try: + serialized = self.serializer.dumps(obj) + return self._shardify(serialized, self.shard_size) + except Exception as e: + raise SerializationError(f"Failed to serialize object: {str(e)}") from e + + def from_shards(self, shards: List[bytes]) -> Any: + """ + Reconstruct object from shards. 
+ + Args: + shards: List of binary shards in order + + Returns: + The reconstructed object + + Raises: + SerializationError: If the shards cannot be deserialized + """ + try: + serialized = b"".join(shards) + return self.serializer.loads(serialized) + except Exception as e: + raise SerializationError(f"Failed to deserialize object: {str(e)}") from e diff --git a/redis_model_store/store.py b/redis_model_store/store.py new file mode 100644 index 0000000..2c99fb1 --- /dev/null +++ b/redis_model_store/store.py @@ -0,0 +1,470 @@ +import json +from typing import Any, List, Optional, Set + +from pydantic import BaseModel, Field +from redis import Redis +from redisvl.index import SearchIndex +from redisvl.query import FilterQuery +from redisvl.query.filter import FilterExpression, Tag + +from redis_model_store.shard_manager import ModelShardManager +from redis_model_store.utils import ( + PIPELINE_BATCH_SIZE, + current_timestamp, + new_model_version, + setup_logger, +) + +logger = setup_logger(__name__) + + +class ModelStoreError(Exception): + """Raised when model store operations fail. + + This is the base exception for all model store operations including: + - Model saving/loading + - Version management + - Store initialization + """ + + pass + + +class ModelVersion(BaseModel): + """ + Metadata for a specific model version. + + Contains all metadata associated with a stored model version including + its name, version identifier, description, creation time, and storage + locations. + """ + + name: str = Field(..., description="Unique identifier for the model") + description: str = Field( + default="", description="Optional description of this model version" + ) + version: str = Field( + default_factory=new_model_version, + description="Version identifier (e.g. 
semantic version or UUID)", + ) + created_at: float = Field( + default_factory=current_timestamp, + description="Unix timestamp when this version was created", + ) + shard_keys: List[str] = Field( + default=[], description="Redis keys containing the serialized model data" + ) + + @classmethod + def from_dict(cls, result: dict) -> "ModelVersion": + """Create a ModelVersion instance from a query result dict. + + Args: + result: Dictionary containing version metadata fields + + Returns: + New ModelVersion instance + + Raises: + ModelStoreError: If required fields are missing or malformed + """ + try: + if "$.shard_keys" in result: + result["shard_keys"] = result.pop("$.shard_keys") + if isinstance(result["shard_keys"], str): + result["shard_keys"] = json.loads(result["shard_keys"]) + return cls( + name=result["name"], + description=result["description"], + version=result["version"], + created_at=result["created_at"], + shard_keys=result["shard_keys"], + ) + except (KeyError, json.JSONDecodeError) as e: + raise ModelStoreError(f"Invalid model version data: {str(e)}") from e + + +class ModelStore: + """ + High-level interface for storing and retrieving ML models with versioning. + + The ModelStore provides a simple API for saving and loading models while handling: + - Automatic model versioning + - Metadata tracking (creation time, descriptions) + - Efficient storage via sharding + - Version querying and retrieval + + Models are stored in Redis using a combination of: + - JSON storage for version metadata + - Bytes storage for model shards + - Search indices for efficient querying + + All operations are atomic and handle cleanup on failure. + + Example: + >>> store = ModelStore(redis_client) + >>> version = store.save_model( + ... model, + ... name="bert-qa", + ... description="BERT model fine-tuned for QA" + ... 
) + >>> model = store.load_model("bert-qa") # loads latest version + >>> model = store.load_model("bert-qa", version=version) # loads specific version + """ + + def __init__(self, redis_client: Redis, shard_size: int = 1024 * 1024): + """ + Initialize the model store. + + Args: + redis_client: Initialized Redis client instance + shard_size: Maximum size in bytes for each model shard (default: 1MB) + + Raises: + TypeError: If redis_client is not a valid Redis instance + ModelStoreError: If store initialization fails + """ + if not isinstance(redis_client, Redis): + raise TypeError("Must provide a valid Redis client instance") + + try: + self.redis_client = redis_client + self.shard_manager = ModelShardManager(shard_size=shard_size) + + # Initialize model version index + self._query_return_fields = [ + "name", + "description", + "version", + "created_at", + "$.shard_keys", + ] + self.store_idx = SearchIndex.from_dict( + { + "index": { + "name": "model_store", + "prefix": "model_version", + "storage_type": "json", + "key_separator": ":", + }, + "fields": [ + {"name": "name", "type": "tag"}, + {"name": "description", "type": "text"}, + {"name": "version", "type": "tag"}, + {"name": "created_at", "type": "numeric"}, + ], + } + ) + self.store_idx.set_client(redis_client) + self.store_idx.create(overwrite=False, drop=False) + + except Exception as e: + raise ModelStoreError(f"Failed to initialize model store: {str(e)}") from e + + @staticmethod + def _version_key(name: str, version: str) -> str: + """Generate Redis key for a model version.""" + return f"model_version:{name}:{version}" + + def _store_shards(self, model: Any, name: str, version: str) -> List[str]: + """Store model shards in Redis.""" + start_time = current_timestamp() + logger.info("Starting model serialization and storage") + shard_keys: List[str] = [] + + try: + with self.redis_client.pipeline(transaction=False) as pipe: + # Generate shards and pipeline into Redis + for i, shard in 
enumerate(self.shard_manager.to_shards(model)): + shard_key = self.shard_manager.shard_key(name, version, i) + shard_keys.append(shard_key) + pipe.set(shard_key, shard) + # Flush the pipeline batch + if i % PIPELINE_BATCH_SIZE == 0: + pipe.execute() + pipe.execute() + + duration = current_timestamp() - start_time + logger.info(f"Stored model in {len(shard_keys)} shards ({duration:.4f}s)") + return shard_keys + + except Exception as e: + # Clean up any stored shards + if shard_keys: + self.redis_client.delete(*shard_keys) + raise ModelStoreError(f"Failed to store model shards: {str(e)}") from e + + def _load_shards(self, shard_keys: List[str]) -> Any: + """Load and reconstruct model from shards.""" + start_time = current_timestamp() + logger.info("Starting model reconstruction from shards") + + try: + shards: List[bytes] = [] + # Pipeline load shards from Redis + with self.redis_client.pipeline(transaction=False) as pipe: + for i, shard_key in enumerate(shard_keys): + pipe.get(shard_key) + # Flush pipeline batch + if i % PIPELINE_BATCH_SIZE == 0: + shards.extend(pipe.execute()) + shards.extend(pipe.execute()) + # Deserialize model from shards + model = self.shard_manager.from_shards(shards) + duration = current_timestamp() - start_time + logger.info(f"Loaded model from {len(shard_keys)} shards ({duration:.4f}s)") + return model + + except Exception as e: + raise ModelStoreError(f"Failed to load model shards: {str(e)}") from e + + def save_model(self, model: Any, name: str, **kwargs) -> str: + """ + Store a model with versioning and metadata. + + Handles serialization, sharding, and version tracking for the model. + If the operation fails, any partially stored data is cleaned up. 
+ + Args: + model: The model object to store + name: Unique identifier for the model + **kwargs: Additional metadata fields including: + - version: Specific version identifier (optional) + - description: Human readable description (optional) + + Returns: + str: The version identifier for the stored model + + Raises: + ModelStoreError: If the model cannot be saved or version exists + """ + total_start = current_timestamp() + model_version: Optional[ModelVersion] = None + logger.info(f"Saving '{name}' model") + + try: + # Create version record + model_version = ModelVersion(name=name, **kwargs) + model_version_key = self._version_key(name, model_version.version) + + if self.redis_client.exists(model_version_key): + raise ModelStoreError( + f"Version exists: Model {name} version {model_version.version}" + ) + + # Store model chunks and update version + shard_keys = self._store_shards(model, name, model_version.version) + model_version.shard_keys = shard_keys + self.store_idx.load( + data=[model_version.model_dump()], keys=[model_version_key] + ) + + total_duration = current_timestamp() - total_start + logger.info(f"Save operation completed in {total_duration:.4f}s") + return model_version.version + + except Exception as e: + # Clean up version record if it was created + if model_version: + self._delete_version(name, model_version.version) + if not isinstance(e, ModelStoreError): + raise ModelStoreError(f"Failed to save model: {str(e)}") from e + raise + + def load_model(self, name: str, version: Optional[str] = None) -> Any: + """ + Load a model by name and optional version. + + Args: + name: Unique identifier for the model + version: Specific version to load. 
If None, loads the latest version + + Returns: + The reconstructed model object + + Raises: + ModelStoreError: If the model/version is not found or cannot be loaded + """ + total_start = current_timestamp() + logger.info(f"Loading '{name}' model") + + try: + # Get model version metadata + if version: + model_version = self.get_version(name, version) + else: + model_version = self.get_latest_version(name) + + # Load model data + model = self._load_shards(model_version.shard_keys) + total_duration = current_timestamp() - total_start + logger.info(f"Load operation completed in {total_duration:.4f}s") + return model + + except ModelStoreError: + raise + except Exception as e: + raise ModelStoreError(f"Failed to load model: {str(e)}") from e + + def get_version(self, name: str, version: str) -> ModelVersion: + """ + Get metadata for a specific model version. + + Args: + name: Model identifier + version: Version identifier + + Returns: + ModelVersion containing the version metadata + + Raises: + ModelStoreError: If the version is not found + """ + model_version_key = self._version_key(name, version) + model_version_dict = self.redis_client.json().get(model_version_key) + if not model_version_dict: + raise ModelStoreError(f"Version not found: Model {name} version {version}") + return ModelVersion.from_dict(model_version_dict) + + def get_latest_version(self, name: str) -> ModelVersion: + """ + Get the most recent version of a model. 
+ + Args: + name: Model identifier + + Returns: + ModelVersion for the most recently created version + + Raises: + ModelStoreError: If no versions exist for the model + """ + query = FilterQuery( + filter_expression=Tag("name") == name, + num_results=1, + return_fields=self._query_return_fields, + ).sort_by("created_at", asc=False) + + results = self.store_idx.query(query) + if not results: + raise ModelStoreError(f"No versions found for model: {name}") + + return ModelVersion.from_dict(results[0]) + + def get_all_versions(self, name: str) -> List[ModelVersion]: + """ + Get all versions of a model sorted by creation time. + + Args: + name: Model identifier + + Returns: + List of ModelVersion objects, sorted newest to oldest + + Raises: + ModelStoreError: If no versions exist or metadata is invalid + """ + query = FilterQuery( + filter_expression=Tag("name") == name, + return_fields=self._query_return_fields, + ).sort_by("created_at", asc=False) + + versions: List[ModelVersion] = [] + for results in self.store_idx.paginate(query, page_size=50): + if results: + versions.extend(ModelVersion.from_dict(result) for result in results) + + if not versions: + raise ModelStoreError(f"No versions found for model: {name}") + + return versions + + def _delete_version(self, name: str, version: str) -> int: + """Delete a model version and its shards.""" + try: + model_version = self.get_version(name, version) + except ModelStoreError: + # Ignore case where model version doesn't exist + pass + + keys = model_version.shard_keys + [self._version_key(name, version)] + self.store_idx.drop_keys(keys) + return len(keys) + + def delete_version(self, name: str, version: str) -> int: + """ + Delete a specific model version. + + Removes both the version metadata and all associated model shards. 
+ + Args: + name: Model identifier + version: Version identifier + + Returns: + Number of Redis keys deleted + + """ + return self._delete_version(name, version) + + def clear(self, name: Optional[str] = None) -> int: + """ + Clear model versions from the store. + + Args: + name: If provided, only clear versions of this model. + If None, clear all models and versions. + + Returns: + Number of Redis keys deleted + + Raises: + ModelStoreError: If clearing operation fails + """ + try: + # Get versions to delete based on name param + if name: + versions = self.get_all_versions(name) + else: + # Get all versions across all models + models = self.list_models() + versions = [] + for model_name in models: + versions.extend(self.get_all_versions(model_name)) + + # Delete each version and count total keys removed + total_deleted = 0 + for version in versions: + total_deleted += self._delete_version(version.name, version.version) + return total_deleted + + except ModelStoreError: + raise + except Exception as e: + raise ModelStoreError(f"Failed to clear store: {str(e)}") from e + + def list_models(self) -> List[str]: + """ + Get a list of all model names in the store. 
+ + Returns: + List of unique model names, sorted alphabetically + + Raises: + ModelStoreError: If query fails + """ + try: + query = FilterQuery( + return_fields=["name"], + ).sort_by("name", asc=True) + + # Use set to get unique model names + model_names: Set[str] = set() + for results in self.store_idx.paginate(query, page_size=50): + if results: + model_names.update(result["name"] for result in results) + + return sorted(model_names) + + except Exception as e: + raise ModelStoreError(f"Failed to list models: {str(e)}") from e diff --git a/model_store/utils.py b/redis_model_store/utils.py similarity index 98% rename from model_store/utils.py rename to redis_model_store/utils.py index 8ebdad9..f5cb911 100644 --- a/model_store/utils.py +++ b/redis_model_store/utils.py @@ -3,7 +3,7 @@ from uuid import uuid4 #: How many commands to queue in a Redis pipeline before executing. -PIPELINE_BATCH_SIZE = 256 +PIPELINE_BATCH_SIZE = 64 def setup_logger(name): diff --git a/scripts.py b/scripts.py index 1c54de9..74e286e 100644 --- a/scripts.py +++ b/scripts.py @@ -2,30 +2,30 @@ def format(): - subprocess.run(["isort", "./model_store", "--profile", "black"], check=True) - subprocess.run(["black", "./model_store"], check=True) + subprocess.run(["isort", "./redis_model_store", "./tests", "--profile", "black"], check=True) + subprocess.run(["black", "./redis_model_store", "./tests"], check=True) def check_format(): - subprocess.run(["black", "--check", "./model_store"], check=True) + subprocess.run(["black", "--check", "./redis_model_store"], check=True) def sort_imports(): - subprocess.run(["isort", "./model_store", "./tests/", "--profile", "black"], check=True) + subprocess.run(["isort", "./redis_model_store", "./tests/", "--profile", "black"], check=True) def check_sort_imports(): subprocess.run( - ["isort", "./model_store", "--check-only", "--profile", "black"], check=True + ["isort", "./redis_model_store", "--check-only", "--profile", "black"], check=True ) def check_lint(): 
- subprocess.run(["pylint", "--rcfile=.pylintrc", "./model_store"], check=True) + subprocess.run(["pylint", "--rcfile=.pylintrc", "./redis_model_store"], check=True) def check_mypy(): - subprocess.run(["python", "-m", "mypy", "./model_store"], check=True) + subprocess.run(["python", "-m", "mypy", "./redis_model_store"], check=True) def test(): diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml new file mode 100644 index 0000000..e8648df --- /dev/null +++ b/tests/docker-compose.yaml @@ -0,0 +1,14 @@ +version: "3.9" +services: + redis: + image: "redis/redis-stack:${REDIS_VERSION}" + ports: + - "6379" + environment: + - "REDIS_ARGS=--save '' --appendonly no" + deploy: + replicas: 1 + restart_policy: + condition: on-failure + labels: + - "com.docker.compose.publishers=redis,6379,6379" \ No newline at end of file diff --git a/tests/test_serialize.py b/tests/test_serialize.py new file mode 100644 index 0000000..cfe5519 --- /dev/null +++ b/tests/test_serialize.py @@ -0,0 +1,53 @@ +import pytest + +from redis_model_store.serialize import PickleSerializer + + +@pytest.fixture +def serializer(): + return PickleSerializer() + + +class SampleClass: + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return ( + isinstance(other, SampleClass) and self.x == other.x and self.y == other.y + ) + + +def test_pickle_serializer_simple_types(serializer): + # Test with different types + test_cases = [ + 42, + "hello world", + [1, 2, 3], + {"a": 1, "b": 2}, + (1, "two", 3.0), + True, + None, + ] + + for obj in test_cases: + serialized = serializer.dumps(obj) + deserialized = serializer.loads(serialized) + assert deserialized == obj + + +def test_pickle_serializer_complex_object(serializer): + obj = SampleClass(42, "test") + + serialized = serializer.dumps(obj) + deserialized = serializer.loads(serialized) + + assert deserialized == obj + assert 
deserialized.x == 42 + assert deserialized.y == "test" + + +def test_pickle_serializer_invalid_data(serializer): + with pytest.raises(Exception): + serializer.loads(b"invalid data") diff --git a/tests/test_shard_manager.py b/tests/test_shard_manager.py new file mode 100644 index 0000000..0ec138e --- /dev/null +++ b/tests/test_shard_manager.py @@ -0,0 +1,106 @@ +import pytest + +from redis_model_store.serialize import PickleSerializer +from redis_model_store.shard_manager import ModelShardManager, SerializationError + + +@pytest.fixture +def small_shard_manager(): + return ModelShardManager(shard_size=10) # Very small shard size for testing + + +@pytest.fixture +def shard_manager(): + return ModelShardManager(shard_size=1024) # More realistic shard size + + +class LargeObject: + def __init__(self, data): + self.data = data + + def __eq__(self, other): + return isinstance(other, LargeObject) and self.data == other.data + + +def test_shard_manager_init(): + # Test valid initialization + manager = ModelShardManager(shard_size=1024) + assert manager.shard_size == 1024 + assert isinstance(manager.serializer, PickleSerializer) + + # Test invalid shard size + with pytest.raises(ValueError): + ModelShardManager(shard_size=0) + + with pytest.raises(ValueError): + ModelShardManager(shard_size=-1) + + +def test_shard_key_format(): + manager = ModelShardManager(shard_size=1024) + key = manager.shard_key("model1", "v1", 0) + assert key == "shard:model1:v1:0" + + +def test_small_object_single_shard(shard_manager): + # Object small enough to fit in one shard + obj = {"key": "value"} + + shards = list(shard_manager.to_shards(obj)) + assert len(shards) == 1 + + reconstructed = shard_manager.from_shards(shards) + assert reconstructed == obj + + +def test_large_object_multiple_shards(small_shard_manager): + # Create object that will require multiple shards + obj = LargeObject("x" * 25) # Will be split into multiple shards + + shards = list(small_shard_manager.to_shards(obj)) + assert 
len(shards) > 1 + + reconstructed = small_shard_manager.from_shards(shards) + assert isinstance(reconstructed, LargeObject) + assert reconstructed == obj + assert reconstructed.data == obj.data + + +def test_shard_size_boundaries(small_shard_manager): + # Test objects of different sizes around the shard boundary + test_sizes = [9, 10, 11, 19, 20, 21, 100] + + for size in test_sizes: + obj = LargeObject("x" * size) + shards = list(small_shard_manager.to_shards(obj)) + reconstructed = small_shard_manager.from_shards(shards) + assert reconstructed == obj + + +def test_invalid_shards(): + manager = ModelShardManager(shard_size=10) + + # Test with corrupted data + with pytest.raises(SerializationError): + manager.from_shards([b"invalid data"]) + + # Test with empty shards + with pytest.raises(SerializationError): + manager.from_shards([]) + + +def test_complex_object_serialization(shard_manager): + # Test with a more complex nested structure + obj = { + "list": [1, 2, 3], + "dict": {"nested": "value"}, + "tuple": (4, 5, 6), + "object": LargeObject("test"), + } + + shards = list(shard_manager.to_shards(obj)) + reconstructed = shard_manager.from_shards(shards) + + assert reconstructed == obj + assert isinstance(reconstructed["object"], LargeObject) + assert reconstructed["object"].data == "test" diff --git a/tests/test_store.py b/tests/test_store.py new file mode 100644 index 0000000..fa26e69 --- /dev/null +++ b/tests/test_store.py @@ -0,0 +1,221 @@ +import json +import os +from typing import Generator + +import pytest +from redis import Redis + +from redis_model_store.store import ModelStore, ModelStoreError, ModelVersion + + +@pytest.fixture +def redis_client(): + """Provide a Redis client for testing.""" + client = Redis.from_url(os.environ["REDIS_URL"]) + yield client + # Clean up after tests + client.flushdb() + + +@pytest.fixture +def store(redis_client: Redis): + """Provide a ModelStore instance for testing.""" + return ModelStore(redis_client, shard_size=1024) + + 
+@pytest.fixture +def sample_model(): + """Provide a simple model for testing.""" + return {"weights": [1.0, 2.0, 3.0], "config": {"layers": 3, "activation": "relu"}} + + +@pytest.fixture +def large_model(): + """Provide a model large enough to create multiple shards.""" + return { + "weights": [1.0456] * 1000, # Large array + "config": {"layers": 10, "activation": "relu"}, + } + + +@pytest.fixture +def populated_store(store: ModelStore, sample_model: dict): + """Provide a store pre-populated with test models.""" + versions = { + "model-a": ["v1.0"], + "model-b": ["v1.0", "v2.0"], + "model-c": ["v1.0", "v1.1", "v2.0"], + } + + for name, model_versions in versions.items(): + for version in model_versions: + store.save_model( + sample_model, + name=name, + version=version, + description=f"Test model {name} {version}", + ) + + yield store + + +class TestModelVersion: + """Test ModelVersion creation and serialization.""" + + def test_create_minimal(self): + """Should create ModelVersion with only required fields.""" + version = ModelVersion(name="test-model") + + assert version.name == "test-model" + assert version.description == "" + assert version.version # auto-generated + assert version.created_at > 0 + assert version.shard_keys == [] + + def test_create_complete(self): + """Should create ModelVersion with all fields specified.""" + version = ModelVersion( + name="test-model", + description="Test model", + version="v1.0", + created_at=1234567890.0, + shard_keys=["shard:1", "shard:2"], + ) + + assert version.name == "test-model" + assert version.description == "Test model" + assert version.version == "v1.0" + assert version.created_at == 1234567890.0 + assert version.shard_keys == ["shard:1", "shard:2"] + + def test_from_dict_valid(self): + """Should create ModelVersion from valid query result.""" + data = { + "name": "test-model", + "description": "Test model", + "version": "v1.0", + "created_at": 1234567890.0, + "$.shard_keys": json.dumps(["shard:1", "shard:2"]), 
+ } + + version = ModelVersion.from_dict(data) + assert version.name == "test-model" + assert version.shard_keys == ["shard:1", "shard:2"] + + @pytest.mark.parametrize( + "invalid_data", + [ + {}, # Empty dict + {"name": "test"}, # Missing required fields + { # Invalid shard keys format + "name": "test", + "version": "v1", + "created_at": 123, + "$.shard_keys": "invalid", + }, + ], + ) + def test_from_dict_invalid(self, invalid_data): + """Should raise error for invalid data formats.""" + with pytest.raises(ModelStoreError, match="Invalid model version data"): + ModelVersion.from_dict(invalid_data) + + +class TestModelStore: + """Test ModelStore operations.""" + + def test_init_invalid_client(self): + """Should reject invalid Redis client.""" + with pytest.raises(TypeError, match="Must provide a valid Redis client"): + ModelStore(None) + + def test_save_and_load_basic(self, store: ModelStore, sample_model: dict): + """Should save and load model with basic metadata.""" + version = store.save_model( + sample_model, name="test-model", description="Test model" + ) + assert version # version string returned + + loaded = store.load_model("test-model") + assert loaded == sample_model + + def test_save_and_load_large(self, store: ModelStore, large_model: dict): + """Should handle models requiring multiple shards.""" + version = store.save_model(large_model, name="large-model") + loaded = store.load_model("large-model") + assert loaded == large_model + + def test_save_duplicate_version(self, store: ModelStore, sample_model: dict): + """Should prevent duplicate version creation.""" + store.save_model(sample_model, name="test", version="v1.0") + + with pytest.raises(ModelStoreError, match="Version exists"): + store.save_model(sample_model, name="test", version="v1.0") + + @pytest.mark.parametrize( + "name,version", + [ + ("nonexistent", None), # No such model + ("test-model", "v999"), # No such version + ], + ) + def test_load_nonexistent(self, store: ModelStore, name: str, 
version: str): + """Should handle loading nonexistent models/versions.""" + with pytest.raises(ModelStoreError): + store.load_model(name, version) + + def test_version_management(self, store: ModelStore, sample_model: dict): + """Should manage multiple versions correctly.""" + # Create versions + v1 = store.save_model(sample_model, name="test", version="v1.0") + v2 = store.save_model(sample_model, name="test", version="v2.0") + + # Get specific version + version = store.get_version("test", "v1.0") + assert version.version == "v1.0" + + # Get latest + latest = store.get_latest_version("test") + assert latest.version == "v2.0" + + # Get all versions + versions = store.get_all_versions("test") + assert len(versions) == 2 + assert [v.version for v in versions] == ["v2.0", "v1.0"] + + def test_list_models(self, populated_store: ModelStore): + """Should list available models correctly.""" + models = populated_store.list_models() + assert models == ["model-a", "model-b", "model-c"] + + # After deletion + populated_store.clear("model-b") + models = populated_store.list_models() + assert models == ["model-a", "model-c"] + + def test_delete_version(self, populated_store: ModelStore): + """Should delete specific version and maintain others.""" + # Delete middle version + deleted = populated_store.delete_version("model-c", "v1.1") + assert deleted > 0 + + versions = populated_store.get_all_versions("model-c") + assert [v.version for v in versions] == ["v2.0", "v1.0"] + + def test_clear_specific(self, populated_store: ModelStore): + """Should clear specific model and maintain others.""" + populated_store.clear("model-b") + + # model-b should be gone + with pytest.raises(ModelStoreError): + populated_store.get_version("model-b", "v1.0") + + # others should remain + assert populated_store.list_models() == ["model-a", "model-c"] + + def test_clear_all(self, populated_store: ModelStore): + """Should clear entire store.""" + deleted = populated_store.clear() + assert deleted > 0 + + 
assert populated_store.list_models() == [] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d7d5eb6 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,47 @@ +import logging +import uuid +from datetime import datetime, timezone + +from redis_model_store.utils import current_timestamp, new_model_version, setup_logger + + +def test_setup_logger(): + logger = setup_logger("test_logger") + + assert isinstance(logger, logging.Logger) + assert logger.name == "test_logger" + assert logger.level == logging.INFO + assert len(logger.handlers) == 1 + assert isinstance(logger.handlers[0], logging.StreamHandler) + + # Test that calling setup_logger again doesn't add another handler + logger = setup_logger("test_logger") + assert len(logger.handlers) == 1 + + +def test_current_timestamp(): + timestamp = current_timestamp() + now = datetime.now(timezone.utc).timestamp() + + assert isinstance(timestamp, float) + # Check if timestamp is recent (within 1 second) + assert abs(timestamp - now) < 1 + + +def test_new_model_version(): + version = new_model_version() + + assert isinstance(version, str) + # Verify it's a valid UUID + assert uuid.UUID(version) + + # Test uniqueness + another_version = new_model_version() + assert version != another_version + + +def test_pipeline_batch_size_constant(): + from redis_model_store.utils import PIPELINE_BATCH_SIZE + + assert isinstance(PIPELINE_BATCH_SIZE, int) + assert PIPELINE_BATCH_SIZE > 0