Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed consolidated Group getitem with multi-part key #2363

Merged
88 changes: 50 additions & 38 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,10 @@ def _from_bytes_v2(

@classmethod
def _from_bytes_v3(
cls, store_path: StorePath, zarr_json_bytes: Buffer, use_consolidated: bool | None
cls,
store_path: StorePath,
zarr_json_bytes: Buffer,
use_consolidated: bool | None,
) -> AsyncGroup:
group_metadata = json.loads(zarr_json_bytes.to_bytes())
if use_consolidated and group_metadata.get("consolidated_metadata") is None:
Expand Down Expand Up @@ -666,14 +669,33 @@ def _getitem_consolidated(
# the caller needs to verify this!
assert self.metadata.consolidated_metadata is not None

try:
metadata = self.metadata.consolidated_metadata.metadata[key]
except KeyError as e:
# The Group Metadata has consolidated metadata, but the key
# isn't present. We trust this to mean that the key isn't in
# the hierarchy, and *don't* fall back to checking the store.
msg = f"'{key}' not found in consolidated metadata."
raise KeyError(msg) from e
# we support nested getitems like group/subgroup/array
indexers = key.split("/")
indexers.reverse()
metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata = self.metadata

while indexers:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't the flattened form of the consolidated metadata make this a lot simpler?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe the problem with that is that it requires that the consolidated metadata be complete? whereas the iterative approach can handle a group with 'live' metadata.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I missed this earlier. Flattened metadata would make this specific section simpler, but I think it would complicate things later, since we'd need to "unflatten" it to put all of its children in its consolidated_metadata. Doable, but not obviously simpler in the end.

indexer = indexers.pop()
if isinstance(metadata, ArrayV2Metadata | ArrayV3Metadata):
# we've indexed into an array with group["array/subarray"]. Invalid.
raise KeyError(key)
if metadata.consolidated_metadata is None:
# we've indexed into a group without consolidated metadata.
# This isn't normal; typically, consolidated metadata
# will include explicit markers for when there are no child
# nodes as metadata={}.
# We have some freedom in exactly how we interpret this case.
# For now, we treat None as the same as {}, i.e. we don't
# have any children.
raise KeyError(key)
try:
metadata = metadata.consolidated_metadata.metadata[indexer]
except KeyError as e:
# The Group Metadata has consolidated metadata, but the key
# isn't present. We trust this to mean that the key isn't in
# the hierarchy, and *don't* fall back to checking the store.
msg = f"'{key}' not found in consolidated metadata."
raise KeyError(msg) from e

# update store_path to ensure that AsyncArray/Group.name is correct
if prefix != "/":
Expand Down Expand Up @@ -932,11 +954,7 @@ async def create_array(

@deprecated("Use AsyncGroup.create_array instead.")
async def create_dataset(
self,
name: str,
*,
shape: ShapeLike,
**kwargs: Any,
self, name: str, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
"""Create an array.

Expand All @@ -947,8 +965,6 @@ async def create_dataset(
----------
name : str
Array name.
shape : int or tuple of ints
Array shape.
kwargs : dict
Additional arguments passed to :func:`zarr.AsyncGroup.create_array`.

Expand All @@ -959,7 +975,7 @@ async def create_dataset(
.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.create_array` instead.
"""
return await self.create_array(name, shape=shape, **kwargs)
return await self.create_array(name, **kwargs)

@deprecated("Use AsyncGroup.require_array instead.")
async def require_dataset(
Expand Down Expand Up @@ -1081,6 +1097,8 @@ async def nmembers(
-------
count : int
"""
# check if we can use consolidated metadata, which requires that we have non-None
# consolidated metadata at all points in the hierarchy.
if self.metadata.consolidated_metadata is not None:
return len(self.metadata.consolidated_metadata.flattened_metadata)
# TODO: consider using aioitertools.builtins.sum for this
Expand All @@ -1094,7 +1112,8 @@ async def members(
self,
max_depth: int | None = 0,
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
Expand Down Expand Up @@ -1125,12 +1144,12 @@ async def members(
async def _members(
self, max_depth: int | None, current_depth: int
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
if self.metadata.consolidated_metadata is not None:
# we should be able to do members without any additional I/O
members = self._members_consolidated(max_depth, current_depth)

for member in members:
yield member
return
Expand Down Expand Up @@ -1186,7 +1205,8 @@ async def _members(
def _members_consolidated(
self, max_depth: int | None, current_depth: int, prefix: str = ""
) -> Generator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
consolidated_metadata = self.metadata.consolidated_metadata

Expand Down Expand Up @@ -1271,7 +1291,11 @@ async def full(
self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
return await async_api.full(
shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs
shape=shape,
fill_value=fill_value,
store=self.store_path,
path=name,
**kwargs,
)

async def empty_like(
Expand Down Expand Up @@ -1627,13 +1651,7 @@ def create_dataset(self, name: str, **kwargs: Any) -> Array:
return Array(self._sync(self._async_group.create_dataset(name, **kwargs)))

@deprecated("Use Group.require_array instead.")
def require_dataset(
self,
name: str,
*,
shape: ShapeLike,
**kwargs: Any,
) -> Array:
def require_dataset(self, name: str, **kwargs: Any) -> Array:
"""Obtain an array, creating if it doesn't exist.

Arrays are known as "datasets" in HDF5 terminology. For compatibility
Expand All @@ -1660,15 +1678,9 @@ def require_dataset(
.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead.
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
return Array(self._sync(self._async_group.require_array(name, **kwargs)))

def require_array(
self,
name: str,
*,
shape: ShapeLike,
**kwargs: Any,
) -> Array:
def require_array(self, name: str, **kwargs: Any) -> Array:
"""Obtain an array, creating if it doesn't exist.


Expand All @@ -1690,7 +1702,7 @@ def require_array(
-------
a : Array
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
return Array(self._sync(self._async_group.require_array(name, **kwargs)))

def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
return Array(self._sync(self._async_group.empty(name=name, shape=shape, **kwargs)))
Expand Down
52 changes: 51 additions & 1 deletion tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,18 +313,53 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat, consolidated: bool
group = Group.from_store(store, zarr_format=zarr_format)
subgroup = group.create_group(name="subgroup")
subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,))
subsubarray = subgroup.create_array(name="subarray", shape=(10,), chunk_shape=(10,))

if consolidated:
group = zarr.api.synchronous.consolidate_metadata(store=store, zarr_format=zarr_format)
# we're going to assume that `group.metadata` is correct, and reuse that to focus
# on indexing in this test. Other tests verify the correctness of group.metadata
object.__setattr__(
subgroup.metadata, "consolidated_metadata", ConsolidatedMetadata(metadata={})
subgroup.metadata,
"consolidated_metadata",
ConsolidatedMetadata(
metadata={"subarray": group.metadata.consolidated_metadata.metadata["subarray"]}
),
)

assert group["subgroup"] == subgroup
assert group["subarray"] == subarray
assert group["subgroup"]["subarray"] == subsubarray
assert group["subgroup/subarray"] == subsubarray

with pytest.raises(KeyError):
group["nope"]

with pytest.raises(KeyError, match="subarray/subsubarray"):
group["subarray/subsubarray"]

# Now test the mixed case
if consolidated:
object.__setattr__(
group.metadata.consolidated_metadata.metadata["subgroup"],
"consolidated_metadata",
None,
)

# test the implementation directly
with pytest.raises(KeyError):
group._async_group._getitem_consolidated(
group.store_path, "subgroup/subarray", prefix="/"
)

with pytest.raises(KeyError):
# We've chosen to trust the consolidated metadata, which doesn't
# contain this array
group["subgroup/subarray"]

with pytest.raises(KeyError, match="subarray/subsubarray"):
group["subarray/subsubarray"]


def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None:
group = Group.from_store(store, zarr_format=zarr_format)
Expand Down Expand Up @@ -1014,6 +1049,21 @@ async def test_group_members_async(store: Store, consolidated_metadata: bool) ->
with pytest.raises(ValueError, match="max_depth"):
[x async for x in group.members(max_depth=-1)]

if consolidated_metadata:
# test for mixed known and unknown metadata.
# For now, we trust the consolidated metadata.
object.__setattr__(
group.metadata.consolidated_metadata.metadata["g0"].consolidated_metadata.metadata[
"g1"
],
"consolidated_metadata",
None,
)
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
assert len(all_children) == 4
nmembers = await group.nmembers(max_depth=None)
assert nmembers == 4


async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format)
Expand Down