Skip to content

Commit

Permalink
Zarr Python v3 (#484)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jun 28, 2024
1 parent 3fe71a4 commit e7ff365
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 9 deletions.
1 change: 1 addition & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ omit =
cubed/runtime/executors/lithops.py
cubed/runtime/executors/modal*.py
cubed/storage/backends/tensorstore.py
cubed/storage/backends/zarr_python_v3.py
cubed/vendor/*
58 changes: 58 additions & 0 deletions .github/workflows/zarr-v3-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
name: Zarr v3 Tests

on:
push:
branches:
- "main"
pull_request:
workflow_dispatch:

concurrency:
group: Tests-${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.10"]

steps:
- name: Checkout source
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Setup Graphviz
uses: ts-graphviz/setup-graphviz@v2
with:
macos-skip-brew-update: 'true'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
- name: Install
run: |
python -m pip install -e .[test]
python -m pip install -U git+https://github.com/zarr-developers/zarr-python.git@v3
- name: Run tests
env:
COVERAGE_CORE: sysmon
run: |
if [ "$RUNNER_OS" == "Windows" ]; then
pytest -v
else
pytest -v --cov=cubed --cov-report=term-missing --cov-fail-under=90
fi
shell: bash
6 changes: 2 additions & 4 deletions cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,13 +379,11 @@ def visualize(
elif node_type == "array":
target = d["target"]
chunkmem = memory_repr(chunk_memory(target))
nbytes = None

# materialized arrays are light orange, virtual arrays are white
if isinstance(target, (LazyZarrArray, zarr.Array)):
d["style"] = "filled"
d["fillcolor"] = "#ffd8b1"
nbytes = memory_repr(target.nbytes)
if n in array_display_names:
var_name = array_display_names[n]
label = f"{n}\n{var_name}"
Expand All @@ -394,8 +392,8 @@ def visualize(
tooltip += f"chunks: {target.chunks}\n"
tooltip += f"dtype: {target.dtype}\n"
tooltip += f"chunk memory: {chunkmem}\n"
if nbytes is not None:
tooltip += f"nbytes: {nbytes}\n"
if hasattr(target, "nbytes"):
tooltip += f"nbytes: {memory_repr(target.nbytes)}\n"

del d["target"]

Expand Down
14 changes: 13 additions & 1 deletion cubed/storage/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,22 @@ def open_backend_array(
# e.g. set globally with CUBED_STORAGE_NAME=tensorstore
storage_name = config.get("storage_name", None)

if storage_name is None or storage_name == "zarr-python":
if storage_name is None:
import zarr

if zarr.__version__[0] == "3":
storage_name = "zarr-python-v3"
else:
storage_name = "zarr-python"

if storage_name == "zarr-python":
from cubed.storage.backends.zarr_python import open_zarr_array

open_func = open_zarr_array
elif storage_name == "zarr-python-v3":
from cubed.storage.backends.zarr_python_v3 import open_zarr_v3_array

open_func = open_zarr_v3_array
elif storage_name == "tensorstore":
from cubed.storage.backends.tensorstore import open_tensorstore_array

Expand Down
80 changes: 80 additions & 0 deletions cubed/storage/backends/zarr_python_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Optional

import zarr

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
from cubed.utils import join_path


class ZarrV3ArrayGroup(dict):
def __init__(
self,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
):
dict.__init__(self)
self.shape = shape
self.dtype = dtype
self.chunks = chunks

def __getitem__(self, key):
if isinstance(key, str):
return super().__getitem__(key)
return {field: zarray[key] for field, zarray in self.items()}

def set_basic_selection(self, selection, value, fields=None):
self[fields][selection] = value


def open_zarr_v3_array(
store: T_Store,
mode: str,
*,
shape: Optional[T_Shape] = None,
dtype: Optional[T_DType] = None,
chunks: Optional[T_RegularChunks] = None,
path: Optional[str] = None,
**kwargs,
):
if isinstance(chunks, int):
chunks = (chunks,)

if mode in ("r", "r+"):
# TODO: remove when https://github.com/zarr-developers/zarr-python/issues/1978 is fixed
if mode == "r+":
mode = "w"
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
return zarr.open(store=store, mode=mode, path=path)
else:
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_dtype, _ = dtype.fields[field]
field_path = field if path is None else join_path(path, field)
ret[field] = zarr.open(store=store, mode=mode, path=field_path)
return ret
else:
overwrite = True if mode == "a" else False
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
return zarr.create(
shape=shape,
dtype=dtype,
chunk_shape=chunks,
store=store,
overwrite=overwrite,
path=path,
)
else:
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_dtype, _ = dtype.fields[field]
field_path = field if path is None else join_path(path, field)
ret[field] = zarr.create(
shape=shape,
dtype=field_dtype,
chunk_shape=chunks,
store=store,
overwrite=overwrite,
path=field_path,
)
return ret
2 changes: 1 addition & 1 deletion cubed/tests/storage/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_lazy_zarr_array(tmp_path):
arr = lazy_zarr_array(zarr_path, shape=(3, 3), dtype=int, chunks=(2, 2))

assert not zarr_path.exists()
with pytest.raises(ValueError):
with pytest.raises((TypeError, ValueError)):
arr.open()

arr.create()
Expand Down
6 changes: 3 additions & 3 deletions cubed/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_store(tmp_path, spec):
target = zarr.empty(a.shape, store=store)

cubed.store(a, target)
assert_array_equal(target, np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
assert_array_equal(target[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))


def test_store_multiple(tmp_path, spec):
Expand All @@ -146,8 +146,8 @@ def test_store_multiple(tmp_path, spec):
target2 = zarr.empty(b.shape, store=store2)

cubed.store([a, b], [target1, target2])
assert_array_equal(target1, np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
assert_array_equal(target2, np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]))
assert_array_equal(target1[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
assert_array_equal(target2[:], np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]]))


def test_store_fails(tmp_path, spec):
Expand Down

0 comments on commit e7ff365

Please sign in to comment.