From e492be23f8c36095a29902953e19a0892dd65c89 Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Tue, 26 Mar 2024 21:04:57 +0100 Subject: [PATCH] feat: functional .children method for groups --- src/zarr/v3/group.py | 37 ++++++++++++++++++++++++++++++++++--- tests/test_group_v3.py | 4 +++- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/zarr/v3/group.py b/src/zarr/v3/group.py index acd5ca0d62..4a84cbed80 100644 --- a/src/zarr/v3/group.py +++ b/src/zarr/v3/group.py @@ -4,7 +4,17 @@ import asyncio import json import logging -from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, Iterator, List +from typing import ( + Any, + AsyncGenerator, + Dict, + Literal, + Optional, + Union, + AsyncIterator, + Iterator, + List, +) from zarr.v3.abc.metadata import Metadata from zarr.v3.array import AsyncArray, Array @@ -271,8 +281,29 @@ def __repr__(self): async def nchildren(self) -> int: raise NotImplementedError - async def children(self) -> AsyncIterator[AsyncArray, AsyncGroup]: - raise NotImplementedError + async def children(self) -> AsyncGenerator[AsyncArray, AsyncGroup]: + """ + Returns an async iterator over the arrays and groups contained in this group. + """ + if not self.store_path.store.supports_listing: + msg = ( + f"The store associated with this group ({type(self.store_path.store)}) " + "does not support listing, " + "specifically the `list_dir` method. " + "This function requires a store that supports listing." + ) + + raise ValueError(msg) + subkeys = await self.store_path.store.list_dir(self.store_path.path) + # would be nice to make these special keys accessible programmatically, + # and scoped to specific zarr versions + subkeys_filtered = filter(lambda v: v not in ("zarr.json", ".zgroup", ".zattrs"), subkeys) + # might be smarter to wrap this in asyncio gather + for subkey in subkeys_filtered: + try: + yield await self.getitem(subkey) + except ValueError: + pass async def contains(self, child: str) -> bool: raise NotImplementedError diff --git a/tests/test_group_v3.py b/tests/test_group_v3.py index 1498d6779b..01859bb7ae 100644 --- a/tests/test_group_v3.py +++ b/tests/test_group_v3.py @@ -21,13 +21,15 @@ def test_group(store_path) -> None: runtime_configuration=RuntimeConfiguration(), ) group = Group(agroup) - assert agroup.metadata is group.metadata # create two groups foo = group.create_group("foo") bar = foo.create_group("bar", attributes={"baz": "qux"}) + # check that bar is in the children of foo + assert foo.children == [bar] + # create an array from the "bar" group data = np.arange(0, 4 * 4, dtype="uint16").reshape((4, 4)) arr = bar.create_array(