Skip to content

Commit

Permalink
implement group.members (#1726)
Browse files Browse the repository at this point in the history
* feat: functional .children method for groups

* changes necessary for correctly generating list of children

* add stand-alone test for group.children

* give type hints a glow-up

* test: use separate assert statements to avoid platform-dependent ordering issues

* test: put fixtures in conftest, add MemoryStore fixture

* docs: release notes

* test: remove prematurely-added mock s3 fixture

* fix: Rename children to members; AsyncGroup.members yields tuples of (name, AsyncArray / AsyncGroup) pairs; Group.members repackages these into a dict.

* fix: make Group.members return a tuple of str, Array | Group pairs

* fix: revert changes to synchronization code; this is churn that we need to deal with

* make mypy happy

* feat: implement member-specific iteration methods in asyncgroup

* chore: clean up some post-merge issues

* chore: remove extra directory added by test code

---------

Co-authored-by: Joseph Hamman <joe@earthmover.io>
  • Loading branch information
d-v-b and jhamman authored Apr 24, 2024
1 parent 0f755cc commit 57d6ace
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 64 deletions.
6 changes: 6 additions & 0 deletions docs/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ Release notes
Unreleased (v3)
---------------

Enhancements
~~~~~~~~~~~~

* Implement listing of the sub-arrays and sub-groups for a V3 ``Group``.
By :user:`Davis Bennett <d-v-b>` :issue:`1726`.

Maintenance
~~~~~~~~~~~

Expand Down
180 changes: 124 additions & 56 deletions src/zarr/group.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from dataclasses import asdict, dataclass, field, replace

import asyncio
import json
import logging
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List

if TYPE_CHECKING:
from typing import (
Any,
AsyncGenerator,
Literal,
AsyncIterator,
)
from zarr.abc.metadata import Metadata

from zarr.array import AsyncArray, Array
Expand All @@ -25,7 +33,7 @@ def parse_zarr_format(data: Any) -> Literal[2, 3]:


# todo: convert None to empty dict
def parse_attributes(data: Any) -> Dict[str, Any]:
def parse_attributes(data: Any) -> dict[str, Any]:
if data is None:
return {}
elif isinstance(data, dict) and all(map(lambda v: isinstance(v, str), data.keys())):
Expand All @@ -36,12 +44,12 @@ def parse_attributes(data: Any) -> Dict[str, Any]:

@dataclass(frozen=True)
class GroupMetadata(Metadata):
attributes: Dict[str, Any] = field(default_factory=dict)
attributes: dict[str, Any] = field(default_factory=dict)
zarr_format: Literal[2, 3] = 3
node_type: Literal["group"] = field(default="group", init=False)

# todo: rename this, since it doesn't return bytes
def to_bytes(self) -> Dict[str, bytes]:
def to_bytes(self) -> dict[str, bytes]:
if self.zarr_format == 3:
return {ZARR_JSON: json.dumps(self.to_dict()).encode()}
else:
Expand All @@ -50,34 +58,34 @@ def to_bytes(self) -> Dict[str, bytes]:
ZATTRS_JSON: json.dumps(self.attributes).encode(),
}

def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Literal[2, 3] = 3):
def __init__(self, attributes: dict[str, Any] | None = None, zarr_format: Literal[2, 3] = 3):
attributes_parsed = parse_attributes(attributes)
zarr_format_parsed = parse_zarr_format(zarr_format)

object.__setattr__(self, "attributes", attributes_parsed)
object.__setattr__(self, "zarr_format", zarr_format_parsed)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> GroupMetadata:
def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
assert data.pop("node_type", None) in ("group", None)
return cls(**data)

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return asdict(self)


@dataclass(frozen=True)
class AsyncGroup:
metadata: GroupMetadata
store_path: StorePath
runtime_configuration: RuntimeConfiguration
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration()

