Fixed consolidated Group getitem with multi-part key (#2363)
* Fixed consolidated Group getitem with multi-part key

This fixes `Group.__getitem__` when indexing with a key
like 'subgroup/array'. The basic idea is to rewrite the indexing
operation as `group['subgroup']['array']` by splitting the key
and doing each operation independently.

Closes #2358

---------

Co-authored-by: Joe Hamman <joe@earthmover.io>
TomAugspurger and jhamman authored Oct 17, 2024
1 parent 3a7426f commit 4d663cc
Showing 2 changed files with 101 additions and 39 deletions.
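As background for the diff below, here is a minimal, hypothetical sketch of the key-splitting approach described in the commit message. The `resolve_consolidated` helper and the plain-dict metadata are illustrative stand-ins, not the actual zarr classes; they only show how a key like `'subgroup/array'` can be resolved one component at a time against nested consolidated metadata, without falling back to the store.

```python
from __future__ import annotations

from typing import Any


def resolve_consolidated(node: dict[str, Any], key: str) -> Any:
    """Walk nested consolidated metadata one path component at a time."""
    for part in key.split("/"):
        # Treat missing/None consolidated metadata the same as "no children".
        children = node.get("consolidated_metadata") or {}
        if part not in children:
            # Trust the consolidated metadata: a missing child means the key
            # is not in the hierarchy, so don't fall back to the store.
            raise KeyError(f"{key!r} not found in consolidated metadata.")
        node = children[part]
    return node


# Example: group["a/b"] behaves like group["a"]["b"]
root = {
    "consolidated_metadata": {
        "a": {"consolidated_metadata": {"b": {"shape": (10,)}}},
    }
}
print(resolve_consolidated(root, "a/b"))  # {'shape': (10,)}
```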
88 changes: 50 additions & 38 deletions src/zarr/core/group.py
@@ -572,7 +572,10 @@ def _from_bytes_v2(

@classmethod
def _from_bytes_v3(
cls, store_path: StorePath, zarr_json_bytes: Buffer, use_consolidated: bool | None
cls,
store_path: StorePath,
zarr_json_bytes: Buffer,
use_consolidated: bool | None,
) -> AsyncGroup:
group_metadata = json.loads(zarr_json_bytes.to_bytes())
if use_consolidated and group_metadata.get("consolidated_metadata") is None:
@@ -666,14 +669,33 @@ def _getitem_consolidated(
# the caller needs to verify this!
assert self.metadata.consolidated_metadata is not None

try:
metadata = self.metadata.consolidated_metadata.metadata[key]
except KeyError as e:
# The Group Metadata has consolidated metadata, but the key
# isn't present. We trust this to mean that the key isn't in
# the hierarchy, and *don't* fall back to checking the store.
msg = f"'{key}' not found in consolidated metadata."
raise KeyError(msg) from e
# we support nested getitems like group/subgroup/array
indexers = key.split("/")
indexers.reverse()
metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata = self.metadata

while indexers:
indexer = indexers.pop()
if isinstance(metadata, ArrayV2Metadata | ArrayV3Metadata):
# we've indexed into an array with group["array/subarray"]. Invalid.
raise KeyError(key)
if metadata.consolidated_metadata is None:
# we've indexed into a group without consolidated metadata.
# This isn't normal; typically, consolidated metadata
# will include explicit markers for when there are no child
# nodes as metadata={}.
# We have some freedom in exactly how we interpret this case.
# For now, we treat None as the same as {}, i.e. we don't
# have any children.
raise KeyError(key)
try:
metadata = metadata.consolidated_metadata.metadata[indexer]
except KeyError as e:
# The Group Metadata has consolidated metadata, but the key
# isn't present. We trust this to mean that the key isn't in
# the hierarchy, and *don't* fall back to checking the store.
msg = f"'{key}' not found in consolidated metadata."
raise KeyError(msg) from e

# update store_path to ensure that AsyncArray/Group.name is correct
if prefix != "/":
@@ -932,11 +954,7 @@ async def create_array(

@deprecated("Use AsyncGroup.create_array instead.")
async def create_dataset(
self,
name: str,
*,
shape: ShapeLike,
**kwargs: Any,
self, name: str, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
"""Create an array.
@@ -947,8 +965,6 @@ async def create_dataset(
----------
name : str
Array name.
shape : int or tuple of ints
Array shape.
kwargs : dict
Additional arguments passed to :func:`zarr.AsyncGroup.create_array`.
@@ -959,7 +975,7 @@ async def create_dataset(
.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.create_array` instead.
"""
return await self.create_array(name, shape=shape, **kwargs)
return await self.create_array(name, **kwargs)

@deprecated("Use AsyncGroup.require_array instead.")
async def require_dataset(
@@ -1081,6 +1097,8 @@ async def nmembers(
-------
count : int
"""
# check if we can use consolidated metadata, which requires that we have non-None
# consolidated metadata at all points in the hierarchy.
if self.metadata.consolidated_metadata is not None:
return len(self.metadata.consolidated_metadata.flattened_metadata)
# TODO: consider using aioitertools.builtins.sum for this
@@ -1094,7 +1112,8 @@ async def members(
self,
max_depth: int | None = 0,
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
"""
Returns an AsyncGenerator over the arrays and groups contained in this group.
@@ -1125,12 +1144,12 @@ async def members(
async def _members(
self, max_depth: int | None, current_depth: int
) -> AsyncGenerator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
if self.metadata.consolidated_metadata is not None:
# we should be able to do members without any additional I/O
members = self._members_consolidated(max_depth, current_depth)

for member in members:
yield member
return
@@ -1186,7 +1205,8 @@ async def _members(
def _members_consolidated(
self, max_depth: int | None, current_depth: int, prefix: str = ""
) -> Generator[
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup], None
tuple[str, AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup],
None,
]:
consolidated_metadata = self.metadata.consolidated_metadata

@@ -1271,7 +1291,11 @@ async def full(
self, *, name: str, shape: ChunkCoords, fill_value: Any | None, **kwargs: Any
) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]:
return await async_api.full(
shape=shape, fill_value=fill_value, store=self.store_path, path=name, **kwargs
shape=shape,
fill_value=fill_value,
store=self.store_path,
path=name,
**kwargs,
)

async def empty_like(
@@ -1627,13 +1651,7 @@ def create_dataset(self, name: str, **kwargs: Any) -> Array:
return Array(self._sync(self._async_group.create_dataset(name, **kwargs)))

@deprecated("Use Group.require_array instead.")
def require_dataset(
self,
name: str,
*,
shape: ShapeLike,
**kwargs: Any,
) -> Array:
def require_dataset(self, name: str, **kwargs: Any) -> Array:
"""Obtain an array, creating if it doesn't exist.
Arrays are known as "datasets" in HDF5 terminology. For compatibility
@@ -1660,15 +1678,9 @@ def require_dataset(
.. deprecated:: 3.0.0
The h5py compatibility methods will be removed in 3.1.0. Use `Group.require_array` instead.
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
return Array(self._sync(self._async_group.require_array(name, **kwargs)))

def require_array(
self,
name: str,
*,
shape: ShapeLike,
**kwargs: Any,
) -> Array:
def require_array(self, name: str, **kwargs: Any) -> Array:
"""Obtain an array, creating if it doesn't exist.
@@ -1690,7 +1702,7 @@ def require_array(
-------
a : Array
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))
return Array(self._sync(self._async_group.require_array(name, **kwargs)))

def empty(self, *, name: str, shape: ChunkCoords, **kwargs: Any) -> Array:
return Array(self._sync(self._async_group.empty(name=name, shape=shape, **kwargs)))
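The tests below exercise this behaviour end to end. For orientation, a hedged usage sketch of what the fix enables, assuming an in-memory store and the public zarr 3 API (exact signatures such as `chunk_shape` vs. `chunks` may differ between versions):

```python
import zarr
import zarr.storage

# Build a small hierarchy, then consolidate its metadata into the root.
store = zarr.storage.MemoryStore()
root = zarr.group(store=store)
sub = root.create_group("subgroup")
sub.create_array("subarray", shape=(10,), chunk_shape=(10,))
zarr.consolidate_metadata(store)

# Re-open: reads consolidated metadata, so child lookups need no extra I/O.
root = zarr.open_group(store)
arr = root["subgroup/subarray"]  # multi-part key now resolves correctly

try:
    root["subgroup/missing"]     # absent child -> KeyError, no store fallback
except KeyError:
    print("missing child raises KeyError")
```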
52 changes: 51 additions & 1 deletion tests/v3/test_group.py
@@ -306,18 +306,53 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat, consolidated: bool
group = Group.from_store(store, zarr_format=zarr_format)
subgroup = group.create_group(name="subgroup")
subarray = group.create_array(name="subarray", shape=(10,), chunk_shape=(10,))
subsubarray = subgroup.create_array(name="subarray", shape=(10,), chunk_shape=(10,))

if consolidated:
group = zarr.api.synchronous.consolidate_metadata(store=store, zarr_format=zarr_format)
# we're going to assume that `group.metadata` is correct, and reuse that to focus
# on indexing in this test. Other tests verify the correctness of group.metadata
object.__setattr__(
subgroup.metadata, "consolidated_metadata", ConsolidatedMetadata(metadata={})
subgroup.metadata,
"consolidated_metadata",
ConsolidatedMetadata(
metadata={"subarray": group.metadata.consolidated_metadata.metadata["subarray"]}
),
)

assert group["subgroup"] == subgroup
assert group["subarray"] == subarray
assert group["subgroup"]["subarray"] == subsubarray
assert group["subgroup/subarray"] == subsubarray

with pytest.raises(KeyError):
group["nope"]

with pytest.raises(KeyError, match="subarray/subsubarray"):
group["subarray/subsubarray"]

# Now test the mixed case
if consolidated:
object.__setattr__(
group.metadata.consolidated_metadata.metadata["subgroup"],
"consolidated_metadata",
None,
)

# test the implementation directly
with pytest.raises(KeyError):
group._async_group._getitem_consolidated(
group.store_path, "subgroup/subarray", prefix="/"
)

with pytest.raises(KeyError):
# We've chosen to trust the consolidated metadata, which doesn't
# contain this array.
group["subgroup/subarray"]

with pytest.raises(KeyError, match="subarray/subsubarray"):
group["subarray/subsubarray"]


def test_group_get_with_default(store: Store, zarr_format: ZarrFormat) -> None:
group = Group.from_store(store, zarr_format=zarr_format)
@@ -1008,6 +1043,21 @@ async def test_group_members_async(store: Store, consolidated_metadata: bool) ->
with pytest.raises(ValueError, match="max_depth"):
[x async for x in group.members(max_depth=-1)]

if consolidated_metadata:
# test for mixed known and unknown metadata.
# For now, we trust the consolidated metadata.
object.__setattr__(
group.metadata.consolidated_metadata.metadata["g0"].consolidated_metadata.metadata[
"g1"
],
"consolidated_metadata",
None,
)
all_children = sorted([x async for x in group.members(max_depth=None)], key=lambda x: x[0])
assert len(all_children) == 4
nmembers = await group.nmembers(max_depth=None)
assert nmembers == 4


async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format)
