diff --git a/docs/release.rst b/docs/release.rst index 3ed47ff9f5..b78e709c0e 100644 --- a/docs/release.rst +++ b/docs/release.rst @@ -18,6 +18,12 @@ Release notes Unreleased (v3) --------------- +Enhancements +~~~~~~~~~~~~ + +* Implement listing of the sub-arrays and sub-groups for a V3 ``Group``. + By :user:`Davis Bennett <d-v-b>` :issue:`1726`. + Maintenance ~~~~~~~~~~~ diff --git a/src/zarr/group.py b/src/zarr/group.py index cd2c00dc11..c40b5f9a34 100644 --- a/src/zarr/group.py +++ b/src/zarr/group.py @@ -1,10 +1,18 @@ from __future__ import annotations +from typing import TYPE_CHECKING from dataclasses import asdict, dataclass, field, replace import asyncio import json import logging -from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List + +if TYPE_CHECKING: + from typing import ( + Any, + AsyncGenerator, + Literal, + AsyncIterator, + ) from zarr.abc.metadata import Metadata from zarr.array import AsyncArray, Array @@ -25,7 +33,7 @@ def parse_zarr_format(data: Any) -> Literal[2, 3]: # todo: convert None to empty dict -def parse_attributes(data: Any) -> Dict[str, Any]: +def parse_attributes(data: Any) -> dict[str, Any]: if data is None: return {} elif isinstance(data, dict) and all(map(lambda v: isinstance(v, str), data.keys())): @@ -36,12 +44,12 @@ def parse_attributes(data: Any) -> Dict[str, Any]: @dataclass(frozen=True) class GroupMetadata(Metadata): - attributes: Dict[str, Any] = field(default_factory=dict) + attributes: dict[str, Any] = field(default_factory=dict) zarr_format: Literal[2, 3] = 3 node_type: Literal["group"] = field(default="group", init=False) # todo: rename this, since it doesn't return bytes - def to_bytes(self) -> Dict[str, bytes]: + def to_bytes(self) -> dict[str, bytes]: if self.zarr_format == 3: return {ZARR_JSON: json.dumps(self.to_dict()).encode()} else: @@ -50,7 +58,7 @@ def to_bytes(self) -> Dict[str, bytes]: ZATTRS_JSON: json.dumps(self.attributes).encode(), } - def __init__(self, attributes: Optional[Dict[str, 
Any]] = None, zarr_format: Literal[2, 3] = 3): + def __init__(self, attributes: dict[str, Any] | None = None, zarr_format: Literal[2, 3] = 3): attributes_parsed = parse_attributes(attributes) zarr_format_parsed = parse_zarr_format(zarr_format) @@ -58,11 +66,11 @@ def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Lit object.__setattr__(self, "zarr_format", zarr_format_parsed) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> GroupMetadata: + def from_dict(cls, data: dict[str, Any]) -> GroupMetadata: assert data.pop("node_type", None) in ("group", None) return cls(**data) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return asdict(self) @@ -70,14 +78,14 @@ def to_dict(self) -> Dict[str, Any]: class AsyncGroup: metadata: GroupMetadata store_path: StorePath - runtime_configuration: RuntimeConfiguration + runtime_configuration: RuntimeConfiguration = RuntimeConfiguration() @classmethod async def create( cls, store: StoreLike, *, - attributes: Optional[Dict[str, Any]] = None, + attributes: dict[str, Any] = {}, exists_ok: bool = False, zarr_format: Literal[2, 3] = 3, runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(), @@ -89,7 +97,7 @@ async def create( elif zarr_format == 2: assert not await (store_path / ZGROUP_JSON).exists() group = cls( - metadata=GroupMetadata(attributes=attributes or {}, zarr_format=zarr_format), + metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format), store_path=store_path, runtime_configuration=runtime_configuration, ) @@ -137,7 +145,7 @@ async def open( def from_dict( cls, store_path: StorePath, - data: Dict[str, Any], + data: dict[str, Any], runtime_configuration: RuntimeConfiguration, ) -> AsyncGroup: group = cls( @@ -150,14 +158,24 @@ def from_dict( async def getitem( self, key: str, - ) -> Union[AsyncArray, AsyncGroup]: + ) -> AsyncArray | AsyncGroup: store_path = self.store_path / key + # Note: + # in zarr-python v2, we first check if `key` 
references an Array, else if `key` references + # a group, using standalone `contains_array` and `contains_group` functions. These functions + # are reusable, but for v3 they would perform redundant I/O operations. + # Not clear how much of that strategy we want to keep here. + + # if `key` names an object in storage, it cannot be an array or group + if await store_path.exists(): + raise KeyError(key) + if self.metadata.zarr_format == 3: zarr_json_bytes = await (store_path / ZARR_JSON).get() if zarr_json_bytes is None: # implicit group? - logger.warning("group at {} is an implicit group", store_path) + logger.warning("group at %s is an implicit group", store_path) zarr_json = { "zarr_format": self.metadata.zarr_format, "node_type": "group", @@ -196,7 +214,7 @@ async def getitem( else: if zgroup_bytes is None: # implicit group? - logger.warning("group at {} is an implicit group", store_path) + logger.warning("group at %s is an implicit group", store_path) zgroup = ( json.loads(zgroup_bytes) if zgroup_bytes is not None @@ -248,7 +266,7 @@ async def create_array(self, path: str, **kwargs) -> AsyncArray: **kwargs, ) - async def update_attributes(self, new_attributes: Dict[str, Any]): + async def update_attributes(self, new_attributes: dict[str, Any]): # metadata.attributes is "frozen" so we simply clear and update the dict self.metadata.attributes.clear() self.metadata.attributes.update(new_attributes) @@ -269,26 +287,68 @@ async def update_attributes(self, new_attributes: Dict[str, Any]): def __repr__(self): return f"" - async def nchildren(self) -> int: - raise NotImplementedError - - async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]: - raise NotImplementedError - - async def contains(self, child: str) -> bool: + async def nmembers(self) -> int: raise NotImplementedError - async def group_keys(self) -> AsyncIterator[str]: - raise NotImplementedError + async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]: + """ + 
Returns an AsyncGenerator over the arrays and groups contained in this group. + This method requires that `store_path.store` supports directory listing. + + The results are not guaranteed to be ordered. + """ + if not self.store_path.store.supports_listing: + msg = ( + f"The store associated with this group ({type(self.store_path.store)}) " + "does not support listing, " + "specifically via the `list_dir` method. " + "This function requires a store that supports listing." + ) - async def groups(self) -> AsyncIterator[AsyncGroup]: - raise NotImplementedError + raise ValueError(msg) + subkeys = await self.store_path.store.list_dir(self.store_path.path) + # would be nice to make these special keys accessible programmatically, + # and scoped to specific zarr versions + subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys) + # is there a better way to schedule this? + for subkey in subkeys_filtered: + try: + yield (subkey, await self.getitem(subkey)) + except KeyError: + # keyerror is raised when `subkey` names an object (in the object storage sense), + # as opposed to a prefix, in the store under the prefix associated with this group + # in which case `subkey` cannot be the name of a sub-array or sub-group. 
+ logger.warning( + "Object at %s is not recognized as a component of a Zarr hierarchy.", subkey + ) + pass - async def array_keys(self) -> AsyncIterator[str]: + async def contains(self, member: str) -> bool: raise NotImplementedError + # todo: decide if this method should be separate from `groups` + async def group_keys(self) -> AsyncGenerator[str, None]: + async for key, value in self.members(): + if isinstance(value, AsyncGroup): + yield key + + # todo: decide if this method should be separate from `group_keys` + async def groups(self) -> AsyncGenerator[AsyncGroup, None]: + async for key, value in self.members(): + if isinstance(value, AsyncGroup): + yield value + + # todo: decide if this method should be separate from `arrays` + async def array_keys(self) -> AsyncGenerator[str, None]: + async for key, value in self.members(): + if isinstance(value, AsyncArray): + yield key + + # todo: decide if this method should be separate from `array_keys` async def arrays(self) -> AsyncIterator[AsyncArray]: - raise NotImplementedError + async for key, value in self.members(): + if isinstance(value, AsyncArray): + yield value async def tree(self, expand=False, level=None) -> Any: raise NotImplementedError @@ -331,7 +391,7 @@ def create( cls, store: StoreLike, *, - attributes: Optional[Dict[str, Any]] = None, + attributes: dict[str, Any] = {}, exists_ok: bool = False, runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(), ) -> Group: @@ -358,7 +418,7 @@ def open( ) return cls(obj) - def __getitem__(self, path: str) -> Union[Array, Group]: + def __getitem__(self, path: str) -> Array | Group: obj = self._sync(self._async_group.getitem(path)) if isinstance(obj, AsyncArray): return Array(obj) @@ -378,7 +438,7 @@ def __setitem__(self, key, value): """__setitem__ is not supported in v3""" raise NotImplementedError - async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group: + async def update_attributes_async(self, new_attributes: dict[str, 
Any]) -> Group: new_metadata = replace(self.metadata, attributes=new_attributes) # Write new metadata @@ -389,6 +449,10 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group async_group = replace(self._async_group, metadata=new_metadata) return replace(self, _async_group=async_group) + @property + def store_path(self) -> StorePath: + return self._async_group.store_path + @property def metadata(self) -> GroupMetadata: return self._async_group.metadata @@ -401,50 +465,54 @@ def attrs(self) -> Attributes: def info(self): return self._async_group.info - @property - def store_path(self) -> StorePath: - return self._async_group.store_path - - def update_attributes(self, new_attributes: Dict[str, Any]): + def update_attributes(self, new_attributes: dict[str, Any]): self._sync(self._async_group.update_attributes(new_attributes)) return self @property - def nchildren(self) -> int: - return self._sync(self._async_group.nchildren()) + def nmembers(self) -> int: + return self._sync(self._async_group.nmembers()) @property - def children(self) -> List[Union[Array, Group]]: - raise NotImplementedError - # Uncomment with AsyncGroup implements this method - # _children: List[Union[AsyncArray, AsyncGroup]] = self._sync_iter( - # self._async_group.children() - # ) - # return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children] + def members(self) -> tuple[tuple[str, Array | Group], ...]: + """ + Return the sub-arrays and sub-groups of this group as a `tuple` of (name, array | group) + pairs + """ + _members: list[tuple[str, AsyncArray | AsyncGroup]] = self._sync_iter( + self._async_group.members() + ) + ret: list[tuple[str, Array | Group]] = [] + for key, value in _members: + if isinstance(value, AsyncArray): + ret.append((key, Array(value))) + else: + ret.append((key, Group(value))) + return tuple(ret) - def __contains__(self, child) -> bool: - return self._sync(self._async_group.contains(child)) + def __contains__(self, 
member) -> bool: + return self._sync(self._async_group.contains(member)) - def group_keys(self) -> List[str]: - raise NotImplementedError + def group_keys(self) -> list[str]: # uncomment with AsyncGroup implements this method # return self._sync_iter(self._async_group.group_keys()) + raise NotImplementedError - def groups(self) -> List[Group]: + def groups(self) -> list[Group]: # TODO: in v2 this was a generator that return key: Group - raise NotImplementedError # uncomment with AsyncGroup implements this method # return [Group(obj) for obj in self._sync_iter(self._async_group.groups())] + raise NotImplementedError - def array_keys(self) -> List[str]: + def array_keys(self) -> list[str]: # uncomment with AsyncGroup implements this method - # return self._sync_iter(self._async_group.array_keys()) + # return self._sync_iter(self._async_group.array_keys) raise NotImplementedError - def arrays(self) -> List[Array]: - raise NotImplementedError + def arrays(self) -> list[Array]: # uncomment with AsyncGroup implements this method - # return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())] + # return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)] + raise NotImplementedError def tree(self, expand=False, level=None) -> Any: return self._sync(self._async_group.tree(expand=expand, level=level)) diff --git a/tests/v2/conftest.py b/tests/v2/conftest.py index a7a445c640..c84cdfa439 100644 --- a/tests/v2/conftest.py +++ b/tests/v2/conftest.py @@ -1,5 +1,4 @@ import pathlib - import pytest diff --git a/tests/v3/conftest.py b/tests/v3/conftest.py new file mode 100644 index 0000000000..3dc55c0298 --- /dev/null +++ b/tests/v3/conftest.py @@ -0,0 +1,32 @@ +import pathlib +import pytest + +from zarr.store import LocalStore, StorePath, MemoryStore, RemoteStore + + +@pytest.fixture(params=[str, pathlib.Path]) +def path_type(request): + return request.param + + +# todo: harmonize this with local_store fixture +@pytest.fixture +def store_path(tmpdir): + 
store = LocalStore(str(tmpdir)) + p = StorePath(store) + return p + + +@pytest.fixture(scope="function") +def local_store(tmpdir): + return LocalStore(str(tmpdir)) + + +@pytest.fixture(scope="function") +def remote_store(): + return RemoteStore() + + +@pytest.fixture(scope="function") +def memory_store(): + return MemoryStore() diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py index 1150469db1..941256bdd2 100644 --- a/tests/v3/test_group.py +++ b/tests/v3/test_group.py @@ -1,3 +1,10 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +from zarr.sync import sync + +if TYPE_CHECKING: + from zarr.store import MemoryStore, LocalStore import pytest import numpy as np @@ -6,21 +13,60 @@ from zarr.config import RuntimeConfiguration -@pytest.fixture -def store_path(tmpdir): - store = LocalStore(str(tmpdir)) - p = StorePath(store) - return p +# todo: put RemoteStore in here +@pytest.mark.parametrize("store_type", ("local_store", "memory_store")) +def test_group_members(store_type, request): + """ + Test that `Group.members` returns correct values, i.e. the arrays and groups + (explicit and implicit) contained in that group. + """ + + store: LocalStore | MemoryStore = request.getfixturevalue(store_type) + path = "group" + agroup = AsyncGroup( + metadata=GroupMetadata(), + store_path=StorePath(store=store, path=path), + ) + group = Group(agroup) + members_expected = {} + members_expected["subgroup"] = group.create_group("subgroup") + # make a sub-sub-subgroup, to ensure that the children calculation doesn't go + # too deep in the hierarchy + _ = members_expected["subgroup"].create_group("subsubgroup") -def test_group(store_path) -> None: + members_expected["subarray"] = group.create_array( + "subarray", shape=(100,), dtype="uint8", chunk_shape=(10,), exists_ok=True + ) + + # add an extra object to the domain of the group. + # the list of children should ignore this object. 
+ sync(store.set(f"{path}/extra_object", b"000000")) + # add an extra object under a directory-like prefix in the domain of the group. + # this creates an implicit group called implicit_subgroup + sync(store.set(f"{path}/implicit_subgroup/extra_object", b"000000")) + # make the implicit subgroup + members_expected["implicit_subgroup"] = Group( + AsyncGroup( + metadata=GroupMetadata(), + store_path=StorePath(store=store, path=f"{path}/implicit_subgroup"), + ) + ) + members_observed = group.members + # members are not guaranteed to be ordered, so sort before comparing + assert sorted(dict(members_observed)) == sorted(members_expected) + + +@pytest.mark.parametrize("store_type", (("local_store",))) +def test_group(store_type, request) -> None: + store = request.getfixturevalue(store_type) + store_path = StorePath(store) agroup = AsyncGroup( metadata=GroupMetadata(), store_path=store_path, runtime_configuration=RuntimeConfiguration(), ) group = Group(agroup) - assert agroup.metadata is group.metadata # create two groups