@classmethod
async def create(
cls,
store: StoreLike,
*,
attributes: Optional[Dict[str, Any]] = None,
attributes: dict[str, Any] = {},
exists_ok: bool = False,
zarr_format: Literal[2, 3] = 3,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
Expand All @@ -89,7 +97,7 @@ async def create(
elif zarr_format == 2:
assert not await (store_path / ZGROUP_JSON).exists()
group = cls(
metadata=GroupMetadata(attributes=attributes or {}, zarr_format=zarr_format),
metadata=GroupMetadata(attributes=attributes, zarr_format=zarr_format),
store_path=store_path,
runtime_configuration=runtime_configuration,
)
Expand Down Expand Up @@ -137,7 +145,7 @@ async def open(
def from_dict(
cls,
store_path: StorePath,
data: Dict[str, Any],
data: dict[str, Any],
runtime_configuration: RuntimeConfiguration,
) -> AsyncGroup:
group = cls(
Expand All @@ -150,14 +158,24 @@ def from_dict(
async def getitem(
self,
key: str,
) -> Union[AsyncArray, AsyncGroup]:
) -> AsyncArray | AsyncGroup:
store_path = self.store_path / key

# Note:
# in zarr-python v2, we first check if `key` references an Array, else if `key` references
# a group,using standalone `contains_array` and `contains_group` functions. These functions
# are reusable, but for v3 they would perform redundant I/O operations.
# Not clear how much of that strategy we want to keep here.

# if `key` names an object in storage, it cannot be an array or group
if await store_path.exists():
raise KeyError(key)

if self.metadata.zarr_format == 3:
zarr_json_bytes = await (store_path / ZARR_JSON).get()
if zarr_json_bytes is None:
# implicit group?
logger.warning("group at {} is an implicit group", store_path)
logger.warning("group at %s is an implicit group", store_path)
zarr_json = {
"zarr_format": self.metadata.zarr_format,
"node_type": "group",
Expand Down Expand Up @@ -196,7 +214,7 @@ async def getitem(
else:
if zgroup_bytes is None:
# implicit group?
logger.warning("group at {} is an implicit group", store_path)
logger.warning("group at %s is an implicit group", store_path)
zgroup = (
json.loads(zgroup_bytes)
if zgroup_bytes is not None
Expand Down Expand Up @@ -248,7 +266,7 @@ async def create_array(self, path: str, **kwargs) -> AsyncArray:
**kwargs,
)

async def update_attributes(self, new_attributes: Dict[str, Any]):
async def update_attributes(self, new_attributes: dict[str, Any]):
# metadata.attributes is "frozen" so we simply clear and update the dict
self.metadata.attributes.clear()
self.metadata.attributes.update(new_attributes)
Expand All @@ -269,26 +287,68 @@ async def update_attributes(self, new_attributes: Dict[str, Any]):
def __repr__(self):
return f"<AsyncGroup {self.store_path}>"

async def nchildren(self) -> int:
raise NotImplementedError

async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]:
raise NotImplementedError

async def contains(self, child: str) -> bool:
async def nmembers(self) -> int:
raise NotImplementedError

async def group_keys(self) -> AsyncIterator[str]:
raise NotImplementedError
async def members(self) -> AsyncGenerator[tuple[str, AsyncArray | AsyncGroup], None]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
This method requires that `store_path.store` supports directory listing.
The results are not guaranteed to be ordered.
"""
if not self.store_path.store.supports_listing:
msg = (
f"The store associated with this group ({type(self.store_path.store)}) "
"does not support listing, "
"specifically via the `list_dir` method. "
"This function requires a store that supports listing."
)

async def groups(self) -> AsyncIterator[AsyncGroup]:
raise NotImplementedError
raise ValueError(msg)
subkeys = await self.store_path.store.list_dir(self.store_path.path)
# would be nice to make these special keys accessible programmatically,
# and scoped to specific zarr versions
subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys)
# is there a better way to schedule this?
for subkey in subkeys_filtered:
try:
yield (subkey, await self.getitem(subkey))
except KeyError:
# keyerror is raised when `subkey` names an object (in the object storage sense),
# as opposed to a prefix, in the store under the prefix associated with this group
# in which case `subkey` cannot be the name of a sub-array or sub-group.
logger.warning(
"Object at %s is not recognized as a component of a Zarr hierarchy.", subkey
)
pass

async def array_keys(self) -> AsyncIterator[str]:
async def contains(self, member: str) -> bool:
raise NotImplementedError

# todo: decide if this method should be separate from `groups`
async def group_keys(self) -> AsyncGenerator[str, None]:
async for key, value in self.members():
if isinstance(value, AsyncGroup):
yield key

# todo: decide if this method should be separate from `group_keys`
async def groups(self) -> AsyncGenerator[AsyncGroup, None]:
async for key, value in self.members():
if isinstance(value, AsyncGroup):
yield value

# todo: decide if this method should be separate from `arrays`
async def array_keys(self) -> AsyncGenerator[str, None]:
async for key, value in self.members():
if isinstance(value, AsyncArray):
yield key

# todo: decide if this method should be separate from `array_keys`
async def arrays(self) -> AsyncIterator[AsyncArray]:
raise NotImplementedError
async for key, value in self.members():
if isinstance(value, AsyncArray):
yield value

async def tree(self, expand=False, level=None) -> Any:
raise NotImplementedError
Expand Down Expand Up @@ -331,7 +391,7 @@ def create(
cls,
store: StoreLike,
*,
attributes: Optional[Dict[str, Any]] = None,
attributes: dict[str, Any] = {},
exists_ok: bool = False,
runtime_configuration: RuntimeConfiguration = RuntimeConfiguration(),
) -> Group:
Expand All @@ -358,7 +418,7 @@ def open(
)
return cls(obj)

def __getitem__(self, path: str) -> Union[Array, Group]:
def __getitem__(self, path: str) -> Array | Group:
obj = self._sync(self._async_group.getitem(path))
if isinstance(obj, AsyncArray):
return Array(obj)
Expand All @@ -378,7 +438,7 @@ def __setitem__(self, key, value):
"""__setitem__ is not supported in v3"""
raise NotImplementedError

async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group:
async def update_attributes_async(self, new_attributes: dict[str, Any]) -> Group:
new_metadata = replace(self.metadata, attributes=new_attributes)

# Write new metadata
Expand All @@ -389,6 +449,10 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group
async_group = replace(self._async_group, metadata=new_metadata)
return replace(self, _async_group=async_group)

@property
def store_path(self) -> StorePath:
return self._async_group.store_path

@property
def metadata(self) -> GroupMetadata:
return self._async_group.metadata
Expand All @@ -401,50 +465,54 @@ def attrs(self) -> Attributes:
def info(self):
return self._async_group.info

@property
def store_path(self) -> StorePath:
return self._async_group.store_path

def update_attributes(self, new_attributes: Dict[str, Any]):
def update_attributes(self, new_attributes: dict[str, Any]):
self._sync(self._async_group.update_attributes(new_attributes))
return self

@property
def nchildren(self) -> int:
return self._sync(self._async_group.nchildren())
def nmembers(self) -> int:
return self._sync(self._async_group.nmembers())

@property
def children(self) -> List[Union[Array, Group]]:
raise NotImplementedError
# Uncomment with AsyncGroup implements this method
# _children: List[Union[AsyncArray, AsyncGroup]] = self._sync_iter(
# self._async_group.children()
# )
# return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]
def members(self) -> tuple[tuple[str, Array | Group], ...]:
"""
Return the sub-arrays and sub-groups of this group as a `tuple` of (name, array | group)
pairs
"""
_members: list[tuple[str, AsyncArray | AsyncGroup]] = self._sync_iter(
self._async_group.members()
)
ret: list[tuple[str, Array | Group]] = []
for key, value in _members:
if isinstance(value, AsyncArray):
ret.append((key, Array(value)))
else:
ret.append((key, Group(value)))
return tuple(ret)

def __contains__(self, child) -> bool:
return self._sync(self._async_group.contains(child))
def __contains__(self, member) -> bool:
return self._sync(self._async_group.contains(member))

def group_keys(self) -> List[str]:
raise NotImplementedError
def group_keys(self) -> list[str]:
# uncomment with AsyncGroup implements this method
# return self._sync_iter(self._async_group.group_keys())
raise NotImplementedError

def groups(self) -> List[Group]:
def groups(self) -> list[Group]:
# TODO: in v2 this was a generator that return key: Group
raise NotImplementedError
# uncomment with AsyncGroup implements this method
# return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]
raise NotImplementedError

def array_keys(self) -> List[str]:
def array_keys(self) -> list[str]:
# uncomment with AsyncGroup implements this method
# return self._sync_iter(self._async_group.array_keys())
# return self._sync_iter(self._async_group.array_keys)
raise NotImplementedError

def arrays(self) -> List[Array]:
raise NotImplementedError
def arrays(self) -> list[Array]:
# uncomment with AsyncGroup implements this method
# return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]
# return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)]
raise NotImplementedError

def tree(self, expand=False, level=None) -> Any:
return self._sync(self._async_group.tree(expand=expand, level=level))
Expand Down
1 change: 0 additions & 1 deletion tests/v2/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pathlib

import pytest


Expand Down
32 changes: 32 additions & 0 deletions tests/v3/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pathlib
import pytest

from zarr.store import LocalStore, StorePath, MemoryStore, RemoteStore


@pytest.fixture(params=[str, pathlib.Path])
def path_type(request):
return request.param


# todo: harmonize this with local_store fixture
@pytest.fixture
def store_path(tmpdir):
store = LocalStore(str(tmpdir))
p = StorePath(store)
return p


@pytest.fixture(scope="function")
def local_store(tmpdir):
return LocalStore(str(tmpdir))


@pytest.fixture(scope="function")
def remote_store():
return RemoteStore()


@pytest.fixture(scope="function")
def memory_store():
return MemoryStore()
Loading

0 comments on commit 57d6ace

Please sign in to comment.