Skip to content

Commit

Permalink
test(stores): provide public store test class
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamman committed Apr 21, 2024
1 parent 49f1505 commit 31f66ff
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/zarr/v3/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], N
# and scoped to specific zarr versions
if key not in ("zarr.json", ".zgroup", ".zattrs"):
try:
# TODO: performance optimization -- load children concurrently
# TODO: performance optimization -- load children concurrently
child = await self.getitem(key)
yield key, child
except KeyError:
Expand Down
5 changes: 1 addition & 4 deletions src/zarr/v3/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ async def list(self) -> AsyncGenerator[str, None]:
if p.is_file():
yield str(p)


async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
"""Retrieve all keys in the store with a given prefix.
Expand All @@ -173,7 +172,6 @@ async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
if p.is_file():
yield str(p)


async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
Retrieve all keys and prefixes with a given prefix and which do not contain the character
Expand All @@ -189,12 +187,11 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
base = self.root / prefix
to_strip = str(base) + "/"

try:
key_iter = base.iterdir()
except (FileNotFoundError, NotADirectoryError):
key_iter = []

for key in key_iter:
yield str(key).replace(to_strip, "")

2 changes: 1 addition & 1 deletion src/zarr/v3/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async def get(
async def get_partial_values(
self, key_ranges: List[Tuple[str, Tuple[int, int]]]
) -> List[bytes]:
raise NotImplementedError
return [await self.get(key, range) for key, range in key_ranges]

async def exists(self, key: str) -> bool:
return key in self._store_dict
Expand Down
46 changes: 46 additions & 0 deletions src/zarr/v3/testing/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import pytest

from zarr.v3.abc.store import Store


class StoreTests:
store_cls: type[Store]

@pytest.fixture(scope="function")
def store(self) -> Store:
return self.store_cls()

def test_store_type(self, store: Store) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_repr(self, store: Store) -> None:
assert repr(store)

@pytest.mark.asyncio
@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
await store.set(key, data)
assert await store.get(key) == data

# @pytest.mark.parametrize("key, data, error, error_msg", [
# ("", b"\x01\x02\x03\x04", ValueError, "invalid key")
# ])
# async def test_set_raises(self, store, key: Any, data: Any) -> None:
# with pytest.raises(TypeError):
# await store.set(key, data)

@pytest.mark.asyncio
@pytest.mark.parametrize("key", ["foo/c/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None:
# put all of the data
await store.set(key, data)
# read back just part of it
vals = await store.get_partial_values([(key, (0, 2))])
assert vals == [data[0:2]]

# read back multiple parts of it at once
vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
assert vals == [data[0:2], data[2:4]]
14 changes: 14 additions & 0 deletions tests/v3/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,25 @@

# import numpy as np
from __future__ import annotations
from zarr.v3.testing.store import StoreTests
from zarr.v3.store.local import LocalStore
from zarr.v3.store.memory import MemoryStore
from pathlib import Path
import pytest


class TestLocalStore(StoreTests):
store_cls = LocalStore

@pytest.fixture(scope="function")
def store(self, tmpdir) -> LocalStore:
return self.store_cls(str(tmpdir))


class TestMemoryStore(StoreTests):
store_cls = MemoryStore


@pytest.mark.parametrize("auto_mkdir", (True, False))
def test_local_store_init(tmpdir, auto_mkdir: bool) -> None:
tmpdir_str = str(tmpdir)
Expand Down

0 comments on commit 31f66ff

Please sign in to comment.