Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v3] h5py compat methods on Group #2128

Merged
merged 11 commits into from
Sep 4, 2024
1 change: 1 addition & 0 deletions src/zarr/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ZATTRS_JSON = ".zattrs"

BytesLike = bytes | bytearray | memoryview
ShapeLike = tuple[int, ...] | int
ChunkCoords = tuple[int, ...]
ChunkCoordsLike = Iterable[int]
ZarrFormat = Literal[2, 3]
Expand Down
176 changes: 174 additions & 2 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import asdict, dataclass, field, replace
from typing import TYPE_CHECKING, Literal, cast, overload

import numpy as np
import numpy.typing as npt
from typing_extensions import deprecated

Expand All @@ -25,6 +26,7 @@
ZGROUP_JSON,
ChunkCoords,
ZarrFormat,
parse_shapelike,
)
from zarr.core.config import config
from zarr.core.sync import SyncMixin, sync
Expand Down Expand Up @@ -250,7 +252,7 @@ async def getitem(
if zarray is not None:
# TODO: update this once the V2 array support is part of the primary array class
zarr_json = {**zarray, "attributes": zattrs}
return AsyncArray.from_dict(store_path, zarray)
return AsyncArray.from_dict(store_path, zarr_json)
else:
zgroup = (
json.loads(zgroup_bytes.to_bytes())
Expand Down Expand Up @@ -324,6 +326,42 @@ async def create_group(
zarr_format=self.metadata.zarr_format,
)

async def require_group(self, name: str, overwrite: bool = False) -> AsyncGroup:
    """Return the sub-group called ``name``, creating it when it is missing.

    Parameters
    ----------
    name : string
        Group name.
    overwrite : bool, optional
        If True, no lookup is performed; ``create_group`` is called with
        ``exists_ok=True`` and its result is returned.

    Returns
    -------
    g : AsyncGroup

    Raises
    ------
    TypeError
        If ``overwrite`` is False and the node found under ``name`` is not
        an ``AsyncGroup``.
    """
    if overwrite:
        # TODO: check that exists_ok=True errors if an array exists where the group is being created
        return await self.create_group(name, exists_ok=True)

    try:
        item: AsyncGroup | AsyncArray = await self.getitem(name)
    except KeyError:
        # The lookup found nothing under this name, so create the group.
        return await self.create_group(name)

    if isinstance(item, AsyncGroup):
        return item
    raise TypeError(f"Incompatible object ({item.__class__.__name__}) already exists")

async def require_groups(self, *names: str) -> tuple[AsyncGroup, ...]:
    """Require several sub-groups in a single call.

    Every name is handed to ``require_group``; the resulting coroutines are
    awaited together with ``asyncio.gather`` and the groups are returned as a
    tuple in the same order as ``names``. Calling with no names returns ``()``.
    """
    if len(names) == 0:
        return ()
    pending = [self.require_group(group_name) for group_name in names]
    groups = await asyncio.gather(*pending)
    return tuple(groups)

async def create_array(
self,
name: str,
Expand Down Expand Up @@ -413,6 +451,78 @@ async def create_array(
data=data,
)

async def create_dataset(self, name: str, **kwargs: Any) -> AsyncArray:
    """Create an array (h5py-style alias for ``create_array``).

    Arrays are known as "datasets" in HDF5 terminology. For compatibility
    with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.require_dataset` method.

    Parameters
    ----------
    name : string
        Array name.
    kwargs : dict
        Additional arguments, forwarded unchanged to
        :func:`zarr.AsyncGroup.create_array`.

    Returns
    -------
    a : AsyncArray
    """
    new_array = await self.create_array(name, **kwargs)
    return new_array

async def require_dataset(
    self,
    name: str,
    *,
    shape: ChunkCoords,
    dtype: npt.DTypeLike = None,
    exact: bool = False,
    **kwargs: Any,
) -> AsyncArray:
    """Return the array called ``name``, creating it when the lookup finds nothing.

    Arrays are known as "datasets" in HDF5 terminology. For compatibility
    with h5py, Zarr groups also implement the :func:`zarr.AsyncGroup.create_dataset` method.

    Other `kwargs` are as per :func:`zarr.AsyncGroup.create_dataset`; they are
    only used when the array has to be created.

    Parameters
    ----------
    name : string
        Array name.
    shape : int or tuple of ints
        Array shape. For an existing array it is passed through
        ``parse_shapelike`` and must equal the array's shape.
    dtype : string or dtype, optional
        NumPy dtype. For an existing array it is converted with ``np.dtype``
        before being compared.
    exact : bool, optional
        If True, require `dtype` to match exactly. If false, require
        `dtype` can be cast from array dtype.

    Returns
    -------
    a : AsyncArray

    Raises
    ------
    TypeError
        If the existing node is not an array, or its shape or dtype is
        incompatible with the requested one.
    """
    try:
        ds = await self.getitem(name)
    except KeyError:
        # Lookup failed: create the array with the arguments exactly as given.
        return await self.create_dataset(name, shape=shape, dtype=dtype, **kwargs)

    if not isinstance(ds, AsyncArray):
        raise TypeError(f"Incompatible object ({ds.__class__.__name__}) already exists")

    shape = parse_shapelike(shape)
    if shape != ds.shape:
        raise TypeError(f"Incompatible shape ({ds.shape} vs {shape})")

    dtype = np.dtype(dtype)
    if exact:
        mismatch = ds.dtype != dtype
    else:
        mismatch = not np.can_cast(ds.dtype, dtype)
    if mismatch:
        raise TypeError(f"Incompatible dtype ({ds.dtype} vs {dtype})")

    return ds

async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup:
# metadata.attributes is "frozen" so we simply clear and update the dict
self.metadata.attributes.clear()
Expand Down Expand Up @@ -612,8 +722,9 @@ def create(
def open(
    cls,
    store: StoreLike,
    zarr_format: Literal[2, 3, None] = 3,
) -> Group:
    """Open the group held in ``store`` and wrap it in ``cls``.

    Parameters
    ----------
    store : StoreLike
        Store passed to ``AsyncGroup.open``.
    zarr_format : {2, 3, None}, optional
        Forwarded unchanged to ``AsyncGroup.open``. Defaults to 3.

    Returns
    -------
    g : Group
        ``cls`` constructed around the opened ``AsyncGroup``.
    """
    # Open the store exactly once. The earlier call without ``zarr_format``
    # was immediately overwritten, so it only repeated the open while
    # ignoring the requested format.
    obj = sync(AsyncGroup.open(store, zarr_format=zarr_format))
    return cls(obj)

def __getitem__(self, path: str) -> Array | Group:
Expand Down Expand Up @@ -717,6 +828,26 @@ def tree(self, expand: bool = False, level: int | None = None) -> Any:
def create_group(self, name: str, **kwargs: Any) -> Group:
    """Create a sub-group and return it wrapped in ``Group``.

    ``name`` and ``kwargs`` are forwarded to the wrapped async group's
    ``create_group``; the coroutine is executed through ``self._sync``.
    """
    async_group = self._sync(self._async_group.create_group(name, **kwargs))
    return Group(async_group)

def require_group(self, name: str, **kwargs: Any) -> Group:
    """Return the sub-group called ``name``, creating it when it is missing.

    Delegates to the wrapped async group's ``require_group`` through
    ``self._sync`` and wraps the result in ``Group``.

    Parameters
    ----------
    name : string
        Group name.
    overwrite : bool, optional
        Overwrite any existing group with given `name` if present.

    Returns
    -------
    g : Group
    """
    async_group = self._sync(self._async_group.require_group(name, **kwargs))
    return Group(async_group)

def require_groups(self, *names: str) -> tuple[Group, ...]:
    """Require several sub-groups in a single call.

    The names are forwarded to the wrapped async group's ``require_groups``;
    each returned async group is wrapped in ``Group``, preserving order.
    """
    async_groups = self._sync(self._async_group.require_groups(*names))
    return tuple(Group(async_group) for async_group in async_groups)

def create_array(
self,
name: str,
Expand Down Expand Up @@ -811,6 +942,47 @@ def create_array(
)
)

def create_dataset(self, name: str, **kwargs: Any) -> Array:
    """Create an array and return it wrapped in ``Array``.

    Arrays are known as "datasets" in HDF5 terminology. For compatibility
    with h5py, Zarr groups also implement the :func:`zarr.Group.require_dataset` method.

    Parameters
    ----------
    name : string
        Array name.
    kwargs : dict
        Additional arguments passed to :func:`zarr.Group.create_array`

    Returns
    -------
    a : Array
    """
    async_array = self._sync(self._async_group.create_dataset(name, **kwargs))
    return Array(async_array)

def require_dataset(self, name: str, **kwargs: Any) -> Array:
    """Return the array called ``name``, creating it if it doesn't exist.

    Arrays are known as "datasets" in HDF5 terminology. For compatibility
    with h5py, Zarr groups also implement the :func:`zarr.Group.create_dataset` method.

    Other `kwargs` are as per :func:`zarr.Group.create_dataset`. All keyword
    arguments are forwarded to the wrapped async group's ``require_dataset``.

    Parameters
    ----------
    name : string
        Array name.
    shape : int or tuple of ints
        Array shape.
    dtype : string or dtype, optional
        NumPy dtype.
    exact : bool, optional
        If True, require `dtype` to match exactly. If false, require
        `dtype` can be cast from array dtype.

    Returns
    -------
    a : Array
    """
    async_array = self._sync(self._async_group.require_dataset(name, **kwargs))
    return Array(async_array)

def empty(self, **kwargs: Any) -> Array:
    """Call the wrapped async group's ``empty`` with ``kwargs`` and wrap the result in ``Array``."""
    async_array = self._sync(self._async_group.empty(**kwargs))
    return Array(async_array)

Expand Down
87 changes: 87 additions & 0 deletions tests/v3/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,3 +730,90 @@ async def test_group_members_async(store: LocalStore | MemoryStore):

with pytest.raises(ValueError, match="max_depth"):
[x async for x in group.members(max_depth=-1)]


async def test_require_group(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """``require_group`` returns an existing group and rejects a name held by an array."""
    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)

    # Seed the hierarchy with a "foo" group that carries attributes.
    await root.create_group("foo", attributes={"foo": 100})

    # A plain require_group hands back that group with its attributes.
    required = await root.require_group("foo")
    assert required.attrs == {"foo": 100}

    # overwrite=True must also yield a usable group under the same name.
    required = await root.require_group("foo", overwrite=True)

    # Put an array inside "foo" so that the name "bar" is taken by a non-group.
    await required.create_array(
        "bar", shape=(10,), dtype="uint8", chunk_shape=(2,), attributes={"foo": 100}
    )

    # test that overwriting a group w/ children fails
    # TODO: figure out why ensure_no_existing_node is not catching the foo.bar array
    #
    # with pytest.raises(ContainsArrayError):
    #     await root.require_group("foo", overwrite=True)

    # Requiring a group where an array already lives is a TypeError.
    with pytest.raises(TypeError):
        await required.require_group("bar")


async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """``require_groups`` handles existing names, new names, and an empty call."""
    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
    # Two pre-existing groups, each tagged with its own attributes.
    await root.create_group("foo", attributes={"foo": 100})
    await root.create_group("bar", attributes={"bar": 200})

    # Both names already exist: the stored groups come back in call order.
    existing = await root.require_groups("foo", "bar")
    foo_group, bar_group = existing
    assert foo_group.attrs == {"foo": 100}
    assert bar_group.attrs == {"bar": 200}

    # One existing name and one new name: the new group has no attributes.
    mixed = await root.require_groups("foo", "spam")
    foo_group, spam_group = mixed
    assert foo_group.attrs == {"foo": 100}
    assert spam_group.attrs == {}

    # No names at all gives back an empty tuple.
    nothing = await root.require_groups()
    assert nothing == ()


async def test_create_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """``create_dataset`` creates an array and refuses names that are already taken."""
    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)
    created = await root.create_dataset("foo", shape=(10,), dtype="uint8")
    assert created.shape == (10,)

    # A second dataset under the same name collides with the existing array.
    with pytest.raises(ContainsArrayError):
        await root.create_dataset("foo", shape=(100,), dtype="int8")

    # A dataset under a name held by a group collides with that group.
    await root.create_group("bar")
    with pytest.raises(ContainsGroupError):
        await root.create_dataset("bar", shape=(100,), dtype="int8")


async def test_require_dataset(store: LocalStore | MemoryStore, zarr_format: ZarrFormat) -> None:
    """``require_dataset`` creates once, then validates shape, dtype and node type."""
    root = await AsyncGroup.create(store=store, zarr_format=zarr_format)

    # First call creates the array; the second call returns the stored one.
    created = await root.require_dataset("foo", shape=(10,), dtype="i8", attributes={"foo": 101})
    assert created.attrs == {"foo": 101}
    fetched = await root.require_dataset("foo", shape=(10,), dtype="i8")
    assert fetched.attrs == {"foo": 101}

    # exact = False: an int shape and a castable dtype are accepted.
    await root.require_dataset("foo", shape=10, dtype="f8")

    # exact = True: the same dtype request is now rejected.
    with pytest.raises(TypeError, match="Incompatible dtype"):
        await root.require_dataset("foo", shape=(10,), dtype="f8", exact=True)

    # A different shape is rejected.
    with pytest.raises(TypeError, match="Incompatible shape"):
        await root.require_dataset("foo", shape=(100, 100), dtype="i8")

    # A dtype the stored dtype cannot be cast to is rejected even with exact=False.
    with pytest.raises(TypeError, match="Incompatible dtype"):
        await root.require_dataset("foo", shape=(10,), dtype="f4")

    # A name that belongs to a group is rejected.
    await root.create_group("bar")
    with pytest.raises(TypeError, match="Incompatible object"):
        await root.require_dataset("bar", shape=(10,), dtype="int8")