From 8faf9944898dca8ca727eaebbcc42ef1f57fa4f7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 11 Dec 2024 15:38:07 +0100 Subject: [PATCH 01/59] sketch out batch creation routine --- src/zarr/core/group.py | 100 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 3613e3e12b..efac25a0a5 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -6,6 +6,7 @@ import logging import warnings from collections import defaultdict +from collections.abc import AsyncIterator from dataclasses import asdict, dataclass, field, fields, replace from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload @@ -1195,6 +1196,37 @@ async def require_array( return ds + async def create_nodes( + self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] + ) -> tuple[tuple[str, AsyncGroup | AsyncArray]]: + """ + Create a set of arrays or groups rooted at this group. + """ + _nodes: ( + dict[str, GroupMetadata | ArrayV3Metadata] | dict[str, GroupMetadata | ArrayV2Metadata] + ) + match self.metadata.zarr_format: + case 2: + if not all( + isinstance(node, ArrayV2Metadata | GroupMetadata) for node in nodes.values() + ): + raise ValueError("Only v2 arrays and groups are supported") + _nodes = cast(dict[str, ArrayV2Metadata | GroupMetadata], nodes) + return await create_nodes_v2( + store=self.store_path.store, path=self.path, nodes=_nodes + ) + case 3: + if not all( + isinstance(node, ArrayV3Metadata | GroupMetadata) for node in nodes.values() + ): + raise ValueError("Only v3 arrays and groups are supported") + _nodes = cast(dict[str, ArrayV3Metadata | GroupMetadata], nodes) + return await create_nodes_v3( + store=self.store_path.store, path=self.path, nodes=_nodes + ) + case _: + raise ValueError(f"Unsupported zarr format: {self.metadata.zarr_format}") + async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: """Update group attributes. @@ -2627,3 +2659,71 @@ def array( ) ) ) + + +async def _save_metadata_return_node( + node: AsyncArray[Any] | AsyncGroup, +) -> AsyncArray[Any] | AsyncGroup: + if isinstance(node, AsyncArray): + await node._save_metadata(node.metadata, ensure_parents=False) + else: + await node._save_metadata(ensure_parents=False) + return node + + +async def create_nodes_v2( + *, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata] +) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata]]]: ... + + +async def create_nodes( + *, store_path: StorePath, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] +) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]: + """ + Create a collection of arrays and groups concurrently and atomically. To ensure atomicity, + no attempt is made to ensure that intermediate groups are created. + """ + create_tasks = [] + for key, value in nodes.items(): + new_store_path = store_path / key + node: AsyncArray[Any] | AsyncGroup + match value: + case ArrayV3Metadata() | ArrayV2Metadata(): + node = AsyncArray(value, store_path=new_store_path) + case GroupMetadata(): + node = AsyncGroup(value, store_path=new_store_path) + case _: + raise ValueError(f"Unexpected metadata type {type(value)}") + create_tasks.append(_save_metadata_return_node(node)) + for coro in asyncio.as_completed(create_tasks): + yield await coro + + +T = TypeVar("T") + + +def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]: + """ + Given a dict of {string: T} pairs, where the keys are strings separated by some separator, + return the result of splitting each key with the separator. + + Parameters + ---------- + data : dict[str, T] + A dict of {string:, T} pairs. + + Returns + ------- + dict[tuple[str,...], T] + The same values, but the keys have been split and converted to tuples. + + Examples + -------- + >>> _tuplize_tree({"a": 1}, separator='/') + {("a",): 1} + + >>> _tuplize_tree({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/') + {("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3} + """ + + return {tuple(k.split(separator)): v for k, v in data.items()} From 89529110e77d6d1e734d31275651118361278b9b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 18 Dec 2024 15:41:35 +0100 Subject: [PATCH 02/59] scratch state of easy batch creation --- tests/conftest.py | 188 ++++++++++++++++++++++++++++++++++++++++++++ tests/test_group.py | 15 +++- 2 files changed, 201 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 35f31d39b3..e4d3e3cea1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ from __future__ import annotations import pathlib +from collections.abc import Iterable from dataclasses import dataclass, field from typing import TYPE_CHECKING @@ -10,7 +11,14 @@ from hypothesis import HealthCheck, Verbosity, settings from zarr import AsyncGroup, config +from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec from zarr.abc.store import Store +from zarr.codecs.bytes import BytesCodec +from zarr.codecs.sharding import ShardingCodec +from zarr.core.chunk_grids import _guess_chunks +from zarr.core.chunk_key_encodings import ChunkKeyEncoding +from zarr.core.metadata.v2 import ArrayV2Metadata +from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import sync from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage.remote import RemoteStore @@ -160,3 +168,183 @@ def zarr_format(request: pytest.FixtureRequest) -> ZarrFormat: suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow], verbosity=Verbosity.verbose, ) +import numcodecs + + +def meta_from_array_v2( + array: np.ndarray[Any, Any], + chunks: ChunkCoords | Literal["auto"] = "auto", + compressor: numcodecs.abc.Codec | Literal["auto"] | None = "auto", + filters: Iterable[numcodecs.abc.Codec] | Literal["auto"] = "auto", + fill_value: Any = "auto", + order: MemoryOrder | Literal["auto"] = "auto", + dimension_separator: Literal[".", "/", "auto"] = "auto", + attributes: dict[str, Any] | None = None, +) -> ArrayV2Metadata: + """ + Create a v2 metadata object from a numpy array + """ + + _chunks = auto_chunks(chunks, array.shape, array.dtype) + _compressor = auto_compressor(compressor) + _filters = auto_filters(filters) + _fill_value = auto_fill_value(fill_value) + _order = auto_order(order) + _dimension_separator = auto_dimension_separator(dimension_separator) + return ArrayV2Metadata( + shape=array.shape, + dtype=array.dtype, + chunks=_chunks, + compressor=_compressor, + filters=_filters, + fill_value=_fill_value, + order=_order, + dimension_separator=_dimension_separator, + attributes=attributes, + ) + + +from typing import TypedDict + + +class ChunkEncoding(TypedDict): + filters: tuple[ArrayArrayCodec] + compressors: tuple[BytesBytesCodec] + serializer: ArrayBytesCodec + + +class ChunkingSpec(TypedDict): + shard_shape: tuple[int, ...] + chunk_shape: tuple[int, ...] | None + chunk_key_encoding: ChunkKeyEncoding + + +def meta_from_array_v3( + array: np.ndarray[Any, Any], + shard_shape: tuple[int, ...] | Literal["auto"] | None, + chunk_shape: tuple[int, ...] | Literal["auto"], + serializer: ArrayBytesCodec | Literal["auto"] = "auto", + compressors: Iterable[BytesBytesCodec] | Literal["auto"] = "auto", + filters: Iterable[ArrayArrayCodec] | Literal["auto"] = "auto", + fill_value: Any = "auto", + chunk_key_encoding: ChunkKeyEncoding | Literal["auto"] = "auto", + dimension_names: Iterable[str] | None = None, + attributes: dict[str, Any] | None = None, +) -> ArrayV3Metadata: + _write_chunks, _read_chunks = auto_chunks_v3( + shard_shape=shard_shape, chunk_shape=chunk_shape, array_shape=array.shape, dtype=array.dtype + ) + _codecs = auto_codecs(serializer=serializer, compressors=compressors, filters=filters) + if _read_chunks is not None: + _codecs = (ShardingCodec(codecs=_codecs, chunk_shape=_read_chunks),) + + _fill_value = auto_fill_value(fill_value) + _chunk_key_encoding = auto_chunk_key_encoding(chunk_key_encoding) + return ArrayV3Metadata( + shape=array.shape, + dtype=array.dtype, + codecs=_codecs, + chunk_key_encoding=_chunk_key_encoding, + fill_value=fill_value, + chunk_grid={"name": "regular", "config": {"chunk_shape": shard_shape}}, + attributes=attributes, + dimension_names=dimension_names, + ) + + +from zarr.abc.codec import Codec +from zarr.codecs import ZstdCodec + + +def auto_codecs( + *, + filters: Iterable[ArrayArrayCodec] | Literal["auto"] = "auto", + compressors: Iterable[BytesBytesCodec] | Literal["auto"] = "auto", + serializer: ArrayBytesCodec | Literal["auto"] = "auto", +) -> tuple[Codec, ...]: + """ + Heuristically generate a tuple of codecs + """ + _compressors: tuple[BytesBytesCodec, ...] + _filters: tuple[ArrayArrayCodec, ...] + _serializer: ArrayBytesCodec + if filters == "auto": + _filters = () + else: + _filters = tuple(filters) + + if compressors == "auto": + _compressors = (ZstdCodec(level=3),) + else: + _compressors = tuple(compressors) + + if serializer == "auto": + _serializer = BytesCodec() + else: + _serializer = serializer + return (*_filters, _serializer, *_compressors) + + +def auto_dimension_separator(dimension_separator: Literal[".", "/", "auto"]) -> Literal[".", "/"]: + if dimension_separator == "auto": + return "/" + return dimension_separator + + +def auto_order(order: MemoryOrder | Literal["auto"]) -> MemoryOrder: + if order == "auto": + return "C" + return order + + +def auto_fill_value(fill_value: Any) -> Any: + if fill_value == "auto": + return 0 + return fill_value + + +def auto_compressor( + compressor: numcodecs.abc.Codec | Literal["auto"] | None, +) -> numcodecs.abc.Codec | None: + if compressor == "auto": + return numcodecs.Zstd(level=3) + return compressor + + +def auto_filters( + filters: Iterable[numcodecs.abc.Codec] | Literal["auto"], +) -> tuple[numcodecs.abc.Codec, ...]: + if filters == "auto": + return () + return tuple(filters) + + +def auto_chunks( + chunks: tuple[int, ...] | Literal["auto"], shape: tuple[int, ...], dtype: npt.DTypeLike +) -> tuple[int, ...]: + if chunks == "auto": + return _guess_chunks(shape, np.dtype(dtype).itemsize) + return chunks + + +def auto_chunks_v3( + *, + shard_shape: tuple[int, ...] | Literal["auto"], + chunk_shape: tuple[int, ...] | Literal["auto"] | None, + array_shape: tuple[int, ...], + dtype: npt.DTypeLike, +) -> tuple[tuple[int, ...], tuple[int, ...] | None]: + match (shard_shape, chunk_shape): + case ("auto", "auto"): + # stupid default but easy to think about + return ((256,) * len(array_shape), (64,) * len(array_shape)) + case ("auto", None): + return (_guess_chunks(array_shape, np.dtype(dtype).itemsize), None) + case ("auto", _): + return (chunk_shape, chunk_shape) + case (_, None): + return (shard_shape, None) + case (_, "auto"): + return (shard_shape, shard_shape) + case _: + return (shard_shape, chunk_shape) diff --git a/tests/test_group.py b/tests/test_group.py index afa290207d..531ba7449d 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -17,13 +17,13 @@ from zarr.abc.store import Store from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.group import ConsolidatedMetadata, GroupMetadata +from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_nodes from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage.common import make_store_path -from .conftest import parse_store +from .conftest import meta_from_array_v2, parse_store if TYPE_CHECKING: from _pytest.compat import LEGACY_PATH @@ -1420,3 +1420,14 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None del g1["0"] with pytest.raises(KeyError): g1["0/0"] + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +async def test_create_nodes(store: Store) -> None: + """ + Ensure that create_nodes works. + """ + arrays = {str(idx): meta_from_array_v2(np.arange(idx)) for idx in range(1, 5)} + spath = await make_store_path(store, path="foo") + results = [a async for a in create_nodes(store_path=spath, nodes=arrays)] + breakpoint() From c700e390a4db52565f9c492bfde2adef10ce24a2 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 3 Jan 2025 19:33:23 +0100 Subject: [PATCH 03/59] rename tupleize keys --- src/zarr/core/group.py | 53 ++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 43c6270d2c..87b08bcda2 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -6,8 +6,9 @@ import logging import warnings from collections import defaultdict -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Awaitable from dataclasses import asdict, dataclass, field, fields, replace +from functools import partial from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload import numpy as np @@ -44,7 +45,7 @@ from zarr.storage.common import StorePath, ensure_no_existing_node if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Generator, Iterable, Iterator + from collections.abc import AsyncGenerator, Callable, Generator, Iterable, Iterator from typing import Any from zarr.abc.codec import Codec @@ -1226,7 +1227,7 @@ async def require_array( async def create_nodes( self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] - ) -> tuple[tuple[str, AsyncGroup | AsyncArray]]: + ) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]]: """ Create a set of arrays or groups rooted at this group. """ @@ -2746,23 +2747,36 @@ def array( ) -async def _save_metadata_return_node( +async def _with_semaphore( + func: Callable[[Any], Awaitable[T]], semaphore: asyncio.Semaphore | None = None +) -> T: + if semaphore is None: + return await func(None) + async with semaphore: + return await func(None) + + +async def _save_metadata( node: AsyncArray[Any] | AsyncGroup, ) -> AsyncArray[Any] | AsyncGroup: - if isinstance(node, AsyncArray): - await node._save_metadata(node.metadata, ensure_parents=False) - else: - await node._save_metadata(ensure_parents=False) + """ + Save the metadata for an array or group, and return the array or group + """ + match node: + case AsyncArray(): + await node._save_metadata(node.metadata, ensure_parents=False) + case AsyncGroup(): + await node._save_metadata(ensure_parents=False) + case _: + raise ValueError(f"Unexpected node type {type(node)}") return node -async def create_nodes_v2( - *, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata] -) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata]]]: ... - - async def create_nodes( - *, store_path: StorePath, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + *, + store_path: StorePath, + nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], + semaphore: asyncio.Semaphore | None = None, ) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]: """ Create a collection of arrays and groups concurrently and atomically. To ensure atomicity, @@ -2779,7 +2793,10 @@ async def create_nodes( node = AsyncGroup(value, store_path=new_store_path) case _: raise ValueError(f"Unexpected metadata type {type(value)}") - create_tasks.append(_save_metadata_return_node(node)) + partial_func = partial(_save_metadata, node) + fut = _with_semaphore(partial_func, semaphore) + create_tasks.append(fut) + for coro in asyncio.as_completed(create_tasks): yield await coro @@ -2787,7 +2804,7 @@ async def create_nodes( T = TypeVar("T") -def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]: +def _split_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]: """ Given a dict of {string: T} pairs, where the keys are strings separated by some separator, return the result of splitting each key with the separator. @@ -2804,10 +2821,10 @@ def _tuplize_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T Examples -------- - >>> _tuplize_tree({"a": 1}, separator='/') + >>> _split_keys({"a": 1}, separator='/') {("a",): 1} - >>> _tuplize_tree({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/') + >>> _split_keys({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/') {("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3} """ From 57ceb649e8406db899ffeffe063a987bc4a4ec61 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 14:18:52 +0100 Subject: [PATCH 04/59] tests and proper implementation for create_nodes and create_hierarchy --- src/zarr/core/group.py | 131 +++++++------- src/zarr/core/sync.py | 18 +- tests/conftest.py | 381 ++++++++++++++++++++++------------------- tests/test_group.py | 49 +++++- 4 files changed, 338 insertions(+), 241 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 82b676b208..6b2ca73e1a 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -6,9 +6,9 @@ import logging import warnings from collections import defaultdict -from collections.abc import AsyncIterator, Awaitable from dataclasses import asdict, dataclass, field, fields, replace from functools import partial +from itertools import accumulate from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload import numpy as np @@ -50,13 +50,20 @@ from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder -from zarr.core.sync import SyncMixin, sync +from zarr.core.sync import SyncMixin, _with_semaphore, sync from zarr.errors import MetadataValidationError from zarr.storage import StoreLike, StorePath, make_store_path from zarr.storage._common import ensure_no_existing_node if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Callable, Generator, Iterable, Iterator + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Generator, + Iterable, + Iterator, + Mapping, + ) from typing import Any from zarr.core.array_spec import ArrayConfig, ArrayConfigLike @@ -1265,36 +1272,14 @@ async def require_array( return ds - async def create_nodes( + async def _create_nodes( self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] - ) -> tuple[tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]]: + ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a set of arrays or groups rooted at this group. """ - _nodes: ( - dict[str, GroupMetadata | ArrayV3Metadata] | dict[str, GroupMetadata | ArrayV2Metadata] - ) - match self.metadata.zarr_format: - case 2: - if not all( - isinstance(node, ArrayV2Metadata | GroupMetadata) for node in nodes.values() - ): - raise ValueError("Only v2 arrays and groups are supported") - _nodes = cast(dict[str, ArrayV2Metadata | GroupMetadata], nodes) - return await create_nodes_v2( - store=self.store_path.store, path=self.path, nodes=_nodes - ) - case 3: - if not all( - isinstance(node, ArrayV3Metadata | GroupMetadata) for node in nodes.values() - ): - raise ValueError("Only v3 arrays and groups are supported") - _nodes = cast(dict[str, ArrayV3Metadata | GroupMetadata], nodes) - return await create_nodes_v3( - store=self.store_path.store, path=self.path, nodes=_nodes - ) - case _: - raise ValueError(f"Unsupported zarr format: {self.metadata.zarr_format}") + async for node in create_hierarchy(store_path=self.store_path, nodes=nodes): + yield node async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: """Update group attributes. @@ -2818,15 +2803,6 @@ def array( ) -async def _with_semaphore( - func: Callable[[Any], Awaitable[T]], semaphore: asyncio.Semaphore | None = None -) -> T: - if semaphore is None: - return await func(None) - async with semaphore: - return await func(None) - - async def _save_metadata( node: AsyncArray[Any] | AsyncGroup, ) -> AsyncArray[Any] | AsyncGroup: @@ -2843,6 +2819,43 @@ async def _save_metadata( return node +async def create_hierarchy( + *, + store_path: StorePath, + nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], + semaphore: asyncio.Semaphore | None = None, +) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: + """ + Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input + ``nodes`` will be created as needed. + + This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy + concurrently. The groups and arrays in the hierarchy are created in a single pass, and the + function yields the created nodes in the order they are created. + + Parameters + ---------- + store_path : StorePath + The StorePath object pointing to the root of the hierarchy. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + semaphore : asyncio.Semaphore | None + An optional semaphore to limit the number of concurrent tasks. If not + provided, the number of concurrent tasks is not limited. + + Yields + ------ + AsyncGroup | AsyncArray + The created nodes in the order they are created. + """ + nodes_parsed = parse_hierarchy_dict(nodes) + async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): + yield node + + async def create_nodes( *, store_path: StorePath, @@ -2875,28 +2888,28 @@ async def create_nodes( T = TypeVar("T") -def _split_keys(data: dict[str, T], separator: str) -> dict[tuple[str, ...], T]: +def parse_hierarchy_dict( + data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: """ - Given a dict of {string: T} pairs, where the keys are strings separated by some separator, - return the result of splitting each key with the separator. - - Parameters - ---------- - data : dict[str, T] - A dict of {string:, T} pairs. - - Returns - ------- - dict[tuple[str,...], T] - The same values, but the keys have been split and converted to tuples. + If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups, + then return an identical copy of that dict. Otherwise, return a version of the input dict + with groups added where they are needed to make the hierarchy explicit. - Examples - -------- - >>> _split_keys({"a": 1}, separator='/') - {("a",): 1} + For example, an input of {'a/b/c': ...} will result in a return value of + {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}. - >>> _split_keys({"a/b": 1, "a/b/c": 2, "c": 3}, separator='/') - {("a", "b"): 1, ("a", "b", "c"): 2, ("c",): 3} + This function is useful for ensuring that the input to create_hierarchy is a complete + Zarr hierarchy. """ - - return {tuple(k.split(separator)): v for k, v in data.items()} + # Create a copy of the input dict + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data} + for k, v in data.items(): + # Split the key into its path components + key_split = k.split("/") + # Iterate over the path components + for subpath in accumulate(key_split, lambda a, b: f"{a}/{b}"): + # If a component is not already in the output dict, add it + if subpath not in out: + out[subpath] = GroupMetadata(zarr_format=v.zarr_format) + return out diff --git a/src/zarr/core/sync.py b/src/zarr/core/sync.py index f7d4529478..653c9a5fc0 100644 --- a/src/zarr/core/sync.py +++ b/src/zarr/core/sync.py @@ -5,14 +5,14 @@ import logging import threading from concurrent.futures import ThreadPoolExecutor, wait -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from typing_extensions import ParamSpec from zarr.core.config import config if TYPE_CHECKING: - from collections.abc import AsyncIterator, Coroutine + from collections.abc import AsyncIterator, Awaitable, Callable, Coroutine from typing import Any logger = logging.getLogger(__name__) @@ -192,3 +192,17 @@ async def iter_to_list() -> list[T]: return [item async for item in async_iterator] return self._sync(iter_to_list()) + + +async def _with_semaphore( + func: Callable[[], Awaitable[T]], semaphore: asyncio.Semaphore | None = None +) -> T: + """ + Await the result of invoking the no-argument-callable ``func`` within the context manager + provided by a Semaphore, if one is provided. Otherwise, just await the result of invoking + ``func``. + """ + if semaphore is None: + return await func() + async with semaphore: + return await func() diff --git a/tests/conftest.py b/tests/conftest.py index 6433094cc0..9bd42a81ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,8 @@ from __future__ import annotations import pathlib -from collections.abc import Iterable from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np import numpy.typing as npt @@ -11,24 +10,31 @@ from hypothesis import HealthCheck, Verbosity, settings from zarr import AsyncGroup, config -from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec +from zarr.abc.codec import Codec from zarr.abc.store import Store -from zarr.codecs.bytes import BytesCodec -from zarr.codecs.sharding import ShardingCodec -from zarr.core.chunk_grids import _guess_chunks -from zarr.core.chunk_key_encodings import ChunkKeyEncoding +from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation +from zarr.core.array import ( + _parse_chunk_encoding_v2, + _parse_chunk_encoding_v3, + _parse_chunk_key_encoding, +) +from zarr.core.chunk_grids import RegularChunkGrid, _auto_partition +from zarr.core.common import JSON, parse_dtype, parse_shapelike +from zarr.core.config import config as zarr_config from zarr.core.metadata.v2 import ArrayV2Metadata from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import sync from zarr.storage import FsspecStore, LocalStore, MemoryStore, StorePath, ZipStore if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Iterable from typing import Any, Literal from _pytest.compat import LEGACY_PATH - from zarr.core.common import ChunkCoords, MemoryOrder, ZarrFormat + from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike + from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike + from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat async def parse_store( @@ -167,183 +173,210 @@ def zarr_format(request: pytest.FixtureRequest) -> ZarrFormat: suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow], verbosity=Verbosity.verbose, ) -import numcodecs - -def meta_from_array_v2( - array: np.ndarray[Any, Any], +# TODO: uncomment these overrides when we can get mypy to accept them +""" +@overload +def create_array_metadata( + *, + shape: ShapeLike, + dtype: npt.DTypeLike, + chunks: ChunkCoords | Literal["auto"], + shards: None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: MemoryOrder | None, + zarr_format: Literal[2], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: None, +) -> ArrayV2Metadata: ... + + +@overload +def create_array_metadata( + *, + shape: ShapeLike, + dtype: npt.DTypeLike, + chunks: ChunkCoords | Literal["auto"], + shards: ShardsLike | None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: None, + zarr_format: Literal[3], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: Iterable[str] | None, +) -> ArrayV3Metadata: ... +""" + + +def create_array_metadata( + *, + shape: ShapeLike, + dtype: npt.DTypeLike, chunks: ChunkCoords | Literal["auto"] = "auto", - compressor: numcodecs.abc.Codec | Literal["auto"] | None = "auto", - filters: Iterable[numcodecs.abc.Codec] | Literal["auto"] = "auto", - fill_value: Any = "auto", - order: MemoryOrder | Literal["auto"] = "auto", - dimension_separator: Literal[".", "/", "auto"] = "auto", - attributes: dict[str, Any] | None = None, -) -> ArrayV2Metadata: + shards: ShardsLike | None = None, + filters: FiltersLike = "auto", + compressors: CompressorsLike = "auto", + serializer: SerializerLike = "auto", + fill_value: Any | None = None, + order: MemoryOrder | None = None, + zarr_format: ZarrFormat, + attributes: dict[str, JSON] | None = None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None = None, + dimension_names: Iterable[str] | None = None, +) -> ArrayV2Metadata | ArrayV3Metadata: """ - Create a v2 metadata object from a numpy array + Create array metadata """ - - _chunks = auto_chunks(chunks, array.shape, array.dtype) - _compressor = auto_compressor(compressor) - _filters = auto_filters(filters) - _fill_value = auto_fill_value(fill_value) - _order = auto_order(order) - _dimension_separator = auto_dimension_separator(dimension_separator) - return ArrayV2Metadata( - shape=array.shape, - dtype=array.dtype, - chunks=_chunks, - compressor=_compressor, - filters=_filters, - fill_value=_fill_value, - order=_order, - dimension_separator=_dimension_separator, - attributes=attributes, + dtype_parsed = parse_dtype(dtype, zarr_format=zarr_format) + shape_parsed = parse_shapelike(shape) + chunk_key_encoding_parsed = _parse_chunk_key_encoding( + chunk_key_encoding, zarr_format=zarr_format ) + shard_shape_parsed, chunk_shape_parsed = _auto_partition( + array_shape=shape_parsed, shard_shape=shards, chunk_shape=chunks, dtype=dtype_parsed + ) -from typing import TypedDict - - -class ChunkEncoding(TypedDict): - filters: tuple[ArrayArrayCodec] - compressors: tuple[BytesBytesCodec] - serializer: ArrayBytesCodec - - -class ChunkingSpec(TypedDict): - shard_shape: tuple[int, ...] - chunk_shape: tuple[int, ...] | None - chunk_key_encoding: ChunkKeyEncoding - - -def meta_from_array_v3( + if order is None: + order_parsed = zarr_config.get("array.order") + else: + order_parsed = order + chunks_out: tuple[int, ...] + + if zarr_format == 2: + filters_parsed, compressor_parsed = _parse_chunk_encoding_v2( + compressor=compressors, filters=filters, dtype=np.dtype(dtype) + ) + return ArrayV2Metadata( + shape=shape_parsed, + dtype=np.dtype(dtype), + chunks=chunk_shape_parsed, + order=order_parsed, + dimension_separator=chunk_key_encoding_parsed.separator, + fill_value=fill_value, + compressor=compressor_parsed, + filters=filters_parsed, + attributes=attributes, + ) + elif zarr_format == 3: + array_array, array_bytes, bytes_bytes = _parse_chunk_encoding_v3( + compressors=compressors, + filters=filters, + serializer=serializer, + dtype=dtype_parsed, + ) + + sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes)) + codecs_out: tuple[Codec, ...] + if shard_shape_parsed is not None: + index_location = None + if isinstance(shards, dict): + index_location = ShardingCodecIndexLocation(shards.get("index_location", None)) + if index_location is None: + index_location = ShardingCodecIndexLocation.end + sharding_codec = ShardingCodec( + chunk_shape=chunk_shape_parsed, codecs=sub_codecs, index_location=index_location + ) + sharding_codec.validate( + shape=chunk_shape_parsed, + dtype=dtype_parsed, + chunk_grid=RegularChunkGrid(chunk_shape=shard_shape_parsed), + ) + codecs_out = (sharding_codec,) + chunks_out = shard_shape_parsed + else: + chunks_out = chunk_shape_parsed + codecs_out = sub_codecs + + return ArrayV3Metadata( + shape=shape_parsed, + data_type=dtype_parsed, + chunk_grid=RegularChunkGrid(chunk_shape=chunks_out), + chunk_key_encoding=chunk_key_encoding_parsed, + fill_value=fill_value, + codecs=codecs_out, + attributes=attributes, + dimension_names=dimension_names, + ) + + raise ValueError(f"Invalid Zarr format: {zarr_format}") + + +# TODO: uncomment these overrides when we can get mypy to accept them +""" +@overload +def meta_from_array( array: np.ndarray[Any, Any], - shard_shape: tuple[int, ...] | Literal["auto"] | None, - chunk_shape: tuple[int, ...] | Literal["auto"], - serializer: ArrayBytesCodec | Literal["auto"] = "auto", - compressors: Iterable[BytesBytesCodec] | Literal["auto"] = "auto", - filters: Iterable[ArrayArrayCodec] | Literal["auto"] = "auto", - fill_value: Any = "auto", - chunk_key_encoding: ChunkKeyEncoding | Literal["auto"] = "auto", + chunks: ChunkCoords | Literal["auto"], + shards: None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: MemoryOrder | None, + zarr_format: Literal[2], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: Iterable[str] | None, +) -> ArrayV2Metadata: ... + + +@overload +def meta_from_array( + array: np.ndarray[Any, Any], + chunks: ChunkCoords | Literal["auto"], + shards: ShardsLike | None, + filters: FiltersLike, + compressors: CompressorsLike, + serializer: SerializerLike, + fill_value: Any | None, + order: None, + zarr_format: Literal[3], + attributes: dict[str, JSON] | None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None, + dimension_names: Iterable[str] | None, +) -> ArrayV3Metadata: ... + +""" + + +def meta_from_array( + array: np.ndarray[Any, Any], + *, + chunks: ChunkCoords | Literal["auto"] = "auto", + shards: ShardsLike | None = None, + filters: FiltersLike = "auto", + compressors: CompressorsLike = "auto", + serializer: SerializerLike = "auto", + fill_value: Any | None = None, + order: MemoryOrder | None = None, + zarr_format: ZarrFormat = 3, + attributes: dict[str, JSON] | None = None, + chunk_key_encoding: ChunkKeyEncoding | ChunkKeyEncodingLike | None = None, dimension_names: Iterable[str] | None = None, - attributes: dict[str, Any] | None = None, -) -> ArrayV3Metadata: - _write_chunks, _read_chunks = auto_chunks_v3( - shard_shape=shard_shape, chunk_shape=chunk_shape, array_shape=array.shape, dtype=array.dtype - ) - _codecs = auto_codecs(serializer=serializer, compressors=compressors, filters=filters) - if _read_chunks is not None: - _codecs = (ShardingCodec(codecs=_codecs, chunk_shape=_read_chunks),) - - _fill_value = auto_fill_value(fill_value) - _chunk_key_encoding = auto_chunk_key_encoding(chunk_key_encoding) - return ArrayV3Metadata( +) -> ArrayV3Metadata | ArrayV2Metadata: + """ + Create array metadata from an array + """ + return create_array_metadata( shape=array.shape, dtype=array.dtype, - codecs=_codecs, - chunk_key_encoding=_chunk_key_encoding, + chunks=chunks, + shards=shards, + filters=filters, + compressors=compressors, + serializer=serializer, fill_value=fill_value, - chunk_grid={"name": "regular", "config": {"chunk_shape": shard_shape}}, + order=order, + zarr_format=zarr_format, attributes=attributes, + chunk_key_encoding=chunk_key_encoding, dimension_names=dimension_names, ) - - -from zarr.abc.codec import Codec -from zarr.codecs import ZstdCodec - - -def auto_codecs( - *, - filters: Iterable[ArrayArrayCodec] | Literal["auto"] = "auto", - compressors: Iterable[BytesBytesCodec] | Literal["auto"] = "auto", - serializer: ArrayBytesCodec | Literal["auto"] = "auto", -) -> tuple[Codec, ...]: - """ - Heuristically generate a tuple of codecs - """ - _compressors: tuple[BytesBytesCodec, ...] - _filters: tuple[ArrayArrayCodec, ...] - _serializer: ArrayBytesCodec - if filters == "auto": - _filters = () - else: - _filters = tuple(filters) - - if compressors == "auto": - _compressors = (ZstdCodec(level=3),) - else: - _compressors = tuple(compressors) - - if serializer == "auto": - _serializer = BytesCodec() - else: - _serializer = serializer - return (*_filters, _serializer, *_compressors) - - -def auto_dimension_separator(dimension_separator: Literal[".", "/", "auto"]) -> Literal[".", "/"]: - if dimension_separator == "auto": - return "/" - return dimension_separator - - -def auto_order(order: MemoryOrder | Literal["auto"]) -> MemoryOrder: - if order == "auto": - return "C" - return order - - -def auto_fill_value(fill_value: Any) -> Any: - if fill_value == "auto": - return 0 - return fill_value - - -def auto_compressor( - compressor: numcodecs.abc.Codec | Literal["auto"] | None, -) -> numcodecs.abc.Codec | None: - if compressor == "auto": - return numcodecs.Zstd(level=3) - return compressor - - -def auto_filters( - filters: Iterable[numcodecs.abc.Codec] | Literal["auto"], -) -> tuple[numcodecs.abc.Codec, ...]: - if filters == "auto": - return () - return tuple(filters) - - -def auto_chunks( - chunks: tuple[int, ...] | Literal["auto"], shape: tuple[int, ...], dtype: npt.DTypeLike -) -> tuple[int, ...]: - if chunks == "auto": - return _guess_chunks(shape, np.dtype(dtype).itemsize) - return chunks - - -def auto_chunks_v3( - *, - shard_shape: tuple[int, ...] | Literal["auto"], - chunk_shape: tuple[int, ...] | Literal["auto"] | None, - array_shape: tuple[int, ...], - dtype: npt.DTypeLike, -) -> tuple[tuple[int, ...], tuple[int, ...] | None]: - match (shard_shape, chunk_shape): - case ("auto", "auto"): - # stupid default but easy to think about - return ((256,) * len(array_shape), (64,) * len(array_shape)) - case ("auto", None): - return (_guess_chunks(array_shape, np.dtype(dtype).itemsize), None) - case ("auto", _): - return (chunk_shape, chunk_shape) - case (_, None): - return (shard_shape, None) - case (_, "auto"): - return (shard_shape, shard_shape) - case _: - return (shard_shape, chunk_shape) diff --git a/tests/test_group.py b/tests/test_group.py index 35a3996901..699926c1da 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -4,6 +4,7 @@ import operator import pickle import warnings +from pathlib import Path from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -18,12 +19,12 @@ from zarr.abc.store import Store from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_nodes +from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_hierarchy, create_nodes from zarr.core.sync import sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore, make_store_path -from .conftest import meta_from_array_v2, parse_store +from .conftest import meta_from_array, parse_store if TYPE_CHECKING: from _pytest.compat import LEGACY_PATH @@ -1441,13 +1442,49 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None @pytest.mark.parametrize("store", ["memory"], indirect=True) -async def test_create_nodes(store: Store) -> None: +async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: """ - Ensure that create_nodes works. + Ensure that ``create_nodes`` can create a zarr hierarchy from a model of that + hierarchy in dict form. Note that this creates an incomplete Zarr hierarchy. """ - arrays = {str(idx): meta_from_array_v2(np.arange(idx)) for idx in range(1, 5)} + path = "foo" + expected_meta = { + "group": GroupMetadata(attributes={"foo": 10}), + "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), + "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format), + "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), + "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), + } spath = await make_store_path(store, path="foo") - results = [a async for a in create_nodes(store_path=spath, nodes=arrays)] + observed_nodes = { + str(Path(a.name).relative_to("/" + path)): a + async for a in create_nodes(store_path=spath, nodes=expected_meta) + } + assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: + """ + Test that ``create_hierarchy`` can create a complete Zarr hierarchy, even if the input describes + an incomplete one. + """ + path = "foo" + hierarchy_spec = { + "group": GroupMetadata(attributes={"foo": 10}), + "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), + "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format), + "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), + "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), + } + expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} + spath = await make_store_path(store, path="foo") + observed_nodes = { + str(Path(a.name).relative_to("/" + path)): a + async for a in create_hierarchy(store_path=spath, nodes=expected_meta) + } + assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} + @pytest.mark.parametrize("store", ["local", "memory"], indirect=["store"]) def test_deprecated_compressor(store: Store) -> None: From 181d3d00b441153067b7d17d4c0ce47ae51d414c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 14:28:03 +0100 Subject: [PATCH 05/59] privatize --- src/zarr/core/group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 6b2ca73e1a..b40c1e2c46 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2851,7 +2851,7 @@ async def create_hierarchy( AsyncGroup | AsyncArray The created nodes in the order they are created. """ - nodes_parsed = parse_hierarchy_dict(nodes) + nodes_parsed = _parse_hierarchy_dict(nodes) async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): yield node @@ -2888,7 +2888,7 @@ async def create_nodes( T = TypeVar("T") -def parse_hierarchy_dict( +def _parse_hierarchy_dict( data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: """ From e8e610706676e9a82208fcf43d2a103e2b7602a9 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 14:48:54 +0100 Subject: [PATCH 06/59] use Posixpath instead of Path in tests; avoid redundant cast --- tests/conftest.py | 6 +++--- tests/test_group.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 9bd42a81ed..603a93985d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pathlib from dataclasses import dataclass, field -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt @@ -10,7 +10,6 @@ from hypothesis import HealthCheck, Verbosity, settings from zarr import AsyncGroup, config -from zarr.abc.codec import Codec from zarr.abc.store import Store from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.core.array import ( @@ -32,6 +31,7 @@ from _pytest.compat import LEGACY_PATH + from zarr.abc.codec import Codec from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat @@ -273,7 +273,7 @@ def create_array_metadata( dtype=dtype_parsed, ) - sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes)) + sub_codecs = (*array_array, array_bytes, *bytes_bytes) codecs_out: tuple[Codec, ...] if shard_shape_parsed is not None: index_location = None diff --git a/tests/test_group.py b/tests/test_group.py index 699926c1da..1aadc150e7 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -4,7 +4,7 @@ import operator import pickle import warnings -from pathlib import Path +from pathlib import PosixPath from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -1457,7 +1457,7 @@ async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: } spath = await make_store_path(store, path="foo") observed_nodes = { - str(Path(a.name).relative_to("/" + path)): a + str(PosixPath(a.name).relative_to("/" + path)): a async for a in create_nodes(store_path=spath, nodes=expected_meta) } assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1480,7 +1480,7 @@ async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} spath = await make_store_path(store, path="foo") observed_nodes = { - str(Path(a.name).relative_to("/" + path)): a + str(PosixPath(a.name).relative_to("/" + path)): a async for a in create_hierarchy(store_path=spath, nodes=expected_meta) } assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} From 4f2c954849ab9d998557a25d6120de93a7b30409 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 14:52:22 +0100 Subject: [PATCH 07/59] restore cast --- tests/conftest.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 603a93985d..9bd42a81ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pathlib from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import numpy as np import numpy.typing as npt @@ -10,6 +10,7 @@ from hypothesis import HealthCheck, Verbosity, settings from zarr import AsyncGroup, config +from zarr.abc.codec import Codec from zarr.abc.store import Store from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.core.array import ( @@ -31,7 +32,6 @@ from _pytest.compat import LEGACY_PATH - from zarr.abc.codec import Codec from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat @@ -273,7 +273,7 @@ def create_array_metadata( dtype=dtype_parsed, ) - sub_codecs = (*array_array, array_bytes, *bytes_bytes) + sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes)) codecs_out: tuple[Codec, ...] if shard_shape_parsed is not None: index_location = None From cf728347bb0733dcef3045851cd7bcee4c3d296b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 16:09:56 +0100 Subject: [PATCH 08/59] pureposixpath instead of posixpath --- tests/test_group.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index 30320bbbfd..0679183c0c 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -5,7 +5,7 @@ import pickle import time import warnings -from pathlib import PosixPath +from pathlib import PurePosixPath from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -1459,7 +1459,7 @@ async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: } spath = await make_store_path(store, path="foo") observed_nodes = { - str(PosixPath(a.name).relative_to("/" + path)): a + str(PurePosixPath(a.name).relative_to("/" + path)): a async for a in create_nodes(store_path=spath, nodes=expected_meta) } assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1482,7 +1482,7 @@ async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} spath = await make_store_path(store, path="foo") observed_nodes = { - str(PosixPath(a.name).relative_to("/" + path)): a + str(PurePosixPath(a.name).relative_to("/" + path)): a async for a in create_hierarchy(store_path=spath, nodes=expected_meta) } assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} From e2cff8c2ba81a1a936746e7b1e5563493850e1f4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 17:35:08 +0100 Subject: [PATCH 09/59] group-level create_hierarchy --- src/zarr/core/group.py | 54 ++++++++++++++++++++++++++++++++++++++++++ tests/test_group.py | 18 ++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 6acd9faff4..6bca066fce 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field, fields, replace from functools import partial from itertools import accumulate +from pathlib import PurePosixPath from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload import numpy as np @@ -1424,6 +1425,33 @@ async def _members( ): yield member + # TODO: find a better name for this. create_tree could work. + # TODO: include an example in the docstring + async def create_hierarchy( + self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] + ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: + """ + Create a hierarchy of arrays or groups rooted at this group. + + This method takes a dictionary where the keys are the names of the arrays or groups + to create and the values are the metadata or objects representing the arrays or groups. + + The method returns an asynchronous iterator over the created nodes. + + Parameters + ---------- + nodes : A dictionary representing the hierarchy to create + + Returns + ------- + An asynchronous iterator over the created nodes. + """ + semaphore = asyncio.Semaphore(config.get("async.concurrency")) + async for node in create_hierarchy( + store_path=self.store_path, nodes=nodes, semaphore=semaphore + ): + yield node + async def keys(self) -> AsyncGenerator[str, None]: """Iterate over member names.""" async for key, _ in self.members(): @@ -2046,6 +2074,32 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], return tuple((kv[0], _parse_async_node(kv[1])) for kv in _members) + def create_hierarchy( + self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] + ) -> dict[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: + """ + Create a hierarchy of arrays or groups rooted at this group. + + This method takes a dictionary where the keys are the names of the arrays or groups + to create and the values are the metadata objects for the arrays or groups. + + The method returns an asynchronous iterator over the created nodes. + + Parameters + ---------- + nodes : A dictionary representing the hierarchy to create + + Returns + ------- + A dict containing the created nodes.The keys are the same as th + """ + nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes)) + if self.path == "": + root = "/" + else: + root = self.path + return {str(PurePosixPath(n.name).relative_to(root)): n for n in nodes_created} + def keys(self) -> Generator[str, None]: """Return an iterator over group member names. diff --git a/tests/test_group.py b/tests/test_group.py index 0679183c0c..dcf13a9a06 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1488,6 +1488,24 @@ async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat): + """ + Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. + """ + g = Group.from_store(store) + tree = { + "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), + "a/b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a/b"}), + "a/b/c": meta_from_array( + np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/b/c"} + ), + } + nodes = g.create_hierarchy(tree) + for k, v in nodes.items(): + assert v.metadata == tree[k] + + def test_group_members_performance(store: MemoryStore) -> None: """ Test that the execution time of Group.members is less than the number of members times the From 0912ecb2826bd89409d3dce7d81ff0a727018a37 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 7 Jan 2025 17:35:56 +0100 Subject: [PATCH 10/59] docstring --- src/zarr/core/group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 6bca066fce..627b4cd021 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2083,7 +2083,7 @@ def create_hierarchy( This method takes a dictionary where the keys are the names of the arrays or groups to create and the values are the metadata objects for the arrays or groups. - The method returns an asynchronous iterator over the created nodes. + The method returns a dict containing the created nodes. Parameters ---------- From 089feefca33cfc1215d4729445a87ec6e2419492 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 8 Jan 2025 19:55:52 +0100 Subject: [PATCH 11/59] sketch out from_flat for groups --- src/zarr/core/group.py | 43 ++++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 64ca065e01..f081caa161 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -10,7 +10,7 @@ from functools import partial from itertools import accumulate from pathlib import PurePosixPath -from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload +from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload import numpy as np import numpy.typing as npt @@ -426,6 +426,27 @@ class AsyncGroup: metadata: GroupMetadata store_path: StorePath + # TODO: make this correct and work + # TODO: ensure that this can be bound properly to subclass of AsyncGroup + @classmethod + async def from_flat( + cls, + store: StoreLike, + *, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False) -> Self: + + if overwrite: + store_path = await make_store_path(store, mode='w') + else: + store_path = await make_store_path(store, mode='w-') + semaphore = asyncio.Semaphore(config.get("async.concurrency")) + + nodes_created = {x.name: x async for x in create_hierarchy( + store_path=store_path, nodes=nodes, semaphore=semaphore + )} + return nodes_created[''] + @classmethod async def from_store( cls, @@ -1269,15 +1290,6 @@ async def require_array( return ds - async def _create_nodes( - self, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] - ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: - """ - Create a set of arrays or groups rooted at this group. - """ - async for node in create_hierarchy(store_path=self.store_path, nodes=nodes): - yield node - async def update_attributes(self, new_attributes: dict[str, Any]) -> AsyncGroup: """Update group attributes. @@ -1731,6 +1743,17 @@ async def move(self, source: str, dest: str) -> None: @dataclass(frozen=True) class Group(SyncMixin): _async_group: AsyncGroup + + @classmethod + def from_flat( + cls, + store: StoreLike, + *, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False) -> Group: + nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite)) + # return the root node of the hierarchy + return nodes[''] @classmethod def from_store( From 116ab87528bce4f3c5e8aade2cc4148517a65760 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 9 Jan 2025 17:21:18 +0100 Subject: [PATCH 12/59] better concurrency for v2 --- src/zarr/core/config.py | 2 +- src/zarr/core/group.py | 217 ++++++++++++++++++++++++++++++++-------- tests/test_group.py | 103 +++++++++++++++++-- 3 files changed, 270 insertions(+), 52 deletions(-) diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 7920d220a4..958d90f535 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -94,7 +94,7 @@ def reset(self) -> None: ], }, }, - "async": {"concurrency": 10, "timeout": None}, + "async": {"concurrency": 256, "timeout": None}, "threading": {"max_workers": None}, "json_indent": 2, "codec_pipeline": { diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index f081caa161..2f0d0ff018 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1,16 +1,24 @@ from __future__ import annotations import asyncio +import contextlib import itertools import json import logging import warnings from collections import defaultdict from dataclasses import asdict, dataclass, field, fields, replace -from functools import partial from itertools import accumulate from pathlib import PurePosixPath -from typing import TYPE_CHECKING, Literal, Self, TypeVar, assert_never, cast, overload +from typing import ( + TYPE_CHECKING, + Literal, + Self, + TypeVar, + assert_never, + cast, + overload, +) import numpy as np import numpy.typing as npt @@ -51,7 +59,7 @@ from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder -from zarr.core.sync import SyncMixin, _with_semaphore, sync +from zarr.core.sync import SyncMixin, sync from zarr.errors import MetadataValidationError from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path @@ -60,6 +68,7 @@ from collections.abc import ( AsyncGenerator, AsyncIterator, + Coroutine, Generator, Iterable, Iterator, @@ -431,21 +440,32 @@ class AsyncGroup: @classmethod async def from_flat( cls, - store: StoreLike, - *, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False) -> Self: - + store: StoreLike, + *, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, + ) -> Self: + if not _is_rooted(nodes): + msg = ( + "The input does not specify a root node. ", + "This function can only create hierarchies that contain a root node, which is ", + "defined as a group that is ancestral to all the other arrays and ", + "groups in the hierarchy.", + ) + raise ValueError(msg) + if overwrite: - store_path = await make_store_path(store, mode='w') + store_path = await make_store_path(store, mode="w") else: - store_path = await make_store_path(store, mode='w-') + store_path = await make_store_path(store, mode="w-") + semaphore = asyncio.Semaphore(config.get("async.concurrency")) - - nodes_created = {x.name: x async for x in create_hierarchy( - store_path=store_path, nodes=nodes, semaphore=semaphore - )} - return nodes_created[''] + + nodes_created = { + x.name: x + async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore) + } + # TODO: make this work @classmethod async def from_store( @@ -1743,17 +1763,18 @@ async def move(self, source: str, dest: str) -> None: @dataclass(frozen=True) class Group(SyncMixin): _async_group: AsyncGroup - + @classmethod def from_flat( - cls, + cls, store: StoreLike, - *, + *, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False) -> Group: + overwrite: bool = False, + ) -> Group: nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite)) # return the root node of the hierarchy - return nodes[''] + return nodes[""] @classmethod def from_store( @@ -2110,17 +2131,28 @@ def create_hierarchy( Parameters ---------- - nodes : A dictionary representing the hierarchy to create + nodes : A dictionary representing the hierarchy to create. The keys should be relative paths + and the values should be the metadata for the arrays or groups to create. Returns ------- - A dict containing the created nodes.The keys are the same as th - """ + A dict containing the created nodes, with the same keys as the input + """ + # check that all the nodes have the same zarr_format as Self + for key, value in nodes.items(): + if value.zarr_format != self.metadata.zarr_format: + msg = ( + "The zarr_format of the nodes must be the same as the parent group. " + f"The node at {key} has zarr_format {value.zarr_format}, but the parent group" + f" has zarr_format {self.metadata.zarr_format}." + ) + raise ValueError(msg) nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes)) if self.path == "": root = "/" else: root = self.path + # TODO: make this safe against invalid path inputs return {str(PurePosixPath(n.name).relative_to(root)): n for n in nodes_created} def keys(self) -> Generator[str, None]: @@ -2859,6 +2891,7 @@ def array( async def _save_metadata( node: AsyncArray[Any] | AsyncGroup, + overwrite: bool, ) -> AsyncArray[Any] | AsyncGroup: """ Save the metadata for an array or group, and return the array or group @@ -2878,6 +2911,7 @@ async def create_hierarchy( store_path: StorePath, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], semaphore: asyncio.Semaphore | None = None, + overwrite: bool = False, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input @@ -2906,6 +2940,7 @@ async def create_hierarchy( The created nodes in the order they are created. """ nodes_parsed = _parse_hierarchy_dict(nodes) + async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): yield node @@ -2913,35 +2948,73 @@ async def create_hierarchy( async def create_nodes( *, store_path: StorePath, - nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], semaphore: asyncio.Semaphore | None = None, -) -> AsyncIterator[AsyncGroup | AsyncArray[Any]]: +) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ - Create a collection of arrays and groups concurrently and atomically. To ensure atomicity, + Create a collection of zarr v2 arrays and groups concurrently and atomically. To ensure atomicity, no attempt is made to ensure that intermediate groups are created. """ - create_tasks = [] + ctx: asyncio.Semaphore | contextlib.nullcontext[None] + if semaphore is None: + ctx = contextlib.nullcontext() + else: + ctx = semaphore + + create_tasks: list[Coroutine[None, None, str]] = [] + for key, value in nodes.items(): - new_store_path = store_path / key - node: AsyncArray[Any] | AsyncGroup - match value: - case ArrayV3Metadata() | ArrayV2Metadata(): - node = AsyncArray(value, store_path=new_store_path) - case GroupMetadata(): - node = AsyncGroup(value, store_path=new_store_path) - case _: - raise ValueError(f"Unexpected metadata type {type(value)}") - partial_func = partial(_save_metadata, node) - fut = _with_semaphore(partial_func, semaphore) - create_tasks.append(fut) + create_tasks.extend( + _prepare_save_metadata(store_path.store, f"{store_path.path}/{key}", value) + ) + + created_keys = [] + async with ctx: + for coro in asyncio.as_completed(create_tasks): + created_key = await coro + relative_path = PurePosixPath(created_key).relative_to(store_path.path) + created_keys.append(str(relative_path)) + # convert /foo/bar/baz/.zattrs to bar/baz + node_name = str(relative_path.parent) + meta_out = nodes[node_name] + + if meta_out.zarr_format == 3: + if isinstance(meta_out, GroupMetadata): + yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name) + else: + yield AsyncArray(metadata=meta_out, store_path=store_path / node_name) + else: + # For zarr v2 + # we only want to yield when both the metadata and attributes are created + # so we track which keys have been created, and wait for both the meta key and + # the attrs key to be created before yielding back the AsyncArray / AsyncGroup + + attrs_done = f"{node_name}/.zattrs" in created_keys + + if isinstance(meta_out, GroupMetadata): + meta_done = f"{node_name}/.zgroup" in created_keys + else: + meta_done = f"{node_name}/.zarray" in created_keys - for coro in asyncio.as_completed(create_tasks): - yield await coro + if meta_done and attrs_done: + if isinstance(meta_out, GroupMetadata): + yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name) + else: + yield AsyncArray(metadata=meta_out, store_path=store_path / node_name) T = TypeVar("T") +def _is_rooted(data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]) -> bool: + """ + Check if the data describes a hierarchy that's rooted, which means there is a single node with + the least number of components in its key + """ + # a dict + return False + + def _parse_hierarchy_dict( data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: @@ -2953,19 +3026,54 @@ def _parse_hierarchy_dict( For example, an input of {'a/b/c': ...} will result in a return value of {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}. - This function is useful for ensuring that the input to create_hierarchy is a complete + The input is also checked for the following conditions, and an error is raised if any + of them are violated: + + - No arrays can contain group or arrays (i.e., all arrays must be leaf nodes). + - All arrays and groups must have the same ``zarr_format`` value. + + This function ensures that the input is transformed into a specification of a complete and valid Zarr hierarchy. """ # Create a copy of the input dict out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data} + + observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []} + + # We will iterate over the dict again, but a full pass here ensures that the error message + # is comprehensive, and I think the performance cost will be negligible. + for k, v in data.items(): + observed_zarr_formats[v.zarr_format].append(k) + + if len(observed_zarr_formats[2]) > 0 and len(observed_zarr_formats[3]) > 0: + msg = ( + "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " + f"The following keys map to Zarr v2 nodes: {observed_zarr_formats.get(2)}. " + f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}." + "Ensure that all nodes have the same Zarr format." + ) + + raise ValueError(msg) + for k, v in data.items(): + # TODO: ensure that the key is a valid path # Split the key into its path components key_split = k.split("/") - # Iterate over the path components - for subpath in accumulate(key_split, lambda a, b: f"{a}/{b}"): + + # Iterate over the intermediate path components + *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") + for subpath in subpaths: # If a component is not already in the output dict, add it if subpath not in out: out[subpath] = GroupMetadata(zarr_format=v.zarr_format) + else: + if not isinstance(out[subpath], GroupMetadata): + msg = ( + f"The node at {subpath} contains other nodes, but it is not a Zarr group. " + "This is invalid. Only Zarr groups can contain other nodes." + ) + raise ValueError(msg) + return out @@ -3155,3 +3263,24 @@ def _build_node_v2( return AsyncGroup(metadata, store_path=store_path) case _: raise ValueError(f"Unexpected metadata type: {type(metadata)}") + + +async def _set_return_key(store: Store, key: str, value: Buffer) -> str: + """ + Store.set, but the key and the value are returned. + Useful when saving metadata via asyncio.as_completed, because + we need to know which key was saved. + """ + await store.set(key, value) + return key + + +def _prepare_save_metadata( + store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata +) -> tuple[Coroutine[None, None, str], ...]: + """ + Prepare to save a metadata document to storage. Returns a tuple of coroutines that must be awaited. + """ + + to_save = metadata.to_buffer_dict(default_buffer_prototype()) + return tuple(_set_return_key(store, f"{path}/{key}", value) for key, value in to_save.items()) diff --git a/tests/test_group.py b/tests/test_group.py index 2507594e78..148087ac13 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -3,6 +3,7 @@ import contextlib import operator import pickle +import re import time import warnings from pathlib import PurePosixPath @@ -21,7 +22,7 @@ from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_hierarchy, create_nodes -from zarr.core.sync import sync +from zarr.core.sync import _collect_aiterator, sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage._common import make_store_path @@ -1473,14 +1474,14 @@ async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: """ path = "foo" hierarchy_spec = { - "group": GroupMetadata(attributes={"foo": 10}), + "group": GroupMetadata(attributes={"foo": 10}, zarr_format=zarr_format), "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} - spath = await make_store_path(store, path="foo") + spath = await make_store_path(store, path=path) observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a async for a in create_hierarchy(store_path=spath, nodes=expected_meta) @@ -1493,7 +1494,7 @@ def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat): """ Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. """ - g = Group.from_store(store) + g = Group.from_store(store, zarr_format=zarr_format) tree = { "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), "a/b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a/b"}), @@ -1502,11 +1503,99 @@ def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat): ), } nodes = g.create_hierarchy(tree) - for k, v in nodes.items(): - assert v.metadata == tree[k] + for k, v in g.members(max_depth=None): + assert v.metadata == tree[k] == nodes[k].metadata + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_group_create_hierarchy_invalid_mixed_zarr_format(store: Store, zarr_format: ZarrFormat): + """ + Test that ```Group.create_hierarchy``` will raise an error if the zarr_format of the nodes is + different from the parent group. + """ + other_format = 2 if zarr_format == 3 else 3 + g = Group.from_store(store, zarr_format=other_format) + tree = { + "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), + "a/b": meta_from_array(np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/c"}), + } + + msg = "The zarr_format of the nodes must be the same as the parent group." + with pytest.raises(ValueError, match=msg): + _ = g.create_hierarchy(tree) + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("defect", ["array/array", "array/group"]) +async def test_create_hierarchy_invalid_nested( + store: Store, defect: tuple[str, str], zarr_format +) -> None: + """ + Test that create_hierarchy will not create a Zarr array that contains a Zarr group + or Zarr array. + """ + + if defect == "array/array": + hierarchy_spec = { + "array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), + "array_0/subarray": meta_from_array(np.arange(4), zarr_format=zarr_format), + } + elif defect == "array/group": + hierarchy_spec = { + "array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), + "array_0/subgroup": GroupMetadata(attributes={"foo": 10}, zarr_format=zarr_format), + } + msg = "Only Zarr groups can contain other nodes." + with pytest.raises(ValueError, match=msg): + spath = await make_store_path(store, path="foo") + await _collect_aiterator(create_hierarchy(store_path=spath, nodes=hierarchy_spec)) -def test_group_members_performance(store: MemoryStore) -> None: + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +async def test_create_hierarchy_invalid_mixed_format(store: Store): + """ + Test that create_hierarchy will not create a Zarr group that contains a both Zarr v2 and + Zarr v3 nodes. + """ + spath = await make_store_path(store, path="foo") + msg = ( + "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " + "The following keys map to Zarr v2 nodes: ['v2']. " + "The following keys map to Zarr v3 nodes: ['v3']." + "Ensure that all nodes have the same Zarr format." + ) + with pytest.raises(ValueError, match=re.escape(msg)): + await _collect_aiterator( + create_hierarchy( + store_path=spath, + nodes={ + "v2": GroupMetadata(zarr_format=2), + "v3": GroupMetadata(zarr_format=3), + }, + ) + ) + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +async def test_group_from_flat(store: Store, zarr_format): + """ + Test that the AsyncGroup.from_flat method creates a zarr group in one shot. + """ + hierarchy_spec = { + "a": GroupMetadata(zarr_format=zarr_format), + "a/b": GroupMetadata(zarr_format=zarr_format), + "a/b/c": GroupMetadata(zarr_format=zarr_format), + } + g = await AsyncGroup.from_flat(store, nodes=hierarchy_spec) + assert g.members() == [ + ("b", GroupMetadata(zarr_format=zarr_format)), + ("b/c", GroupMetadata(zarr_format=zarr_format)), + ] + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_group_members_performance(store: Store) -> None: """ Test that the execution time of Group.members is less than the number of members times the latency for accessing each member. From e38c1ca6feb64897b26e5e0e558226a34e8e2c3d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 9 Jan 2025 21:42:03 +0100 Subject: [PATCH 13/59] revert change to default concurrency --- src/zarr/core/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/core/config.py b/src/zarr/core/config.py index 958d90f535..7920d220a4 100644 --- a/src/zarr/core/config.py +++ b/src/zarr/core/config.py @@ -94,7 +94,7 @@ def reset(self) -> None: ], }, }, - "async": {"concurrency": 256, "timeout": None}, + "async": {"concurrency": 10, "timeout": None}, "threading": {"max_workers": None}, "json_indent": 2, "codec_pipeline": { From 2fb9083c4e98d026a4905b41dadde725e6da4e16 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 9 Jan 2025 21:48:35 +0100 Subject: [PATCH 14/59] create root correctly --- src/zarr/core/group.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 2f0d0ff018..4bb909b643 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2967,12 +2967,15 @@ async def create_nodes( create_tasks.extend( _prepare_save_metadata(store_path.store, f"{store_path.path}/{key}", value) ) - + if store_path.path == "": + root = "/" + else: + root = store_path.path created_keys = [] async with ctx: for coro in asyncio.as_completed(create_tasks): created_key = await coro - relative_path = PurePosixPath(created_key).relative_to(store_path.path) + relative_path = PurePosixPath(created_key).relative_to(root) created_keys.append(str(relative_path)) # convert /foo/bar/baz/.zattrs to bar/baz node_name = str(relative_path.parent) From b099fba88b7cac38c97ce344dde15455a5cdb404 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 10 Jan 2025 15:03:58 +0100 Subject: [PATCH 15/59] working _from_flat --- src/zarr/core/group.py | 201 ++++++++++++++++++++++++----------------- tests/test_group.py | 37 +++++--- 2 files changed, 143 insertions(+), 95 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 4bb909b643..e4de08e6c1 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -13,7 +13,6 @@ from typing import ( TYPE_CHECKING, Literal, - Self, TypeVar, assert_never, cast, @@ -437,35 +436,6 @@ class AsyncGroup: # TODO: make this correct and work # TODO: ensure that this can be bound properly to subclass of AsyncGroup - @classmethod - async def from_flat( - cls, - store: StoreLike, - *, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False, - ) -> Self: - if not _is_rooted(nodes): - msg = ( - "The input does not specify a root node. ", - "This function can only create hierarchies that contain a root node, which is ", - "defined as a group that is ancestral to all the other arrays and ", - "groups in the hierarchy.", - ) - raise ValueError(msg) - - if overwrite: - store_path = await make_store_path(store, mode="w") - else: - store_path = await make_store_path(store, mode="w-") - - semaphore = asyncio.Semaphore(config.get("async.concurrency")) - - nodes_created = { - x.name: x - async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore) - } - # TODO: make this work @classmethod async def from_store( @@ -1764,18 +1734,6 @@ async def move(self, source: str, dest: str) -> None: class Group(SyncMixin): _async_group: AsyncGroup - @classmethod - def from_flat( - cls, - store: StoreLike, - *, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False, - ) -> Group: - nodes = sync(AsyncGroup.from_flat(store, nodes=nodes, overwrite=overwrite)) - # return the root node of the hierarchy - return nodes[""] - @classmethod def from_store( cls, @@ -2889,23 +2847,6 @@ def array( ) -async def _save_metadata( - node: AsyncArray[Any] | AsyncGroup, - overwrite: bool, -) -> AsyncArray[Any] | AsyncGroup: - """ - Save the metadata for an array or group, and return the array or group - """ - match node: - case AsyncArray(): - await node._save_metadata(node.metadata, ensure_parents=False) - case AsyncGroup(): - await node._save_metadata(ensure_parents=False) - case _: - raise ValueError(f"Unexpected node type {type(node)}") - return node - - async def create_hierarchy( *, store_path: StorePath, @@ -2962,22 +2903,17 @@ async def create_nodes( ctx = semaphore create_tasks: list[Coroutine[None, None, str]] = [] - for key, value in nodes.items(): - create_tasks.extend( - _prepare_save_metadata(store_path.store, f"{store_path.path}/{key}", value) - ) - if store_path.path == "": - root = "/" - else: - root = store_path.path + write_key = str(PurePosixPath(store_path.path) / key) + create_tasks.extend(_persist_metadata(store_path.store, write_key, value)) + created_keys = [] async with ctx: for coro in asyncio.as_completed(create_tasks): created_key = await coro - relative_path = PurePosixPath(created_key).relative_to(root) + relative_path = PurePosixPath(created_key).relative_to(store_path.path) created_keys.append(str(relative_path)) - # convert /foo/bar/baz/.zattrs to bar/baz + # convert foo/bar/baz/.zattrs to bar/baz node_name = str(relative_path.parent) meta_out = nodes[node_name] @@ -3009,13 +2945,17 @@ async def create_nodes( T = TypeVar("T") -def _is_rooted(data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]) -> bool: +def _get_roots( + data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> tuple[str, ...]: """ - Check if the data describes a hierarchy that's rooted, which means there is a single node with - the least number of components in its key + Return the keys of the root(s) of the hierarchy """ - # a dict - return False + keys_split = sorted((key.split("/") for key in data), key=len) + groups: defaultdict[int, list[str]] = defaultdict(list) + for key_split in keys_split: + groups[len(key_split)].append("/".join(key_split)) + return tuple(groups[min(groups.keys())]) def _parse_hierarchy_dict( @@ -3066,9 +3006,9 @@ def _parse_hierarchy_dict( # Iterate over the intermediate path components *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") for subpath in subpaths: - # If a component is not already in the output dict, add it + # If a component is not already in the output dict, add an implicit group marker if subpath not in out: - out[subpath] = GroupMetadata(zarr_format=v.zarr_format) + out[subpath] = _ImplicitGroupMetadata(zarr_format=v.zarr_format) else: if not isinstance(out[subpath], GroupMetadata): msg = ( @@ -3268,22 +3208,115 @@ def _build_node_v2( raise ValueError(f"Unexpected metadata type: {type(metadata)}") -async def _set_return_key(store: Store, key: str, value: Buffer) -> str: +async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str: """ - Store.set, but the key and the value are returned. - Useful when saving metadata via asyncio.as_completed, because - we need to know which key was saved. + Either write a value to storage at the given key, or ensure that there is already a value in + storage at the given key. The key is returned in either case. + Useful when saving values via routines that return results in execution order, + like asyncio.as_completed, because in this case we need to know which key was saved in order + to yield the right object to the caller. + + Parameters + ---------- + store : Store + The store to save the value to. + key : str + The key to save the value to. + value : Buffer + The value to save. + replace : bool + If True, then the value will be written even if a value associated with the key + already exists in storage. If False, an existing value will not be overwritten. """ - await store.set(key, value) + if replace: + await store.set(key, value) + else: + await store.set_if_not_exists(key, value) return key -def _prepare_save_metadata( +def _persist_metadata( store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata ) -> tuple[Coroutine[None, None, str], ...]: """ - Prepare to save a metadata document to storage. Returns a tuple of coroutines that must be awaited. + Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. + If ``metadata`` is an instance of ``_ImplicitGroupMetadata``, then _set_return_key will be invoked with + ``replace=False``, which defers to a pre-existing metadata document in storage if one exists. Otherwise, existing values will be overwritten. """ to_save = metadata.to_buffer_dict(default_buffer_prototype()) - return tuple(_set_return_key(store, f"{path}/{key}", value) for key, value in to_save.items()) + if isinstance(metadata, _ImplicitGroupMetadata): + replace = False + else: + replace = True + # TODO: should this function be a generator that yields values instead of eagerly returning a tuple? + return tuple( + _set_return_key(store=store, key=f"{path}/{key}", value=value, replace=replace) + for key, value in to_save.items() + ) + + +class _ImplicitGroupMetadata(GroupMetadata): + """ + This class represents the metadata document of a group that should created at some + location in storage if and only if there is not already a group at that location. + + This class is used to fill group-shaped "holes" in a dict specification of a Zarr hierarchy. + + When attempting to write this class to disk, the writer should first check if a Zarr group + already exists at the desired location. If such a group does exist, the writer should do nothing. + If not, the writer should write this metadata document to storage. + + """ + + def __init__( + self, + attributes: dict[str, Any] | None = None, + zarr_format: ZarrFormat = 3, + consolidated_metadata: ConsolidatedMetadata | None = None, + ) -> None: + if attributes is not None: + raise ValueError("attributes must be None for implicit groups") + + if consolidated_metadata is not None: + raise ValueError("consolidated_metadata must be None for implicit groups") + + super().__init__(attributes, zarr_format, consolidated_metadata) + + +async def _from_flat( + store: StoreLike, + *, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> AsyncGroup: + """ + Create an ``AsyncGroup`` from a store + a dict of nodes. + """ + roots = _get_roots(nodes) + if len(roots) != 1: + msg = ( + "The input does not specify a root node. " + "This function can only create hierarchies that contain a root node, which is " + "defined as a group that is ancestral to all the other arrays and " + "groups in the hierarchy." + ) + raise ValueError(msg) + else: + root = roots[0] + + if overwrite: + store_path = await make_store_path(store, mode="w") + else: + store_path = await make_store_path(store, mode="w-") + + semaphore = asyncio.Semaphore(config.get("async.concurrency")) + + nodes_created = { + x.path: x + async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore) + } + root_group = nodes_created[root] + if not isinstance(root_group, AsyncGroup): + raise TypeError("Invalid root node returned from create_hierarchy.") + return root_group diff --git a/tests/test_group.py b/tests/test_group.py index 148087ac13..08e112344a 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -21,7 +21,13 @@ from zarr.abc.store import Store from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype -from zarr.core.group import ConsolidatedMetadata, GroupMetadata, create_hierarchy, create_nodes +from zarr.core.group import ( + ConsolidatedMetadata, + GroupMetadata, + _from_flat, + create_hierarchy, + create_nodes, +) from zarr.core.sync import _collect_aiterator, sync from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore @@ -1578,20 +1584,29 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store): @pytest.mark.parametrize("store", ["memory"], indirect=True) -async def test_group_from_flat(store: Store, zarr_format): +@pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.parametrize("root_key", ["", "a", "a/b"]) +async def test_group_from_flat(store: Store, zarr_format, root_key: str): """ Test that the AsyncGroup.from_flat method creates a zarr group in one shot. """ - hierarchy_spec = { - "a": GroupMetadata(zarr_format=zarr_format), - "a/b": GroupMetadata(zarr_format=zarr_format), - "a/b/c": GroupMetadata(zarr_format=zarr_format), + root_key = "a" + root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} + members_expected_meta = { + f"{root_key}/b": GroupMetadata( + zarr_format=zarr_format, attributes={"path": f"{root_key}/b"} + ), + f"{root_key}/b/c": GroupMetadata( + zarr_format=zarr_format, attributes={"path": f"{root_key}/b/c"} + ), } - g = await AsyncGroup.from_flat(store, nodes=hierarchy_spec) - assert g.members() == [ - ("b", GroupMetadata(zarr_format=zarr_format)), - ("b/c", GroupMetadata(zarr_format=zarr_format)), - ] + g = await _from_flat(store, nodes=root_meta | members_expected_meta) + members = await _collect_aiterator(g.members(max_depth=None)) + members_observed_meta = {k: v.metadata for k, v in members} + members_expected_meta_relative = { + str(PurePosixPath(k).relative_to(root_key)): v for k, v in members_expected_meta.items() + } + assert members_observed_meta == members_expected_meta_relative @pytest.mark.parametrize("store", ["memory"], indirect=True) From 4562e8600d96c341a264b99a29d5e3acd38914c7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 10 Jan 2025 20:42:47 +0100 Subject: [PATCH 16/59] working dict serialization for _ImplicitGroupMetadata --- src/zarr/core/group.py | 7 +++++-- tests/test_group.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index e4de08e6c1..e175204909 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3258,12 +3258,12 @@ def _persist_metadata( class _ImplicitGroupMetadata(GroupMetadata): """ - This class represents the metadata document of a group that should created at some + This class represents the metadata document of a group that should be created at a location in storage if and only if there is not already a group at that location. This class is used to fill group-shaped "holes" in a dict specification of a Zarr hierarchy. - When attempting to write this class to disk, the writer should first check if a Zarr group + When attempting to write this class to storage, the writer should first check if a Zarr group already exists at the desired location. If such a group does exist, the writer should do nothing. If not, the writer should write this metadata document to storage. @@ -3283,6 +3283,9 @@ def __init__( super().__init__(attributes, zarr_format, consolidated_metadata) + def to_dict(self) -> dict[str, JSON]: + return asdict(self) + async def _from_flat( store: StoreLike, diff --git a/tests/test_group.py b/tests/test_group.py index 08e112344a..c164dba68a 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -25,6 +25,7 @@ ConsolidatedMetadata, GroupMetadata, _from_flat, + _ImplicitGroupMetadata, create_hierarchy, create_nodes, ) @@ -1609,6 +1610,28 @@ async def test_group_from_flat(store: Store, zarr_format, root_key: str): assert members_observed_meta == members_expected_meta_relative +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("zarr_format", [2, 3]) +async def test_create_hierarchy_implicit_groups(store: Store, zarr_format): + """ + Test that writing a hierarchy with implicit groups does not result in altering an existing group + """ + spath = await make_store_path(store, path="") + key = "a" + attrs = {"name": key} + _ = await _from_flat( + store, + nodes={ + key: GroupMetadata(zarr_format=zarr_format, attributes=attrs), + }, + ) + + _ = await _collect_aiterator( + create_nodes(store_path=spath, nodes={key: _ImplicitGroupMetadata(zarr_format=zarr_format)}) + ) + assert zarr.open_group(store, path=key).metadata.attributes == attrs + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: Store) -> None: """ From cdfd5de9672ec6560ae538b59e9de7d8e0f821d0 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 15 Jan 2025 11:54:00 +0100 Subject: [PATCH 17/59] remove implicit group metadata, and add some key name normalization --- src/zarr/core/group.py | 80 +++++++++++++----------------------------- tests/test_group.py | 39 ++++++-------------- 2 files changed, 36 insertions(+), 83 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index e175204909..2e2c621ad5 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2881,7 +2881,8 @@ async def create_hierarchy( The created nodes in the order they are created. """ nodes_parsed = _parse_hierarchy_dict(nodes) - + if overwrite: + await store_path.delete_dir() async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): yield node @@ -2904,17 +2905,24 @@ async def create_nodes( create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): - write_key = str(PurePosixPath(store_path.path) / key) + write_key = f"{store_path.path}/{key}".lstrip("/") create_tasks.extend(_persist_metadata(store_path.store, write_key, value)) created_keys = [] async with ctx: for coro in asyncio.as_completed(create_tasks): created_key = await coro - relative_path = PurePosixPath(created_key).relative_to(store_path.path) - created_keys.append(str(relative_path)) - # convert foo/bar/baz/.zattrs to bar/baz - node_name = str(relative_path.parent) + # the created key will be in the store key space. we have to remove the store_path.path + # component of that path to bring it back to the relative key space of store_path + + relative_path = created_key.removeprefix(store_path.path).lstrip("/") + created_keys.append(relative_path) + + if len(relative_path.split("/")) == 1: + node_name = "" + else: + node_name = "/".join(["", *relative_path.split("/")[:-1]]) + meta_out = nodes[node_name] if meta_out.zarr_format == 3: @@ -2928,18 +2936,19 @@ async def create_nodes( # so we track which keys have been created, and wait for both the meta key and # the attrs key to be created before yielding back the AsyncArray / AsyncGroup - attrs_done = f"{node_name}/.zattrs" in created_keys + attrs_done = f"{node_name}/.zattrs".lstrip("/") in created_keys if isinstance(meta_out, GroupMetadata): - meta_done = f"{node_name}/.zgroup" in created_keys + meta_done = f"{node_name}/.zgroup".lstrip("/") in created_keys else: - meta_done = f"{node_name}/.zarray" in created_keys + meta_done = f"{node_name}/.zarray".lstrip("/") in created_keys if meta_done and attrs_done: if isinstance(meta_out, GroupMetadata): yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name) else: yield AsyncArray(metadata=meta_out, store_path=store_path / node_name) + continue T = TypeVar("T") @@ -3006,9 +3015,9 @@ def _parse_hierarchy_dict( # Iterate over the intermediate path components *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") for subpath in subpaths: - # If a component is not already in the output dict, add an implicit group marker + # If a component is not already in the output dict, add a group if subpath not in out: - out[subpath] = _ImplicitGroupMetadata(zarr_format=v.zarr_format) + out[subpath] = GroupMetadata(zarr_format=v.zarr_format) else: if not isinstance(out[subpath], GroupMetadata): msg = ( @@ -3245,50 +3254,14 @@ def _persist_metadata( """ to_save = metadata.to_buffer_dict(default_buffer_prototype()) - if isinstance(metadata, _ImplicitGroupMetadata): - replace = False - else: - replace = True - # TODO: should this function be a generator that yields values instead of eagerly returning a tuple? return tuple( - _set_return_key(store=store, key=f"{path}/{key}", value=value, replace=replace) + _set_return_key(store=store, key=f"{path}/{key}".lstrip("/"), value=value, replace=True) for key, value in to_save.items() ) -class _ImplicitGroupMetadata(GroupMetadata): - """ - This class represents the metadata document of a group that should be created at a - location in storage if and only if there is not already a group at that location. - - This class is used to fill group-shaped "holes" in a dict specification of a Zarr hierarchy. - - When attempting to write this class to storage, the writer should first check if a Zarr group - already exists at the desired location. If such a group does exist, the writer should do nothing. - If not, the writer should write this metadata document to storage. - - """ - - def __init__( - self, - attributes: dict[str, Any] | None = None, - zarr_format: ZarrFormat = 3, - consolidated_metadata: ConsolidatedMetadata | None = None, - ) -> None: - if attributes is not None: - raise ValueError("attributes must be None for implicit groups") - - if consolidated_metadata is not None: - raise ValueError("consolidated_metadata must be None for implicit groups") - - super().__init__(attributes, zarr_format, consolidated_metadata) - - def to_dict(self) -> dict[str, JSON]: - return asdict(self) - - async def _from_flat( - store: StoreLike, + store_path: StorePath, *, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, @@ -3308,16 +3281,13 @@ async def _from_flat( else: root = roots[0] - if overwrite: - store_path = await make_store_path(store, mode="w") - else: - store_path = await make_store_path(store, mode="w-") - semaphore = asyncio.Semaphore(config.get("async.concurrency")) nodes_created = { x.path: x - async for x in create_hierarchy(store_path=store_path, nodes=nodes, semaphore=semaphore) + async for x in create_hierarchy( + store_path=store_path, nodes=nodes, semaphore=semaphore, overwrite=overwrite + ) } root_group = nodes_created[root] if not isinstance(root_group, AsyncGroup): diff --git a/tests/test_group.py b/tests/test_group.py index c164dba68a..c3cbc653a5 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -25,7 +25,6 @@ ConsolidatedMetadata, GroupMetadata, _from_flat, - _ImplicitGroupMetadata, create_hierarchy, create_nodes, ) @@ -1474,7 +1473,8 @@ async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: @pytest.mark.parametrize("store", ["memory"], indirect=True) -async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: +@pytest.mark.parametrize("overwrite", [True, False]) +async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: ZarrFormat) -> None: """ Test that ``create_hierarchy`` can create a complete Zarr hierarchy, even if the input describes an incomplete one. @@ -1487,11 +1487,15 @@ async def test_create_hierarchy(store: Store, zarr_format: ZarrFormat) -> None: "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } + pre_existing_nodes = {"extra": GroupMetadata(zarr_format=zarr_format)} + # we expect create_hierarchy to insert a group that was missing from the hierarchy spec expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} spath = await make_store_path(store, path=path) + # initialize the group with some nodes + await _collect_aiterator(_from_flat(store_path=spath, nodes=pre_existing_nodes)) observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a - async for a in create_hierarchy(store_path=spath, nodes=expected_meta) + async for a in create_hierarchy(store_path=spath, nodes=expected_meta, overwrite=overwrite) } assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1587,11 +1591,12 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store): @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "a", "a/b"]) -async def test_group_from_flat(store: Store, zarr_format, root_key: str): +@pytest.mark.parametrize("path", ["", "foo"]) +async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: str): """ Test that the AsyncGroup.from_flat method creates a zarr group in one shot. """ - root_key = "a" + spath = await make_store_path(store, path=path) root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} members_expected_meta = { f"{root_key}/b": GroupMetadata( @@ -1601,7 +1606,7 @@ async def test_group_from_flat(store: Store, zarr_format, root_key: str): zarr_format=zarr_format, attributes={"path": f"{root_key}/b/c"} ), } - g = await _from_flat(store, nodes=root_meta | members_expected_meta) + g = await _from_flat(spath, nodes=root_meta | members_expected_meta) members = await _collect_aiterator(g.members(max_depth=None)) members_observed_meta = {k: v.metadata for k, v in members} members_expected_meta_relative = { @@ -1610,28 +1615,6 @@ async def test_group_from_flat(store: Store, zarr_format, root_key: str): assert members_observed_meta == members_expected_meta_relative -@pytest.mark.parametrize("store", ["memory"], indirect=True) -@pytest.mark.parametrize("zarr_format", [2, 3]) -async def test_create_hierarchy_implicit_groups(store: Store, zarr_format): - """ - Test that writing a hierarchy with implicit groups does not result in altering an existing group - """ - spath = await make_store_path(store, path="") - key = "a" - attrs = {"name": key} - _ = await _from_flat( - store, - nodes={ - key: GroupMetadata(zarr_format=zarr_format, attributes=attrs), - }, - ) - - _ = await _collect_aiterator( - create_nodes(store_path=spath, nodes={key: _ImplicitGroupMetadata(zarr_format=zarr_format)}) - ) - assert zarr.open_group(store, path=key).metadata.attributes == attrs - - @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: Store) -> None: """ From 787d6bf4a9f031c47d97aeaa8f715a8addb8ed54 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 11:58:55 +0100 Subject: [PATCH 18/59] add path normalization routines --- src/zarr/core/group.py | 48 +++++++++++++++++++++++++++++++++--------- tests/test_group.py | 34 ++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 10 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 1cb439af5e..cf10d5e85e 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -62,6 +62,7 @@ from zarr.errors import MetadataValidationError from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path +from zarr.storage._utils import normalize_path if TYPE_CHECKING: from collections.abc import ( @@ -2961,7 +2962,8 @@ def _get_roots( data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> tuple[str, ...]: """ - Return the keys of the root(s) of the hierarchy + Return the keys of the root(s) of the hierarchy. A root is a key with the fewest number of + path segments. """ keys_split = sorted((key.split("/") for key in data), key=len) groups: defaultdict[int, list[str]] = defaultdict(list) @@ -2978,8 +2980,8 @@ def _parse_hierarchy_dict( then return an identical copy of that dict. Otherwise, return a version of the input dict with groups added where they are needed to make the hierarchy explicit. - For example, an input of {'a/b/c': ...} will result in a return value of - {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}. + For example, an input of {'a/b/c': ArrayMetadata} will result in a return value of + {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}. The input is also checked for the following conditions, and an error is raised if any of them are violated: @@ -2990,8 +2992,6 @@ def _parse_hierarchy_dict( This function ensures that the input is transformed into a specification of a complete and valid Zarr hierarchy. """ - # Create a copy of the input dict - out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data} observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []} @@ -3007,16 +3007,15 @@ def _parse_hierarchy_dict( f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}." "Ensure that all nodes have the same Zarr format." ) - raise ValueError(msg) + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {} + for k, v in data.items(): # TODO: ensure that the key is a valid path - # Split the key into its path components - key_split = k.split("/") - # Iterate over the intermediate path components - *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") + *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) + for subpath in subpaths: # If a component is not already in the output dict, add a group if subpath not in out: @@ -3032,6 +3031,35 @@ def _parse_hierarchy_dict( return out +def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: + """ + Normalize the input paths according to the normalization scheme used for zarr node paths. + If any two paths normalize to the same value, raise a ValueError. + """ + path_map: dict[str, str] = {} + for path in paths: + parsed = normalize_path(path) + if parsed in path_map: + msg = ( + f"After normalization, the value '{path}' collides with '{path_map[parsed]}'. " + f"Both '{path}' and '{path_map[parsed]}' normalize to the same value: '{parsed}'. " + f"You should use either '{path}' or '{path_map[parsed]}', but not both." + ) + raise ValueError(msg) + path_map[parsed] = path + return tuple(path_map.keys()) + + +def _normalize_path_keys(data: dict[str, T]) -> dict[str, T]: + """ + Normalize the keys of the input dict according to the normalization scheme used for zarr node + paths. If any two keys in the input normalize to the value, raise a ValueError. Return the + values of data with the normalized keys. + """ + parsed_keys = _normalize_paths(data.keys()) + return dict(zip(parsed_keys, data.values(), strict=False)) + + async def _getitem_semaphore( node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: diff --git a/tests/test_group.py b/tests/test_group.py index 6a08290e5e..3e125ccea1 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -25,6 +25,8 @@ ConsolidatedMetadata, GroupMetadata, _from_flat, + _normalize_path_keys, + _normalize_paths, create_hierarchy, create_nodes, ) @@ -32,6 +34,7 @@ from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage._common import make_store_path +from zarr.storage._utils import normalize_path from zarr.testing.store import LatencyStore from .conftest import meta_from_array, parse_store @@ -1615,6 +1618,37 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s assert members_observed_meta == members_expected_meta_relative +@pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) +def test_normalize_paths_invalid(paths: tuple[str, str]): + """ + Ensure that calling _normalize_paths on values that will normalize to the same value + will generate a ValueError. + """ + a, b = paths + msg = f"After normalization, the value '{b}' collides with '{a}'. " + with pytest.raises(ValueError, match=msg): + _normalize_paths(paths) + + +@pytest.mark.parametrize( + "paths", [("/a", "a/b"), ("a", "a/b"), ("a/", "a///b"), ("/a/", "//a/b///")] +) +def test_normalize_paths_valid(paths: tuple[str, str]): + """ + Ensure that calling _normalize_paths on values that normalize to distinct values + returns a tuple of those normalized values. + """ + expected = tuple(map(normalize_path, paths)) + assert _normalize_paths(paths) == expected + + +def test_normalize_path_keys(): + data = {"": 10, "a": "hello", "a/b": None, "/a/b/c/d": None} + observed = _normalize_path_keys(data) + expected = {normalize_path(k): v for k, v in data.items()} + assert observed == expected + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: Store) -> None: """ From d07435ba088063f462bf889dd2e528ffb0917bec Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 12:43:39 +0100 Subject: [PATCH 19/59] use _join_paths for safer path concatenation --- src/zarr/core/group.py | 137 +++++++++++++++++++++++++++-------------- tests/test_group.py | 44 ++++++++----- 2 files changed, 120 insertions(+), 61 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 1cb439af5e..65dfd5442e 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -9,7 +9,6 @@ from collections import defaultdict from dataclasses import asdict, dataclass, field, fields, replace from itertools import accumulate -from pathlib import PurePosixPath from typing import ( TYPE_CHECKING, Literal, @@ -2079,7 +2078,9 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], def create_hierarchy( self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] - ) -> dict[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: + ) -> Iterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] + ]: """ Create a hierarchy of arrays or groups rooted at this group. @@ -2097,6 +2098,14 @@ def create_hierarchy( ------- A dict containing the created nodes, with the same keys as the input """ + if "" in nodes: + msg = ( + "Found the key '' in nodes, which denotes the root group. Creating the root group " + "from an existing group is not supported. If you want to create an entire Zarr group, " + "including the root group, from a dict then use the _from_flat method." + ) + raise ValueError(msg) + # check that all the nodes have the same zarr_format as Self for key, value in nodes.items(): if value.zarr_format != self.metadata.zarr_format: @@ -2107,12 +2116,8 @@ def create_hierarchy( ) raise ValueError(msg) nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes)) - if self.path == "": - root = "/" - else: - root = self.path - # TODO: make this safe against invalid path inputs - return {str(PurePosixPath(n.name).relative_to(root)): n for n in nodes_created} + for n in nodes_created: + yield (_join_paths([self.path, n.name]), n) def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2884,8 +2889,12 @@ async def create_hierarchy( The created nodes in the order they are created. """ nodes_parsed = _parse_hierarchy_dict(nodes) + if overwrite: await store_path.delete_dir() + else: + # TODO: check if any of the nodes already exist, and error if so + raise NotImplementedError async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): yield node @@ -2897,10 +2906,11 @@ async def create_nodes( semaphore: asyncio.Semaphore | None = None, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ - Create a collection of zarr v2 arrays and groups concurrently and atomically. To ensure atomicity, + Create a collection of zarr arrays and groups concurrently and atomically. To ensure atomicity, no attempt is made to ensure that intermediate groups are created. """ ctx: asyncio.Semaphore | contextlib.nullcontext[None] + if semaphore is None: ctx = contextlib.nullcontext() else: @@ -2908,49 +2918,62 @@ async def create_nodes( create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): - write_key = f"{store_path.path}/{key}".lstrip("/") - create_tasks.extend(_persist_metadata(store_path.store, write_key, value)) + # transform the key, which is relative to a store_path.path, to a key in the store + write_prefix = _join_paths([store_path.path, key]) + create_tasks.extend(_persist_metadata(store_path.store, write_prefix, value)) - created_keys = [] + created_object_keys = [] async with ctx: for coro in asyncio.as_completed(create_tasks): created_key = await coro - # the created key will be in the store key space. we have to remove the store_path.path + + # the created key will be in the store key space, and it will end with the name of + # a metadata document. + # we have to remove the store_path.path # component of that path to bring it back to the relative key space of store_path - relative_path = created_key.removeprefix(store_path.path).lstrip("/") - created_keys.append(relative_path) + # the relative path of the object we just created -- we need this to track which metadata documents + # were written so that we can yield a complete v2 Array / Group class after both .zattrs + # and the metadata JSON was created. + object_path_relative = created_key.removeprefix(store_path.path).lstrip("/") + created_object_keys.append(object_path_relative) - if len(relative_path.split("/")) == 1: + # get the node name from the object key + if len(object_path_relative.split("/")) == 1: + # this is the root node + meta_out = nodes[""] node_name = "" else: - node_name = "/".join(["", *relative_path.split("/")[:-1]]) - - meta_out = nodes[node_name] + # turn "foo/" into "foo" + node_name = object_path_relative[: object_path_relative.rfind("/")] + meta_out = nodes[node_name] if meta_out.zarr_format == 3: + # yes, it is silly that we relativize, then de-relativize this same path + node_store_path = store_path / node_name if isinstance(meta_out, GroupMetadata): - yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name) + yield AsyncGroup(metadata=meta_out, store_path=node_store_path) else: - yield AsyncArray(metadata=meta_out, store_path=store_path / node_name) + yield AsyncArray(metadata=meta_out, store_path=node_store_path) else: # For zarr v2 # we only want to yield when both the metadata and attributes are created # so we track which keys have been created, and wait for both the meta key and # the attrs key to be created before yielding back the AsyncArray / AsyncGroup - attrs_done = f"{node_name}/.zattrs".lstrip("/") in created_keys + attrs_done = _join_paths([node_name, ZATTRS_JSON]) in created_object_keys if isinstance(meta_out, GroupMetadata): - meta_done = f"{node_name}/.zgroup".lstrip("/") in created_keys + meta_done = _join_paths([node_name, ZGROUP_JSON]) in created_object_keys else: - meta_done = f"{node_name}/.zarray".lstrip("/") in created_keys + meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys if meta_done and attrs_done: + node_store_path = store_path / node_name if isinstance(meta_out, GroupMetadata): - yield AsyncGroup(metadata=meta_out, store_path=store_path / node_name) + yield AsyncGroup(metadata=meta_out, store_path=node_store_path) else: - yield AsyncArray(metadata=meta_out, store_path=store_path / node_name) + yield AsyncArray(metadata=meta_out, store_path=node_store_path) continue @@ -2963,6 +2986,8 @@ def _get_roots( """ Return the keys of the root(s) of the hierarchy """ + if "" in data: + return ("",) keys_split = sorted((key.split("/") for key in data), key=len) groups: defaultdict[int, list[str]] = defaultdict(list) for key_split in keys_split: @@ -2970,6 +2995,15 @@ def _get_roots( return tuple(groups[min(groups.keys())]) +def _join_paths(paths: Iterable[str]) -> str: + """ + Filter out instances of '' and join the remaining strings with '/'. + + Because the root node of a zarr hierarchy is represented by an empty string, + """ + return "/".join(filter(lambda v: v != "", paths)) + + def _parse_hierarchy_dict( data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: @@ -2993,7 +3027,7 @@ def _parse_hierarchy_dict( # Create a copy of the input dict out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data} - observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []} + observed_zarr_formats: dict[ZarrFormat, list[str | None]] = {2: [], 3: []} # We will iterate over the dict again, but a full pass here ensures that the error message # is comprehensive, and I think the performance cost will be negligible. @@ -3011,23 +3045,30 @@ def _parse_hierarchy_dict( raise ValueError(msg) for k, v in data.items(): - # TODO: ensure that the key is a valid path - # Split the key into its path components - key_split = k.split("/") - - # Iterate over the intermediate path components - *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") - for subpath in subpaths: - # If a component is not already in the output dict, add a group - if subpath not in out: - out[subpath] = GroupMetadata(zarr_format=v.zarr_format) - else: - if not isinstance(out[subpath], GroupMetadata): - msg = ( - f"The node at {subpath} contains other nodes, but it is not a Zarr group. " - "This is invalid. Only Zarr groups can contain other nodes." - ) - raise ValueError(msg) + if k is None: + # root node + pass + else: + if k.startswith("/"): + msg = f"Keys of hierarchy dicts must be relative paths, i.e. they cannot start with '/'. Got {k}, which violates this rule." + raise ValueError(k) + # TODO: ensure that the key is a valid path + # Split the key into its path components + key_split = k.split("/") + + # Iterate over the intermediate path components + *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") + for subpath in subpaths: + # If a component is not already in the output dict, add a group + if subpath not in out: + out[subpath] = GroupMetadata(zarr_format=v.zarr_format) + else: + if not isinstance(out[subpath], GroupMetadata): + msg = ( + f"The node at {subpath} contains other nodes, but it is not a Zarr group. " + "This is invalid. Only Zarr groups can contain other nodes." + ) + raise ValueError(msg) return out @@ -3258,7 +3299,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=f"{path}/{key}".lstrip("/"), value=value, replace=True) + _set_return_key(store=store, key=_join_paths([path, key]), value=value, replace=True) for key, value in to_save.items() ) @@ -3278,7 +3319,7 @@ async def _from_flat( "The input does not specify a root node. " "This function can only create hierarchies that contain a root node, which is " "defined as a group that is ancestral to all the other arrays and " - "groups in the hierarchy." + "groups in the hierarchy, or a single array." ) raise ValueError(msg) else: @@ -3292,7 +3333,9 @@ async def _from_flat( store_path=store_path, nodes=nodes, semaphore=semaphore, overwrite=overwrite ) } - root_group = nodes_created[root] + # the names of the created nodes will be relative to the store_path instance + root_relative_to_store_path = _join_paths([store_path.path, root]) + root_group = nodes_created[root_relative_to_store_path] if not isinstance(root_group, AsyncGroup): raise TypeError("Invalid root node returned from create_hierarchy.") return root_group diff --git a/tests/test_group.py b/tests/test_group.py index 6a08290e5e..fe9eadabf6 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -25,6 +25,7 @@ ConsolidatedMetadata, GroupMetadata, _from_flat, + _join_paths, create_hierarchy, create_nodes, ) @@ -1492,7 +1493,7 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} spath = await make_store_path(store, path=path) # initialize the group with some nodes - await _collect_aiterator(_from_flat(store_path=spath, nodes=pre_existing_nodes)) + await _from_flat(store_path=spath, nodes=pre_existing_nodes) observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a async for a in create_hierarchy(store_path=spath, nodes=expected_meta, overwrite=overwrite) @@ -1501,7 +1502,8 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr @pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat): +@pytest.mark.parametrize("overwrite", [True, False]) +def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite: bool): """ Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. """ @@ -1533,7 +1535,7 @@ def test_group_create_hierarchy_invalid_mixed_zarr_format(store: Store, zarr_for msg = "The zarr_format of the nodes must be the same as the parent group." with pytest.raises(ValueError, match=msg): - _ = g.create_hierarchy(tree) + _ = tuple(g.create_hierarchy(tree)) @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1588,9 +1590,9 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store): ) -@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) -@pytest.mark.parametrize("root_key", ["", "a", "a/b"]) +@pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: str): """ @@ -1598,19 +1600,33 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s """ spath = await make_store_path(store, path=path) root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} - members_expected_meta = { - f"{root_key}/b": GroupMetadata( - zarr_format=zarr_format, attributes={"path": f"{root_key}/b"} - ), - f"{root_key}/b/c": GroupMetadata( - zarr_format=zarr_format, attributes={"path": f"{root_key}/b/c"} - ), + group_names = ["a", "a/b"] + array_names = ["a/b/c", "a/b/d"] + + # just to ensure that we don't use the same name twice in tests + assert set(group_names) & set(array_names) == set() + + groups_expected_meta = { + _join_paths([root_key, node_name]): GroupMetadata( + zarr_format=zarr_format, attributes={"path": node_name} + ) + for node_name in group_names + } + arrays_expected_meta = { + _join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format) + for node_name in array_names } - g = await _from_flat(spath, nodes=root_meta | members_expected_meta) + + nodes_create = root_meta | groups_expected_meta | arrays_expected_meta + + g = await _from_flat(spath, nodes=nodes_create, overwrite=True) + assert g.metadata.attributes == {"path": root_key} + members = await _collect_aiterator(g.members(max_depth=None)) members_observed_meta = {k: v.metadata for k, v in members} members_expected_meta_relative = { - str(PurePosixPath(k).relative_to(root_key)): v for k, v in members_expected_meta.items() + k.removeprefix(root_key).lstrip("/"): v + for k, v in (groups_expected_meta | arrays_expected_meta).items() } assert members_observed_meta == members_expected_meta_relative From 63dd07fd991dfeb0cdbbe81dde11c1abaca1f166 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 17:40:26 +0100 Subject: [PATCH 20/59] handle overwrite --- src/zarr/core/group.py | 280 +++++++++++++++++++++++++++++------------ src/zarr/errors.py | 7 ++ tests/test_group.py | 38 +++++- 3 files changed, 239 insertions(+), 86 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 3592b5ce1a..7b1bfe5f77 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -58,7 +58,12 @@ from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder from zarr.core.sync import SyncMixin, sync -from zarr.errors import MetadataValidationError +from zarr.errors import ( + ContainsArrayError, + ContainsGroupError, + MetadataValidationError, + RootedHierarchyError, +) from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path from zarr.storage._utils import normalize_path @@ -683,53 +688,14 @@ async def getitem( """ store_path = self.store_path / key logger.debug("key=%s, store_path=%s", key, store_path) - metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata # Consolidated metadata lets us avoid some I/O operations so try that first. if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) - - # Note: - # in zarr-python v2, we first check if `key` references an Array, else if `key` references - # a group,using standalone `contains_array` and `contains_group` functions. These functions - # are reusable, but for v3 they would perform redundant I/O operations. - # Not clear how much of that strategy we want to keep here. elif self.metadata.zarr_format == 3: - zarr_json_bytes = await (store_path / ZARR_JSON).get() - if zarr_json_bytes is None: - raise KeyError(key) - else: - zarr_json = json.loads(zarr_json_bytes.to_bytes()) - metadata = _build_metadata_v3(zarr_json) - return _build_node_v3(metadata, store_path) - + return await _read_node_v3(store_path=store_path) elif self.metadata.zarr_format == 2: - # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs? - # This guarantees that we will always make at least one extra request to the store - zgroup_bytes, zarray_bytes, zattrs_bytes = await asyncio.gather( - (store_path / ZGROUP_JSON).get(), - (store_path / ZARRAY_JSON).get(), - (store_path / ZATTRS_JSON).get(), - ) - - if zgroup_bytes is None and zarray_bytes is None: - raise KeyError(key) - - # unpack the zarray, if this is None then we must be opening a group - zarray = json.loads(zarray_bytes.to_bytes()) if zarray_bytes else None - zgroup = json.loads(zgroup_bytes.to_bytes()) if zgroup_bytes else None - # unpack the zattrs, this can be None if no attrs were written - zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {} - - if zarray is not None: - metadata = _build_metadata_v2(zarray, zattrs) - return _build_node_v2(metadata=metadata, store_path=store_path) - else: - # this is just for mypy - if TYPE_CHECKING: - assert zgroup is not None - metadata = _build_metadata_v2(zgroup, zattrs) - return _build_node_v2(metadata=metadata, store_path=store_path) + return await _read_node_v2(store_path=store_path) else: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") @@ -1431,7 +1397,9 @@ async def _members( # TODO: find a better name for this. create_tree could work. # TODO: include an example in the docstring async def create_hierarchy( - self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] + self, + nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + overwrite: bool, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a hierarchy of arrays or groups rooted at this group. @@ -1445,13 +1413,20 @@ async def create_hierarchy( ---------- nodes : A dictionary representing the hierarchy to create + overwrite : bool + Whether or not existing arrays / groups should be replaced. + Returns ------- - An asynchronous iterator over the created nodes. + An asynchronous iterator over the created arrays and / or groups. """ semaphore = asyncio.Semaphore(config.get("async.concurrency")) async for node in create_hierarchy( - store_path=self.store_path, nodes=nodes, semaphore=semaphore + store_path=self.store_path, + nodes=nodes, + semaphore=semaphore, + overwrite=overwrite, + allow_root=False, ): yield node @@ -2078,7 +2053,9 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], return tuple((kv[0], _parse_async_node(kv[1])) for kv in _members) def create_hierarchy( - self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] + self, + nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + overwrite: bool = False, ) -> Iterator[ tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] ]: @@ -2099,14 +2076,6 @@ def create_hierarchy( ------- A dict containing the created nodes, with the same keys as the input """ - if "" in nodes: - msg = ( - "Found the key '' in nodes, which denotes the root group. Creating the root group " - "from an existing group is not supported. If you want to create an entire Zarr group, " - "including the root group, from a dict then use the _from_flat method." - ) - raise ValueError(msg) - # check that all the nodes have the same zarr_format as Self for key, value in nodes.items(): if value.zarr_format != self.metadata.zarr_format: @@ -2116,9 +2085,19 @@ def create_hierarchy( f" has zarr_format {self.metadata.zarr_format}." ) raise ValueError(msg) - nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes)) - for n in nodes_created: - yield (_join_paths([self.path, n.name]), n) + try: + nodes_created = self._sync_iter( + self._async_group.create_hierarchy(nodes, overwrite=overwrite) + ) + for n in nodes_created: + yield (_join_paths([self.path, n.name]), n) + except RootedHierarchyError as e: + msg = ( + "The input defines a root node, but a root node already exists, namely this Group instance." + "It is an error to use this method to create a root node. " + "Remove the root node from the input dict, or use a function like _from_flat to create a rooted hierarchy." + ) + raise ValueError(msg) from e def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2862,6 +2841,7 @@ async def create_hierarchy( nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], semaphore: asyncio.Semaphore | None = None, overwrite: bool = False, + allow_root: bool = True, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input @@ -2883,6 +2863,11 @@ async def create_hierarchy( semaphore : asyncio.Semaphore | None An optional semaphore to limit the number of concurrent tasks. If not provided, the number of concurrent tasks is not limited. + allow_root : bool + Whether to allow a root node to be created. If ``False``, attempting to create a root node + will result in an error. Use this option when calling this function as part of a method + defined on ``AsyncGroup`` instances, because in this case the root node has already been + created. Yields ------ @@ -2891,11 +2876,75 @@ async def create_hierarchy( """ nodes_parsed = _parse_hierarchy_dict(nodes) - if overwrite: - await store_path.delete_dir() - else: - # TODO: check if any of the nodes already exist, and error if so - raise NotImplementedError + if not allow_root and "" in nodes_parsed: + msg = ( + "Found the key '' in nodes (after key name normalization). That key denotes the root of a hierarchy, but ``allow_root`` is False, and so creating this node " + "is not allowed. Either remove this key from ``nodes``, or set ``allow_root`` to True." + ) + raise RootedHierarchyError(msg) + + # we allow creating empty hierarchies -- it's a no-op + if len(nodes_parsed) > 0: + if overwrite: + await store_path.delete_dir() + else: + # attempt to fetch all of the metadata described in hierarchy + # first figure out which zarr format we are dealing with + sample, *_ = nodes_parsed.values() + redundant_implicit_groups = [] + # TODO: decide if this set difference is sufficient for detecting implicit groups. + # an alternative would be to use an explicit implicit group class. + + implicit_group_names = set(nodes_parsed.keys()) - set(nodes.keys()) + + zarr_format = sample.zarr_format + if zarr_format == 3 or zarr_format == 2: + func = _read_metadata_v3 + else: + raise ValueError(f"Invalid zarr_format: {zarr_format}") + + coros = (func(store_path=store_path / key) for key in nodes_parsed) + extant_node_query = dict( + zip( + nodes_parsed.keys(), + await asyncio.gather(*coros, return_exceptions=True), + strict=False, + ) + ) + + for key, value in extant_node_query.items(): + if isinstance(value, BaseException): + if isinstance(value, KeyError): + # ignore KeyErrors, because they represent nodes we can safely create + pass + else: + # Any other exception is a real error + raise value + else: + # this is a node that already exists, but a node with this name was specified in + # nodes_parsed. + # Two cases produce exceptions: + # 1. The node is a group, and a node with this name was explicitly defined in + # nodes + # 2. The node is an array. + # The third case is when this extant node is a group, but its name was not + # explicitly defined in nodes. This means it was added as an implicit group by + # _parse_hierarchy_dict, and we can remove the reference to this node from + # nodes_parsed. We don't need to create this node. + + if isinstance(value, GroupMetadata): + if key not in implicit_group_names: + raise ContainsGroupError(store_path.store, key) + else: + # as there is already a group with this name, we should not create a new one + redundant_implicit_groups.append(key) + elif isinstance(value, ArrayV2Metadata | ArrayV3Metadata): + raise ContainsArrayError(store_path.store, key) + + nodes_parsed = { + k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups + } + async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): yield node @@ -3043,12 +3092,19 @@ def _parse_hierarchy_dict( ) raise ValueError(msg) - out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {} + # normalize the keys of the dict + + data_normed: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = ( + _normalize_path_keys(data) + ) + + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed} for k, v in data.items(): # TODO: ensure that the key is a valid path key_split = k.split("/") - *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) + # we use /.join here because it checks the types of its inputs, unlike an f string + *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) # noqa: FLY002 for subpath in subpaths: # If a component is not already in the output dict, add a group @@ -3061,7 +3117,6 @@ def _parse_hierarchy_dict( "This is invalid. Only Zarr groups can contain other nodes." ) raise ValueError(msg) - return out @@ -3084,7 +3139,7 @@ def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: return tuple(path_map.keys()) -def _normalize_path_keys(data: dict[str, T]) -> dict[str, T]: +def _normalize_path_keys(data: Mapping[str, T]) -> dict[str, T]: """ Normalize the keys of the input dict according to the normalization scheme used for zarr node paths. If any two keys in the input normalize to the value, raise a ValueError. Return the @@ -3212,20 +3267,56 @@ async def _iter_members_deep( yield key, node -def _resolve_metadata_v2( - blobs: tuple[str | bytes | bytearray, str | bytes | bytearray], -) -> ArrayV2Metadata | GroupMetadata: - zarr_metadata = json.loads(blobs[0]) - attrs = json.loads(blobs[1]) - if "shape" in zarr_metadata: - return ArrayV2Metadata.from_dict(zarr_metadata | {"attrs": attrs}) +async def _read_metadata_v3(store_path: StorePath) -> ArrayV3Metadata | GroupMetadata: + """ + Given a store_path, return ArrayV3Metadata or GroupMetadata defined by the metadata + document stored at store_path.path / zarr.json. If no such document is found, raise a KeyError. + """ + zarr_json_bytes = await (store_path / ZARR_JSON).get() + if zarr_json_bytes is None: + raise KeyError(store_path.path) + else: + zarr_json = json.loads(zarr_json_bytes.to_bytes()) + return _build_metadata_v3(zarr_json) + + +async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMetadata: + """ + Given a store_path, return ArrayV2Metadata or GroupMetadata defined by the metadata + document stored at store_path.path / (.zgroup | .zarray). If no such document is found, + raise a KeyError. + """ + # TODO: consider first fetching array metadata, and only fetching group metadata when we don't + # find an array + zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( + (store_path / ZARRAY_JSON).get(), + (store_path / ZGROUP_JSON).get(), + (store_path / ZATTRS_JSON).get(), + ) + + if zattrs_bytes is None: + zattrs = {} else: - return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs}) + zattrs = json.loads(zattrs_bytes.to_bytes()) + + # TODO: decide how to handle finding both array and group metadata. The spec does not seem to + # consider this situation. A practical approach would be to ignore that combination, and only + # return the array metadata. + if zarray_bytes is not None: + zmeta = json.loads(zarray_bytes.to_bytes()) + else: + if zgroup_bytes is None: + # neither .zarray or .zgroup were found results in KeyError + raise KeyError(store_path.path) + else: + zmeta = json.loads(zgroup_bytes.to_bytes()) + return _build_metadata_v2(zmeta, zattrs) -def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata: + +def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMetadata: """ - Take a dict and convert it into the correct metadata type. + Convert a dict representation of Zarr V3 metadata into the corresponding metadata class. """ if "node_type" not in zarr_json: raise KeyError("missing `node_type` key in metadata document.") @@ -3239,10 +3330,10 @@ def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMeta def _build_metadata_v2( - zarr_json: dict[str, Any], attrs_json: dict[str, Any] + zarr_json: dict[str, object], attrs_json: dict[str, JSON] ) -> ArrayV2Metadata | GroupMetadata: """ - Take a dict and convert it into the correct metadata type. + Convert a dict representation of Zarr V2 metadata into the corresponding metadata class. """ match zarr_json: case {"shape": _}: @@ -3282,6 +3373,37 @@ def _build_node_v2( raise ValueError(f"Unexpected metadata type: {type(metadata)}") +async def _read_node_v2(store_path: StorePath) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: + """ + Read a Zarr v2 AsyncArray or AsyncGroup from a location defined by a StorePath. + """ + metadata = await _read_metadata_v2(store_path=store_path) + return _build_node_v2(metadata=metadata, store_path=store_path) + + +async def _read_node_v3(store_path: StorePath) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Read a Zarr v3 AsyncArray or AsyncGroup from a location defined by a StorePath. + """ + metadata = await _read_metadata_v3(store_path=store_path) + return _build_node_v3(metadata=metadata, store_path=store_path) + + +async def _read_node( + store_path: StorePath, zarr_format: ZarrFormat +) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Read and AsyncArray or AsyncGroup from a location defined by a StorePath. + """ + match zarr_format: + case 2: + return await _read_node_v2(store_path=store_path) + case 3: + return await _read_node_v3(store_path=store_path) + case _: + raise ValueError(f"Unexpected zarr format: {zarr_format}") + + async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str: """ Either write a value to storage at the given key, or ensure that there is already a value in @@ -3314,8 +3436,6 @@ def _persist_metadata( ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. - If ``metadata`` is an instance of ``_ImplicitGroupMetadata``, then _set_return_key will be invoked with - ``replace=False``, which defers to a pre-existing metadata document in storage if one exists. Otherwise, existing values will be overwritten. """ to_save = metadata.to_buffer_dict(default_buffer_prototype()) diff --git a/src/zarr/errors.py b/src/zarr/errors.py index 441cdab9a3..855ea51b9d 100644 --- a/src/zarr/errors.py +++ b/src/zarr/errors.py @@ -57,3 +57,10 @@ class NodeTypeValidationError(MetadataValidationError): This can be raised when the value is invalid or unexpected given the context, for example an 'array' node when we expected a 'group'. """ + + +class RootedHierarchyError(BaseZarrError): + """ + Exception raised when attempting to create a rooted hierarchy in a context where that is not + permitted. + """ diff --git a/tests/test_group.py b/tests/test_group.py index def4fc554a..3d599141b7 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -28,6 +28,7 @@ _join_paths, _normalize_path_keys, _normalize_paths, + _read_node, create_hierarchy, create_nodes, ) @@ -1487,20 +1488,29 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr hierarchy_spec = { "group": GroupMetadata(attributes={"foo": 10}, zarr_format=zarr_format), "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), - "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), - "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } - pre_existing_nodes = {"extra": GroupMetadata(zarr_format=zarr_format)} + pre_existing_nodes = { + "group/extra": GroupMetadata(zarr_format=zarr_format, attributes={"name": "extra"}), + "": GroupMetadata(zarr_format=zarr_format, attributes={"name": "root"}), + } # we expect create_hierarchy to insert a group that was missing from the hierarchy spec expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} spath = await make_store_path(store, path=path) + # initialize the group with some nodes - await _from_flat(store_path=spath, nodes=pre_existing_nodes) + sync(_collect_aiterator(create_nodes(store_path=spath, nodes=pre_existing_nodes))) + observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a async for a in create_hierarchy(store_path=spath, nodes=expected_meta, overwrite=overwrite) } + if not overwrite: + extra_group = await _read_node(spath / "group/extra", zarr_format=zarr_format) + assert extra_group.metadata.attributes == {"name": "extra"} + else: + with pytest.raises(KeyError): + await _read_node(spath / "group/extra", zarr_format=zarr_format) assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1518,15 +1528,31 @@ def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/b/c"} ), } - nodes = g.create_hierarchy(tree) + nodes = g.create_hierarchy(tree, overwrite=overwrite) for k, v in g.members(max_depth=None): assert v.metadata == tree[k] == nodes[k].metadata +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("overwrite", [True, False]) +def test_group_create_hierarchy_no_root(store: Store, zarr_format: ZarrFormat, overwrite: bool): + """ + Test that the Group.create_hierarchy method will error if the dict provided contains a root. + """ + g = Group.from_store(store, zarr_format=zarr_format) + tree = { + "": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), + } + with pytest.raises( + ValueError, match="It is an error to use this method to create a root node. " + ): + _ = tuple(g.create_hierarchy(tree, overwrite=overwrite)) + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_create_hierarchy_invalid_mixed_zarr_format(store: Store, zarr_format: ZarrFormat): """ - Test that ```Group.create_hierarchy``` will raise an error if the zarr_format of the nodes is + Test that ``Group.create_hierarchy`` will raise an error if the zarr_format of the nodes is different from the parent group. """ other_format = 2 if zarr_format == 3 else 3 From 15c4a7ef045a5477d95e728ec7fa9efa0221ee70 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 19:18:08 +0100 Subject: [PATCH 21/59] rename _from_flat to _create_rooted_hierarchy, add sync version --- src/zarr/core/group.py | 34 +++++++++++++++++++++++++++------- tests/test_group.py | 4 ++-- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 7b1bfe5f77..2d3e9567b7 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3445,14 +3445,16 @@ def _persist_metadata( ) -async def _from_flat( +async def _create_rooted_hierarchy( store_path: StorePath, *, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, -) -> AsyncGroup: +) -> AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: """ - Create an ``AsyncGroup`` from a store + a dict of nodes. + Create an ``AsyncGroup`` or ``AsyncArray`` from a store and a dict of metadata documents. + This function ensures that its input contains a specification of a root node, + calls ``create_hierarchy`` to create nodes, and returns the root node of the hierarchy. """ roots = _get_roots(nodes) if len(roots) != 1: @@ -3476,7 +3478,25 @@ async def _from_flat( } # the names of the created nodes will be relative to the store_path instance root_relative_to_store_path = _join_paths([store_path.path, root]) - root_group = nodes_created[root_relative_to_store_path] - if not isinstance(root_group, AsyncGroup): - raise TypeError("Invalid root node returned from create_hierarchy.") - return root_group + return nodes_created[root_relative_to_store_path] + + +def _create_rooted_hierarchy_sync( + store_path: StorePath, + *, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> Group | Array: + """ + Create a ``Group`` from a store and a dict of metadata documents. Calls the async method + ``_create_rooted_hierarchy`` and waits for the result. + """ + async_node = sync( + _create_rooted_hierarchy(store_path=store_path, nodes=nodes, overwrite=overwrite) + ) + if isinstance(async_node, AsyncGroup): + return Group(async_node) + elif isinstance(async_node, AsyncArray): + return Array(async_node) + else: + raise TypeError(f"Unexpected node type: {type(async_node)}") diff --git a/tests/test_group.py b/tests/test_group.py index 3d599141b7..f55e76ade2 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -24,7 +24,7 @@ from zarr.core.group import ( ConsolidatedMetadata, GroupMetadata, - _from_flat, + _create_rooted_hierarchy, _join_paths, _normalize_path_keys, _normalize_paths, @@ -1648,7 +1648,7 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s nodes_create = root_meta | groups_expected_meta | arrays_expected_meta - g = await _from_flat(spath, nodes=nodes_create, overwrite=True) + g = await _create_rooted_hierarchy(spath, nodes=nodes_create, overwrite=True) assert g.metadata.attributes == {"path": root_key} members = await _collect_aiterator(g.members(max_depth=None)) From bd9afd1909f02aad798a040f7f4f0a89f41a3e18 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 19:52:36 +0100 Subject: [PATCH 22/59] add test for _create_rooted_hierarchy when the output should be an array, and for when the input is invalid --- tests/test_group.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index f55e76ade2..da9c94cf02 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1623,9 +1623,9 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store): @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) -async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: str): +async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: str, root_key: str): """ - Test that the AsyncGroup.from_flat method creates a zarr group in one shot. + Test that the _create_rooted_hierarchy can create a group. """ spath = await make_store_path(store, path=path) root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} @@ -1660,6 +1660,42 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s assert members_observed_meta == members_expected_meta_relative +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) +@pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.parametrize("root_key", ["", "root"]) +@pytest.mark.parametrize("path", ["", "foo"]) +async def test_create_rooted_hierarchy_array(store: Store, zarr_format, path: str, root_key: str): + """ + Test that the _create_rooted_hierarchy can create an array. + """ + spath = await make_store_path(store, path=path) + root_meta = { + root_key: meta_from_array( + np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} + ) + } + + nodes_create = root_meta + + a = await _create_rooted_hierarchy(spath, nodes=nodes_create, overwrite=True) + assert a.metadata.attributes == {"path": root_key} + + +async def test_create_rooted_hierarchy_invalid(): + """ + Ensure _create_rooted_hierarchy will raise a ValueError if the input does not contain + a root node. + """ + zarr_format = 3 + nodes = { + "a": GroupMetadata(zarr_format=zarr_format), + "b": GroupMetadata(zarr_format=zarr_format), + } + msg = "The input does not specify a root node. " + with pytest.raises(ValueError, match=msg): + await _create_rooted_hierarchy(store_path=StorePath("store"), nodes=nodes) + + @pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) def test_normalize_paths_invalid(paths: tuple[str, str]): """ From 8be3876e9bd9ff3f03402ae5615c5e99bb53a207 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 20:04:23 +0100 Subject: [PATCH 23/59] increase coverage, one way or another --- src/zarr/core/group.py | 12 ++++++----- tests/test_group.py | 45 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8c49e12a42..8ea1271fc2 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3326,7 +3326,9 @@ def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMet case {"node_type": "group"}: return GroupMetadata.from_dict(zarr_json) case _: - raise ValueError("invalid value for `node_type` key in metadata document") + raise ValueError( + "invalid value for `node_type` key in metadata document" + ) # pragma: no cover def _build_metadata_v2( @@ -3354,7 +3356,7 @@ def _build_node_v3( case GroupMetadata(): return AsyncGroup(metadata, store_path=store_path) case _: - raise ValueError(f"Unexpected metadata type: {type(metadata)}") + raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover def _build_node_v2( @@ -3370,7 +3372,7 @@ def _build_node_v2( case GroupMetadata(): return AsyncGroup(metadata, store_path=store_path) case _: - raise ValueError(f"Unexpected metadata type: {type(metadata)}") + raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover async def _read_node_v2(store_path: StorePath) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: @@ -3401,7 +3403,7 @@ async def _read_node( case 3: return await _read_node_v3(store_path=store_path) case _: - raise ValueError(f"Unexpected zarr format: {zarr_format}") + raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str: @@ -3499,4 +3501,4 @@ def _create_rooted_hierarchy_sync( elif isinstance(async_node, AsyncArray): return Array(async_node) else: - raise TypeError(f"Unexpected node type: {type(async_node)}") + raise TypeError(f"Unexpected node type: {type(async_node)}") # pragma: no cover diff --git a/tests/test_group.py b/tests/test_group.py index da9c94cf02..159f4a3af8 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -25,6 +25,7 @@ ConsolidatedMetadata, GroupMetadata, _create_rooted_hierarchy, + _create_rooted_hierarchy_sync, _join_paths, _normalize_path_keys, _normalize_paths, @@ -1648,7 +1649,7 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st nodes_create = root_meta | groups_expected_meta | arrays_expected_meta - g = await _create_rooted_hierarchy(spath, nodes=nodes_create, overwrite=True) + g = await _create_rooted_hierarchy(spath, nodes=nodes_create) assert g.metadata.attributes == {"path": root_key} members = await _collect_aiterator(g.members(max_depth=None)) @@ -1660,6 +1661,48 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st assert members_observed_meta == members_expected_meta_relative +# TODO: simplify testing sync versions of async functions. +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) +@pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.parametrize("root_key", ["", "root"]) +@pytest.mark.parametrize("path", ["", "foo"]) +def test_create_rooted_hierarchy_sync_group(store: Store, zarr_format, path: str, root_key: str): + """ + Test that the _create_rooted_hierarchy_sync can create a group. + """ + spath = sync(make_store_path(store, path=path)) + root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} + group_names = ["a", "a/b"] + array_names = ["a/b/c", "a/b/d"] + + # just to ensure that we don't use the same name twice in tests + assert set(group_names) & set(array_names) == set() + + groups_expected_meta = { + _join_paths([root_key, node_name]): GroupMetadata( + zarr_format=zarr_format, attributes={"path": node_name} + ) + for node_name in group_names + } + arrays_expected_meta = { + _join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format) + for node_name in array_names + } + + nodes_create = root_meta | groups_expected_meta | arrays_expected_meta + + g = _create_rooted_hierarchy_sync(spath, nodes=nodes_create) + assert g.metadata.attributes == {"path": root_key} + + members = g.members(max_depth=None) + members_observed_meta = {k: v.metadata for k, v in members} + members_expected_meta_relative = { + k.removeprefix(root_key).lstrip("/"): v + for k, v in (groups_expected_meta | arrays_expected_meta).items() + } + assert members_observed_meta == members_expected_meta_relative + + @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) From 06e5482dbb0ffbe9a759567db54e1153c5fb4ac6 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 22:21:23 +0100 Subject: [PATCH 24/59] remove replace kwarg for _set_return_key --- src/zarr/core/group.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8ea1271fc2..d28c94546f 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3402,14 +3402,13 @@ async def _read_node( return await _read_node_v2(store_path=store_path) case 3: return await _read_node_v3(store_path=store_path) - case _: + case _: # pragma: no cover raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover -async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str: +async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: """ - Either write a value to storage at the given key, or ensure that there is already a value in - storage at the given key. The key is returned in either case. + Write a value to storage at the given key. The key is returned. Useful when saving values via routines that return results in execution order, like asyncio.as_completed, because in this case we need to know which key was saved in order to yield the right object to the caller. @@ -3422,14 +3421,8 @@ async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: boo The key to save the value to. value : Buffer The value to save. - replace : bool - If True, then the value will be written even if a value associated with the key - already exists in storage. If False, an existing value will not be overwritten. """ - if replace: - await store.set(key, value) - else: - await store.set_if_not_exists(key, value) + await store.set(key, value) return key @@ -3442,7 +3435,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=_join_paths([path, key]), value=value, replace=True) + _set_return_key(store=store, key=_join_paths([path, key]), value=value) for key, value in to_save.items() ) From 37186d601aa1b487743c3989863410a87c43d16b Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 22:23:12 +0100 Subject: [PATCH 25/59] shield lines from coverage --- src/zarr/core/group.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index d28c94546f..3baced34a7 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3325,7 +3325,7 @@ def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMet return ArrayV3Metadata.from_dict(zarr_json) case {"node_type": "group"}: return GroupMetadata.from_dict(zarr_json) - case _: + case _: # pragma: no cover raise ValueError( "invalid value for `node_type` key in metadata document" ) # pragma: no cover @@ -3340,7 +3340,7 @@ def _build_metadata_v2( match zarr_json: case {"shape": _}: return ArrayV2Metadata.from_dict(zarr_json | {"attributes": attrs_json}) - case _: + case _: # pragma: no cover return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) @@ -3355,7 +3355,7 @@ def _build_node_v3( return AsyncArray(metadata, store_path=store_path) case GroupMetadata(): return AsyncGroup(metadata, store_path=store_path) - case _: + case _: # pragma: no cover raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover @@ -3371,7 +3371,7 @@ def _build_node_v2( return AsyncArray(metadata, store_path=store_path) case GroupMetadata(): return AsyncGroup(metadata, store_path=store_path) - case _: + case _: # pragma: no cover raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover @@ -3402,7 +3402,7 @@ async def _read_node( return await _read_node_v2(store_path=store_path) case 3: return await _read_node_v3(store_path=store_path) - case _: # pragma: no cover + case _: # pragma: no cover raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover From ed4e8464e7278ba3156f6b8e429e061dd73d0ebf Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 22:23:22 +0100 Subject: [PATCH 26/59] add some tests --- tests/test_group.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/test_group.py b/tests/test_group.py index 159f4a3af8..118b335389 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -29,6 +29,7 @@ _join_paths, _normalize_path_keys, _normalize_paths, + _parse_hierarchy_dict, _read_node, create_hierarchy, create_nodes, @@ -1642,6 +1643,7 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st ) for node_name in group_names } + arrays_expected_meta = { _join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format) for node_name in array_names @@ -1660,6 +1662,18 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st } assert members_observed_meta == members_expected_meta_relative +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) +def test_create_hierarchy_implicit_groups(store: Store): + spath = sync(make_store_path(store, path='')) + nodes = { + '': GroupMetadata(zarr_format=3, attributes={'implicit': False}), + 'a/b/c': GroupMetadata(zarr_format=3, attributes={'implicit': False}) + } + + hierarchy_parsed = _parse_hierarchy_dict(nodes) + g = _create_rooted_hierarchy_sync(spath, nodes=nodes) + for key, value in hierarchy_parsed.items(): + assert g[key].metadata.attributes == value.attributes # TODO: simplify testing sync versions of async functions. @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @@ -1721,9 +1735,32 @@ async def test_create_rooted_hierarchy_array(store: Store, zarr_format, path: st nodes_create = root_meta a = await _create_rooted_hierarchy(spath, nodes=nodes_create, overwrite=True) + assert isinstance(a, AsyncArray) + assert a.metadata.attributes == {"path": root_key} + +@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) +@pytest.mark.parametrize("zarr_format", [2, 3]) +@pytest.mark.parametrize("root_key", ["", "root"]) +@pytest.mark.parametrize("path", ["", "foo"]) +async def test_create_rooted_hierarchy_sync_array(store: Store, zarr_format, path: str, root_key: str): + """ + Test that _create_rooted_hierarchy_sync can create an array. + """ + spath = await make_store_path(store, path=path) + root_meta = { + root_key: meta_from_array( + np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} + ) + } + + nodes_create = root_meta + + a =_create_rooted_hierarchy_sync(spath, nodes=nodes_create, overwrite=True) + assert isinstance(a, Array) assert a.metadata.attributes == {"path": root_key} + async def test_create_rooted_hierarchy_invalid(): """ Ensure _create_rooted_hierarchy will raise a ValueError if the input does not contain From 02ac91d9c97a2a416edbdddfc029673edadf0059 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 22:23:54 +0100 Subject: [PATCH 27/59] lint --- tests/test_group.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/test_group.py b/tests/test_group.py index 118b335389..ba8205ad05 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1643,7 +1643,7 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st ) for node_name in group_names } - + arrays_expected_meta = { _join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format) for node_name in array_names @@ -1662,19 +1662,21 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st } assert members_observed_meta == members_expected_meta_relative + @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) def test_create_hierarchy_implicit_groups(store: Store): - spath = sync(make_store_path(store, path='')) + spath = sync(make_store_path(store, path="")) nodes = { - '': GroupMetadata(zarr_format=3, attributes={'implicit': False}), - 'a/b/c': GroupMetadata(zarr_format=3, attributes={'implicit': False}) - } - + "": GroupMetadata(zarr_format=3, attributes={"implicit": False}), + "a/b/c": GroupMetadata(zarr_format=3, attributes={"implicit": False}), + } + hierarchy_parsed = _parse_hierarchy_dict(nodes) g = _create_rooted_hierarchy_sync(spath, nodes=nodes) for key, value in hierarchy_parsed.items(): assert g[key].metadata.attributes == value.attributes + # TODO: simplify testing sync versions of async functions. @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @@ -1738,11 +1740,14 @@ async def test_create_rooted_hierarchy_array(store: Store, zarr_format, path: st assert isinstance(a, AsyncArray) assert a.metadata.attributes == {"path": root_key} + @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) -async def test_create_rooted_hierarchy_sync_array(store: Store, zarr_format, path: str, root_key: str): +async def test_create_rooted_hierarchy_sync_array( + store: Store, zarr_format, path: str, root_key: str +): """ Test that _create_rooted_hierarchy_sync can create an array. """ @@ -1755,12 +1760,11 @@ async def test_create_rooted_hierarchy_sync_array(store: Store, zarr_format, pat nodes_create = root_meta - a =_create_rooted_hierarchy_sync(spath, nodes=nodes_create, overwrite=True) + a = _create_rooted_hierarchy_sync(spath, nodes=nodes_create, overwrite=True) assert isinstance(a, Array) assert a.metadata.attributes == {"path": root_key} - async def test_create_rooted_hierarchy_invalid(): """ Ensure _create_rooted_hierarchy will raise a ValueError if the input does not contain From f6a08a0a1587d7677a74b185a77938b0e99ea6b5 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 22 Jan 2025 23:17:46 +0100 Subject: [PATCH 28/59] improve coverage with more tests --- src/zarr/core/group.py | 18 +++++-- tests/test_group.py | 110 +++++++++++++++++++++++++++++++++++------ 2 files changed, 108 insertions(+), 20 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 3baced34a7..8717810bec 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -72,6 +72,7 @@ from collections.abc import ( AsyncGenerator, AsyncIterator, + Callable, Coroutine, Generator, Iterable, @@ -2898,10 +2899,17 @@ async def create_hierarchy( implicit_group_names = set(nodes_parsed.keys()) - set(nodes.keys()) zarr_format = sample.zarr_format - if zarr_format == 3 or zarr_format == 2: + # TODO: this type hint is so long + func: ( + Callable[[StorePath], Coroutine[Any, Any, GroupMetadata | ArrayV3Metadata]] + | Callable[[StorePath], Coroutine[Any, Any, GroupMetadata | ArrayV2Metadata]] + ) + if zarr_format == 3: func = _read_metadata_v3 - else: - raise ValueError(f"Invalid zarr_format: {zarr_format}") + elif zarr_format == 2: + func = _read_metadata_v2 + else: # pragma: no cover + raise ValueError(f"Invalid zarr_format: {zarr_format}") # pragma: no cover coros = (func(store_path=store_path / key) for key in nodes_parsed) extant_node_query = dict( @@ -3319,7 +3327,7 @@ def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMet Convert a dict representation of Zarr V3 metadata into the corresponding metadata class. """ if "node_type" not in zarr_json: - raise KeyError("missing `node_type` key in metadata document.") + raise MetadataValidationError("node_type", "array or group", "nothing (the key is missing)") match zarr_json: case {"node_type": "array"}: return ArrayV3Metadata.from_dict(zarr_json) @@ -3395,7 +3403,7 @@ async def _read_node( store_path: StorePath, zarr_format: ZarrFormat ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ - Read and AsyncArray or AsyncGroup from a location defined by a StorePath. + Read an AsyncArray or AsyncGroup from a location defined by a StorePath. """ match zarr_format: case 2: diff --git a/tests/test_group.py b/tests/test_group.py index ba8205ad05..7c8eaf9e8f 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -24,6 +24,7 @@ from zarr.core.group import ( ConsolidatedMetadata, GroupMetadata, + _build_metadata_v3, _create_rooted_hierarchy, _create_rooted_hierarchy_sync, _join_paths, @@ -34,8 +35,9 @@ create_hierarchy, create_nodes, ) +from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import _collect_aiterator, sync -from zarr.errors import ContainsArrayError, ContainsGroupError +from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage._common import make_store_path from zarr.storage._utils import normalize_path @@ -1516,9 +1518,53 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("extant_node", ["array", "group"]) +async def test_create_hierarchy_existing_nodes( + store: Store, extant_node: Literal["array", "group"], zarr_format: ZarrFormat +) -> None: + """ + Test that create_hierarchy with overwrite = False will not overwrite an existing array or group, + and raises an exception instead. + """ + spath = await make_store_path(store, path="path") + extant_node_path = "node" + + if extant_node == "array": + extant_metadata = meta_from_array( + np.zeros(4), zarr_format=zarr_format, attributes={"extant": True} + ) + new_metadata = meta_from_array(np.zeros(4), zarr_format=zarr_format) + err_cls = ContainsArrayError + else: + extant_metadata = GroupMetadata(zarr_format=zarr_format, attributes={"extant": True}) + new_metadata = GroupMetadata(zarr_format=zarr_format) + err_cls = ContainsGroupError + + # write the extant metadata + sync( + _collect_aiterator( + create_nodes(store_path=spath, nodes={extant_node_path: extant_metadata}) + ) + ) + + msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}." + # ensure that we cannot invoke create_hierarchy with overwrite=False here + with pytest.raises(err_cls, match=re.escape(msg)): + sync( + _collect_aiterator( + create_hierarchy(store_path=spath, nodes={"node": new_metadata}, overwrite=False) + ) + ) + # ensure that the extant metadata was not overwritten + assert ( + await _read_node(spath / extant_node_path, zarr_format=zarr_format) + ).metadata.attributes == {"extant": True} + + @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) -def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite: bool): +def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite: bool) -> None: """ Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. """ @@ -1537,7 +1583,9 @@ def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) -def test_group_create_hierarchy_no_root(store: Store, zarr_format: ZarrFormat, overwrite: bool): +def test_group_create_hierarchy_no_root( + store: Store, zarr_format: ZarrFormat, overwrite: bool +) -> None: """ Test that the Group.create_hierarchy method will error if the dict provided contains a root. """ @@ -1552,7 +1600,9 @@ def test_group_create_hierarchy_no_root(store: Store, zarr_format: ZarrFormat, o @pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_group_create_hierarchy_invalid_mixed_zarr_format(store: Store, zarr_format: ZarrFormat): +def test_group_create_hierarchy_invalid_mixed_zarr_format( + store: Store, zarr_format: ZarrFormat +) -> None: """ Test that ``Group.create_hierarchy`` will raise an error if the zarr_format of the nodes is different from the parent group. @@ -1597,7 +1647,7 @@ async def test_create_hierarchy_invalid_nested( @pytest.mark.parametrize("store", ["memory"], indirect=True) -async def test_create_hierarchy_invalid_mixed_format(store: Store): +async def test_create_hierarchy_invalid_mixed_format(store: Store) -> None: """ Test that create_hierarchy will not create a Zarr group that contains a both Zarr v2 and Zarr v3 nodes. @@ -1625,7 +1675,9 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store): @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) -async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: str, root_key: str): +async def test_create_rooted_hierarchy_group( + store: Store, zarr_format, path: str, root_key: str +) -> None: """ Test that the _create_rooted_hierarchy can create a group. """ @@ -1663,8 +1715,8 @@ async def test_create_rooted_hierarchy_group(store: Store, zarr_format, path: st assert members_observed_meta == members_expected_meta_relative -@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) -def test_create_hierarchy_implicit_groups(store: Store): +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_create_hierarchy_implicit_groups(store: Store) -> None: spath = sync(make_store_path(store, path="")) nodes = { "": GroupMetadata(zarr_format=3, attributes={"implicit": False}), @@ -1682,7 +1734,9 @@ def test_create_hierarchy_implicit_groups(store: Store): @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) -def test_create_rooted_hierarchy_sync_group(store: Store, zarr_format, path: str, root_key: str): +def test_create_rooted_hierarchy_sync_group( + store: Store, zarr_format, path: str, root_key: str +) -> None: """ Test that the _create_rooted_hierarchy_sync can create a group. """ @@ -1723,7 +1777,9 @@ def test_create_rooted_hierarchy_sync_group(store: Store, zarr_format, path: str @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) -async def test_create_rooted_hierarchy_array(store: Store, zarr_format, path: str, root_key: str): +async def test_create_rooted_hierarchy_array( + store: Store, zarr_format, path: str, root_key: str +) -> None: """ Test that the _create_rooted_hierarchy can create an array. """ @@ -1747,7 +1803,7 @@ async def test_create_rooted_hierarchy_array(store: Store, zarr_format, path: st @pytest.mark.parametrize("path", ["", "foo"]) async def test_create_rooted_hierarchy_sync_array( store: Store, zarr_format, path: str, root_key: str -): +) -> None: """ Test that _create_rooted_hierarchy_sync can create an array. """ @@ -1765,7 +1821,7 @@ async def test_create_rooted_hierarchy_sync_array( assert a.metadata.attributes == {"path": root_key} -async def test_create_rooted_hierarchy_invalid(): +async def test_create_rooted_hierarchy_invalid() -> None: """ Ensure _create_rooted_hierarchy will raise a ValueError if the input does not contain a root node. @@ -1781,7 +1837,7 @@ async def test_create_rooted_hierarchy_invalid(): @pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) -def test_normalize_paths_invalid(paths: tuple[str, str]): +def test_normalize_paths_invalid(paths: tuple[str, str]) -> None: """ Ensure that calling _normalize_paths on values that will normalize to the same value will generate a ValueError. @@ -1795,7 +1851,7 @@ def test_normalize_paths_invalid(paths: tuple[str, str]): @pytest.mark.parametrize( "paths", [("/a", "a/b"), ("a", "a/b"), ("a/", "a///b"), ("/a/", "//a/b///")] ) -def test_normalize_paths_valid(paths: tuple[str, str]): +def test_normalize_paths_valid(paths: tuple[str, str]) -> None: """ Ensure that calling _normalize_paths on values that normalize to distinct values returns a tuple of those normalized values. @@ -1804,7 +1860,10 @@ def test_normalize_paths_valid(paths: tuple[str, str]): assert _normalize_paths(paths) == expected -def test_normalize_path_keys(): +def test_normalize_path_keys() -> None: + """ + Test that normalize_path_keys returns a dict where each key has been normalized. + """ data = {"": 10, "a": "hello", "a/b": None, "/a/b/c/d": None} observed = _normalize_path_keys(data) expected = {normalize_path(k): v for k, v in data.items()} @@ -1874,3 +1933,24 @@ def test_group_members_concurrency_limit(store: MemoryStore) -> None: elapsed = time.time() - start assert elapsed > num_groups * get_latency + + +@pytest.mark.parametrize("option", ["array", "group", "invalid"]) +def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None: + """ + Test that _build_metadata_v3 returns the correct metadata for a v3 array or group + """ + match option: + case "array": + metadata_dict = meta_from_array(np.arange(10), zarr_format=3).to_dict() + assert _build_metadata_v3(metadata_dict) == ArrayV3Metadata.from_dict(metadata_dict) + case "group": + metadata_dict = GroupMetadata(attributes={"foo": 10}, zarr_format=3).to_dict() + assert _build_metadata_v3(metadata_dict) == GroupMetadata.from_dict(metadata_dict) + case "invalid": + metadata_dict = GroupMetadata(zarr_format=3).to_dict() + metadata_dict.pop("node_type") + # TODO: fix the error message + msg = "Invalid value for 'node_type'. Expected 'array or group'. Got 'nothing (the key is missing)'." + with pytest.raises(MetadataValidationError, match=re.escape(msg)): + _build_metadata_v3(metadata_dict) From 661678fc2c00ca17b5d98a78884a6511023740b4 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 16:14:21 +0100 Subject: [PATCH 29/59] use store + path instead of StorePath for hierarchy api --- src/zarr/core/group.py | 95 +++++++++++++++++++++++------------------- tests/test_group.py | 58 ++++++++++++++------------ 2 files changed, 84 insertions(+), 69 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8717810bec..08634d1a1b 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -694,9 +694,9 @@ async def getitem( if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) elif self.metadata.zarr_format == 3: - return await _read_node_v3(store_path=store_path) + return await _read_node_v3(store=self.store, path=store_path.path) elif self.metadata.zarr_format == 2: - return await _read_node_v2(store_path=store_path) + return await _read_node_v2(store=self.store, path=store_path.path) else: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") @@ -1423,7 +1423,8 @@ async def create_hierarchy( """ semaphore = asyncio.Semaphore(config.get("async.concurrency")) async for node in create_hierarchy( - store_path=self.store_path, + store=self.store, + path=self.path, nodes=nodes, semaphore=semaphore, overwrite=overwrite, @@ -2838,7 +2839,8 @@ def array( async def create_hierarchy( *, - store_path: StorePath, + store: Store, + path: str, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], semaphore: asyncio.Semaphore | None = None, overwrite: bool = False, @@ -2887,7 +2889,7 @@ async def create_hierarchy( # we allow creating empty hierarchies -- it's a no-op if len(nodes_parsed) > 0: if overwrite: - await store_path.delete_dir() + await store.delete_dir(path) else: # attempt to fetch all of the metadata described in hierarchy # first figure out which zarr format we are dealing with @@ -2901,8 +2903,8 @@ async def create_hierarchy( zarr_format = sample.zarr_format # TODO: this type hint is so long func: ( - Callable[[StorePath], Coroutine[Any, Any, GroupMetadata | ArrayV3Metadata]] - | Callable[[StorePath], Coroutine[Any, Any, GroupMetadata | ArrayV2Metadata]] + Callable[[Store, str], Coroutine[Any, Any, GroupMetadata | ArrayV3Metadata]] + | Callable[[Store, str], Coroutine[Any, Any, GroupMetadata | ArrayV2Metadata]] ) if zarr_format == 3: func = _read_metadata_v3 @@ -2911,7 +2913,7 @@ async def create_hierarchy( else: # pragma: no cover raise ValueError(f"Invalid zarr_format: {zarr_format}") # pragma: no cover - coros = (func(store_path=store_path / key) for key in nodes_parsed) + coros = (func(store=store, path=_join_paths([path, key])) for key in nodes_parsed) extant_node_query = dict( zip( nodes_parsed.keys(), @@ -2942,24 +2944,25 @@ async def create_hierarchy( if isinstance(value, GroupMetadata): if key not in implicit_group_names: - raise ContainsGroupError(store_path.store, key) + raise ContainsGroupError(store, key) else: # as there is already a group with this name, we should not create a new one redundant_implicit_groups.append(key) elif isinstance(value, ArrayV2Metadata | ArrayV3Metadata): - raise ContainsArrayError(store_path.store, key) + raise ContainsArrayError(store, key) nodes_parsed = { k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups } - async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): + async for node in create_nodes(store=store, path=path, nodes=nodes_parsed, semaphore=semaphore): yield node async def create_nodes( *, - store_path: StorePath, + store: Store, + path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], semaphore: asyncio.Semaphore | None = None, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: @@ -2977,8 +2980,8 @@ async def create_nodes( create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): # transform the key, which is relative to a store_path.path, to a key in the store - write_prefix = _join_paths([store_path.path, key]) - create_tasks.extend(_persist_metadata(store_path.store, write_prefix, value)) + write_prefix = _join_paths([path, key]) + create_tasks.extend(_persist_metadata(store, write_prefix, value)) created_object_keys = [] async with ctx: @@ -2993,7 +2996,7 @@ async def create_nodes( # the relative path of the object we just created -- we need this to track which metadata documents # were written so that we can yield a complete v2 Array / Group class after both .zattrs # and the metadata JSON was created. - object_path_relative = created_key.removeprefix(store_path.path).lstrip("/") + object_path_relative = created_key.removeprefix(path).lstrip("/") created_object_keys.append(object_path_relative) # get the node name from the object key @@ -3008,7 +3011,7 @@ async def create_nodes( if meta_out.zarr_format == 3: # yes, it is silly that we relativize, then de-relativize this same path - node_store_path = store_path / node_name + node_store_path = StorePath(store=store, path=path) / node_name if isinstance(meta_out, GroupMetadata): yield AsyncGroup(metadata=meta_out, store_path=node_store_path) else: @@ -3027,7 +3030,7 @@ async def create_nodes( meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys if meta_done and attrs_done: - node_store_path = store_path / node_name + node_store_path = StorePath(store=store, path=path) / node_name if isinstance(meta_out, GroupMetadata): yield AsyncGroup(metadata=meta_out, store_path=node_store_path) else: @@ -3275,20 +3278,22 @@ async def _iter_members_deep( yield key, node -async def _read_metadata_v3(store_path: StorePath) -> ArrayV3Metadata | GroupMetadata: +async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupMetadata: """ Given a store_path, return ArrayV3Metadata or GroupMetadata defined by the metadata document stored at store_path.path / zarr.json. If no such document is found, raise a KeyError. """ - zarr_json_bytes = await (store_path / ZARR_JSON).get() + zarr_json_bytes = await store.get( + _join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype() + ) if zarr_json_bytes is None: - raise KeyError(store_path.path) + raise KeyError(path) else: zarr_json = json.loads(zarr_json_bytes.to_bytes()) return _build_metadata_v3(zarr_json) -async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMetadata: +async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupMetadata: """ Given a store_path, return ArrayV2Metadata or GroupMetadata defined by the metadata document stored at store_path.path / (.zgroup | .zarray). If no such document is found, @@ -3297,9 +3302,9 @@ async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMet # TODO: consider first fetching array metadata, and only fetching group metadata when we don't # find an array zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( - (store_path / ZARRAY_JSON).get(), - (store_path / ZGROUP_JSON).get(), - (store_path / ZATTRS_JSON).get(), + store.get(_join_paths([path, ZARRAY_JSON]), prototype=default_buffer_prototype()), + store.get(_join_paths([path, ZGROUP_JSON]), prototype=default_buffer_prototype()), + store.get(_join_paths([path, ZATTRS_JSON]), prototype=default_buffer_prototype()), ) if zattrs_bytes is None: @@ -3315,7 +3320,7 @@ async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMet else: if zgroup_bytes is None: # neither .zarray or .zgroup were found results in KeyError - raise KeyError(store_path.path) + raise KeyError(path) else: zmeta = json.loads(zgroup_bytes.to_bytes()) @@ -3353,11 +3358,15 @@ def _build_metadata_v2( def _build_node_v3( - metadata: ArrayV3Metadata | GroupMetadata, store_path: StorePath + *, + store: Store, + path: str, + metadata: ArrayV3Metadata | GroupMetadata, ) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Take a metadata object and return a node (AsyncArray or AsyncGroup). """ + store_path = StorePath(store=store, path=path) match metadata: case ArrayV3Metadata(): return AsyncArray(metadata, store_path=store_path) @@ -3368,12 +3377,12 @@ def _build_node_v3( def _build_node_v2( - metadata: ArrayV2Metadata | GroupMetadata, store_path: StorePath + *, store: Store, path: str, metadata: ArrayV2Metadata | GroupMetadata ) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: """ Take a metadata object and return a node (AsyncArray or AsyncGroup). """ - + store_path = StorePath(store=store, path=path) match metadata: case ArrayV2Metadata(): return AsyncArray(metadata, store_path=store_path) @@ -3383,33 +3392,33 @@ def _build_node_v2( raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover -async def _read_node_v2(store_path: StorePath) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: +async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: """ Read a Zarr v2 AsyncArray or AsyncGroup from a location defined by a StorePath. """ - metadata = await _read_metadata_v2(store_path=store_path) - return _build_node_v2(metadata=metadata, store_path=store_path) + metadata = await _read_metadata_v2(store=store, path=path) + return _build_node_v2(store=store, path=path, metadata=metadata) -async def _read_node_v3(store_path: StorePath) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: +async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Read a Zarr v3 AsyncArray or AsyncGroup from a location defined by a StorePath. """ - metadata = await _read_metadata_v3(store_path=store_path) - return _build_node_v3(metadata=metadata, store_path=store_path) + metadata = await _read_metadata_v3(store=store, path=path) + return _build_node_v3(store=store, path=path, metadata=metadata) async def _read_node( - store_path: StorePath, zarr_format: ZarrFormat + store: Store, path: str, zarr_format: ZarrFormat ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Read an AsyncArray or AsyncGroup from a location defined by a StorePath. """ match zarr_format: case 2: - return await _read_node_v2(store_path=store_path) + return await _read_node_v2(store=store, path=path) case 3: - return await _read_node_v3(store_path=store_path) + return await _read_node_v3(store=store, path=path) case _: # pragma: no cover raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover @@ -3449,8 +3458,9 @@ def _persist_metadata( async def _create_rooted_hierarchy( - store_path: StorePath, *, + store: Store, + path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, ) -> AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: @@ -3476,17 +3486,18 @@ async def _create_rooted_hierarchy( nodes_created = { x.path: x async for x in create_hierarchy( - store_path=store_path, nodes=nodes, semaphore=semaphore, overwrite=overwrite + store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite ) } # the names of the created nodes will be relative to the store_path instance - root_relative_to_store_path = _join_paths([store_path.path, root]) + root_relative_to_store_path = _join_paths([path, root]) return nodes_created[root_relative_to_store_path] def _create_rooted_hierarchy_sync( - store_path: StorePath, *, + store: Store, + path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, ) -> Group | Array: @@ -3495,7 +3506,7 @@ def _create_rooted_hierarchy_sync( ``_create_rooted_hierarchy`` and waits for the result. """ async_node = sync( - _create_rooted_hierarchy(store_path=store_path, nodes=nodes, overwrite=overwrite) + _create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) ) if isinstance(async_node, AsyncGroup): return Group(async_node) diff --git a/tests/test_group.py b/tests/test_group.py index 7c8eaf9e8f..b088b4299d 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1473,10 +1473,9 @@ async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } - spath = await make_store_path(store, path="foo") observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a - async for a in create_nodes(store_path=spath, nodes=expected_meta) + async for a in create_nodes(store=store, path=path, nodes=expected_meta) } assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1500,21 +1499,26 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr } # we expect create_hierarchy to insert a group that was missing from the hierarchy spec expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} - spath = await make_store_path(store, path=path) # initialize the group with some nodes - sync(_collect_aiterator(create_nodes(store_path=spath, nodes=pre_existing_nodes))) + sync(_collect_aiterator(create_nodes(store=store, path=path, nodes=pre_existing_nodes))) observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a - async for a in create_hierarchy(store_path=spath, nodes=expected_meta, overwrite=overwrite) + async for a in create_hierarchy( + store=store, path=path, nodes=expected_meta, overwrite=overwrite + ) } if not overwrite: - extra_group = await _read_node(spath / "group/extra", zarr_format=zarr_format) + extra_group = await _read_node( + store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format + ) assert extra_group.metadata.attributes == {"name": "extra"} else: with pytest.raises(KeyError): - await _read_node(spath / "group/extra", zarr_format=zarr_format) + await _read_node( + store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format + ) assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1527,9 +1531,8 @@ async def test_create_hierarchy_existing_nodes( Test that create_hierarchy with overwrite = False will not overwrite an existing array or group, and raises an exception instead. """ - spath = await make_store_path(store, path="path") extant_node_path = "node" - + path = "path" if extant_node == "array": extant_metadata = meta_from_array( np.zeros(4), zarr_format=zarr_format, attributes={"extant": True} @@ -1544,7 +1547,7 @@ async def test_create_hierarchy_existing_nodes( # write the extant metadata sync( _collect_aiterator( - create_nodes(store_path=spath, nodes={extant_node_path: extant_metadata}) + create_nodes(store=store, path=path, nodes={extant_node_path: extant_metadata}) ) ) @@ -1553,12 +1556,16 @@ async def test_create_hierarchy_existing_nodes( with pytest.raises(err_cls, match=re.escape(msg)): sync( _collect_aiterator( - create_hierarchy(store_path=spath, nodes={"node": new_metadata}, overwrite=False) + create_hierarchy( + store=store, path=path, nodes={"node": new_metadata}, overwrite=False + ) ) ) # ensure that the extant metadata was not overwritten assert ( - await _read_node(spath / extant_node_path, zarr_format=zarr_format) + await _read_node( + store=store, path=_join_paths([path, extant_node_path]), zarr_format=zarr_format + ) ).metadata.attributes == {"extant": True} @@ -1642,8 +1649,8 @@ async def test_create_hierarchy_invalid_nested( msg = "Only Zarr groups can contain other nodes." with pytest.raises(ValueError, match=msg): - spath = await make_store_path(store, path="foo") - await _collect_aiterator(create_hierarchy(store_path=spath, nodes=hierarchy_spec)) + path = "foo" + await _collect_aiterator(create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1652,7 +1659,7 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store) -> None: Test that create_hierarchy will not create a Zarr group that contains a both Zarr v2 and Zarr v3 nodes. """ - spath = await make_store_path(store, path="foo") + path = "foo" msg = ( "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " "The following keys map to Zarr v2 nodes: ['v2']. " @@ -1662,7 +1669,8 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store) -> None: with pytest.raises(ValueError, match=re.escape(msg)): await _collect_aiterator( create_hierarchy( - store_path=spath, + store=store, + path=path, nodes={ "v2": GroupMetadata(zarr_format=2), "v3": GroupMetadata(zarr_format=3), @@ -1681,7 +1689,6 @@ async def test_create_rooted_hierarchy_group( """ Test that the _create_rooted_hierarchy can create a group. """ - spath = await make_store_path(store, path=path) root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} group_names = ["a", "a/b"] array_names = ["a/b/c", "a/b/d"] @@ -1703,7 +1710,7 @@ async def test_create_rooted_hierarchy_group( nodes_create = root_meta | groups_expected_meta | arrays_expected_meta - g = await _create_rooted_hierarchy(spath, nodes=nodes_create) + g = await _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) assert g.metadata.attributes == {"path": root_key} members = await _collect_aiterator(g.members(max_depth=None)) @@ -1717,14 +1724,14 @@ async def test_create_rooted_hierarchy_group( @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_create_hierarchy_implicit_groups(store: Store) -> None: - spath = sync(make_store_path(store, path="")) + path = "" nodes = { "": GroupMetadata(zarr_format=3, attributes={"implicit": False}), "a/b/c": GroupMetadata(zarr_format=3, attributes={"implicit": False}), } hierarchy_parsed = _parse_hierarchy_dict(nodes) - g = _create_rooted_hierarchy_sync(spath, nodes=nodes) + g = _create_rooted_hierarchy_sync(store=store, path=path, nodes=nodes) for key, value in hierarchy_parsed.items(): assert g[key].metadata.attributes == value.attributes @@ -1740,7 +1747,6 @@ def test_create_rooted_hierarchy_sync_group( """ Test that the _create_rooted_hierarchy_sync can create a group. """ - spath = sync(make_store_path(store, path=path)) root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} group_names = ["a", "a/b"] array_names = ["a/b/c", "a/b/d"] @@ -1761,7 +1767,7 @@ def test_create_rooted_hierarchy_sync_group( nodes_create = root_meta | groups_expected_meta | arrays_expected_meta - g = _create_rooted_hierarchy_sync(spath, nodes=nodes_create) + g = _create_rooted_hierarchy_sync(store=store, path=path, nodes=nodes_create) assert g.metadata.attributes == {"path": root_key} members = g.members(max_depth=None) @@ -1783,7 +1789,6 @@ async def test_create_rooted_hierarchy_array( """ Test that the _create_rooted_hierarchy can create an array. """ - spath = await make_store_path(store, path=path) root_meta = { root_key: meta_from_array( np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} @@ -1792,7 +1797,7 @@ async def test_create_rooted_hierarchy_array( nodes_create = root_meta - a = await _create_rooted_hierarchy(spath, nodes=nodes_create, overwrite=True) + a = await _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create, overwrite=True) assert isinstance(a, AsyncArray) assert a.metadata.attributes == {"path": root_key} @@ -1807,7 +1812,6 @@ async def test_create_rooted_hierarchy_sync_array( """ Test that _create_rooted_hierarchy_sync can create an array. """ - spath = await make_store_path(store, path=path) root_meta = { root_key: meta_from_array( np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} @@ -1816,7 +1820,7 @@ async def test_create_rooted_hierarchy_sync_array( nodes_create = root_meta - a = _create_rooted_hierarchy_sync(spath, nodes=nodes_create, overwrite=True) + a = _create_rooted_hierarchy_sync(store=store, path=path, nodes=nodes_create, overwrite=True) assert isinstance(a, Array) assert a.metadata.attributes == {"path": root_key} @@ -1833,7 +1837,7 @@ async def test_create_rooted_hierarchy_invalid() -> None: } msg = "The input does not specify a root node. " with pytest.raises(ValueError, match=msg): - await _create_rooted_hierarchy(store_path=StorePath("store"), nodes=nodes) + await _create_rooted_hierarchy(store=store, path="", nodes=nodes) @pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) From 7a718d542451a19a80d536cd6bfc5b19d1031a2c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 16:19:00 +0100 Subject: [PATCH 30/59] docstrings --- src/zarr/core/group.py | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 08634d1a1b..b530717282 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2856,8 +2856,11 @@ async def create_hierarchy( Parameters ---------- - store_path : StorePath - The StorePath object pointing to the root of the hierarchy. + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, and the values are the metadata of the nodes. The @@ -2865,7 +2868,7 @@ async def create_hierarchy( or ArrayV2Metadata. semaphore : asyncio.Semaphore | None An optional semaphore to limit the number of concurrent tasks. If not - provided, the number of concurrent tasks is not limited. + provided, the number of concurrent tasks is unlimited. allow_root : bool Whether to allow a root node to be created. If ``False``, attempting to create a root node will result in an error. Use this option when calling this function as part of a method @@ -2966,9 +2969,32 @@ async def create_nodes( nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], semaphore: asyncio.Semaphore | None = None, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: - """ - Create a collection of zarr arrays and groups concurrently and atomically. To ensure atomicity, - no attempt is made to ensure that intermediate groups are created. + """Create a collection of arrays and / or groups concurrently. + + Note: no attempt is made to validate that these arrays and / or groups collectively form a + valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that + the ``nodes`` parameter satisfies any correctness constraints. + + Parameters + ---------- + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + semaphore : asyncio.Semaphore | None + An optional semaphore to limit the number of concurrent tasks. If not + provided, the number of concurrent tasks is unlimited. + + Yields + ------ + AsyncGroup | AsyncArray + The created nodes in the order they are created. """ ctx: asyncio.Semaphore | contextlib.nullcontext[None] @@ -2979,7 +3005,7 @@ async def create_nodes( create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): - # transform the key, which is relative to a store_path.path, to a key in the store + # make the key absolute write_prefix = _join_paths([path, key]) create_tasks.extend(_persist_metadata(store, write_prefix, value)) From 23bfef5bbbef7a7378168216491a95d5900e3b96 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 17:18:17 +0100 Subject: [PATCH 31/59] docstrings --- src/zarr/core/group.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index b530717282..c002431fc7 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3096,15 +3096,22 @@ def _parse_hierarchy_dict( data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: """ + Take an input Mapping of str: node pairs, and parse it into + a dict of str: node pairs that models valid, complete Zarr hierarchy. + If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups, - then return an identical copy of that dict. Otherwise, return a version of the input dict - with groups added where they are needed to make the hierarchy explicit. + then return a dict with the exact same data as the input. + + Otherwise, return a dict derived from the input with groups as needed to make + the hierarchy complete. - For example, an input of {'a/b/c': ArrayMetadata} will result in a return value of - {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}. + For example, an input of {'a/b/c': ArrayMetadata} is incomplete, because it references two + groups ('a' and 'a/b') but these keys are not present in the input. Applying this function + to that input will result in a return value of + {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}, i.e. the implied groups + were added. - The input is also checked for the following conditions, and an error is raised if any - of them are violated: + The input is also checked for the following conditions; an error is raised if any are violated: - No arrays can contain group or arrays (i.e., all arrays must be leaf nodes). - All arrays and groups must have the same ``zarr_format`` value. @@ -3179,18 +3186,21 @@ def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: def _normalize_path_keys(data: Mapping[str, T]) -> dict[str, T]: """ Normalize the keys of the input dict according to the normalization scheme used for zarr node - paths. If any two keys in the input normalize to the value, raise a ValueError. Return the - values of data with the normalized keys. + paths. If any two keys in the input normalize to the same value, raise a ValueError. + Returns a dict where the keys are the elements of the input and the values are the + normalized form of each key. """ parsed_keys = _normalize_paths(data.keys()) - return dict(zip(parsed_keys, data.values(), strict=False)) + return dict(zip(parsed_keys, data.values(), strict=True)) async def _getitem_semaphore( node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: """ - Combine node.getitem with an optional semaphore. If the semaphore parameter is an + Wrap Group.getitem with an optional semaphore. + + If the semaphore parameter is an asyncio.Semaphore instance, then the getitem operation is performed inside an async context manager provided by that semaphore. If the semaphore parameter is None, then getitem is invoked without a context manager. From 528253482cc246b726fd479dd7ed9ea2ae3f995c Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 17:22:43 +0100 Subject: [PATCH 32/59] release notes --- changes/2665.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/2665.feature.rst diff --git a/changes/2665.feature.rst b/changes/2665.feature.rst new file mode 100644 index 0000000000..40bec542ce --- /dev/null +++ b/changes/2665.feature.rst @@ -0,0 +1 @@ +Adds functions for concurrently creating multiple arrays and groups. \ No newline at end of file From 6507e434bcebd137b05dccbfb3f50e7d32c317ff Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 22:13:23 +0100 Subject: [PATCH 33/59] refactor sync / async functions, and make tests more compact accordingly --- src/zarr/core/group.py | 181 ++++++++++++++++++------- tests/test_group.py | 296 +++++++++++++++++++++++------------------ 2 files changed, 297 insertions(+), 180 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index c002431fc7..41f9677159 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -57,7 +57,7 @@ from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder -from zarr.core.sync import SyncMixin, sync +from zarr.core.sync import SyncMixin, _collect_aiterator, sync from zarr.errors import ( ContainsArrayError, ContainsGroupError, @@ -1422,7 +1422,7 @@ async def create_hierarchy( An asynchronous iterator over the created arrays and / or groups. """ semaphore = asyncio.Semaphore(config.get("async.concurrency")) - async for node in create_hierarchy( + async for node in create_hierarchy_a( store=self.store, path=self.path, nodes=nodes, @@ -2837,7 +2837,7 @@ def array( ) -async def create_hierarchy( +async def create_hierarchy_a( *, store: Store, path: str, @@ -2848,11 +2848,10 @@ async def create_hierarchy( ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input - ``nodes`` will be created as needed. + will be created as needed. This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy - concurrently. The groups and arrays in the hierarchy are created in a single pass, and the - function yields the created nodes in the order they are created. + concurrently. AsyncArrays and AsyncGroups are yielded in the order they are created. Parameters ---------- @@ -2958,11 +2957,58 @@ async def create_hierarchy( k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups } - async for node in create_nodes(store=store, path=path, nodes=nodes_parsed, semaphore=semaphore): + async for node in create_nodes_a( + store=store, path=path, nodes=nodes_parsed, semaphore=semaphore + ): yield node -async def create_nodes( +def create_hierarchy( + store: Store, + path: str, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, + allow_root: bool = True, +) -> Iterator[Group | Array]: + """ + Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input + will be created as needed. + + This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy + concurrently. Arrays and Groups are yielded in the order they are created. + + Parameters + ---------- + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + allow_root : bool + Whether to allow a root node to be created. If ``False``, attempting to create a root node + will result in an error. Use this option when calling this function as part of a method + defined on ``AsyncGroup`` instances, because in this case the root node has already been + created. + + Yields + ------ + Group | Array + The created nodes in the order they are created. + """ + coro = create_hierarchy_a( + store=store, path=path, nodes=nodes, overwrite=overwrite, allow_root=allow_root + ) + + for result in sync(_collect_aiterator(coro)): + yield _parse_async_node(result) + + +async def create_nodes_a( *, store: Store, path: str, @@ -3056,14 +3102,53 @@ async def create_nodes( meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys if meta_done and attrs_done: - node_store_path = StorePath(store=store, path=path) / node_name - if isinstance(meta_out, GroupMetadata): - yield AsyncGroup(metadata=meta_out, store_path=node_store_path) - else: - yield AsyncArray(metadata=meta_out, store_path=node_store_path) + yield _build_node( + store=store, path=_join_paths([path, node_name]), metadata=meta_out + ) + continue +def create_nodes( + *, + store: Store, + path: str, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + semaphore: asyncio.Semaphore | None = None, +) -> Iterator[Group | Array]: + """Create a collection of arrays and / or groups concurrently. + + Note: no attempt is made to validate that these arrays and / or groups collectively form a + valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that + the ``nodes`` parameter satisfies any correctness constraints. + + Parameters + ---------- + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + semaphore : asyncio.Semaphore | None + An optional semaphore to limit the number of concurrent tasks. If not + provided, the number of concurrent tasks is unlimited. + + Yields + ------ + Group | Array + The created nodes in the order they are created. + """ + coro = create_nodes_a(store=store, path=path, nodes=nodes, semaphore=semaphore) + + for result in sync(_collect_aiterator(coro)): + yield _parse_async_node(result) + + T = TypeVar("T") @@ -3393,34 +3478,31 @@ def _build_metadata_v2( return GroupMetadata.from_dict(zarr_json | {"attributes": attrs_json}) -def _build_node_v3( - *, - store: Store, - path: str, - metadata: ArrayV3Metadata | GroupMetadata, -) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: - """ - Take a metadata object and return a node (AsyncArray or AsyncGroup). - """ - store_path = StorePath(store=store, path=path) - match metadata: - case ArrayV3Metadata(): - return AsyncArray(metadata, store_path=store_path) - case GroupMetadata(): - return AsyncGroup(metadata, store_path=store_path) - case _: # pragma: no cover - raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover +@overload +def _build_node( + *, store: Store, path: str, metadata: ArrayV2Metadata +) -> AsyncArray[ArrayV2Metadata]: ... -def _build_node_v2( - *, store: Store, path: str, metadata: ArrayV2Metadata | GroupMetadata -) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: +@overload +def _build_node( + *, store: Store, path: str, metadata: ArrayV3Metadata +) -> AsyncArray[ArrayV3Metadata]: ... + + +@overload +def _build_node(*, store: Store, path: str, metadata: GroupMetadata) -> AsyncGroup: ... + + +def _build_node( + *, store: Store, path: str, metadata: ArrayV3Metadata | ArrayV2Metadata | GroupMetadata +) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Take a metadata object and return a node (AsyncArray or AsyncGroup). """ store_path = StorePath(store=store, path=path) match metadata: - case ArrayV2Metadata(): + case ArrayV2Metadata() | ArrayV3Metadata(): return AsyncArray(metadata, store_path=store_path) case GroupMetadata(): return AsyncGroup(metadata, store_path=store_path) @@ -3433,7 +3515,7 @@ async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] Read a Zarr v2 AsyncArray or AsyncGroup from a location defined by a StorePath. """ metadata = await _read_metadata_v2(store=store, path=path) - return _build_node_v2(store=store, path=path, metadata=metadata) + return _build_node(store=store, path=path, metadata=metadata) async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: @@ -3441,14 +3523,14 @@ async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] Read a Zarr v3 AsyncArray or AsyncGroup from a location defined by a StorePath. """ metadata = await _read_metadata_v3(store=store, path=path) - return _build_node_v3(store=store, path=path, metadata=metadata) + return _build_node(store=store, path=path, metadata=metadata) -async def _read_node( +async def _read_node_a( store: Store, path: str, zarr_format: ZarrFormat ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ - Read an AsyncArray or AsyncGroup from a location defined by a StorePath. + Read an AsyncArray or AsyncGroup from a path in a Store. """ match zarr_format: case 2: @@ -3459,6 +3541,14 @@ async def _read_node( raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover +def read_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: + """ + Read an Array or Group from a path in a Store. + """ + + return _parse_async_node(sync(_read_node_a(store=store, path=path, zarr_format=zarr_format))) + + async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: """ Write a value to storage at the given key. The key is returned. @@ -3493,7 +3583,7 @@ def _persist_metadata( ) -async def _create_rooted_hierarchy( +async def _create_rooted_hierarchy_a( *, store: Store, path: str, @@ -3521,7 +3611,7 @@ async def _create_rooted_hierarchy( nodes_created = { x.path: x - async for x in create_hierarchy( + async for x in create_hierarchy_a( store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite ) } @@ -3530,7 +3620,7 @@ async def _create_rooted_hierarchy( return nodes_created[root_relative_to_store_path] -def _create_rooted_hierarchy_sync( +def _create_rooted_hierarchy( *, store: Store, path: str, @@ -3542,11 +3632,6 @@ def _create_rooted_hierarchy_sync( ``_create_rooted_hierarchy`` and waits for the result. """ async_node = sync( - _create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) + _create_rooted_hierarchy_a(store=store, path=path, nodes=nodes, overwrite=overwrite) ) - if isinstance(async_node, AsyncGroup): - return Group(async_node) - elif isinstance(async_node, AsyncArray): - return Array(async_node) - else: - raise TypeError(f"Unexpected node type: {type(async_node)}") # pragma: no cover + return _parse_async_node(async_node) diff --git a/tests/test_group.py b/tests/test_group.py index b088b4299d..52ba9e4827 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -6,7 +6,6 @@ import re import time import warnings -from pathlib import PurePosixPath from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -26,14 +25,16 @@ GroupMetadata, _build_metadata_v3, _create_rooted_hierarchy, - _create_rooted_hierarchy_sync, + _create_rooted_hierarchy_a, _join_paths, _normalize_path_keys, _normalize_paths, _parse_hierarchy_dict, - _read_node, create_hierarchy, + create_hierarchy_a, create_nodes, + create_nodes_a, + read_node, ) from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import _collect_aiterator, sync @@ -1460,7 +1461,10 @@ def test_delitem_removes_children(store: Store, zarr_format: ZarrFormat) -> None @pytest.mark.parametrize("store", ["memory"], indirect=True) -async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_nodes( + impl: Literal["async", "sync"], store: Store, zarr_format: ZarrFormat +) -> None: """ Ensure that ``create_nodes`` can create a zarr hierarchy from a model of that hierarchy in dict form. Note that this creates an incomplete Zarr hierarchy. @@ -1473,59 +1477,87 @@ async def test_create_nodes(store: Store, zarr_format: ZarrFormat) -> None: "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } - observed_nodes = { - str(PurePosixPath(a.name).relative_to("/" + path)): a - async for a in create_nodes(store=store, path=path, nodes=expected_meta) - } + if impl == "async": + created = tuple( + [a async for a in create_nodes_a(store=store, path=path, nodes=expected_meta)] + ) + elif impl == "sync": + created = tuple(create_nodes(store=store, path=path, nodes=expected_meta)) + else: + raise ValueError(f"Invalid impl: {impl}") + + observed_nodes = {a.path.removeprefix(path + "/"): a for a in created} assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) -async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: ZarrFormat) -> None: +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy( + impl: Literal["async", "sync"], store: Store, overwrite: bool, zarr_format: ZarrFormat +) -> None: """ Test that ``create_hierarchy`` can create a complete Zarr hierarchy, even if the input describes an incomplete one. """ path = "foo" + hierarchy_spec = { - "group": GroupMetadata(attributes={"foo": 10}, zarr_format=zarr_format), - "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), - "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), + "group": GroupMetadata(attributes={"path": "group"}, zarr_format=zarr_format), + "group/array_0": meta_from_array( + np.arange(3), attributes={"path": "group/array_0"}, zarr_format=zarr_format + ), + "group/subgroup/array_0": meta_from_array( + np.arange(4), attributes={"path": "group/subgroup/array_0"}, zarr_format=zarr_format + ), } pre_existing_nodes = { - "group/extra": GroupMetadata(zarr_format=zarr_format, attributes={"name": "extra"}), + "group/extra": GroupMetadata(zarr_format=zarr_format, attributes={"path": "group/extra"}), "": GroupMetadata(zarr_format=zarr_format, attributes={"name": "root"}), } # we expect create_hierarchy to insert a group that was missing from the hierarchy spec expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} # initialize the group with some nodes - sync(_collect_aiterator(create_nodes(store=store, path=path, nodes=pre_existing_nodes))) + _ = tuple(create_nodes(store=store, path=path, nodes=pre_existing_nodes)) - observed_nodes = { - str(PurePosixPath(a.name).relative_to("/" + path)): a - async for a in create_hierarchy( - store=store, path=path, nodes=expected_meta, overwrite=overwrite + if impl == "sync": + created = tuple( + create_hierarchy(store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite) ) - } + elif impl == "async": + created = tuple( + [ + a + async for a in create_hierarchy_a( + store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite + ) + ] + ) + else: + raise ValueError(f"Invalid impl: {impl}") + + observed_nodes = {a.path.removeprefix(path + "/"): a for a in created} + if not overwrite: - extra_group = await _read_node( + extra_group = read_node( store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format ) - assert extra_group.metadata.attributes == {"name": "extra"} + assert extra_group.metadata.attributes == {"path": "group/extra"} else: with pytest.raises(KeyError): - await _read_node( - store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format - ) + read_node(store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format) assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("extant_node", ["array", "group"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_hierarchy_existing_nodes( - store: Store, extant_node: Literal["array", "group"], zarr_format: ZarrFormat + impl: Literal["async", "sync"], + store: Store, + extant_node: Literal["array", "group"], + zarr_format: ZarrFormat, ) -> None: """ Test that create_hierarchy with overwrite = False will not overwrite an existing array or group, @@ -1533,6 +1565,7 @@ async def test_create_hierarchy_existing_nodes( """ extant_node_path = "node" path = "path" + if extant_node == "array": extant_metadata = meta_from_array( np.zeros(4), zarr_format=zarr_format, attributes={"extant": True} @@ -1545,27 +1578,33 @@ async def test_create_hierarchy_existing_nodes( err_cls = ContainsGroupError # write the extant metadata - sync( - _collect_aiterator( - create_nodes(store=store, path=path, nodes={extant_node_path: extant_metadata}) - ) - ) + tuple(create_nodes(store=store, path=path, nodes={extant_node_path: extant_metadata})) msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}." # ensure that we cannot invoke create_hierarchy with overwrite=False here - with pytest.raises(err_cls, match=re.escape(msg)): - sync( - _collect_aiterator( + if impl == "sync": + with pytest.raises(err_cls, match=re.escape(msg)): + tuple( create_hierarchy( store=store, path=path, nodes={"node": new_metadata}, overwrite=False ) ) - ) + elif impl == "async": + with pytest.raises(err_cls, match=re.escape(msg)): + tuple( + [ + x + async for x in create_hierarchy_a( + store=store, path=path, nodes={"node": new_metadata}, overwrite=False + ) + ] + ) + else: + raise ValueError(f"Invalid impl: {impl}") + # ensure that the extant metadata was not overwritten assert ( - await _read_node( - store=store, path=_join_paths([path, extant_node_path]), zarr_format=zarr_format - ) + read_node(store=store, path=_join_paths([path, extant_node_path]), zarr_format=zarr_format) ).metadata.attributes == {"extant": True} @@ -1628,14 +1667,15 @@ def test_group_create_hierarchy_invalid_mixed_zarr_format( @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("defect", ["array/array", "array/group"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_hierarchy_invalid_nested( - store: Store, defect: tuple[str, str], zarr_format + impl: Literal["async", "sync"], store: Store, defect: tuple[str, str], zarr_format ) -> None: """ Test that create_hierarchy will not create a Zarr array that contains a Zarr group or Zarr array. """ - + path = "foo" if defect == "array/array": hierarchy_spec = { "array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), @@ -1648,13 +1688,21 @@ async def test_create_hierarchy_invalid_nested( } msg = "Only Zarr groups can contain other nodes." - with pytest.raises(ValueError, match=msg): - path = "foo" - await _collect_aiterator(create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) + if impl == "sync": + with pytest.raises(ValueError, match=msg): + tuple(create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) + elif impl == "async": + with pytest.raises(ValueError, match=msg): + await _collect_aiterator( + create_hierarchy_a(store=store, path=path, nodes=hierarchy_spec) + ) @pytest.mark.parametrize("store", ["memory"], indirect=True) -async def test_create_hierarchy_invalid_mixed_format(store: Store) -> None: +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy_invalid_mixed_format( + impl: Literal["async", "sync"], store: Store +) -> None: """ Test that create_hierarchy will not create a Zarr group that contains a both Zarr v2 and Zarr v3 nodes. @@ -1666,25 +1714,39 @@ async def test_create_hierarchy_invalid_mixed_format(store: Store) -> None: "The following keys map to Zarr v3 nodes: ['v3']." "Ensure that all nodes have the same Zarr format." ) - with pytest.raises(ValueError, match=re.escape(msg)): - await _collect_aiterator( - create_hierarchy( - store=store, - path=path, - nodes={ - "v2": GroupMetadata(zarr_format=2), - "v3": GroupMetadata(zarr_format=3), - }, + nodes = { + "v2": GroupMetadata(zarr_format=2), + "v3": GroupMetadata(zarr_format=3), + } + if impl == "sync": + with pytest.raises(ValueError, match=re.escape(msg)): + tuple( + create_hierarchy( + store=store, + path=path, + nodes=nodes, + ) ) - ) + elif impl == "async": + with pytest.raises(ValueError, match=re.escape(msg)): + await _collect_aiterator( + create_hierarchy_a( + store=store, + path=path, + nodes=nodes, + ) + ) + else: + raise ValueError(f"Invalid impl: {impl}") @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_rooted_hierarchy_group( - store: Store, zarr_format, path: str, root_key: str + impl: Literal["async", "sync"], store: Store, zarr_format, path: str, root_key: str ) -> None: """ Test that the _create_rooted_hierarchy can create a group. @@ -1709,11 +1771,19 @@ async def test_create_rooted_hierarchy_group( } nodes_create = root_meta | groups_expected_meta | arrays_expected_meta + if impl == "async": + g = await _create_rooted_hierarchy_a(store=store, path=path, nodes=nodes_create) + assert isinstance(g, AsyncGroup) + members = await _collect_aiterator(g.members(max_depth=None)) + elif impl == "sync": + g = _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) + assert isinstance(g, Group) + members = g.members(max_depth=None) + else: + raise ValueError(f"Unknown implementation: {impl}") - g = await _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) assert g.metadata.attributes == {"path": root_key} - members = await _collect_aiterator(g.members(max_depth=None)) members_observed_meta = {k: v.metadata for k, v in members} members_expected_meta_relative = { k.removeprefix(root_key).lstrip("/"): v @@ -1723,7 +1793,10 @@ async def test_create_rooted_hierarchy_group( @pytest.mark.parametrize("store", ["memory"], indirect=True) -def test_create_hierarchy_implicit_groups(store: Store) -> None: +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_hierarchy_implicit_groups( + impl: Literal["async", "sync"], store: Store +) -> None: path = "" nodes = { "": GroupMetadata(zarr_format=3, attributes={"implicit": False}), @@ -1731,87 +1804,30 @@ def test_create_hierarchy_implicit_groups(store: Store) -> None: } hierarchy_parsed = _parse_hierarchy_dict(nodes) - g = _create_rooted_hierarchy_sync(store=store, path=path, nodes=nodes) - for key, value in hierarchy_parsed.items(): - assert g[key].metadata.attributes == value.attributes - - -# TODO: simplify testing sync versions of async functions. -@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) -@pytest.mark.parametrize("zarr_format", [2, 3]) -@pytest.mark.parametrize("root_key", ["", "root"]) -@pytest.mark.parametrize("path", ["", "foo"]) -def test_create_rooted_hierarchy_sync_group( - store: Store, zarr_format, path: str, root_key: str -) -> None: - """ - Test that the _create_rooted_hierarchy_sync can create a group. - """ - root_meta = {root_key: GroupMetadata(zarr_format=zarr_format, attributes={"path": root_key})} - group_names = ["a", "a/b"] - array_names = ["a/b/c", "a/b/d"] - - # just to ensure that we don't use the same name twice in tests - assert set(group_names) & set(array_names) == set() - - groups_expected_meta = { - _join_paths([root_key, node_name]): GroupMetadata( - zarr_format=zarr_format, attributes={"path": node_name} - ) - for node_name in group_names - } - arrays_expected_meta = { - _join_paths([root_key, node_name]): meta_from_array(np.zeros(4), zarr_format=zarr_format) - for node_name in array_names - } - - nodes_create = root_meta | groups_expected_meta | arrays_expected_meta - - g = _create_rooted_hierarchy_sync(store=store, path=path, nodes=nodes_create) - assert g.metadata.attributes == {"path": root_key} - - members = g.members(max_depth=None) - members_observed_meta = {k: v.metadata for k, v in members} - members_expected_meta_relative = { - k.removeprefix(root_key).lstrip("/"): v - for k, v in (groups_expected_meta | arrays_expected_meta).items() - } - assert members_observed_meta == members_expected_meta_relative + if impl == "sync": + g = _create_rooted_hierarchy(store=store, path=path, nodes=nodes) + for key, value in hierarchy_parsed.items(): + assert g[key].metadata.attributes == value.attributes + elif impl == "async": + g = await _create_rooted_hierarchy_a(store=store, path=path, nodes=nodes) + for key, value in hierarchy_parsed.items(): + assert (await g.getitem(key)).metadata.attributes == value.attributes + else: + raise ValueError(f"Unknown implementation: {impl}") @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @pytest.mark.parametrize("path", ["", "foo"]) +@pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_rooted_hierarchy_array( - store: Store, zarr_format, path: str, root_key: str + impl: Literal["async", "sync"], store: Store, zarr_format, path: str, root_key: str ) -> None: """ - Test that the _create_rooted_hierarchy can create an array. + Test that _create_rooted_hierarchy can create an array. """ - root_meta = { - root_key: meta_from_array( - np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} - ) - } - - nodes_create = root_meta - - a = await _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create, overwrite=True) - assert isinstance(a, AsyncArray) - assert a.metadata.attributes == {"path": root_key} - -@pytest.mark.parametrize("store", ["memory", "local"], indirect=True) -@pytest.mark.parametrize("zarr_format", [2, 3]) -@pytest.mark.parametrize("root_key", ["", "root"]) -@pytest.mark.parametrize("path", ["", "foo"]) -async def test_create_rooted_hierarchy_sync_array( - store: Store, zarr_format, path: str, root_key: str -) -> None: - """ - Test that _create_rooted_hierarchy_sync can create an array. - """ root_meta = { root_key: meta_from_array( np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} @@ -1820,12 +1836,21 @@ async def test_create_rooted_hierarchy_sync_array( nodes_create = root_meta - a = _create_rooted_hierarchy_sync(store=store, path=path, nodes=nodes_create, overwrite=True) - assert isinstance(a, Array) + if impl == "sync": + a = _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create, overwrite=True) + assert isinstance(a, Array) + elif impl == "async": + a = await _create_rooted_hierarchy_a( + store=store, path=path, nodes=nodes_create, overwrite=True + ) + assert isinstance(a, AsyncArray) + else: + raise ValueError(f"Invalid impl: {impl}") assert a.metadata.attributes == {"path": root_key} -async def test_create_rooted_hierarchy_invalid() -> None: +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_create_rooted_hierarchy_invalid(impl: Literal["async", "sync"]) -> None: """ Ensure _create_rooted_hierarchy will raise a ValueError if the input does not contain a root node. @@ -1836,8 +1861,15 @@ async def test_create_rooted_hierarchy_invalid() -> None: "b": GroupMetadata(zarr_format=zarr_format), } msg = "The input does not specify a root node. " - with pytest.raises(ValueError, match=msg): - await _create_rooted_hierarchy(store=store, path="", nodes=nodes) + + if impl == "async": + with pytest.raises(ValueError, match=msg): + await _create_rooted_hierarchy_a(store=store, path="", nodes=nodes) + elif impl == "sync": + with pytest.raises(ValueError, match=msg): + _create_rooted_hierarchy(store=store, path="", nodes=nodes) + else: + raise ValueError(f"Invalid impl: {impl}") @pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) From 6b56342760b89f186a0b611a34f74a724ed345b6 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 22:42:09 +0100 Subject: [PATCH 34/59] keyerror -> filenotfounderror --- src/zarr/core/group.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 41f9677159..c623db8c2c 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2926,8 +2926,8 @@ async def create_hierarchy_a( for key, value in extant_node_query.items(): if isinstance(value, BaseException): - if isinstance(value, KeyError): - # ignore KeyErrors, because they represent nodes we can safely create + if isinstance(value, FileNotFoundError): + # ignore FileNotFoundError, because they represent nodes we can safely create pass else: # Any other exception is a real error @@ -3402,13 +3402,14 @@ async def _iter_members_deep( async def _read_metadata_v3(store: Store, path: str) -> ArrayV3Metadata | GroupMetadata: """ Given a store_path, return ArrayV3Metadata or GroupMetadata defined by the metadata - document stored at store_path.path / zarr.json. If no such document is found, raise a KeyError. + document stored at store_path.path / zarr.json. If no such document is found, raise a + FileNotFoundError. """ zarr_json_bytes = await store.get( _join_paths([path, ZARR_JSON]), prototype=default_buffer_prototype() ) if zarr_json_bytes is None: - raise KeyError(path) + raise FileNotFoundError(path) else: zarr_json = json.loads(zarr_json_bytes.to_bytes()) return _build_metadata_v3(zarr_json) @@ -3441,7 +3442,7 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM else: if zgroup_bytes is None: # neither .zarray or .zgroup were found results in KeyError - raise KeyError(path) + raise FileNotFoundError(path) else: zmeta = json.loads(zgroup_bytes.to_bytes()) From 3be878d4486900dd1c848dc773bdc180eaf36980 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 22:54:05 +0100 Subject: [PATCH 35/59] keyerror -> filenotfounderror, fixup --- src/zarr/core/group.py | 14 +++++++------- tests/test_group.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index c623db8c2c..583dade38a 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -693,12 +693,12 @@ async def getitem( # Consolidated metadata lets us avoid some I/O operations so try that first. if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) - elif self.metadata.zarr_format == 3: - return await _read_node_v3(store=self.store, path=store_path.path) - elif self.metadata.zarr_format == 2: - return await _read_node_v2(store=self.store, path=store_path.path) - else: - raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") + try: + return await _read_node_a( + store=store_path.store, path=store_path.path, zarr_format=self.metadata.zarr_format + ) + except FileNotFoundError as e: + raise KeyError(key) from e def _getitem_consolidated( self, store_path: StorePath, key: str, prefix: str @@ -3419,7 +3419,7 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM """ Given a store_path, return ArrayV2Metadata or GroupMetadata defined by the metadata document stored at store_path.path / (.zgroup | .zarray). If no such document is found, - raise a KeyError. + raise a FileNotFoundError. """ # TODO: consider first fetching array metadata, and only fetching group metadata when we don't # find an array diff --git a/tests/test_group.py b/tests/test_group.py index 52ba9e4827..d66ad9f140 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1545,7 +1545,7 @@ async def test_create_hierarchy( ) assert extra_group.metadata.attributes == {"path": "group/extra"} else: - with pytest.raises(KeyError): + with pytest.raises(FileNotFoundError): read_node(store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format) assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} From f3c506f53fd951ae1dd25b304a2e92be140755bb Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 28 Jan 2025 23:20:56 +0100 Subject: [PATCH 36/59] add top-level exports --- src/zarr/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index bcbdaf7c19..eb4202d58d 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -30,7 +30,7 @@ ) from zarr.core.array import Array, AsyncArray from zarr.core.config import config -from zarr.core.group import AsyncGroup, Group +from zarr.core.group import AsyncGroup, Group, create_hierarchy, create_nodes # in case setuptools scm screw up and find version to be 0.0.0 assert not __version__.startswith("0.0.0") @@ -50,6 +50,8 @@ "create", "create_array", "create_group", + "create_hierarchy", + "create_nodes", "empty", "empty_like", "full", From 32e06fa5a211b03cab4edb314fcdeb9b121d40de Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 29 Jan 2025 12:34:21 +0100 Subject: [PATCH 37/59] mildly refactor node input validation --- src/zarr/core/group.py | 115 +++++++++++++++++++++++++---------------- tests/test_group.py | 4 +- 2 files changed, 73 insertions(+), 46 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 583dade38a..0cbf1b527c 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1400,6 +1400,7 @@ async def _members( async def create_hierarchy( self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + *, overwrite: bool, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ @@ -1421,16 +1422,37 @@ async def create_hierarchy( ------- An asynchronous iterator over the created arrays and / or groups. """ + # check that all the nodes have the same zarr_format as Self + for key, value in nodes.items(): + if value.zarr_format != self.metadata.zarr_format: + msg = ( + "The zarr_format of the nodes must be the same as the parent group. " + f"The node at {key} has zarr_format {value.zarr_format}, but the parent group" + f" has zarr_format {self.metadata.zarr_format}." + ) + raise ValueError(msg) + semaphore = asyncio.Semaphore(config.get("async.concurrency")) - async for node in create_hierarchy_a( - store=self.store, - path=self.path, - nodes=nodes, - semaphore=semaphore, - overwrite=overwrite, - allow_root=False, - ): - yield node + + try: + async for node in create_hierarchy_a( + store=self.store, + path=self.path, + nodes=nodes, + semaphore=semaphore, + overwrite=overwrite, + allow_root=False, + ): + yield node + + except RootedHierarchyError as e: + msg = ( + "The input defines a root node, but a root node already exists, namely this Group instance." + "It is an error to use this method to create a root node. " + "Remove the root node from the input dict, or use a function like " + "create_rooted_hierarchy to create a rooted hierarchy." + ) + raise ValueError(msg) from e async def keys(self) -> AsyncGenerator[str, None]: """Iterate over member names.""" @@ -2057,17 +2079,21 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], def create_hierarchy( self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + *, overwrite: bool = False, - ) -> Iterator[ - tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] - ]: + ) -> Iterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a hierarchy of arrays or groups rooted at this group. This method takes a dictionary where the keys are the names of the arrays or groups to create and the values are the metadata objects for the arrays or groups. - The method returns a dict containing the created nodes. + This method returns an iterator of created Group or Array objects. + + Note: this method will create additional groups as needed to ensure that a hierarchy is + complete. Usage like ``create_hierarchy({'a/b': GroupMetadata()})`` defines an implicit + group at ``a``. This function will ensure that the group at ``a`` exists, first by checking + if one already exists, and if not, creating one. Parameters ---------- @@ -2076,30 +2102,21 @@ def create_hierarchy( Returns ------- - A dict containing the created nodes, with the same keys as the input + An iterator of Array or Group objects. + + Examples + -------- + >>> import zarr + >>> from zarr.core.group import GroupMetadata + >>> root = zarr.create_group(store={}) + >>> for key, val in root.create_hierarchy({'a/b/c': GroupMetadata()}): + ... print(key, val) + ... + + + """ - # check that all the nodes have the same zarr_format as Self - for key, value in nodes.items(): - if value.zarr_format != self.metadata.zarr_format: - msg = ( - "The zarr_format of the nodes must be the same as the parent group. " - f"The node at {key} has zarr_format {value.zarr_format}, but the parent group" - f" has zarr_format {self.metadata.zarr_format}." - ) - raise ValueError(msg) - try: - nodes_created = self._sync_iter( - self._async_group.create_hierarchy(nodes, overwrite=overwrite) - ) - for n in nodes_created: - yield (_join_paths([self.path, n.name]), n) - except RootedHierarchyError as e: - msg = ( - "The input defines a root node, but a root node already exists, namely this Group instance." - "It is an error to use this method to create a root node. " - "Remove the root node from the input dict, or use a function like _from_flat to create a rooted hierarchy." - ) - raise ValueError(msg) from e + yield from self._sync_iter(self._async_group.create_hierarchy(nodes, overwrite=overwrite)) def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2879,14 +2896,7 @@ async def create_hierarchy_a( AsyncGroup | AsyncArray The created nodes in the order they are created. """ - nodes_parsed = _parse_hierarchy_dict(nodes) - - if not allow_root and "" in nodes_parsed: - msg = ( - "Found the key '' in nodes (after key name normalization). That key denotes the root of a hierarchy, but ``allow_root`` is False, and so creating this node " - "is not allowed. Either remove this key from ``nodes``, or set ``allow_root`` to True." - ) - raise RootedHierarchyError(msg) + nodes_parsed = _parse_hierarchy_dict(data=nodes, allow_root=allow_root) # we allow creating empty hierarchies -- it's a no-op if len(nodes_parsed) > 0: @@ -3178,7 +3188,9 @@ def _join_paths(paths: Iterable[str]) -> str: def _parse_hierarchy_dict( + *, data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + allow_root: bool = True, ) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: """ Take an input Mapping of str: node pairs, and parse it into @@ -3201,6 +3213,11 @@ def _parse_hierarchy_dict( - No arrays can contain group or arrays (i.e., all arrays must be leaf nodes). - All arrays and groups must have the same ``zarr_format`` value. + if ``allow_root`` is set to False, then the input is also checked to ensure that it does not + contain a key that normalizes to the empty string (''), as this is reserved for the root node, + and in some situations creating a root node is not permitted, for example, when creating a + hierarchy relative to an existing group. + This function ensures that the input is transformed into a specification of a complete and valid Zarr hierarchy. """ @@ -3227,6 +3244,15 @@ def _parse_hierarchy_dict( _normalize_path_keys(data) ) + if not allow_root and "" in data_normed: + msg = ( + "Found the empty string '' in data after key name normalization. " + "That key denotes the root of a hierarchy, but ``allow_root`` is False, " + "and so creating this node is not allowed. Remove the problematic key from the input, " + "or set ``allow_root`` to True." + ) + raise RootedHierarchyError(msg) + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed} for k, v in data.items(): @@ -3246,6 +3272,7 @@ def _parse_hierarchy_dict( "This is invalid. Only Zarr groups can contain other nodes." ) raise ValueError(msg) + return out diff --git a/tests/test_group.py b/tests/test_group.py index d66ad9f140..6e02ccf718 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -371,7 +371,7 @@ def test_group_getitem(store: Store, zarr_format: ZarrFormat, consolidated: bool ) with pytest.raises(KeyError): - # We've chosen to trust the consolidted metadata, which doesn't + # We've chosen to trust the consolidated metadata, which doesn't # contain this array group["subgroup/subarray"] @@ -1803,7 +1803,7 @@ async def test_create_hierarchy_implicit_groups( "a/b/c": GroupMetadata(zarr_format=3, attributes={"implicit": False}), } - hierarchy_parsed = _parse_hierarchy_dict(nodes) + hierarchy_parsed = _parse_hierarchy_dict(data=nodes) if impl == "sync": g = _create_rooted_hierarchy(store=store, path=path, nodes=nodes) for key, value in hierarchy_parsed.items(): From 8bd0b57a6f1c88ed218b42f5109035bb6e0e39c2 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Wed, 29 Jan 2025 12:51:46 +0100 Subject: [PATCH 38/59] simplify path normalization --- src/zarr/core/group.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 0cbf1b527c..0ccb44a431 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3633,7 +3633,7 @@ async def _create_rooted_hierarchy_a( ) raise ValueError(msg) else: - root = roots[0] + root_key = roots[0] semaphore = asyncio.Semaphore(config.get("async.concurrency")) @@ -3643,9 +3643,7 @@ async def _create_rooted_hierarchy_a( store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite ) } - # the names of the created nodes will be relative to the store_path instance - root_relative_to_store_path = _join_paths([path, root]) - return nodes_created[root_relative_to_store_path] + return nodes_created[_join_paths([path, root_key])] def _create_rooted_hierarchy( From d05a43ccf1866e914d65428ab27818aa2c6c2c3f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 18:06:51 +0100 Subject: [PATCH 39/59] refactor to separate sync and async routines --- src/zarr/__init__.py | 6 +- src/zarr/api/asynchronous.py | 12 +++- src/zarr/api/synchronous.py | 106 +++++++++++++++++++++++++++++- src/zarr/core/group.py | 123 +++-------------------------------- tests/test_api.py | 8 +++ tests/test_group.py | 92 ++++++++++---------------- 6 files changed, 170 insertions(+), 177 deletions(-) diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index eb4202d58d..31d0797af6 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -8,6 +8,9 @@ create, create_array, create_group, + create_hierarchy, + create_nodes, + create_rooted_hierarchy, empty, empty_like, full, @@ -30,7 +33,7 @@ ) from zarr.core.array import Array, AsyncArray from zarr.core.config import config -from zarr.core.group import AsyncGroup, Group, create_hierarchy, create_nodes +from zarr.core.group import AsyncGroup, Group # in case setuptools scm screw up and find version to be 0.0.0 assert not __version__.startswith("0.0.0") @@ -52,6 +55,7 @@ "create_group", "create_hierarchy", "create_nodes", + "create_rooted_hierarchy", "empty", "empty_like", "full", diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 8eba4fc152..f0bb6f0546 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -23,7 +23,14 @@ _warn_write_empty_chunks_kwarg, parse_dtype, ) -from zarr.core.group import AsyncGroup, ConsolidatedMetadata, GroupMetadata +from zarr.core.group import ( + AsyncGroup, + ConsolidatedMetadata, + GroupMetadata, + create_hierarchy, + create_nodes, + create_rooted_hierarchy, +) from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v2 import _default_compressor, _default_filters from zarr.errors import NodeTypeValidationError @@ -48,6 +55,9 @@ "copy_store", "create", "create_array", + "create_hierarchy", + "create_nodes", + "create_rooted_hierarchy", "empty", "empty_like", "full", diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 305446ec97..600f36079b 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -8,16 +8,17 @@ import zarr.core.array from zarr._compat import _deprecate_positional_args from zarr.core.array import Array, AsyncArray -from zarr.core.group import Group -from zarr.core.sync import sync +from zarr.core.group import Group, GroupMetadata, _parse_async_node +from zarr.core.sync import _collect_aiterator, sync if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Iterator import numpy as np import numpy.typing as npt from zarr.abc.codec import Codec + from zarr.abc.store import Store from zarr.api.asynchronous import ArrayLike, PathLike from zarr.core.array import ( CompressorsLike, @@ -36,6 +37,7 @@ ShapeLike, ZarrFormat, ) + from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.storage import StoreLike __all__ = [ @@ -46,6 +48,9 @@ "copy_store", "create", "create_array", + "create_hierarchy", + "create_nodes", + "create_rooted_hierarchy", "empty", "empty_like", "full", @@ -1124,3 +1129,98 @@ def zeros_like(a: ArrayLike, **kwargs: Any) -> Array: The new array. """ return Array(sync(async_api.zeros_like(a, **kwargs))) + + +def create_hierarchy( + store: Store, + path: str, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, + allow_root: bool = True, +) -> Iterator[Group | Array]: + """ + Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input + will be created as needed. + + This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy + concurrently. Arrays and Groups are yielded in the order they are created. + + Parameters + ---------- + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + allow_root : bool + Whether to allow a root node to be created. If ``False``, attempting to create a root node + will result in an error. Use this option when calling this function as part of a method + defined on ``AsyncGroup`` instances, because in this case the root node has already been + created. + + Yields + ------ + Group | Array + The created nodes in the order they are created. + """ + coro = async_api.create_hierarchy( + store=store, path=path, nodes=nodes, overwrite=overwrite, allow_root=allow_root + ) + + for result in sync(_collect_aiterator(coro)): + yield _parse_async_node(result) + + +def create_nodes( + *, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] +) -> Iterator[Group | Array]: + """Create a collection of arrays and / or groups concurrently. + + Note: no attempt is made to validate that these arrays and / or groups collectively form a + valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that + the ``nodes`` parameter satisfies any correctness constraints. + + Parameters + ---------- + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + + Yields + ------ + Group | Array + The created nodes in the order they are created. + """ + coro = async_api.create_nodes(store=store, path=path, nodes=nodes) + + for result in sync(_collect_aiterator(coro)): + yield _parse_async_node(result) + + +def create_rooted_hierarchy( + *, + store: Store, + path: str, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> Group | Array: + """ + Create a ``Group`` from a store and a dict of metadata documents. Calls the async method + ``_create_rooted_hierarchy`` and waits for the result. + """ + async_node = sync( + async_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) + ) + return _parse_async_node(async_node) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 0ccb44a431..6ff316999f 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -57,7 +57,7 @@ from zarr.core.config import config from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder -from zarr.core.sync import SyncMixin, _collect_aiterator, sync +from zarr.core.sync import SyncMixin, sync from zarr.errors import ( ContainsArrayError, ContainsGroupError, @@ -1435,7 +1435,7 @@ async def create_hierarchy( semaphore = asyncio.Semaphore(config.get("async.concurrency")) try: - async for node in create_hierarchy_a( + async for node in create_hierarchy( store=self.store, path=self.path, nodes=nodes, @@ -2081,7 +2081,7 @@ def create_hierarchy( nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], *, overwrite: bool = False, - ) -> Iterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: + ) -> Iterator[Group | Array]: """ Create a hierarchy of arrays or groups rooted at this group. @@ -2116,7 +2116,8 @@ def create_hierarchy( """ - yield from self._sync_iter(self._async_group.create_hierarchy(nodes, overwrite=overwrite)) + for node in self._sync_iter(self._async_group.create_hierarchy(nodes, overwrite=overwrite)): + yield _parse_async_node(node) def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2854,7 +2855,7 @@ def array( ) -async def create_hierarchy_a( +async def create_hierarchy( *, store: Store, path: str, @@ -2967,58 +2968,11 @@ async def create_hierarchy_a( k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups } - async for node in create_nodes_a( - store=store, path=path, nodes=nodes_parsed, semaphore=semaphore - ): + async for node in create_nodes(store=store, path=path, nodes=nodes_parsed, semaphore=semaphore): yield node -def create_hierarchy( - store: Store, - path: str, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False, - allow_root: bool = True, -) -> Iterator[Group | Array]: - """ - Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input - will be created as needed. - - This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy - concurrently. Arrays and Groups are yielded in the order they are created. - - Parameters - ---------- - store : Store - The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. - nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. - allow_root : bool - Whether to allow a root node to be created. If ``False``, attempting to create a root node - will result in an error. Use this option when calling this function as part of a method - defined on ``AsyncGroup`` instances, because in this case the root node has already been - created. - - Yields - ------ - Group | Array - The created nodes in the order they are created. - """ - coro = create_hierarchy_a( - store=store, path=path, nodes=nodes, overwrite=overwrite, allow_root=allow_root - ) - - for result in sync(_collect_aiterator(coro)): - yield _parse_async_node(result) - - -async def create_nodes_a( +async def create_nodes( *, store: Store, path: str, @@ -3119,46 +3073,6 @@ async def create_nodes_a( continue -def create_nodes( - *, - store: Store, - path: str, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - semaphore: asyncio.Semaphore | None = None, -) -> Iterator[Group | Array]: - """Create a collection of arrays and / or groups concurrently. - - Note: no attempt is made to validate that these arrays and / or groups collectively form a - valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that - the ``nodes`` parameter satisfies any correctness constraints. - - Parameters - ---------- - store : Store - The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. - nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. - semaphore : asyncio.Semaphore | None - An optional semaphore to limit the number of concurrent tasks. If not - provided, the number of concurrent tasks is unlimited. - - Yields - ------ - Group | Array - The created nodes in the order they are created. - """ - coro = create_nodes_a(store=store, path=path, nodes=nodes, semaphore=semaphore) - - for result in sync(_collect_aiterator(coro)): - yield _parse_async_node(result) - - T = TypeVar("T") @@ -3611,7 +3525,7 @@ def _persist_metadata( ) -async def _create_rooted_hierarchy_a( +async def create_rooted_hierarchy( *, store: Store, path: str, @@ -3639,25 +3553,8 @@ async def _create_rooted_hierarchy_a( nodes_created = { x.path: x - async for x in create_hierarchy_a( + async for x in create_hierarchy( store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite ) } return nodes_created[_join_paths([path, root_key])] - - -def _create_rooted_hierarchy( - *, - store: Store, - path: str, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False, -) -> Group | Array: - """ - Create a ``Group`` from a store and a dict of metadata documents. Calls the async method - ``_create_rooted_hierarchy`` and waits for the result. - """ - async_node = sync( - _create_rooted_hierarchy_a(store=store, path=path, nodes=nodes, overwrite=overwrite) - ) - return _parse_async_node(async_node) diff --git a/tests/test_api.py b/tests/test_api.py index aacd558f2a..bb67769c92 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -8,6 +8,7 @@ import zarr import zarr.api.asynchronous +import zarr.api.synchronous import zarr.core.group from zarr import Array, Group from zarr.abc.store import Store @@ -1121,3 +1122,10 @@ def test_open_array_with_mode_r_plus(store: Store) -> None: assert isinstance(z2, Array) assert (z2[:] == 1).all() z2[:] = 3 + + +def test_api_exports() -> None: + """ + Test that the sync API and the async API export the same objects + """ + assert zarr.api.asynchronous.__all__ == zarr.api.synchronous.__all__ diff --git a/tests/test_group.py b/tests/test_group.py index 6e02ccf718..5bbfc76a62 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -15,6 +15,7 @@ import zarr import zarr.api.asynchronous import zarr.api.synchronous +import zarr.api.synchronous as sync_api import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store @@ -24,16 +25,12 @@ ConsolidatedMetadata, GroupMetadata, _build_metadata_v3, - _create_rooted_hierarchy, - _create_rooted_hierarchy_a, _join_paths, _normalize_path_keys, _normalize_paths, - _parse_hierarchy_dict, create_hierarchy, - create_hierarchy_a, create_nodes, - create_nodes_a, + create_rooted_hierarchy, read_node, ) from zarr.core.metadata.v3 import ArrayV3Metadata @@ -1477,12 +1474,12 @@ async def test_create_nodes( "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } - if impl == "async": + if impl == "sync": + created = tuple(sync_api.create_nodes(store=store, path=path, nodes=expected_meta)) + elif impl == "async": created = tuple( - [a async for a in create_nodes_a(store=store, path=path, nodes=expected_meta)] + [a async for a in create_nodes(store=store, path=path, nodes=expected_meta)] ) - elif impl == "sync": - created = tuple(create_nodes(store=store, path=path, nodes=expected_meta)) else: raise ValueError(f"Invalid impl: {impl}") @@ -1519,17 +1516,19 @@ async def test_create_hierarchy( expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} # initialize the group with some nodes - _ = tuple(create_nodes(store=store, path=path, nodes=pre_existing_nodes)) + _ = tuple(sync_api.create_nodes(store=store, path=path, nodes=pre_existing_nodes)) if impl == "sync": created = tuple( - create_hierarchy(store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite) + sync_api.create_hierarchy( + store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite + ) ) elif impl == "async": created = tuple( [ a - async for a in create_hierarchy_a( + async for a in create_hierarchy( store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite ) ] @@ -1578,14 +1577,14 @@ async def test_create_hierarchy_existing_nodes( err_cls = ContainsGroupError # write the extant metadata - tuple(create_nodes(store=store, path=path, nodes={extant_node_path: extant_metadata})) + tuple(sync_api.create_nodes(store=store, path=path, nodes={extant_node_path: extant_metadata})) msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}." # ensure that we cannot invoke create_hierarchy with overwrite=False here if impl == "sync": with pytest.raises(err_cls, match=re.escape(msg)): tuple( - create_hierarchy( + sync_api.create_hierarchy( store=store, path=path, nodes={"node": new_metadata}, overwrite=False ) ) @@ -1594,7 +1593,7 @@ async def test_create_hierarchy_existing_nodes( tuple( [ x - async for x in create_hierarchy_a( + async for x in create_hierarchy( store=store, path=path, nodes={"node": new_metadata}, overwrite=False ) ] @@ -1690,12 +1689,10 @@ async def test_create_hierarchy_invalid_nested( msg = "Only Zarr groups can contain other nodes." if impl == "sync": with pytest.raises(ValueError, match=msg): - tuple(create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) + tuple(sync_api.create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) elif impl == "async": with pytest.raises(ValueError, match=msg): - await _collect_aiterator( - create_hierarchy_a(store=store, path=path, nodes=hierarchy_spec) - ) + await _collect_aiterator(create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1721,7 +1718,7 @@ async def test_create_hierarchy_invalid_mixed_format( if impl == "sync": with pytest.raises(ValueError, match=re.escape(msg)): tuple( - create_hierarchy( + sync_api.create_hierarchy( store=store, path=path, nodes=nodes, @@ -1730,7 +1727,7 @@ async def test_create_hierarchy_invalid_mixed_format( elif impl == "async": with pytest.raises(ValueError, match=re.escape(msg)): await _collect_aiterator( - create_hierarchy_a( + create_hierarchy( store=store, path=path, nodes=nodes, @@ -1771,14 +1768,14 @@ async def test_create_rooted_hierarchy_group( } nodes_create = root_meta | groups_expected_meta | arrays_expected_meta - if impl == "async": - g = await _create_rooted_hierarchy_a(store=store, path=path, nodes=nodes_create) - assert isinstance(g, AsyncGroup) - members = await _collect_aiterator(g.members(max_depth=None)) - elif impl == "sync": - g = _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) + if impl == "sync": + g = sync_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) assert isinstance(g, Group) members = g.members(max_depth=None) + elif impl == "async": + g = await create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) + assert isinstance(g, AsyncGroup) + members = await _collect_aiterator(g.members(max_depth=None)) else: raise ValueError(f"Unknown implementation: {impl}") @@ -1792,30 +1789,6 @@ async def test_create_rooted_hierarchy_group( assert members_observed_meta == members_expected_meta_relative -@pytest.mark.parametrize("store", ["memory"], indirect=True) -@pytest.mark.parametrize("impl", ["async", "sync"]) -async def test_create_hierarchy_implicit_groups( - impl: Literal["async", "sync"], store: Store -) -> None: - path = "" - nodes = { - "": GroupMetadata(zarr_format=3, attributes={"implicit": False}), - "a/b/c": GroupMetadata(zarr_format=3, attributes={"implicit": False}), - } - - hierarchy_parsed = _parse_hierarchy_dict(data=nodes) - if impl == "sync": - g = _create_rooted_hierarchy(store=store, path=path, nodes=nodes) - for key, value in hierarchy_parsed.items(): - assert g[key].metadata.attributes == value.attributes - elif impl == "async": - g = await _create_rooted_hierarchy_a(store=store, path=path, nodes=nodes) - for key, value in hierarchy_parsed.items(): - assert (await g.getitem(key)).metadata.attributes == value.attributes - else: - raise ValueError(f"Unknown implementation: {impl}") - - @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) @@ -1833,14 +1806,15 @@ async def test_create_rooted_hierarchy_array( np.arange(3), zarr_format=zarr_format, attributes={"path": root_key} ) } - nodes_create = root_meta if impl == "sync": - a = _create_rooted_hierarchy(store=store, path=path, nodes=nodes_create, overwrite=True) + a = sync_api.create_rooted_hierarchy( + store=store, path=path, nodes=nodes_create, overwrite=True + ) assert isinstance(a, Array) elif impl == "async": - a = await _create_rooted_hierarchy_a( + a = await create_rooted_hierarchy( store=store, path=path, nodes=nodes_create, overwrite=True ) assert isinstance(a, AsyncArray) @@ -1860,14 +1834,14 @@ async def test_create_rooted_hierarchy_invalid(impl: Literal["async", "sync"]) - "a": GroupMetadata(zarr_format=zarr_format), "b": GroupMetadata(zarr_format=zarr_format), } + path = "" msg = "The input does not specify a root node. " - - if impl == "async": + if impl == "sync": with pytest.raises(ValueError, match=msg): - await _create_rooted_hierarchy_a(store=store, path="", nodes=nodes) - elif impl == "sync": + sync_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes) + elif impl == "async": with pytest.raises(ValueError, match=msg): - _create_rooted_hierarchy(store=store, path="", nodes=nodes) + await create_rooted_hierarchy(store=store, path=path, nodes=nodes) else: raise ValueError(f"Invalid impl: {impl}") From 29bab74b882e3c241a9abcbfbc7f3efc25a42e8d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 21:40:42 +0100 Subject: [PATCH 40/59] remove semaphore kwarg, and add test for concurrency limit sensitivity --- src/zarr/core/group.py | 50 +++++++++++------------------------------- tests/test_group.py | 26 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 37 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 6ff316999f..548793b4ea 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import contextlib import itertools import json import logging @@ -1432,14 +1431,11 @@ async def create_hierarchy( ) raise ValueError(msg) - semaphore = asyncio.Semaphore(config.get("async.concurrency")) - try: async for node in create_hierarchy( store=self.store, path=self.path, nodes=nodes, - semaphore=semaphore, overwrite=overwrite, allow_root=False, ): @@ -2860,7 +2856,6 @@ async def create_hierarchy( store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], - semaphore: asyncio.Semaphore | None = None, overwrite: bool = False, allow_root: bool = True, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: @@ -2883,9 +2878,6 @@ async def create_hierarchy( in the hierarchy, and the values are the metadata of the nodes. The metadata must be either an instance of GroupMetadata, ArrayV3Metadata or ArrayV2Metadata. - semaphore : asyncio.Semaphore | None - An optional semaphore to limit the number of concurrent tasks. If not - provided, the number of concurrent tasks is unlimited. allow_root : bool Whether to allow a root node to be created. If ``False``, attempting to create a root node will result in an error. Use this option when calling this function as part of a method @@ -2968,7 +2960,7 @@ async def create_hierarchy( k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups } - async for node in create_nodes(store=store, path=path, nodes=nodes_parsed, semaphore=semaphore): + async for node in create_nodes(store=store, path=path, nodes=nodes_parsed): yield node @@ -2977,7 +2969,6 @@ async def create_nodes( store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - semaphore: asyncio.Semaphore | None = None, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """Create a collection of arrays and / or groups concurrently. @@ -2997,21 +2988,16 @@ async def create_nodes( in the hierarchy, and the values are the metadata of the nodes. The metadata must be either an instance of GroupMetadata, ArrayV3Metadata or ArrayV2Metadata. - semaphore : asyncio.Semaphore | None - An optional semaphore to limit the number of concurrent tasks. If not - provided, the number of concurrent tasks is unlimited. Yields ------ AsyncGroup | AsyncArray The created nodes in the order they are created. """ - ctx: asyncio.Semaphore | contextlib.nullcontext[None] - if semaphore is None: - ctx = contextlib.nullcontext() - else: - ctx = semaphore + # Note: the only way to alter this value is via the config. If that's undesirable for some reason, + # then we should consider adding a keyword argument this this function + ctx = asyncio.Semaphore(config.get("async.concurrency")) create_tasks: list[Coroutine[None, None, str]] = [] for key, value in nodes.items(): @@ -3023,15 +3009,12 @@ async def create_nodes( async with ctx: for coro in asyncio.as_completed(create_tasks): created_key = await coro + # the created key will be in the store key space, i.e. it is an absolute path. The key + # will also end with the name of a metadata document. We have to remove the store_path.path + # component of the key to bring it back to the relative key space of store_path - # the created key will be in the store key space, and it will end with the name of - # a metadata document. - # we have to remove the store_path.path - # component of that path to bring it back to the relative key space of store_path - - # the relative path of the object we just created -- we need this to track which metadata documents - # were written so that we can yield a complete v2 Array / Group class after both .zattrs - # and the metadata JSON was created. + # we need this to track which metadata documents were written so that we can yield a + # complete v2 Array / Group class after both .zattrs and the metadata JSON was created. object_path_relative = created_key.removeprefix(path).lstrip("/") created_object_keys.append(object_path_relative) @@ -3046,12 +3029,9 @@ async def create_nodes( meta_out = nodes[node_name] if meta_out.zarr_format == 3: - # yes, it is silly that we relativize, then de-relativize this same path - node_store_path = StorePath(store=store, path=path) / node_name - if isinstance(meta_out, GroupMetadata): - yield AsyncGroup(metadata=meta_out, store_path=node_store_path) - else: - yield AsyncArray(metadata=meta_out, store_path=node_store_path) + yield _build_node( + store=store, path=_join_paths([path, node_name]), metadata=meta_out + ) else: # For zarr v2 # we only want to yield when both the metadata and attributes are created @@ -3549,12 +3529,8 @@ async def create_rooted_hierarchy( else: root_key = roots[0] - semaphore = asyncio.Semaphore(config.get("async.concurrency")) - nodes_created = { x.path: x - async for x in create_hierarchy( - store=store, path=path, nodes=nodes, semaphore=semaphore, overwrite=overwrite - ) + async for x in create_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) } return nodes_created[_join_paths([path, root_key])] diff --git a/tests/test_group.py b/tests/test_group.py index 5bbfc76a62..f9bdb76821 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -21,6 +21,7 @@ from zarr.abc.store import Store from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype +from zarr.core.config import config as zarr_config from zarr.core.group import ( ConsolidatedMetadata, GroupMetadata, @@ -1487,6 +1488,31 @@ async def test_create_nodes( assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} +@pytest.mark.parametrize("store", ["memory"], indirect=True) +def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: + """ + Test that the execution time of create_nodes can be constrained by the async concurrency + configuration setting. + """ + set_latency = 0.02 + num_groups = 10 + groups = {str(idx): GroupMetadata() for idx in range(num_groups)} + + latency_store = LatencyStore(store, set_latency=set_latency) + + # check how long it takes to iterate over the groups + # if create_nodes is sensitive to IO latency, + # this should take (num_groups * get_latency) seconds + # otherwise, it should take only marginally more than get_latency seconds + + with zarr_config.set({"async.concurrency": 1}): + start = time.time() + _ = tuple(sync_api.create_nodes(store=latency_store, path="", nodes=groups)) + elapsed = time.time() - start + + assert elapsed > num_groups * set_latency + + @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) @pytest.mark.parametrize("impl", ["async", "sync"]) From 2f02c26c5495462eb9a571c9cea5ffd7c5a9c3b7 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 22:04:23 +0100 Subject: [PATCH 41/59] wire up semaphore correctly, thanks to a test --- src/zarr/api/asynchronous.py | 2 + src/zarr/api/synchronous.py | 49 ++++++++++++++- src/zarr/core/group.py | 114 ++++++++++++++++++----------------- tests/test_group.py | 3 +- 4 files changed, 107 insertions(+), 61 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index f0bb6f0546..4b90f41608 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -30,6 +30,7 @@ create_hierarchy, create_nodes, create_rooted_hierarchy, + read_node, ) from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v2 import _default_compressor, _default_filters @@ -71,6 +72,7 @@ "open_consolidated", "open_group", "open_like", + "read_node", "save", "save_array", "save_group", diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 600f36079b..03eb232041 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -7,6 +7,7 @@ import zarr.api.asynchronous as async_api import zarr.core.array from zarr._compat import _deprecate_positional_args +from zarr.abc.store import Store from zarr.core.array import Array, AsyncArray from zarr.core.group import Group, GroupMetadata, _parse_async_node from zarr.core.sync import _collect_aiterator, sync @@ -1201,7 +1202,7 @@ def create_nodes( Yields ------ Group | Array - The created nodes in the order they are created. + The created nodes. """ coro = async_api.create_nodes(store=store, path=path, nodes=nodes) @@ -1217,10 +1218,52 @@ def create_rooted_hierarchy( overwrite: bool = False, ) -> Group | Array: """ - Create a ``Group`` from a store and a dict of metadata documents. Calls the async method - ``_create_rooted_hierarchy`` and waits for the result. + Create a Zarr hierarchy with a root, and return the root node, which could be a ``Group`` + or ``Array`` instance. + + Parameters + ---------- + store : Store + The storage backend to use. + path : str + The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with + ``path`` prior to creating nodes. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + overwrite : bool + Whether to overwrite existing nodes. Default is ``False``. + + Returns + ------- + Group | Array """ async_node = sync( async_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) ) return _parse_async_node(async_node) + + +def read_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: + """ + Read an Array or Group from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + zarr_format : {2, 3} + The zarr format of the node to read. + + Returns + ------- + Array | Group + """ + + return _parse_async_node( + sync(async_api.read_node(store=store, path=path, zarr_format=zarr_format)) + ) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 548793b4ea..b912f9f249 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -693,7 +693,7 @@ async def getitem( if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) try: - return await _read_node_a( + return await read_node( store=store_path.store, path=store_path.path, zarr_format=self.metadata.zarr_format ) except FileNotFoundError as e: @@ -2997,60 +2997,58 @@ async def create_nodes( # Note: the only way to alter this value is via the config. If that's undesirable for some reason, # then we should consider adding a keyword argument this this function - ctx = asyncio.Semaphore(config.get("async.concurrency")) - + semaphore = asyncio.Semaphore(config.get("async.concurrency")) create_tasks: list[Coroutine[None, None, str]] = [] + for key, value in nodes.items(): # make the key absolute write_prefix = _join_paths([path, key]) - create_tasks.extend(_persist_metadata(store, write_prefix, value)) + create_tasks.extend(_persist_metadata(store, write_prefix, value, semaphore=semaphore)) created_object_keys = [] - async with ctx: - for coro in asyncio.as_completed(create_tasks): - created_key = await coro - # the created key will be in the store key space, i.e. it is an absolute path. The key - # will also end with the name of a metadata document. We have to remove the store_path.path - # component of the key to bring it back to the relative key space of store_path - - # we need this to track which metadata documents were written so that we can yield a - # complete v2 Array / Group class after both .zattrs and the metadata JSON was created. - object_path_relative = created_key.removeprefix(path).lstrip("/") - created_object_keys.append(object_path_relative) - - # get the node name from the object key - if len(object_path_relative.split("/")) == 1: - # this is the root node - meta_out = nodes[""] - node_name = "" + + for coro in asyncio.as_completed(create_tasks): + created_key = await coro + # the created key will be in the store key space, i.e. it is an absolute path. The key + # will also end with the name of a metadata document. We have to remove the store_path.path + # component of the key to bring it back to the relative key space of store_path + + # we need this to track which metadata documents were written so that we can yield a + # complete v2 Array / Group class after both .zattrs and the metadata JSON was created. + object_path_relative = created_key.removeprefix(path).lstrip("/") + created_object_keys.append(object_path_relative) + + # get the node name from the object key + if len(object_path_relative.split("/")) == 1: + # this is the root node + meta_out = nodes[""] + node_name = "" + else: + # turn "foo/" into "foo" + node_name = object_path_relative[: object_path_relative.rfind("/")] + meta_out = nodes[node_name] + + if meta_out.zarr_format == 3: + yield _build_node(store=store, path=_join_paths([path, node_name]), metadata=meta_out) + else: + # For zarr v2 + # we only want to yield when both the metadata and attributes are created + # so we track which keys have been created, and wait for both the meta key and + # the attrs key to be created before yielding back the AsyncArray / AsyncGroup + + attrs_done = _join_paths([node_name, ZATTRS_JSON]) in created_object_keys + + if isinstance(meta_out, GroupMetadata): + meta_done = _join_paths([node_name, ZGROUP_JSON]) in created_object_keys else: - # turn "foo/" into "foo" - node_name = object_path_relative[: object_path_relative.rfind("/")] - meta_out = nodes[node_name] + meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys - if meta_out.zarr_format == 3: + if meta_done and attrs_done: yield _build_node( store=store, path=_join_paths([path, node_name]), metadata=meta_out ) - else: - # For zarr v2 - # we only want to yield when both the metadata and attributes are created - # so we track which keys have been created, and wait for both the meta key and - # the attrs key to be created before yielding back the AsyncArray / AsyncGroup - - attrs_done = _join_paths([node_name, ZATTRS_JSON]) in created_object_keys - - if isinstance(meta_out, GroupMetadata): - meta_done = _join_paths([node_name, ZGROUP_JSON]) in created_object_keys - else: - meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys - if meta_done and attrs_done: - yield _build_node( - store=store, path=_join_paths([path, node_name]), metadata=meta_out - ) - - continue + continue T = TypeVar("T") @@ -3448,7 +3446,7 @@ async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] return _build_node(store=store, path=path, metadata=metadata) -async def _read_node_a( +async def read_node( store: Store, path: str, zarr_format: ZarrFormat ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ @@ -3463,15 +3461,9 @@ async def _read_node_a( raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover -def read_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: - """ - Read an Array or Group from a path in a Store. - """ - - return _parse_async_node(sync(_read_node_a(store=store, path=path, zarr_format=zarr_format))) - - -async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: +async def _set_return_key( + *, store: Store, key: str, value: Buffer, semaphore: asyncio.Semaphore | None = None +) -> str: """ Write a value to storage at the given key. The key is returned. Useful when saving values via routines that return results in execution order, @@ -3486,13 +3478,23 @@ async def _set_return_key(*, store: Store, key: str, value: Buffer) -> str: The key to save the value to. value : Buffer The value to save. + semaphore : asyncio.Semaphore | None + An optional semaphore to use to limit the number of concurrent writes. """ - await store.set(key, value) + + if semaphore is not None: + async with semaphore: + await store.set(key, value) + else: + await store.set(key, value) return key def _persist_metadata( - store: Store, path: str, metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata + store: Store, + path: str, + metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata, + semaphore: asyncio.Semaphore | None = None, ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. @@ -3500,7 +3502,7 @@ def _persist_metadata( to_save = metadata.to_buffer_dict(default_buffer_prototype()) return tuple( - _set_return_key(store=store, key=_join_paths([path, key]), value=value) + _set_return_key(store=store, key=_join_paths([path, key]), value=value, semaphore=semaphore) for key, value in to_save.items() ) diff --git a/tests/test_group.py b/tests/test_group.py index f9bdb76821..d286e69861 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -19,6 +19,7 @@ import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store +from zarr.api.synchronous import read_node from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype from zarr.core.config import config as zarr_config @@ -32,7 +33,6 @@ create_hierarchy, create_nodes, create_rooted_hierarchy, - read_node, ) from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import _collect_aiterator, sync @@ -1509,7 +1509,6 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: start = time.time() _ = tuple(sync_api.create_nodes(store=latency_store, path="", nodes=groups)) elapsed = time.time() - start - assert elapsed > num_groups * set_latency From 6ab833959a203bcfe8d13601d083a51e39654ce8 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 22:05:25 +0100 Subject: [PATCH 42/59] export read_node --- src/zarr/api/synchronous.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 03eb232041..bce0d68420 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -65,6 +65,7 @@ "open_consolidated", "open_group", "open_like", + "read_node", "save", "save_array", "save_group", From 9b97c95317fa7336eb7c546595a8bbcc585a63fd Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 22:15:28 +0100 Subject: [PATCH 43/59] docstrings --- src/zarr/core/group.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index b912f9f249..22385680f1 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3451,7 +3451,21 @@ async def read_node( ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Read an AsyncArray or AsyncGroup from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + zarr_format : {2, 3} + The zarr format of the node to read. + + Returns + ------- + AsyncArray | AsyncGroup """ + match zarr_format: case 2: return await _read_node_v2(store=store, path=path) From e546519d0cab3bda3285bfb6574700920442d54d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 22:17:01 +0100 Subject: [PATCH 44/59] docstrings --- src/zarr/core/group.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 22385680f1..9071062e94 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3432,7 +3432,18 @@ def _build_node( async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: """ - Read a Zarr v2 AsyncArray or AsyncGroup from a location defined by a StorePath. + Read a Zarr v2 AsyncArray or AsyncGroup from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + + Returns + ------- + AsyncArray | AsyncGroup """ metadata = await _read_metadata_v2(store=store, path=path) return _build_node(store=store, path=path, metadata=metadata) @@ -3440,7 +3451,18 @@ async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: """ - Read a Zarr v3 AsyncArray or AsyncGroup from a location defined by a StorePath. + Read a Zarr v3 AsyncArray or AsyncGroup from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + + Returns + ------- + AsyncArray | AsyncGroup """ metadata = await _read_metadata_v3(store=store, path=path) return _build_node(store=store, path=path, metadata=metadata) From 24eab3a1590a3278336bc8a2ce4b7fc1ec3408d5 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 2 Feb 2025 22:55:00 +0100 Subject: [PATCH 45/59] read_node -> get_node --- src/zarr/api/asynchronous.py | 4 ++-- src/zarr/api/synchronous.py | 8 ++++---- src/zarr/core/group.py | 14 +++++++------- tests/test_group.py | 8 ++++---- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 4b90f41608..f4df74ac4d 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -30,7 +30,7 @@ create_hierarchy, create_nodes, create_rooted_hierarchy, - read_node, + get_node, ) from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v2 import _default_compressor, _default_filters @@ -63,6 +63,7 @@ "empty_like", "full", "full_like", + "get_node", "group", "load", "ones", @@ -72,7 +73,6 @@ "open_consolidated", "open_group", "open_like", - "read_node", "save", "save_array", "save_group", diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index bce0d68420..87012c18d8 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -56,6 +56,7 @@ "empty_like", "full", "full_like", + "get_node", "group", "load", "ones", @@ -65,7 +66,6 @@ "open_consolidated", "open_group", "open_like", - "read_node", "save", "save_array", "save_group", @@ -1247,9 +1247,9 @@ def create_rooted_hierarchy( return _parse_async_node(async_node) -def read_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: +def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: """ - Read an Array or Group from a path in a Store. + Get an Array or Group from a path in a Store. Parameters ---------- @@ -1266,5 +1266,5 @@ def read_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group """ return _parse_async_node( - sync(async_api.read_node(store=store, path=path, zarr_format=zarr_format)) + sync(async_api.get_node(store=store, path=path, zarr_format=zarr_format)) ) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9071062e94..bb936f1d31 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -693,7 +693,7 @@ async def getitem( if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) try: - return await read_node( + return await get_node( store=store_path.store, path=store_path.path, zarr_format=self.metadata.zarr_format ) except FileNotFoundError as e: @@ -3430,7 +3430,7 @@ def _build_node( raise ValueError(f"Unexpected metadata type: {type(metadata)}") # pragma: no cover -async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: +async def _get_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: """ Read a Zarr v2 AsyncArray or AsyncGroup from a path in a Store. @@ -3449,7 +3449,7 @@ async def _read_node_v2(store: Store, path: str) -> AsyncArray[ArrayV2Metadata] return _build_node(store=store, path=path, metadata=metadata) -async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: +async def _get_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: """ Read a Zarr v3 AsyncArray or AsyncGroup from a path in a Store. @@ -3468,11 +3468,11 @@ async def _read_node_v3(store: Store, path: str) -> AsyncArray[ArrayV3Metadata] return _build_node(store=store, path=path, metadata=metadata) -async def read_node( +async def get_node( store: Store, path: str, zarr_format: ZarrFormat ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: """ - Read an AsyncArray or AsyncGroup from a path in a Store. + Get an AsyncArray or AsyncGroup from a path in a Store. Parameters ---------- @@ -3490,9 +3490,9 @@ async def read_node( match zarr_format: case 2: - return await _read_node_v2(store=store, path=path) + return await _get_node_v2(store=store, path=path) case 3: - return await _read_node_v3(store=store, path=path) + return await _get_node_v3(store=store, path=path) case _: # pragma: no cover raise ValueError(f"Unexpected zarr format: {zarr_format}") # pragma: no cover diff --git a/tests/test_group.py b/tests/test_group.py index d286e69861..0969f11649 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -19,7 +19,7 @@ import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store -from zarr.api.synchronous import read_node +from zarr.api.synchronous import get_node from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype from zarr.core.config import config as zarr_config @@ -1564,13 +1564,13 @@ async def test_create_hierarchy( observed_nodes = {a.path.removeprefix(path + "/"): a for a in created} if not overwrite: - extra_group = read_node( + extra_group = get_node( store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format ) assert extra_group.metadata.attributes == {"path": "group/extra"} else: with pytest.raises(FileNotFoundError): - read_node(store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format) + get_node(store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format) assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1628,7 +1628,7 @@ async def test_create_hierarchy_existing_nodes( # ensure that the extant metadata was not overwritten assert ( - read_node(store=store, path=_join_paths([path, extant_node_path]), zarr_format=zarr_format) + get_node(store=store, path=_join_paths([path, extant_node_path]), zarr_format=zarr_format) ).metadata.attributes == {"extant": True} From 545cacb543a8e2c1e634530a1fdb530d6faa23f7 Mon Sep 17 00:00:00 2001 From: Davis Bennett Date: Thu, 13 Feb 2025 00:30:56 +0100 Subject: [PATCH 46/59] Update src/zarr/api/synchronous.py Co-authored-by: Joe Hamman --- src/zarr/api/synchronous.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index b111bc407f..6bf9a7e896 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -1149,8 +1149,9 @@ def create_hierarchy( allow_root: bool = True, ) -> Iterator[Group | Array]: """ - Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input - will be created as needed. + Create a complete zarr hierarchy from a collection of metadata objects. + + Groups that are implicitly defined by the input will be created as needed. This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy concurrently. Arrays and Groups are yielded in the order they are created. From 438780b89dd95ed7d28c6e52a93ec36973669c64 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 13 Feb 2025 21:35:29 +0100 Subject: [PATCH 47/59] update docstring --- src/zarr/api/synchronous.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 6bf9a7e896..7093974c39 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -1150,11 +1150,12 @@ def create_hierarchy( ) -> Iterator[Group | Array]: """ Create a complete zarr hierarchy from a collection of metadata objects. - + Groups that are implicitly defined by the input will be created as needed. This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy - concurrently. Arrays and Groups are yielded in the order they are created. + concurrently. Arrays and Groups are yielded in the order they are created. This order is not + deterministic. Parameters ---------- From afe47cd108c4809bdfc5f36be9c2366c6798cab0 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 13 Feb 2025 21:46:25 +0100 Subject: [PATCH 48/59] add function signature tests --- tests/conftest.py | 7 +++---- tests/test_api.py | 35 ++++++++++++++++++++++++++++++++--- tests/test_group.py | 16 ++++++++++++++++ 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 79952f6cb7..8a9bdc1b0b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pathlib from dataclasses import dataclass, field -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt @@ -10,7 +10,6 @@ from hypothesis import HealthCheck, Verbosity, settings from zarr import AsyncGroup, config -from zarr.abc.codec import Codec from zarr.abc.store import Store from zarr.codecs.sharding import ShardingCodec, ShardingCodecIndexLocation from zarr.core.array import ( @@ -29,7 +28,7 @@ if TYPE_CHECKING: from collections.abc import Generator, Iterable from typing import Any, Literal - + from zarr.abc.codec import Codec from _pytest.compat import LEGACY_PATH from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike @@ -291,7 +290,7 @@ def create_array_metadata( dtype=dtype_parsed, ) - sub_codecs = cast(tuple[Codec, ...], (*array_array, array_bytes, *bytes_bytes)) + sub_codecs = (*array_array, array_bytes, *bytes_bytes) codecs_out: tuple[Codec, ...] if shard_shape_parsed is not None: index_location = None diff --git a/tests/test_api.py b/tests/test_api.py index bb67769c92..5d04b7c880 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,4 +1,14 @@ -import pathlib +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable + from zarr.abc.store import Store + from zarr.core.common import JSON, MemoryOrder, ZarrFormat + import pathlib + import warnings from typing import Literal @@ -11,7 +21,6 @@ import zarr.api.synchronous import zarr.core.group from zarr import Array, Group -from zarr.abc.store import Store from zarr.api.synchronous import ( create, create_array, @@ -24,7 +33,6 @@ save_array, save_group, ) -from zarr.core.common import JSON, MemoryOrder, ZarrFormat from zarr.errors import MetadataValidationError from zarr.storage import MemoryStore from zarr.storage._utils import normalize_path @@ -1124,6 +1132,27 @@ def test_open_array_with_mode_r_plus(store: Store) -> None: z2[:] = 3 +@pytest.mark.parametrize( + ("a_func", "b_func"), + [ + (zarr.api.asynchronous.create_hierarchy, zarr.api.synchronous.create_hierarchy), + ( + zarr.api.asynchronous.create_rooted_hierarchy, + zarr.api.synchronous.create_rooted_hierarchy, + ), + ], +) +def test_consistent_signatures( + a_func: Callable[[object], object], b_func: Callable[[object], object] +) -> None: + """ + Ensure that pairs of functions have the same signature + """ + base_sig = inspect.signature(a_func) + test_sig = inspect.signature(b_func) + assert test_sig.parameters == base_sig.parameters + + def test_api_exports() -> None: """ Test that the sync API and the async API export the same objects diff --git a/tests/test_group.py b/tests/test_group.py index 0969f11649..321520cb66 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import inspect import operator import pickle import re @@ -1512,6 +1513,21 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: assert elapsed > num_groups * set_latency +@pytest.mark.parametrize( + ("a_func", "b_func"), + [(zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy)], +) +def test_consistent_signatures( + a_func: Callable[[object], object], b_func: Callable[[object], object] +) -> None: + """ + Ensure that pairs of functions have consistent signatures + """ + base_sig = inspect.signature(a_func) + test_sig = inspect.signature(b_func) + assert test_sig.parameters == base_sig.parameters + + @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) @pytest.mark.parametrize("impl", ["async", "sync"]) From a2547b3e8c9981148d47ebf26410db09c947f795 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Thu, 13 Feb 2025 21:47:16 +0100 Subject: [PATCH 49/59] update exception name --- src/zarr/core/group.py | 2 +- src/zarr/errors.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 8d77585417..39417143a9 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -3152,7 +3152,7 @@ def _parse_hierarchy_dict( "and so creating this node is not allowed. Remove the problematic key from the input, " "or set ``allow_root`` to True." ) - raise RootedHierarchyError(msg) + raise NestedRootError(msg) out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed} diff --git a/src/zarr/errors.py b/src/zarr/errors.py index 855ea51b9d..655fa31b0f 100644 --- a/src/zarr/errors.py +++ b/src/zarr/errors.py @@ -59,8 +59,7 @@ class NodeTypeValidationError(MetadataValidationError): """ -class RootedHierarchyError(BaseZarrError): +class NestedRootError(BaseZarrError): """ - Exception raised when attempting to create a rooted hierarchy in a context where that is not - permitted. + Exception raised when attempting to create a root node relative to a pre-existing root node. """ From 9f0ccfafd5afe4a8f692d61f118ce483d09f3bb9 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 14 Feb 2025 15:21:16 +0100 Subject: [PATCH 50/59] refactor: remove path kwarg, bring back ImplicitGroupMetadata --- src/zarr/api/synchronous.py | 44 ++--- src/zarr/core/group.py | 345 ++++++++++++++++++------------------ src/zarr/errors.py | 6 - tests/conftest.py | 9 +- tests/test_api.py | 3 +- tests/test_group.py | 123 +++++++------ 6 files changed, 262 insertions(+), 268 deletions(-) diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 7093974c39..c63cb9674a 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -1142,12 +1142,11 @@ def zeros_like(a: ArrayLike, **kwargs: Any) -> Array: def create_hierarchy( + *, store: Store, - path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, - allow_root: bool = True, -) -> Iterator[Group | Array]: +) -> Iterator[tuple[str, Group | Array]]: """ Create a complete zarr hierarchy from a collection of metadata objects. @@ -1161,36 +1160,29 @@ def create_hierarchy( ---------- store : Store The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, and the values are the metadata of the nodes. The metadata must be either an instance of GroupMetadata, ArrayV3Metadata or ArrayV2Metadata. - allow_root : bool - Whether to allow a root node to be created. If ``False``, attempting to create a root node - will result in an error. Use this option when calling this function as part of a method - defined on ``AsyncGroup`` instances, because in this case the root node has already been - created. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error will be + raised instead of overwriting an existing array or group. Yields ------ - Group | Array - The created nodes in the order they are created. + tuple[str, Group | Array] + (key, node) pairs the order they are created. """ - coro = async_api.create_hierarchy( - store=store, path=path, nodes=nodes, overwrite=overwrite, allow_root=allow_root - ) + coro = async_api.create_hierarchy(store=store, nodes=nodes, overwrite=overwrite) - for result in sync(_collect_aiterator(coro)): - yield _parse_async_node(result) + for key, value in sync(_collect_aiterator(coro)): + yield key, _parse_async_node(value) def create_nodes( - *, store: Store, path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] -) -> Iterator[Group | Array]: + *, store: Store, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] +) -> Iterator[tuple[str, Group | Array]]: """Create a collection of arrays and / or groups concurrently. Note: no attempt is made to validate that these arrays and / or groups collectively form a @@ -1201,9 +1193,6 @@ def create_nodes( ---------- store : Store The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, and the values are the metadata of the nodes. The @@ -1215,16 +1204,15 @@ def create_nodes( Group | Array The created nodes. """ - coro = async_api.create_nodes(store=store, path=path, nodes=nodes) + coro = async_api.create_nodes(store=store, nodes=nodes) - for result in sync(_collect_aiterator(coro)): - yield _parse_async_node(result) + for key, value in sync(_collect_aiterator(coro)): + yield key, _parse_async_node(value) def create_rooted_hierarchy( *, store: Store, - path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, ) -> Group | Array: @@ -1252,7 +1240,7 @@ def create_rooted_hierarchy( Group | Array """ async_node = sync( - async_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) + async_api.create_rooted_hierarchy(store=store, nodes=nodes, overwrite=overwrite) ) return _parse_async_node(async_node) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 39417143a9..2881935c1c 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -6,16 +6,10 @@ import logging import warnings from collections import defaultdict +from collections.abc import Iterator, Mapping from dataclasses import asdict, dataclass, field, fields, replace from itertools import accumulate -from typing import ( - TYPE_CHECKING, - Literal, - TypeVar, - assert_never, - cast, - overload, -) +from typing import TYPE_CHECKING, Literal, TypeVar, assert_never, cast, overload import numpy as np import numpy.typing as npt @@ -57,12 +51,7 @@ from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder from zarr.core.sync import SyncMixin, sync -from zarr.errors import ( - ContainsArrayError, - ContainsGroupError, - MetadataValidationError, - RootedHierarchyError, -) +from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path from zarr.storage._utils import normalize_path @@ -71,12 +60,9 @@ from collections.abc import ( AsyncGenerator, AsyncIterator, - Callable, Coroutine, Generator, Iterable, - Iterator, - Mapping, ) from typing import Any @@ -430,6 +416,15 @@ def to_dict(self) -> dict[str, Any]: return result +@dataclass(frozen=True) +class ImplicitGroupMarker(GroupMetadata): + """ + Marker for an implicit group. Instances of this class are only used in the context of group + creation as a placeholder to represent groups that should only be created if they do not + already exist in storage + """ + + @dataclass(frozen=True) class AsyncGroup: """ @@ -1394,14 +1389,14 @@ async def _members( ): yield member - # TODO: find a better name for this. create_tree could work. - # TODO: include an example in the docstring async def create_hierarchy( self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], *, - overwrite: bool, - ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: + overwrite: bool = False, + ) -> AsyncIterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] + ]: """ Create a hierarchy of arrays or groups rooted at this group. @@ -1430,25 +1425,24 @@ async def create_hierarchy( f" has zarr_format {self.metadata.zarr_format}." ) raise ValueError(msg) + if normalize_path(key) == "": + msg = ( + "The input defines a root node, but a root node already exists, namely this Group instance." + "It is an error to use this method to create a root node. " + "Remove the root node from the input dict, or use a function like " + "create_rooted_hierarchy to create a rooted hierarchy." + ) + raise ValueError(msg) - try: - async for node in create_hierarchy( - store=self.store, - path=self.path, - nodes=nodes, - overwrite=overwrite, - allow_root=False, - ): - yield node + # insert ImplicitGroupMetadata to represent self + nodes_rooted = nodes | {"": ImplicitGroupMarker(zarr_format=self.metadata.zarr_format)} - except RootedHierarchyError as e: - msg = ( - "The input defines a root node, but a root node already exists, namely this Group instance." - "It is an error to use this method to create a root node. " - "Remove the root node from the input dict, or use a function like " - "create_rooted_hierarchy to create a rooted hierarchy." - ) - raise ValueError(msg) from e + async for key, node in create_hierarchy( + store=self.store, + nodes=nodes_rooted, + overwrite=overwrite, + ): + yield key, node async def keys(self) -> AsyncGenerator[str, None]: """Iterate over member names.""" @@ -2078,7 +2072,7 @@ def create_hierarchy( nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], *, overwrite: bool = False, - ) -> Iterator[Group | Array]: + ) -> Iterator[tuple[str, Group | Array]]: """ Create a hierarchy of arrays or groups rooted at this group. @@ -2099,7 +2093,7 @@ def create_hierarchy( Returns ------- - An iterator of Array or Group objects. + An iterator of (name, Array or Group) tuples. Examples -------- @@ -2113,8 +2107,10 @@ def create_hierarchy( """ - for node in self._sync_iter(self._async_group.create_hierarchy(nodes, overwrite=overwrite)): - yield _parse_async_node(node) + for key, node in self._sync_iter( + self._async_group.create_hierarchy(nodes, overwrite=overwrite) + ): + yield (key, _parse_async_node(node)) def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2863,71 +2859,86 @@ def array( async def create_hierarchy( *, store: Store, - path: str, - nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, - allow_root: bool = True, -) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: +) -> AsyncIterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] +]: """ - Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input - will be created as needed. + Create a complete zarr hierarchy from a collection of metadata objects. + + This function will parse its input to ensure that the hierarchy is valid. In this context, + "valid" means that the following requirements are met: + * All nodes have the same zarr format. + * There are no nodes descending from arrays. + * There are no implicit groups. Any implicit groups will be inserted as needed. For example, + an input like ```{'a': GroupMetadata, 'a/b/c': GroupMetadata}``` defines an implicit group at + the path ```a/b```, and also at the root of the hierarchy, which we denote with the empty string. + After parsing, that group will be added and the input will be: + ```{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': GroupMetadata}``` + + After input parsing, this function then creates all the nodes in the hierarchy concurrently. - This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy - concurrently. AsyncArrays and AsyncGroups are yielded in the order they are created. + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. Parameters ---------- store : Store The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, and the values are the metadata of the nodes. The metadata must be either an instance of GroupMetadata, ArrayV3Metadata or ArrayV2Metadata. - allow_root : bool - Whether to allow a root node to be created. If ``False``, attempting to create a root node - will result in an error. Use this option when calling this function as part of a method - defined on ``AsyncGroup`` instances, because in this case the root node has already been - created. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. Yields ------ AsyncGroup | AsyncArray The created nodes in the order they are created. + + Examples + -------- + """ - nodes_parsed = _parse_hierarchy_dict(data=nodes, allow_root=allow_root) + # normalize the keys to be valid paths + nodes_normed_keys = _normalize_path_keys(nodes) - # we allow creating empty hierarchies -- it's a no-op + # ensure that all nodes have the same zarr_format, and add implicit groups as needed + nodes_parsed = _parse_hierarchy_dict(data=nodes_normed_keys) + redundant_implicit_groups = [] + + # empty hierarchies should be a no-op if len(nodes_parsed) > 0: if overwrite: - await store.delete_dir(path) + # only remove elements from the store if they would be overwritten by nodes + should_delete_keys = ( + k for k, v in nodes_parsed.items() if not isinstance(v, ImplicitGroupMarker) + ) + await asyncio.gather( + *(store.delete_dir(key) for key in should_delete_keys), return_exceptions=True + ) else: # attempt to fetch all of the metadata described in hierarchy # first figure out which zarr format we are dealing with sample, *_ = nodes_parsed.values() - redundant_implicit_groups = [] - # TODO: decide if this set difference is sufficient for detecting implicit groups. - # an alternative would be to use an explicit implicit group class. - - implicit_group_names = set(nodes_parsed.keys()) - set(nodes.keys()) zarr_format = sample.zarr_format - # TODO: this type hint is so long - func: ( - Callable[[Store, str], Coroutine[Any, Any, GroupMetadata | ArrayV3Metadata]] - | Callable[[Store, str], Coroutine[Any, Any, GroupMetadata | ArrayV2Metadata]] + # This type is long. + coros: ( + Generator[Coroutine[Any, Any, ArrayV2Metadata | GroupMetadata], None, None] + | Generator[Coroutine[Any, Any, ArrayV3Metadata | GroupMetadata], None, None] ) - if zarr_format == 3: - func = _read_metadata_v3 - elif zarr_format == 2: - func = _read_metadata_v2 + if zarr_format == 2: + coros = (_read_metadata_v2(store=store, path=key) for key in nodes_parsed) + elif zarr_format == 3: + coros = (_read_metadata_v3(store=store, path=key) for key in nodes_parsed) else: # pragma: no cover raise ValueError(f"Invalid zarr_format: {zarr_format}") # pragma: no cover - coros = (func(store=store, path=_join_paths([path, key])) for key in nodes_parsed) extant_node_query = dict( zip( nodes_parsed.keys(), @@ -2935,50 +2946,54 @@ async def create_hierarchy( strict=False, ) ) - - for key, value in extant_node_query.items(): - if isinstance(value, BaseException): - if isinstance(value, FileNotFoundError): + # iterate over the existing arrays / groups and figure out which of them conflict + # with the arrays / groups we want to create + for key, extant_node in extant_node_query.items(): + proposed_node = nodes_parsed[key] + if isinstance(extant_node, BaseException): + if isinstance(extant_node, FileNotFoundError): # ignore FileNotFoundError, because they represent nodes we can safely create pass else: # Any other exception is a real error - raise value + raise extant_node else: - # this is a node that already exists, but a node with this name was specified in - # nodes_parsed. - # Two cases produce exceptions: - # 1. The node is a group, and a node with this name was explicitly defined in - # nodes - # 2. The node is an array. - # The third case is when this extant node is a group, but its name was not - # explicitly defined in nodes. This means it was added as an implicit group by - # _parse_hierarchy_dict, and we can remove the reference to this node from - # nodes_parsed. We don't need to create this node. - - if isinstance(value, GroupMetadata): - if key not in implicit_group_names: - raise ContainsGroupError(store, key) - else: - # as there is already a group with this name, we should not create a new one + # this is a node that already exists, but a node with the same key was specified + # in nodes_parsed. + if isinstance(extant_node, GroupMetadata): + # a group already exists where we want to create a group + if isinstance(proposed_node, ImplicitGroupMarker): + # we have proposed an implicit group, which is OK -- we will just skip + # creating this particular metadata document redundant_implicit_groups.append(key) - elif isinstance(value, ArrayV2Metadata | ArrayV3Metadata): + else: + # we have proposed an explicit group, which is an error, given that a + # group already exists. + raise ContainsGroupError(store, key) + elif isinstance(extant_node, ArrayV2Metadata | ArrayV3Metadata): + # we are trying to overwrite an existing array. this is an error. raise ContainsArrayError(store, key) - nodes_parsed = { - k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups - } + nodes_explicit: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {} - async for node in create_nodes(store=store, path=path, nodes=nodes_parsed): - yield node + for k, v in nodes_parsed.items(): + if k not in redundant_implicit_groups: + if isinstance(v, ImplicitGroupMarker): + nodes_explicit[k] = GroupMetadata(zarr_format=v.zarr_format) + else: + nodes_explicit[k] = v + + async for key, node in create_nodes(store=store, nodes=nodes_explicit): + yield key, node async def create_nodes( *, store: Store, - path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], -) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: +) -> AsyncIterator[ + tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] +]: """Create a collection of arrays and / or groups concurrently. Note: no attempt is made to validate that these arrays and / or groups collectively form a @@ -2989,9 +3004,6 @@ async def create_nodes( ---------- store : Store The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, and the values are the metadata of the nodes. The @@ -3011,34 +3023,27 @@ async def create_nodes( for key, value in nodes.items(): # make the key absolute - write_prefix = _join_paths([path, key]) - create_tasks.extend(_persist_metadata(store, write_prefix, value, semaphore=semaphore)) + create_tasks.extend(_persist_metadata(store, key, value, semaphore=semaphore)) created_object_keys = [] for coro in asyncio.as_completed(create_tasks): created_key = await coro - # the created key will be in the store key space, i.e. it is an absolute path. The key - # will also end with the name of a metadata document. We have to remove the store_path.path - # component of the key to bring it back to the relative key space of store_path - # we need this to track which metadata documents were written so that we can yield a # complete v2 Array / Group class after both .zattrs and the metadata JSON was created. - object_path_relative = created_key.removeprefix(path).lstrip("/") - created_object_keys.append(object_path_relative) + created_object_keys.append(created_key) # get the node name from the object key - if len(object_path_relative.split("/")) == 1: + if len(created_key.split("/")) == 1: # this is the root node meta_out = nodes[""] node_name = "" else: # turn "foo/" into "foo" - node_name = object_path_relative[: object_path_relative.rfind("/")] + node_name = created_key[: created_key.rfind("/")] meta_out = nodes[node_name] - if meta_out.zarr_format == 3: - yield _build_node(store=store, path=_join_paths([path, node_name]), metadata=meta_out) + yield node_name, _build_node(store=store, path=node_name, metadata=meta_out) else: # For zarr v2 # we only want to yield when both the metadata and attributes are created @@ -3053,9 +3058,7 @@ async def create_nodes( meta_done = _join_paths([node_name, ZARRAY_JSON]) in created_object_keys if meta_done and attrs_done: - yield _build_node( - store=store, path=_join_paths([path, node_name]), metadata=meta_out - ) + yield node_name, _build_node(store=store, path=node_name, metadata=meta_out) continue @@ -3064,7 +3067,7 @@ async def create_nodes( def _get_roots( - data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + data: Iterable[str], ) -> tuple[str, ...]: """ Return the keys of the root(s) of the hierarchy. A root is a key with the fewest number of @@ -3090,21 +3093,20 @@ def _join_paths(paths: Iterable[str]) -> str: def _parse_hierarchy_dict( *, - data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - allow_root: bool = True, -) -> dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: + data: Mapping[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: """ Take an input Mapping of str: node pairs, and parse it into - a dict of str: node pairs that models valid, complete Zarr hierarchy. + a dict of str: node pairs that models a valid, complete Zarr hierarchy. If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups, then return a dict with the exact same data as the input. - Otherwise, return a dict derived from the input with groups as needed to make + Otherwise, return a dict derived from the input with GroupMetadata inserted as needed to make the hierarchy complete. For example, an input of {'a/b/c': ArrayMetadata} is incomplete, because it references two - groups ('a' and 'a/b') but these keys are not present in the input. Applying this function + groups ('a' and 'a/b') that are not specified in the input. Applying this function to that input will result in a return value of {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}, i.e. the implied groups were added. @@ -3123,51 +3125,24 @@ def _parse_hierarchy_dict( Zarr hierarchy. """ - observed_zarr_formats: dict[ZarrFormat, list[str | None]] = {2: [], 3: []} - - # We will iterate over the dict again, but a full pass here ensures that the error message - # is comprehensive, and I think the performance cost will be negligible. - for k, v in data.items(): - observed_zarr_formats[v.zarr_format].append(k) - - if len(observed_zarr_formats[2]) > 0 and len(observed_zarr_formats[3]) > 0: - msg = ( - "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " - f"The following keys map to Zarr v2 nodes: {observed_zarr_formats.get(2)}. " - f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}." - "Ensure that all nodes have the same Zarr format." - ) - raise ValueError(msg) - - # normalize the keys of the dict + data_purified = _ensure_consistent_zarr_format(data) - data_normed: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = ( - _normalize_path_keys(data) - ) + data_normed_keys = _normalize_path_keys(data_purified) - if not allow_root and "" in data_normed: - msg = ( - "Found the empty string '' in data after key name normalization. " - "That key denotes the root of a hierarchy, but ``allow_root`` is False, " - "and so creating this node is not allowed. Remove the problematic key from the input, " - "or set ``allow_root`` to True." - ) - raise NestedRootError(msg) - - out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed} + out: dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = { + **data_normed_keys + } for k, v in data.items(): - # TODO: ensure that the key is a valid path key_split = k.split("/") - # we use /.join here because it checks the types of its inputs, unlike an f string - *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) # noqa: FLY002 + *subpaths, _ = accumulate(key_split, lambda a, b: _join_paths([a, b])) for subpath in subpaths: - # If a component is not already in the output dict, add a group + # If a component is not already in the output dict, add ImplicitGroupMetadata if subpath not in out: - out[subpath] = GroupMetadata(zarr_format=v.zarr_format) + out[subpath] = ImplicitGroupMarker(zarr_format=v.zarr_format) else: - if not isinstance(out[subpath], GroupMetadata): + if not isinstance(out[subpath], GroupMetadata | ImplicitGroupMarker): msg = ( f"The node at {subpath} contains other nodes, but it is not a Zarr group. " "This is invalid. Only Zarr groups can contain other nodes." @@ -3177,6 +3152,34 @@ def _parse_hierarchy_dict( return out +def _ensure_consistent_zarr_format( + data: Mapping[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], +) -> Mapping[str, GroupMetadata | ArrayV2Metadata] | Mapping[str, GroupMetadata | ArrayV3Metadata]: + """ + Ensure that all values of the input dict have the same zarr format. If any do not, + then a value error is raised. + """ + observed_zarr_formats: dict[ZarrFormat, list[str]] = {2: [], 3: []} + + for k, v in data.items(): + observed_zarr_formats[v.zarr_format].append(k) + + if len(observed_zarr_formats[2]) > 0 and len(observed_zarr_formats[3]) > 0: + msg = ( + "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " + f"The following keys map to Zarr v2 nodes: {observed_zarr_formats.get(2)}. " + f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}." + "Ensure that all nodes have the same Zarr format." + ) + raise ValueError(msg) + + return cast( + Mapping[str, GroupMetadata | ArrayV2Metadata] + | Mapping[str, GroupMetadata | ArrayV3Metadata], + data, + ) + + def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: """ Normalize the input paths according to the normalization scheme used for zarr node paths. @@ -3555,7 +3558,6 @@ def _persist_metadata( async def create_rooted_hierarchy( *, store: Store, - path: str, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], overwrite: bool = False, ) -> AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: @@ -3564,7 +3566,7 @@ async def create_rooted_hierarchy( This function ensures that its input contains a specification of a root node, calls ``create_hierarchy`` to create nodes, and returns the root node of the hierarchy. """ - roots = _get_roots(nodes) + roots = _get_roots(nodes.keys()) if len(roots) != 1: msg = ( "The input does not specify a root node. " @@ -3576,8 +3578,7 @@ async def create_rooted_hierarchy( else: root_key = roots[0] - nodes_created = { - x.path: x - async for x in create_hierarchy(store=store, path=path, nodes=nodes, overwrite=overwrite) - } - return nodes_created[_join_paths([path, root_key])] + nodes_created = [ + x async for x in create_hierarchy(store=store, nodes=nodes, overwrite=overwrite) + ] + return dict(nodes_created)[root_key] diff --git a/src/zarr/errors.py b/src/zarr/errors.py index 655fa31b0f..441cdab9a3 100644 --- a/src/zarr/errors.py +++ b/src/zarr/errors.py @@ -57,9 +57,3 @@ class NodeTypeValidationError(MetadataValidationError): This can be raised when the value is invalid or unexpected given the context, for example an 'array' node when we expected a 'group'. """ - - -class NestedRootError(BaseZarrError): - """ - Exception raised when attempting to create a root node relative to a pre-existing root node. - """ diff --git a/tests/conftest.py b/tests/conftest.py index 8a9bdc1b0b..04034cb5b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,9 +28,10 @@ if TYPE_CHECKING: from collections.abc import Generator, Iterable from typing import Any, Literal - from zarr.abc.codec import Codec + from _pytest.compat import LEGACY_PATH + from zarr.abc.codec import Codec from zarr.core.array import CompressorsLike, FiltersLike, SerializerLike, ShardsLike from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike from zarr.core.common import ChunkCoords, MemoryOrder, ShapeLike, ZarrFormat @@ -290,7 +291,7 @@ def create_array_metadata( dtype=dtype_parsed, ) - sub_codecs = (*array_array, array_bytes, *bytes_bytes) + sub_codecs: tuple[Codec, ...] = (*array_array, array_bytes, *bytes_bytes) codecs_out: tuple[Codec, ...] if shard_shape_parsed is not None: index_location = None @@ -299,7 +300,9 @@ def create_array_metadata( if index_location is None: index_location = ShardingCodecIndexLocation.end sharding_codec = ShardingCodec( - chunk_shape=chunk_shape_parsed, codecs=sub_codecs, index_location=index_location + chunk_shape=chunk_shape_parsed, + codecs=sub_codecs, + index_location=index_location, ) sharding_codec.validate( shape=chunk_shape_parsed, diff --git a/tests/test_api.py b/tests/test_api.py index 5d04b7c880..4d642ad9c3 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: + import pathlib from collections.abc import Callable + from zarr.abc.store import Store from zarr.core.common import JSON, MemoryOrder, ZarrFormat - import pathlib import warnings from typing import Literal diff --git a/tests/test_group.py b/tests/test_group.py index 321520cb66..af5b11ab4c 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -46,6 +46,8 @@ from .conftest import meta_from_array, parse_store if TYPE_CHECKING: + from collections.abc import Callable + from _pytest.compat import LEGACY_PATH from zarr.core.common import JSON, ZarrFormat @@ -1468,8 +1470,7 @@ async def test_create_nodes( Ensure that ``create_nodes`` can create a zarr hierarchy from a model of that hierarchy in dict form. Note that this creates an incomplete Zarr hierarchy. """ - path = "foo" - expected_meta = { + node_spec = { "group": GroupMetadata(attributes={"foo": 10}), "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format), @@ -1477,16 +1478,13 @@ async def test_create_nodes( "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } if impl == "sync": - created = tuple(sync_api.create_nodes(store=store, path=path, nodes=expected_meta)) + observed_nodes = dict(sync_api.create_nodes(store=store, nodes=node_spec)) elif impl == "async": - created = tuple( - [a async for a in create_nodes(store=store, path=path, nodes=expected_meta)] - ) + observed_nodes = dict(await _collect_aiterator(create_nodes(store=store, nodes=node_spec))) else: raise ValueError(f"Invalid impl: {impl}") - observed_nodes = {a.path.removeprefix(path + "/"): a for a in created} - assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} + assert node_spec == {k: v.metadata for k, v in observed_nodes.items()} @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1508,7 +1506,7 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: with zarr_config.set({"async.concurrency": 1}): start = time.time() - _ = tuple(sync_api.create_nodes(store=latency_store, path="", nodes=groups)) + _ = tuple(sync_api.create_nodes(store=latency_store, nodes=groups)) elapsed = time.time() - start assert elapsed > num_groups * set_latency @@ -1538,7 +1536,6 @@ async def test_create_hierarchy( Test that ``create_hierarchy`` can create a complete Zarr hierarchy, even if the input describes an incomplete one. """ - path = "foo" hierarchy_spec = { "group": GroupMetadata(attributes={"path": "group"}, zarr_format=zarr_format), @@ -1557,37 +1554,30 @@ async def test_create_hierarchy( expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} # initialize the group with some nodes - _ = tuple(sync_api.create_nodes(store=store, path=path, nodes=pre_existing_nodes)) + _ = dict(sync_api.create_nodes(store=store, nodes=pre_existing_nodes)) if impl == "sync": - created = tuple( - sync_api.create_hierarchy( - store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite - ) + created = dict( + sync_api.create_hierarchy(store=store, nodes=hierarchy_spec, overwrite=overwrite) ) elif impl == "async": - created = tuple( + created = dict( [ a async for a in create_hierarchy( - store=store, path=path, nodes=hierarchy_spec, overwrite=overwrite + store=store, nodes=hierarchy_spec, overwrite=overwrite ) ] ) else: raise ValueError(f"Invalid impl: {impl}") - - observed_nodes = {a.path.removeprefix(path + "/"): a for a in created} - if not overwrite: - extra_group = get_node( - store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format - ) + extra_group = get_node(store=store, path="group/extra", zarr_format=zarr_format) assert extra_group.metadata.attributes == {"path": "group/extra"} else: with pytest.raises(FileNotFoundError): - get_node(store=store, path=_join_paths([path, "group/extra"]), zarr_format=zarr_format) - assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} + get_node(store=store, path="group/extra", zarr_format=zarr_format) + assert expected_meta == {k: v.metadata for k, v in created.items()} @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1604,7 +1594,6 @@ async def test_create_hierarchy_existing_nodes( and raises an exception instead. """ extant_node_path = "node" - path = "path" if extant_node == "array": extant_metadata = meta_from_array( @@ -1618,7 +1607,7 @@ async def test_create_hierarchy_existing_nodes( err_cls = ContainsGroupError # write the extant metadata - tuple(sync_api.create_nodes(store=store, path=path, nodes={extant_node_path: extant_metadata})) + tuple(sync_api.create_nodes(store=store, nodes={extant_node_path: extant_metadata})) msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}." # ensure that we cannot invoke create_hierarchy with overwrite=False here @@ -1626,7 +1615,7 @@ async def test_create_hierarchy_existing_nodes( with pytest.raises(err_cls, match=re.escape(msg)): tuple( sync_api.create_hierarchy( - store=store, path=path, nodes={"node": new_metadata}, overwrite=False + store=store, nodes={"node": new_metadata}, overwrite=False ) ) elif impl == "async": @@ -1635,7 +1624,7 @@ async def test_create_hierarchy_existing_nodes( [ x async for x in create_hierarchy( - store=store, path=path, nodes={"node": new_metadata}, overwrite=False + store=store, nodes={"node": new_metadata}, overwrite=False ) ] ) @@ -1644,27 +1633,56 @@ async def test_create_hierarchy_existing_nodes( # ensure that the extant metadata was not overwritten assert ( - get_node(store=store, path=_join_paths([path, extant_node_path]), zarr_format=zarr_format) + get_node(store=store, path=extant_node_path, zarr_format=zarr_format) ).metadata.attributes == {"extant": True} @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) -def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite: bool) -> None: +@pytest.mark.parametrize("impl", ["async", "sync"]) +async def test_group_create_hierarchy( + store: Store, zarr_format: ZarrFormat, overwrite: bool, impl: Literal["async", "sync"] +) -> None: """ Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. + Also test that off-target nodes are not deleted, and that the root group is not deleted """ g = Group.from_store(store, zarr_format=zarr_format) - tree = { + + node_spec = { "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), "a/b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a/b"}), "a/b/c": meta_from_array( np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/b/c"} ), } - nodes = g.create_hierarchy(tree, overwrite=overwrite) - for k, v in g.members(max_depth=None): - assert v.metadata == tree[k] == nodes[k].metadata + # This node should be kept if overwrite is True + extant_spec = {"b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "b"})} + if impl == "async": + extant_created = dict( + await _collect_aiterator(g._async_group.create_hierarchy(extant_spec, overwrite=False)) + ) + nodes_created = dict( + await _collect_aiterator( + g._async_group.create_hierarchy(node_spec, overwrite=overwrite) + ) + ) + elif impl == "sync": + extant_created = dict(g.create_hierarchy(extant_spec, overwrite=False)) + nodes_created = dict(g.create_hierarchy(node_spec, overwrite=overwrite)) + + all_members = dict(g.members(max_depth=None)) + for k, v in node_spec.items(): + assert all_members[k].metadata == v == nodes_created[k].metadata + + # if overwrite is True, the extant nodes should be erased + for k, v in extant_spec.items(): + if overwrite: + assert k in all_members + # check that we did not erase the root group + assert get_node(store=store, path="", zarr_format=zarr_format) == g + else: + assert all_members[k].metadata == v == extant_created[k].metadata @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1682,7 +1700,7 @@ def test_group_create_hierarchy_no_root( with pytest.raises( ValueError, match="It is an error to use this method to create a root node. " ): - _ = tuple(g.create_hierarchy(tree, overwrite=overwrite)) + _ = dict(g.create_hierarchy(tree, overwrite=overwrite)) @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1709,13 +1727,12 @@ def test_group_create_hierarchy_invalid_mixed_zarr_format( @pytest.mark.parametrize("defect", ["array/array", "array/group"]) @pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_hierarchy_invalid_nested( - impl: Literal["async", "sync"], store: Store, defect: tuple[str, str], zarr_format + impl: Literal["async", "sync"], store: Store, defect: tuple[str, str], zarr_format: ZarrFormat ) -> None: """ Test that create_hierarchy will not create a Zarr array that contains a Zarr group or Zarr array. """ - path = "foo" if defect == "array/array": hierarchy_spec = { "array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), @@ -1730,10 +1747,10 @@ async def test_create_hierarchy_invalid_nested( msg = "Only Zarr groups can contain other nodes." if impl == "sync": with pytest.raises(ValueError, match=msg): - tuple(sync_api.create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) + tuple(sync_api.create_hierarchy(store=store, nodes=hierarchy_spec)) elif impl == "async": with pytest.raises(ValueError, match=msg): - await _collect_aiterator(create_hierarchy(store=store, path=path, nodes=hierarchy_spec)) + await _collect_aiterator(create_hierarchy(store=store, nodes=hierarchy_spec)) @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1745,7 +1762,6 @@ async def test_create_hierarchy_invalid_mixed_format( Test that create_hierarchy will not create a Zarr group that contains a both Zarr v2 and Zarr v3 nodes. """ - path = "foo" msg = ( "Got data with both Zarr v2 and Zarr v3 nodes, which is invalid. " "The following keys map to Zarr v2 nodes: ['v2']. " @@ -1761,7 +1777,6 @@ async def test_create_hierarchy_invalid_mixed_format( tuple( sync_api.create_hierarchy( store=store, - path=path, nodes=nodes, ) ) @@ -1770,7 +1785,6 @@ async def test_create_hierarchy_invalid_mixed_format( await _collect_aiterator( create_hierarchy( store=store, - path=path, nodes=nodes, ) ) @@ -1781,10 +1795,9 @@ async def test_create_hierarchy_invalid_mixed_format( @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) -@pytest.mark.parametrize("path", ["", "foo"]) @pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_rooted_hierarchy_group( - impl: Literal["async", "sync"], store: Store, zarr_format, path: str, root_key: str + impl: Literal["async", "sync"], store: Store, zarr_format, root_key: str ) -> None: """ Test that the _create_rooted_hierarchy can create a group. @@ -1810,11 +1823,11 @@ async def test_create_rooted_hierarchy_group( nodes_create = root_meta | groups_expected_meta | arrays_expected_meta if impl == "sync": - g = sync_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) + g = sync_api.create_rooted_hierarchy(store=store, nodes=nodes_create) assert isinstance(g, Group) members = g.members(max_depth=None) elif impl == "async": - g = await create_rooted_hierarchy(store=store, path=path, nodes=nodes_create) + g = await create_rooted_hierarchy(store=store, nodes=nodes_create) assert isinstance(g, AsyncGroup) members = await _collect_aiterator(g.members(max_depth=None)) else: @@ -1833,10 +1846,9 @@ async def test_create_rooted_hierarchy_group( @pytest.mark.parametrize("store", ["memory", "local"], indirect=True) @pytest.mark.parametrize("zarr_format", [2, 3]) @pytest.mark.parametrize("root_key", ["", "root"]) -@pytest.mark.parametrize("path", ["", "foo"]) @pytest.mark.parametrize("impl", ["async", "sync"]) async def test_create_rooted_hierarchy_array( - impl: Literal["async", "sync"], store: Store, zarr_format, path: str, root_key: str + impl: Literal["async", "sync"], store: Store, zarr_format, root_key: str ) -> None: """ Test that _create_rooted_hierarchy can create an array. @@ -1850,14 +1862,10 @@ async def test_create_rooted_hierarchy_array( nodes_create = root_meta if impl == "sync": - a = sync_api.create_rooted_hierarchy( - store=store, path=path, nodes=nodes_create, overwrite=True - ) + a = sync_api.create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) assert isinstance(a, Array) elif impl == "async": - a = await create_rooted_hierarchy( - store=store, path=path, nodes=nodes_create, overwrite=True - ) + a = await create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) assert isinstance(a, AsyncArray) else: raise ValueError(f"Invalid impl: {impl}") @@ -1875,14 +1883,13 @@ async def test_create_rooted_hierarchy_invalid(impl: Literal["async", "sync"]) - "a": GroupMetadata(zarr_format=zarr_format), "b": GroupMetadata(zarr_format=zarr_format), } - path = "" msg = "The input does not specify a root node. " if impl == "sync": with pytest.raises(ValueError, match=msg): - sync_api.create_rooted_hierarchy(store=store, path=path, nodes=nodes) + sync_api.create_rooted_hierarchy(store=store, nodes=nodes) elif impl == "async": with pytest.raises(ValueError, match=msg): - await create_rooted_hierarchy(store=store, path=path, nodes=nodes) + await create_rooted_hierarchy(store=store, nodes=nodes) else: raise ValueError(f"Invalid impl: {impl}") From 42b9804f8764119c051b85e6812f3dc6ff391071 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 14 Feb 2025 15:21:49 +0100 Subject: [PATCH 51/59] prune top-level synchronous API --- src/zarr/__init__.py | 1 - src/zarr/api/synchronous.py | 60 ------------------------------------- 2 files changed, 61 deletions(-) diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index 31d0797af6..ce77d5f163 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -10,7 +10,6 @@ create_group, create_hierarchy, create_nodes, - create_rooted_hierarchy, empty, empty_like, full, diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index c63cb9674a..5189b98019 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -51,12 +51,10 @@ "create_array", "create_hierarchy", "create_nodes", - "create_rooted_hierarchy", "empty", "empty_like", "full", "full_like", - "get_node", "group", "load", "ones", @@ -1208,61 +1206,3 @@ def create_nodes( for key, value in sync(_collect_aiterator(coro)): yield key, _parse_async_node(value) - - -def create_rooted_hierarchy( - *, - store: Store, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False, -) -> Group | Array: - """ - Create a Zarr hierarchy with a root, and return the root node, which could be a ``Group`` - or ``Array`` instance. - - Parameters - ---------- - store : Store - The storage backend to use. - path : str - The name of the root of the created hierarchy. Every key in ``nodes`` will be prefixed with - ``path`` prior to creating nodes. - nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. - overwrite : bool - Whether to overwrite existing nodes. Default is ``False``. - - Returns - ------- - Group | Array - """ - async_node = sync( - async_api.create_rooted_hierarchy(store=store, nodes=nodes, overwrite=overwrite) - ) - return _parse_async_node(async_node) - - -def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: - """ - Get an Array or Group from a path in a Store. - - Parameters - ---------- - store : Store - The store-like object to read from. - path : str - The path to the node to read. - zarr_format : {2, 3} - The zarr format of the node to read. - - Returns - ------- - Array | Group - """ - - return _parse_async_node( - sync(async_api.get_node(store=store, path=path, zarr_format=zarr_format)) - ) From d7d007099b99dff157317b51e5c3f3046dfbd40f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 14 Feb 2025 15:33:45 +0100 Subject: [PATCH 52/59] more api pruning --- src/zarr/__init__.py | 1 - src/zarr/api/synchronous.py | 31 ------------------------------- 2 files changed, 32 deletions(-) diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index ce77d5f163..e0538259d7 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -9,7 +9,6 @@ create_array, create_group, create_hierarchy, - create_nodes, empty, empty_like, full, diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 5189b98019..9fcedabe72 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -50,7 +50,6 @@ "create", "create_array", "create_hierarchy", - "create_nodes", "empty", "empty_like", "full", @@ -1176,33 +1175,3 @@ def create_hierarchy( for key, value in sync(_collect_aiterator(coro)): yield key, _parse_async_node(value) - - -def create_nodes( - *, store: Store, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] -) -> Iterator[tuple[str, Group | Array]]: - """Create a collection of arrays and / or groups concurrently. - - Note: no attempt is made to validate that these arrays and / or groups collectively form a - valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that - the ``nodes`` parameter satisfies any correctness constraints. - - Parameters - ---------- - store : Store - The storage backend to use. - nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. - - Yields - ------ - Group | Array - The created nodes. - """ - coro = async_api.create_nodes(store=store, nodes=nodes) - - for key, value in sync(_collect_aiterator(coro)): - yield key, _parse_async_node(value) From afdc320c4468c193004840a74357e287efefe8da Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Fri, 14 Feb 2025 17:49:39 +0100 Subject: [PATCH 53/59] put sync wrappers in sync_group module, move utils to utils --- src/zarr/api/asynchronous.py | 6 -- src/zarr/api/synchronous.py | 49 +------------ src/zarr/core/group.py | 44 +---------- src/zarr/core/sync_group.py | 134 ++++++++++++++++++++++++++++++++++ src/zarr/storage/_utils.py | 46 +++++++++++- tests/test_api.py | 23 ------ tests/test_group.py | 93 ++++++++++------------- tests/test_store/test_core.py | 47 +++++++++++- 8 files changed, 267 insertions(+), 175 deletions(-) create mode 100644 src/zarr/core/sync_group.py diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index d1ee8568b5..a4e63ec09e 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -28,9 +28,6 @@ ConsolidatedMetadata, GroupMetadata, create_hierarchy, - create_nodes, - create_rooted_hierarchy, - get_node, ) from zarr.core.metadata import ArrayMetadataDict, ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v2 import _default_compressor, _default_filters @@ -57,13 +54,10 @@ "create", "create_array", "create_hierarchy", - "create_nodes", - "create_rooted_hierarchy", "empty", "empty_like", "full", "full_like", - "get_node", "group", "load", "ones", diff --git a/src/zarr/api/synchronous.py b/src/zarr/api/synchronous.py index 9fcedabe72..219eef80be 100644 --- a/src/zarr/api/synchronous.py +++ b/src/zarr/api/synchronous.py @@ -7,19 +7,18 @@ import zarr.api.asynchronous as async_api import zarr.core.array from zarr._compat import _deprecate_positional_args -from zarr.abc.store import Store from zarr.core.array import Array, AsyncArray -from zarr.core.group import Group, GroupMetadata, _parse_async_node -from zarr.core.sync import _collect_aiterator, sync +from zarr.core.group import Group +from zarr.core.sync import sync +from zarr.core.sync_group import create_hierarchy if TYPE_CHECKING: - from collections.abc import Iterable, Iterator + from collections.abc import Iterable import numpy as np import numpy.typing as npt from zarr.abc.codec import Codec - from zarr.abc.store import Store from zarr.api.asynchronous import ArrayLike, PathLike from zarr.core.array import ( CompressorsLike, @@ -38,7 +37,6 @@ ShapeLike, ZarrFormat, ) - from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.storage import StoreLike __all__ = [ @@ -1136,42 +1134,3 @@ def zeros_like(a: ArrayLike, **kwargs: Any) -> Array: The new array. """ return Array(sync(async_api.zeros_like(a, **kwargs))) - - -def create_hierarchy( - *, - store: Store, - nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], - overwrite: bool = False, -) -> Iterator[tuple[str, Group | Array]]: - """ - Create a complete zarr hierarchy from a collection of metadata objects. - - Groups that are implicitly defined by the input will be created as needed. - - This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy - concurrently. Arrays and Groups are yielded in the order they are created. This order is not - deterministic. - - Parameters - ---------- - store : Store - The storage backend to use. - nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. - overwrite : bool - Whether to overwrite existing nodes. Defaults to ``False``, in which case an error will be - raised instead of overwriting an existing array or group. - - Yields - ------ - tuple[str, Group | Array] - (key, node) pairs the order they are created. - """ - coro = async_api.create_hierarchy(store=store, nodes=nodes, overwrite=overwrite) - - for key, value in sync(_collect_aiterator(coro)): - yield key, _parse_async_node(value) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 2881935c1c..5f5c2e0e29 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -54,7 +54,7 @@ from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path -from zarr.storage._utils import normalize_path +from zarr.storage._utils import _join_paths, _normalize_path_keys, normalize_path if TYPE_CHECKING: from collections.abc import ( @@ -3063,9 +3063,6 @@ async def create_nodes( continue -T = TypeVar("T") - - def _get_roots( data: Iterable[str], ) -> tuple[str, ...]: @@ -3082,15 +3079,6 @@ def _get_roots( return tuple(groups[min(groups.keys())]) -def _join_paths(paths: Iterable[str]) -> str: - """ - Filter out instances of '' and join the remaining strings with '/'. - - Because the root node of a zarr hierarchy is represented by an empty string, - """ - return "/".join(filter(lambda v: v != "", paths)) - - def _parse_hierarchy_dict( *, data: Mapping[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], @@ -3180,36 +3168,6 @@ def _ensure_consistent_zarr_format( ) -def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: - """ - Normalize the input paths according to the normalization scheme used for zarr node paths. - If any two paths normalize to the same value, raise a ValueError. - """ - path_map: dict[str, str] = {} - for path in paths: - parsed = normalize_path(path) - if parsed in path_map: - msg = ( - f"After normalization, the value '{path}' collides with '{path_map[parsed]}'. " - f"Both '{path}' and '{path_map[parsed]}' normalize to the same value: '{parsed}'. " - f"You should use either '{path}' or '{path_map[parsed]}', but not both." - ) - raise ValueError(msg) - path_map[parsed] = path - return tuple(path_map.keys()) - - -def _normalize_path_keys(data: Mapping[str, T]) -> dict[str, T]: - """ - Normalize the keys of the input dict according to the normalization scheme used for zarr node - paths. If any two keys in the input normalize to the same value, raise a ValueError. - Returns a dict where the keys are the elements of the input and the values are the - normalized form of each key. - """ - parsed_keys = _normalize_paths(data.keys()) - return dict(zip(parsed_keys, data.values(), strict=True)) - - async def _getitem_semaphore( node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: diff --git a/src/zarr/core/sync_group.py b/src/zarr/core/sync_group.py new file mode 100644 index 0000000000..829d98affc --- /dev/null +++ b/src/zarr/core/sync_group.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from zarr.core.group import Group, GroupMetadata, _parse_async_node +from zarr.core.group import create_hierarchy as create_hierarchy_async +from zarr.core.group import create_nodes as create_nodes_async +from zarr.core.group import create_rooted_hierarchy as create_rooted_hierarchy_async +from zarr.core.group import get_node as get_node_async +from zarr.core.sync import _collect_aiterator, sync + +if TYPE_CHECKING: + from collections.abc import Iterator + + from zarr.abc.store import Store + from zarr.core.array import Array + from zarr.core.common import ZarrFormat + from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata + + +def create_nodes( + *, store: Store, nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] +) -> Iterator[tuple[str, Group | Array]]: + """Create a collection of arrays and / or groups concurrently. + + Note: no attempt is made to validate that these arrays and / or groups collectively form a + valid Zarr hierarchy. It is the responsibility of the caller of this function to ensure that + the ``nodes`` parameter satisfies any correctness constraints. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + + Yields + ------ + Group | Array + The created nodes. + """ + coro = create_nodes_async(store=store, nodes=nodes) + + for key, value in sync(_collect_aiterator(coro)): + yield key, _parse_async_node(value) + + +def create_hierarchy( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> Iterator[tuple[str, Group | Array]]: + """ + Create a complete zarr hierarchy from a collection of metadata objects. + + Groups that are implicitly defined by the input will be created as needed. + + This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy + concurrently. Arrays and Groups are yielded in the order they are created. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + + Yields + ------ + Group | Array + The created nodes in the order they are created. + """ + coro = create_hierarchy_async(store=store, nodes=nodes, overwrite=overwrite) + + for key, value in sync(_collect_aiterator(coro)): + yield key, _parse_async_node(value) + + +def create_rooted_hierarchy( + *, + store: Store, + nodes: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], + overwrite: bool = False, +) -> Group | Array: + """ + Create a Zarr hierarchy with a root, and return the root node, which could be a ``Group`` + or ``Array`` instance. + + Parameters + ---------- + store : Store + The storage backend to use. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes + in the hierarchy, and the values are the metadata of the nodes. The + metadata must be either an instance of GroupMetadata, ArrayV3Metadata + or ArrayV2Metadata. + overwrite : bool + Whether to overwrite existing nodes. Default is ``False``. + + Returns + ------- + Group | Array + """ + async_node = sync(create_rooted_hierarchy_async(store=store, nodes=nodes, overwrite=overwrite)) + return _parse_async_node(async_node) + + +def get_node(store: Store, path: str, zarr_format: ZarrFormat) -> Array | Group: + """ + Get an Array or Group from a path in a Store. + + Parameters + ---------- + store : Store + The store-like object to read from. + path : str + The path to the node to read. + zarr_format : {2, 3} + The zarr format of the node to read. + + Returns + ------- + Array | Group + """ + + return _parse_async_node(sync(get_node_async(store=store, path=path, zarr_format=zarr_format))) diff --git a/src/zarr/storage/_utils.py b/src/zarr/storage/_utils.py index 4fc3171eb8..eda4342f47 100644 --- a/src/zarr/storage/_utils.py +++ b/src/zarr/storage/_utils.py @@ -2,11 +2,13 @@ import re from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar from zarr.abc.store import OffsetByteRequest, RangeByteRequest, SuffixByteRequest if TYPE_CHECKING: + from collections.abc import Iterable, Mapping + from zarr.abc.store import ByteRequest from zarr.core.buffer import Buffer @@ -66,3 +68,45 @@ def _normalize_byte_range_index(data: Buffer, byte_range: ByteRequest | None) -> else: raise ValueError(f"Unexpected byte_range, got {byte_range}.") return (start, stop) + + +def _join_paths(paths: Iterable[str]) -> str: + """ + Filter out instances of '' and join the remaining strings with '/'. + + Because the root node of a zarr hierarchy is represented by an empty string, + """ + return "/".join(filter(lambda v: v != "", paths)) + + +def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: + """ + Normalize the input paths according to the normalization scheme used for zarr node paths. + If any two paths normalize to the same value, raise a ValueError. + """ + path_map: dict[str, str] = {} + for path in paths: + parsed = normalize_path(path) + if parsed in path_map: + msg = ( + f"After normalization, the value '{path}' collides with '{path_map[parsed]}'. " + f"Both '{path}' and '{path_map[parsed]}' normalize to the same value: '{parsed}'. " + f"You should use either '{path}' or '{path_map[parsed]}', but not both." + ) + raise ValueError(msg) + path_map[parsed] = path + return tuple(path_map.keys()) + + +T = TypeVar("T") + + +def _normalize_path_keys(data: Mapping[str, T]) -> dict[str, T]: + """ + Normalize the keys of the input dict according to the normalization scheme used for zarr node + paths. If any two keys in the input normalize to the same value, raise a ValueError. + Returns a dict where the keys are the elements of the input and the values are the + normalized form of each key. + """ + parsed_keys = _normalize_paths(data.keys()) + return dict(zip(parsed_keys, data.values(), strict=True)) diff --git a/tests/test_api.py b/tests/test_api.py index 4d642ad9c3..d3a6e3eca6 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,11 +1,9 @@ from __future__ import annotations -import inspect from typing import TYPE_CHECKING if TYPE_CHECKING: import pathlib - from collections.abc import Callable from zarr.abc.store import Store from zarr.core.common import JSON, MemoryOrder, ZarrFormat @@ -1133,27 +1131,6 @@ def test_open_array_with_mode_r_plus(store: Store) -> None: z2[:] = 3 -@pytest.mark.parametrize( - ("a_func", "b_func"), - [ - (zarr.api.asynchronous.create_hierarchy, zarr.api.synchronous.create_hierarchy), - ( - zarr.api.asynchronous.create_rooted_hierarchy, - zarr.api.synchronous.create_rooted_hierarchy, - ), - ], -) -def test_consistent_signatures( - a_func: Callable[[object], object], b_func: Callable[[object], object] -) -> None: - """ - Ensure that pairs of functions have the same signature - """ - base_sig = inspect.signature(a_func) - test_sig = inspect.signature(b_func) - assert test_sig.parameters == base_sig.parameters - - def test_api_exports() -> None: """ Test that the sync API and the async API export the same objects diff --git a/tests/test_group.py b/tests/test_group.py index af5b11ab4c..be4d09acef 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -16,11 +16,10 @@ import zarr import zarr.api.asynchronous import zarr.api.synchronous -import zarr.api.synchronous as sync_api import zarr.storage from zarr import Array, AsyncArray, AsyncGroup, Group from zarr.abc.store import Store -from zarr.api.synchronous import get_node +from zarr.core import sync_group from zarr.core._info import GroupInfo from zarr.core.buffer import default_buffer_prototype from zarr.core.config import config as zarr_config @@ -28,19 +27,18 @@ ConsolidatedMetadata, GroupMetadata, _build_metadata_v3, - _join_paths, - _normalize_path_keys, - _normalize_paths, + _get_roots, create_hierarchy, create_nodes, create_rooted_hierarchy, + get_node, ) from zarr.core.metadata.v3 import ArrayV3Metadata from zarr.core.sync import _collect_aiterator, sync from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage._common import make_store_path -from zarr.storage._utils import normalize_path +from zarr.storage._utils import _join_paths from zarr.testing.store import LatencyStore from .conftest import meta_from_array, parse_store @@ -1478,7 +1476,7 @@ async def test_create_nodes( "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } if impl == "sync": - observed_nodes = dict(sync_api.create_nodes(store=store, nodes=node_spec)) + observed_nodes = dict(sync_group.create_nodes(store=store, nodes=node_spec)) elif impl == "async": observed_nodes = dict(await _collect_aiterator(create_nodes(store=store, nodes=node_spec))) else: @@ -1506,14 +1504,20 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: with zarr_config.set({"async.concurrency": 1}): start = time.time() - _ = tuple(sync_api.create_nodes(store=latency_store, nodes=groups)) + _ = tuple(sync_group.create_nodes(store=latency_store, nodes=groups)) elapsed = time.time() - start assert elapsed > num_groups * set_latency @pytest.mark.parametrize( ("a_func", "b_func"), - [(zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy)], + [ + (zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy), + (zarr.core.group.create_hierarchy, zarr.core.sync_group.create_hierarchy), + (zarr.core.group.create_nodes, zarr.core.sync_group.create_nodes), + (zarr.core.group.create_rooted_hierarchy, zarr.core.sync_group.create_rooted_hierarchy), + (zarr.core.group.get_node, zarr.core.sync_group.get_node), + ], ) def test_consistent_signatures( a_func: Callable[[object], object], b_func: Callable[[object], object] @@ -1554,11 +1558,11 @@ async def test_create_hierarchy( expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} # initialize the group with some nodes - _ = dict(sync_api.create_nodes(store=store, nodes=pre_existing_nodes)) + _ = dict(sync_group.create_nodes(store=store, nodes=pre_existing_nodes)) if impl == "sync": created = dict( - sync_api.create_hierarchy(store=store, nodes=hierarchy_spec, overwrite=overwrite) + sync_group.create_hierarchy(store=store, nodes=hierarchy_spec, overwrite=overwrite) ) elif impl == "async": created = dict( @@ -1572,11 +1576,11 @@ async def test_create_hierarchy( else: raise ValueError(f"Invalid impl: {impl}") if not overwrite: - extra_group = get_node(store=store, path="group/extra", zarr_format=zarr_format) + extra_group = sync_group.get_node(store=store, path="group/extra", zarr_format=zarr_format) assert extra_group.metadata.attributes == {"path": "group/extra"} else: with pytest.raises(FileNotFoundError): - get_node(store=store, path="group/extra", zarr_format=zarr_format) + await get_node(store=store, path="group/extra", zarr_format=zarr_format) assert expected_meta == {k: v.metadata for k, v in created.items()} @@ -1607,14 +1611,14 @@ async def test_create_hierarchy_existing_nodes( err_cls = ContainsGroupError # write the extant metadata - tuple(sync_api.create_nodes(store=store, nodes={extant_node_path: extant_metadata})) + tuple(sync_group.create_nodes(store=store, nodes={extant_node_path: extant_metadata})) msg = f"{extant_node} exists in store {store!r} at path {extant_node_path!r}." # ensure that we cannot invoke create_hierarchy with overwrite=False here if impl == "sync": with pytest.raises(err_cls, match=re.escape(msg)): tuple( - sync_api.create_hierarchy( + sync_group.create_hierarchy( store=store, nodes={"node": new_metadata}, overwrite=False ) ) @@ -1633,7 +1637,7 @@ async def test_create_hierarchy_existing_nodes( # ensure that the extant metadata was not overwritten assert ( - get_node(store=store, path=extant_node_path, zarr_format=zarr_format) + await get_node(store=store, path=extant_node_path, zarr_format=zarr_format) ).metadata.attributes == {"extant": True} @@ -1680,7 +1684,7 @@ async def test_group_create_hierarchy( if overwrite: assert k in all_members # check that we did not erase the root group - assert get_node(store=store, path="", zarr_format=zarr_format) == g + assert sync_group.get_node(store=store, path="", zarr_format=zarr_format) == g else: assert all_members[k].metadata == v == extant_created[k].metadata @@ -1747,7 +1751,7 @@ async def test_create_hierarchy_invalid_nested( msg = "Only Zarr groups can contain other nodes." if impl == "sync": with pytest.raises(ValueError, match=msg): - tuple(sync_api.create_hierarchy(store=store, nodes=hierarchy_spec)) + tuple(sync_group.create_hierarchy(store=store, nodes=hierarchy_spec)) elif impl == "async": with pytest.raises(ValueError, match=msg): await _collect_aiterator(create_hierarchy(store=store, nodes=hierarchy_spec)) @@ -1775,7 +1779,7 @@ async def test_create_hierarchy_invalid_mixed_format( if impl == "sync": with pytest.raises(ValueError, match=re.escape(msg)): tuple( - sync_api.create_hierarchy( + sync_group.create_hierarchy( store=store, nodes=nodes, ) @@ -1823,7 +1827,7 @@ async def test_create_rooted_hierarchy_group( nodes_create = root_meta | groups_expected_meta | arrays_expected_meta if impl == "sync": - g = sync_api.create_rooted_hierarchy(store=store, nodes=nodes_create) + g = sync_group.create_rooted_hierarchy(store=store, nodes=nodes_create) assert isinstance(g, Group) members = g.members(max_depth=None) elif impl == "async": @@ -1862,7 +1866,7 @@ async def test_create_rooted_hierarchy_array( nodes_create = root_meta if impl == "sync": - a = sync_api.create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) + a = sync_group.create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) assert isinstance(a, Array) elif impl == "async": a = await create_rooted_hierarchy(store=store, nodes=nodes_create, overwrite=True) @@ -1886,7 +1890,7 @@ async def test_create_rooted_hierarchy_invalid(impl: Literal["async", "sync"]) - msg = "The input does not specify a root node. " if impl == "sync": with pytest.raises(ValueError, match=msg): - sync_api.create_rooted_hierarchy(store=store, nodes=nodes) + sync_group.create_rooted_hierarchy(store=store, nodes=nodes) elif impl == "async": with pytest.raises(ValueError, match=msg): await create_rooted_hierarchy(store=store, nodes=nodes) @@ -1894,40 +1898,6 @@ async def test_create_rooted_hierarchy_invalid(impl: Literal["async", "sync"]) - raise ValueError(f"Invalid impl: {impl}") -@pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) -def test_normalize_paths_invalid(paths: tuple[str, str]) -> None: - """ - Ensure that calling _normalize_paths on values that will normalize to the same value - will generate a ValueError. - """ - a, b = paths - msg = f"After normalization, the value '{b}' collides with '{a}'. " - with pytest.raises(ValueError, match=msg): - _normalize_paths(paths) - - -@pytest.mark.parametrize( - "paths", [("/a", "a/b"), ("a", "a/b"), ("a/", "a///b"), ("/a/", "//a/b///")] -) -def test_normalize_paths_valid(paths: tuple[str, str]) -> None: - """ - Ensure that calling _normalize_paths on values that normalize to distinct values - returns a tuple of those normalized values. - """ - expected = tuple(map(normalize_path, paths)) - assert _normalize_paths(paths) == expected - - -def test_normalize_path_keys() -> None: - """ - Test that normalize_path_keys returns a dict where each key has been normalized. - """ - data = {"": 10, "a": "hello", "a/b": None, "/a/b/c/d": None} - observed = _normalize_path_keys(data) - expected = {normalize_path(k): v for k, v in data.items()} - assert observed == expected - - @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: Store) -> None: """ @@ -2012,3 +1982,14 @@ def test_build_metadata_v3(option: Literal["array", "group", "invalid"]) -> None msg = "Invalid value for 'node_type'. Expected 'array or group'. Got 'nothing (the key is missing)'." with pytest.raises(MetadataValidationError, match=re.escape(msg)): _build_metadata_v3(metadata_dict) + + +@pytest.mark.parametrize("roots", [("",), ("a", "b")]) +def test_get_roots(roots: tuple[str, ...]): + root_nodes = {k: GroupMetadata(attributes={"name": k}) for k in roots} + child_nodes = { + _join_paths([k, "foo"]): GroupMetadata(attributes={"name": _join_paths([k, "foo"])}) + for k in roots + } + data = root_nodes | child_nodes + assert set(_get_roots(data)) == set(roots) diff --git a/tests/test_store/test_core.py b/tests/test_store/test_core.py index 726da06a52..bce582a746 100644 --- a/tests/test_store/test_core.py +++ b/tests/test_store/test_core.py @@ -8,7 +8,7 @@ from zarr.core.common import AccessModeLiteral, ZarrFormat from zarr.storage import FsspecStore, LocalStore, MemoryStore, StoreLike, StorePath from zarr.storage._common import contains_array, contains_group, make_store_path -from zarr.storage._utils import normalize_path +from zarr.storage._utils import _join_paths, _normalize_path_keys, _normalize_paths, normalize_path @pytest.mark.parametrize("path", ["foo", "foo/bar"]) @@ -174,3 +174,48 @@ def test_normalize_path_none(): def test_normalize_path_invalid(path: str): with pytest.raises(ValueError): normalize_path(path) + + +@pytest.mark.parametrize("paths", [("", "foo"), ("foo", "bar")]) +def test_join_paths(paths: tuple[str, str]) -> None: + """ + Test that _join_paths joins paths in a way that is robust to an empty string + """ + observed = _join_paths(paths) + if paths[0] == "": + assert observed == paths[1] + else: + assert observed == "/".join(paths) + + +class TestNormalizePaths: + @staticmethod + def test_valid() -> None: + """ + Test that path normalization works as expected + """ + paths = ["a", "b", "c", "d", "", "//a///b//"] + assert _normalize_paths(paths) == tuple([normalize_path(p) for p in paths]) + + @staticmethod + @pytest.mark.parametrize("paths", [("", "/"), ("///a", "a")]) + def test_invalid(paths: tuple[str, str]) -> None: + """ + Test that name collisions after normalization raise a ``ValueError`` + """ + msg = ( + f"After normalization, the value '{paths[1]}' collides with '{paths[0]}'. " + f"Both '{paths[1]}' and '{paths[0]}' normalize to the same value: '{normalize_path(paths[0])}'. " + f"You should use either '{paths[1]}' or '{paths[0]}', but not both." + ) + with pytest.raises(ValueError, match=msg): + _normalize_paths(paths) + + +def test_normalize_path_keys(): + """ + Test that ``_normalize_path_keys`` just applies the normalize_path function to each key of its + input + """ + data = {"a": 10, "//b": 10} + assert _normalize_path_keys(data) == {normalize_path(k): v for k, v in data.items()} From 50b02b412d953c2730ed5f22f86c35f835a9002f Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Mon, 17 Feb 2025 18:41:09 +0100 Subject: [PATCH 54/59] ensure we always have a root group --- src/zarr/core/group.py | 106 ++++++++++++++++++++++++++++++----------- tests/test_group.py | 55 +++++++++++++++++++-- 2 files changed, 130 insertions(+), 31 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 5f5c2e0e29..35dbf4d9f6 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1434,8 +1434,7 @@ async def create_hierarchy( ) raise ValueError(msg) - # insert ImplicitGroupMetadata to represent self - nodes_rooted = nodes | {"": ImplicitGroupMarker(zarr_format=self.metadata.zarr_format)} + nodes_rooted = nodes async for key, node in create_hierarchy( store=self.store, @@ -2913,20 +2912,43 @@ async def create_hierarchy( # empty hierarchies should be a no-op if len(nodes_parsed) > 0: + # figure out which zarr format we are using + zarr_format = next(iter(nodes_parsed.values())).zarr_format + + # check which implicit groups will require materialization + implicit_group_keys = tuple( + filter(lambda k: isinstance(nodes_parsed[k], ImplicitGroupMarker), nodes_parsed) + ) + # read potential group metadata for each implicit group + maybe_extant_group_coros = ( + _read_group_metadata(store, k, zarr_format=zarr_format) for k in implicit_group_keys + ) + maybe_extant_groups = await asyncio.gather( + *maybe_extant_group_coros, return_exceptions=True + ) + + for key, value in zip(implicit_group_keys, maybe_extant_groups, strict=True): + if isinstance(value, BaseException): + if isinstance(value, FileNotFoundError): + # this is fine -- there was no group there, so we will create one + pass + else: + raise value + else: + # a loop exists already at ``key``, so we can avoid creating anything there + redundant_implicit_groups.append(key) + if overwrite: - # only remove elements from the store if they would be overwritten by nodes - should_delete_keys = ( - k for k, v in nodes_parsed.items() if not isinstance(v, ImplicitGroupMarker) - ) - await asyncio.gather( - *(store.delete_dir(key) for key in should_delete_keys), return_exceptions=True + # we will remove any nodes that collide with arrays and non-implicit groups defined in + # nodes + + # track the keys of nodes we need to delete + to_delete_keys = [] + to_delete_keys.extend( + [k for k, v in nodes_parsed.items() if k not in implicit_group_keys] ) + await asyncio.gather(*(store.delete_dir(key) for key in to_delete_keys)) else: - # attempt to fetch all of the metadata described in hierarchy - # first figure out which zarr format we are dealing with - sample, *_ = nodes_parsed.values() - - zarr_format = sample.zarr_format # This type is long. coros: ( Generator[Coroutine[Any, Any, ArrayV2Metadata | GroupMetadata], None, None] @@ -3084,7 +3106,7 @@ def _parse_hierarchy_dict( data: Mapping[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata]: """ - Take an input Mapping of str: node pairs, and parse it into + Take an input with type Mapping[str, ArrayMetadata | GroupMetadata] and parse it into a dict of str: node pairs that models a valid, complete Zarr hierarchy. If the input represents a complete Zarr hierarchy, i.e. one with no implicit groups, @@ -3093,10 +3115,10 @@ def _parse_hierarchy_dict( Otherwise, return a dict derived from the input with GroupMetadata inserted as needed to make the hierarchy complete. - For example, an input of {'a/b/c': ArrayMetadata} is incomplete, because it references two - groups ('a' and 'a/b') that are not specified in the input. Applying this function + For example, an input of {'a/b': ArrayMetadata} is incomplete, because it references two + groups (the root group '' and a group at 'a') that are not specified in the input. Applying this function to that input will result in a return value of - {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}, i.e. the implied groups + {'': GroupMetadata, 'a': GroupMetadata, 'a/b': ArrayMetadata}, i.e. the implied groups were added. The input is also checked for the following conditions; an error is raised if any are violated: @@ -3104,25 +3126,28 @@ def _parse_hierarchy_dict( - No arrays can contain group or arrays (i.e., all arrays must be leaf nodes). - All arrays and groups must have the same ``zarr_format`` value. - if ``allow_root`` is set to False, then the input is also checked to ensure that it does not - contain a key that normalizes to the empty string (''), as this is reserved for the root node, - and in some situations creating a root node is not permitted, for example, when creating a - hierarchy relative to an existing group. - This function ensures that the input is transformed into a specification of a complete and valid Zarr hierarchy. """ + # ensure that all nodes have the same zarr format data_purified = _ensure_consistent_zarr_format(data) + # ensure that keys are normalized to zarr paths data_normed_keys = _normalize_path_keys(data_purified) - out: dict[str, ImplicitGroupMarker | GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = { - **data_normed_keys - } + # insert an implicit root group if a root was not specified + # but not if an empty dict was provided, because any empty hierarchy has no nodes + if len(data_normed_keys) > 0 and "" not in data_normed_keys: + z_format = next(iter(data_normed_keys.values())).zarr_format + data_normed_keys = data_normed_keys | {"": ImplicitGroupMarker(zarr_format=z_format)} - for k, v in data.items(): + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed_keys} + + for k, v in data_normed_keys.items(): key_split = k.split("/") + + # get every parent path *subpaths, _ = accumulate(key_split, lambda a, b: _join_paths([a, b])) for subpath in subpaths: @@ -3136,7 +3161,6 @@ def _parse_hierarchy_dict( "This is invalid. Only Zarr groups can contain other nodes." ) raise ValueError(msg) - return out @@ -3338,6 +3362,34 @@ async def _read_metadata_v2(store: Store, path: str) -> ArrayV2Metadata | GroupM return _build_metadata_v2(zmeta, zattrs) +async def _read_group_metadata_v2(store: Store, path: str) -> GroupMetadata: + """ + Read group metadata or error + """ + meta = await _read_metadata_v2(store=store, path=path) + if not isinstance(meta, GroupMetadata): + raise FileNotFoundError(f"Group metadata was not found in {store} at {path}") + return meta + + +async def _read_group_metadata_v3(store: Store, path: str) -> GroupMetadata: + """ + Read group metadata or error + """ + meta = await _read_metadata_v3(store=store, path=path) + if not isinstance(meta, GroupMetadata): + raise FileNotFoundError(f"Group metadata was not found in {store} at {path}") + return meta + + +async def _read_group_metadata( + store: Store, path: str, *, zarr_format: ZarrFormat +) -> GroupMetadata: + if zarr_format == 2: + return await _read_group_metadata_v2(store=store, path=path) + return await _read_group_metadata_v3(store=store, path=path) + + def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMetadata: """ Convert a dict representation of Zarr V3 metadata into the corresponding metadata class. diff --git a/tests/test_group.py b/tests/test_group.py index be4d09acef..1fbfd9c2f3 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -26,8 +26,10 @@ from zarr.core.group import ( ConsolidatedMetadata, GroupMetadata, + ImplicitGroupMarker, _build_metadata_v3, _get_roots, + _parse_hierarchy_dict, create_hierarchy, create_nodes, create_rooted_hierarchy, @@ -38,7 +40,7 @@ from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage._common import make_store_path -from zarr.storage._utils import _join_paths +from zarr.storage._utils import _join_paths, normalize_path from zarr.testing.store import LatencyStore from .conftest import meta_from_array, parse_store @@ -1651,7 +1653,8 @@ async def test_group_create_hierarchy( Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. Also test that off-target nodes are not deleted, and that the root group is not deleted """ - g = Group.from_store(store, zarr_format=zarr_format) + root_attrs = {"root": True} + g = Group.from_store(store, zarr_format=zarr_format, attributes=root_attrs) node_spec = { "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), @@ -1683,10 +1686,10 @@ async def test_group_create_hierarchy( for k, v in extant_spec.items(): if overwrite: assert k in all_members - # check that we did not erase the root group - assert sync_group.get_node(store=store, path="", zarr_format=zarr_format) == g else: assert all_members[k].metadata == v == extant_created[k].metadata + # ensure that we left the root group as-is + assert sync_group.get_node(store=store, path="", zarr_format=zarr_format).attrs == root_attrs @pytest.mark.parametrize("store", ["memory"], indirect=True) @@ -1707,6 +1710,50 @@ def test_group_create_hierarchy_no_root( _ = dict(g.create_hierarchy(tree, overwrite=overwrite)) +class TestParseHierarchyDict: + """ + Tests for the function that parses dicts of str : Metadata pairs, ensuring that the output models a + valid Zarr hierarchy + """ + + @staticmethod + def test_normed_keys() -> None: + """ + Test that keys get normalized properly + """ + + nodes = { + "a": GroupMetadata(), + "/b": GroupMetadata(), + "": GroupMetadata(), + "/a//c////": GroupMetadata(), + } + observed = _parse_hierarchy_dict(data=nodes) + expected = {normalize_path(k): v for k, v in nodes.items()} + assert observed == expected + + @staticmethod + def test_empty() -> None: + """ + Test that an empty dict passes through + """ + assert _parse_hierarchy_dict(data={}) == {} + + @staticmethod + def test_implicit_groups() -> None: + """ + Test that implicit groups were added as needed. + """ + requested = {"a/b/c": GroupMetadata()} + expected = requested | { + "": ImplicitGroupMarker(), + "a": ImplicitGroupMarker(), + "a/b": ImplicitGroupMarker(), + } + observed = _parse_hierarchy_dict(data=requested) + assert observed == expected + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_create_hierarchy_invalid_mixed_zarr_format( store: Store, zarr_format: ZarrFormat From 7c56b87ce6c44b4d8c23a4ff4e1ed21f115ecb62 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 18 Feb 2025 14:49:39 +0100 Subject: [PATCH 55/59] docs --- docs/quickstart.rst | 22 +++++++++++++++++++ docs/user-guide/groups.rst | 25 +++++++++++++++++++++ src/zarr/__init__.py | 2 -- src/zarr/core/group.py | 43 +++++++++++++++++++++++-------------- src/zarr/core/sync_group.py | 43 +++++++++++++++++++++++++++++-------- 5 files changed, 108 insertions(+), 27 deletions(-) diff --git a/docs/quickstart.rst b/docs/quickstart.rst index d520554593..66bdae2a2e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -119,6 +119,28 @@ Zarr allows you to create hierarchical groups, similar to directories:: This creates a group with two datasets: ``foo`` and ``bar``. +Batch Hierarchy Creation +~~~~~~~~~~~~~~~~~~~~~~~~ + +Zarr provides tools for creating a collection of arrays and groups with a single function call. +Suppose we want to copy existing groups and arrays into a new storage backend: + + >>> # Create nested groups and add arrays + >>> root = zarr.group("data/example-3.zarr", attributes={'name': 'root'}) + >>> foo = root.create_group(name="foo") + >>> bar = root.create_array( + ... name="bar", shape=(100, 10), chunks=(10, 10), dtype="f4" + ... ) + >>> nodes = {'': root.metadata} | {k: v.metadata for k,v in root.members()} + >>> print(nodes) + >>> from zarr.storage import MemoryStore + >>> new_nodes = dict(zarr.create_hierarchy(store=MemoryStore(), nodes=nodes)) + >>> new_root = new_nodes[''] + >>> assert new_root.attrs == root.attrs + +Note that :func:`zarr.create_hierarchy` will only initialize arrays and groups -- copying array data must +be done in a separate step. + Persistent Storage ------------------ diff --git a/docs/user-guide/groups.rst b/docs/user-guide/groups.rst index 1e72df3478..4268004f70 100644 --- a/docs/user-guide/groups.rst +++ b/docs/user-guide/groups.rst @@ -75,6 +75,31 @@ For more information on groups see the :class:`zarr.Group` API docs. .. _user-guide-diagnostics: +Batch Group Creation +-------------------- + +You can also create multiple groups concurrently with a single function call. :func:`zarr.create_hierarchy` takes +a :class:`zarr.storage.Store` instance and a dict of ``key : metadata`` pairs, parses that dict, and +writes metadata documents to storage: + + >>> from zarr import create_hierarchy + >>> from zarr.core.group import GroupMetadata + >>> from zarr.storage import LocalStore + >>> node_spec = {'a/b/c': GroupMetadata()} + >>> nodes_created = dict(create_hierarchy(store=LocalStore(root='data'), nodes=node_spec)) + >>> print(sorted(nodes_created.items(), key=lambda kv: len(kv[0]))) + [('', ), ('a', ), ('a/b', ), ('a/b/c', )] + +Note that we only specified a single group named ``a/b/c``, but 4 groups were created. These additional groups +were created to ensure that the desired node ``a/b/c`` is connected to the root group ``''`` by a sequence +of intermediate groups. :func:`zarr.create_hierarchy` normalizes the ``nodes`` keyword argument to +ensure that the resulting hierarchy is complete, i.e. all groups or arrays are connected to the root +of the hierarchy via intermediate groups. + +Because :func:`zarr.create_hierarchy` concurrently creates metadata documents, it's more efficient +than repeated calls to :func:`create_group` or :func:`create_array`, provided you can statically define +the metadata for the groups and arrays you want to create. + Array and group diagnostics --------------------------- diff --git a/src/zarr/__init__.py b/src/zarr/__init__.py index e0538259d7..4ffa4c9bbc 100644 --- a/src/zarr/__init__.py +++ b/src/zarr/__init__.py @@ -52,8 +52,6 @@ "create_array", "create_group", "create_hierarchy", - "create_nodes", - "create_rooted_hierarchy", "empty", "empty_like", "full", diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 35dbf4d9f6..ea1bd0dd63 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2866,15 +2866,10 @@ async def create_hierarchy( """ Create a complete zarr hierarchy from a collection of metadata objects. - This function will parse its input to ensure that the hierarchy is valid. In this context, - "valid" means that the following requirements are met: - * All nodes have the same zarr format. - * There are no nodes descending from arrays. - * There are no implicit groups. Any implicit groups will be inserted as needed. For example, - an input like ```{'a': GroupMetadata, 'a/b/c': GroupMetadata}``` defines an implicit group at - the path ```a/b```, and also at the root of the hierarchy, which we denote with the empty string. - After parsing, that group will be added and the input will be: - ```{'': GroupMetadata, 'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': GroupMetadata}``` + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ```{'a/b': GroupMetadata}``` will be parsed to + ```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}``` After input parsing, this function then creates all the nodes in the hierarchy concurrently. @@ -2886,22 +2881,38 @@ async def create_hierarchy( store : Store The storage backend to use. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the root of the ``Store``. The root of the store can be specified with the empty + string ``''``. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` -- it is an error to mix zarr versions in the + same hierarchy. overwrite : bool Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is raised instead of overwriting an existing array or group. + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{`'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + Yields ------ - AsyncGroup | AsyncArray - The created nodes in the order they are created. + tuple[str, AsyncGroup | AsyncArray] + This function yields (path, node) pairs, in the order the nodes were created. Examples -------- - + from zarr.api.asynchronous import create_hierarchy + from zarr.storage import MemoryStore + from zarr.core.group import GroupMetadata + import asyncio + store = MemoryStore() + nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} + + async def run(): + print(dict([x async for x in create_hierarchy(store=store, nodes=nodes)])) + + asyncio.run(run()) + # {'a': , '': } """ # normalize the keys to be valid paths nodes_normed_keys = _normalize_path_keys(nodes) diff --git a/src/zarr/core/sync_group.py b/src/zarr/core/sync_group.py index 829d98affc..52f1326c16 100644 --- a/src/zarr/core/sync_group.py +++ b/src/zarr/core/sync_group.py @@ -57,25 +57,50 @@ def create_hierarchy( """ Create a complete zarr hierarchy from a collection of metadata objects. - Groups that are implicitly defined by the input will be created as needed. + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ```{'a/b': GroupMetadata}``` will be parsed to + ```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}``` - This function takes a parsed hierarchy dictionary and creates all the nodes in the hierarchy - concurrently. Arrays and Groups are yielded in the order they are created. + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. Parameters ---------- store : Store The storage backend to use. nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] - A dictionary defining the hierarchy. The keys are the paths of the nodes - in the hierarchy, and the values are the metadata of the nodes. The - metadata must be either an instance of GroupMetadata, ArrayV3Metadata - or ArrayV2Metadata. + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the root of the ``Store``. The root of the store can be specified with the empty + string ``''``. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` -- it is an error to mix zarr versions in the + same hierarchy. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. + + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{`'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. Yields ------ - Group | Array - The created nodes in the order they are created. + tuple[str, Group | Array] + This function yields (path, node) pairs, in the order the nodes were created. + + Examples + -------- + from zarr import create_hierarchy + from zarr.storage import MemoryStore + from zarr.core.group import GroupMetadata + + store = MemoryStore() + nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} + nodes_created = dict(create_hierarchy(store=store, nodes=nodes)) + print(nodes) + # {'a': GroupMetadata(attributes={'name': 'leaf'}, zarr_format=3, consolidated_metadata=None, node_type='group')} """ coro = create_hierarchy_async(store=store, nodes=nodes, overwrite=overwrite) From 8245e80a14428c2262abeb7ad2e8de319d92250d Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 18 Feb 2025 15:09:45 +0100 Subject: [PATCH 56/59] fix group.create_hierarchy to properly prefix keys with the name of the group --- src/zarr/core/group.py | 16 +++++++++++----- tests/test_group.py | 18 ++++++++++++++---- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index ea1bd0dd63..ebbc95d4e6 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1414,9 +1414,11 @@ async def create_hierarchy( Returns ------- - An asynchronous iterator over the created arrays and / or groups. + An asynchronous iterator of (str, AsyncArray | AsyncGroup) pairs. """ # check that all the nodes have the same zarr_format as Self + prefix = self.path + nodes_parsed = {} for key, value in nodes.items(): if value.zarr_format != self.metadata.zarr_format: msg = ( @@ -1433,15 +1435,19 @@ async def create_hierarchy( "create_rooted_hierarchy to create a rooted hierarchy." ) raise ValueError(msg) - - nodes_rooted = nodes + else: + nodes_parsed[_join_paths([prefix, key])] = value async for key, node in create_hierarchy( store=self.store, - nodes=nodes_rooted, + nodes=nodes_parsed, overwrite=overwrite, ): - yield key, node + if prefix == "": + out_key = key + else: + out_key = key.removeprefix(prefix + "/") + yield out_key, node async def keys(self) -> AsyncGenerator[str, None]: """Iterate over member names.""" diff --git a/tests/test_group.py b/tests/test_group.py index 1fbfd9c2f3..521819ea0e 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1645,17 +1645,24 @@ async def test_create_hierarchy_existing_nodes( @pytest.mark.parametrize("store", ["memory"], indirect=True) @pytest.mark.parametrize("overwrite", [True, False]) +@pytest.mark.parametrize("group_path", ["", "foo"]) @pytest.mark.parametrize("impl", ["async", "sync"]) async def test_group_create_hierarchy( - store: Store, zarr_format: ZarrFormat, overwrite: bool, impl: Literal["async", "sync"] + store: Store, + zarr_format: ZarrFormat, + overwrite: bool, + group_path: str, + impl: Literal["async", "sync"], ) -> None: """ Test that the Group.create_hierarchy method creates specified nodes and returns them in a dict. Also test that off-target nodes are not deleted, and that the root group is not deleted """ root_attrs = {"root": True} - g = Group.from_store(store, zarr_format=zarr_format, attributes=root_attrs) - + g = sync_group.create_rooted_hierarchy( + store=store, + nodes={group_path: GroupMetadata(zarr_format=zarr_format, attributes=root_attrs)}, + ) node_spec = { "a": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), "a/b": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a/b"}), @@ -1689,7 +1696,10 @@ async def test_group_create_hierarchy( else: assert all_members[k].metadata == v == extant_created[k].metadata # ensure that we left the root group as-is - assert sync_group.get_node(store=store, path="", zarr_format=zarr_format).attrs == root_attrs + assert ( + sync_group.get_node(store=store, path=group_path, zarr_format=zarr_format).attrs.asdict() + == root_attrs + ) @pytest.mark.parametrize("store", ["memory"], indirect=True) From df2bdc69757e02d4a46c61b5cb33d2d10ac7e2ad Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 18 Feb 2025 15:24:07 +0100 Subject: [PATCH 57/59] docstrings --- src/zarr/core/group.py | 60 +++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index ebbc95d4e6..9249827ce6 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1400,21 +1400,34 @@ async def create_hierarchy( """ Create a hierarchy of arrays or groups rooted at this group. - This method takes a dictionary where the keys are the names of the arrays or groups - to create and the values are the metadata or objects representing the arrays or groups. + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ```{'a/b': GroupMetadata}``` will be parsed to + ```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}```. - The method returns an asynchronous iterator over the created nodes. + Explicitly specifying a root group, e.g. with ``nodes = {'': GroupMetadata()}`` is an error + because this group instance is the root group. + + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. Parameters ---------- - nodes : A dictionary representing the hierarchy to create - + nodes : A dictionary representing the hierarchy to create. The keys should be paths relative to this group + and the values should be the metadata for the arrays or groups to create. overwrite : bool - Whether or not existing arrays / groups should be replaced. + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. - Returns + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{`'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + + Yields ------- - An asynchronous iterator of (str, AsyncArray | AsyncGroup) pairs. + tuple[str, AsyncArray | AsyncGroup]. """ # check that all the nodes have the same zarr_format as Self prefix = self.path @@ -2081,24 +2094,35 @@ def create_hierarchy( """ Create a hierarchy of arrays or groups rooted at this group. - This method takes a dictionary where the keys are the names of the arrays or groups - to create and the values are the metadata objects for the arrays or groups. + This function will parse its input to ensure that the hierarchy is complete. Any implicit groups + will be inserted as needed. For example, an input like + ```{'a/b': GroupMetadata}``` will be parsed to + ```{'': GroupMetadata, 'a': GroupMetadata, 'b': Groupmetadata}```. - This method returns an iterator of created Group or Array objects. + Explicitly specifying a root group, e.g. with ``nodes = {'': GroupMetadata()}`` is an error + because this group instance is the root group. + + After input parsing, this function then creates all the nodes in the hierarchy concurrently. + + Arrays and Groups are yielded in the order they are created. This order is not stable and + should not be relied on. - Note: this method will create additional groups as needed to ensure that a hierarchy is - complete. Usage like ``create_hierarchy({'a/b': GroupMetadata()})`` defines an implicit - group at ``a``. This function will ensure that the group at ``a`` exists, first by checking - if one already exists, and if not, creating one. Parameters ---------- - nodes : A dictionary representing the hierarchy to create. The keys should be relative paths + nodes : A dictionary representing the hierarchy to create. The keys should be paths relative to this group and the values should be the metadata for the arrays or groups to create. + overwrite : bool + Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is + raised instead of overwriting an existing array or group. - Returns + This function will not erase an existing group unless that group is explicitly named in + ``nodes``. If ``nodes`` defines implicit groups, e.g. ``{`'a/b/c': GroupMetadata}``, and a + group already exists at path ``a``, then this function will leave the group at ``a`` as-is. + + Yields ------- - An iterator of (name, Array or Group) tuples. + tuple[str, Array | Group]. Examples -------- From 35afe7f1c8b1cd06e62720ac5a823c35a421dd17 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 18 Feb 2025 15:45:46 +0100 Subject: [PATCH 58/59] docstrings --- src/zarr/core/group.py | 23 +++++++++++++++++------ src/zarr/core/sync_group.py | 2 ++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 9249827ce6..c34bce4189 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1415,8 +1415,13 @@ async def create_hierarchy( Parameters ---------- - nodes : A dictionary representing the hierarchy to create. The keys should be paths relative to this group - and the values should be the metadata for the arrays or groups to create. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the path of the group. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` as the parent group -- it is an error to mix zarr versions in the + same hierarchy. + + Leading "/" characters from keys will be removed. overwrite : bool Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is raised instead of overwriting an existing array or group. @@ -2107,11 +2112,15 @@ def create_hierarchy( Arrays and Groups are yielded in the order they are created. This order is not stable and should not be relied on. - Parameters ---------- - nodes : A dictionary representing the hierarchy to create. The keys should be paths relative to this group - and the values should be the metadata for the arrays or groups to create. + nodes : dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata] + A dictionary defining the hierarchy. The keys are the paths of the nodes in the hierarchy, + relative to the path of the group. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that + all values must have the same ``zarr_format`` as the parent group -- it is an error to mix zarr versions in the + same hierarchy. + + Leading "/" characters from keys will be removed. overwrite : bool Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is raised instead of overwriting an existing array or group. @@ -2906,7 +2915,7 @@ async def create_hierarchy( Arrays and Groups are yielded in the order they are created. This order is not stable and should not be relied on. - Parameters + Parameters ---------- store : Store The storage backend to use. @@ -2916,6 +2925,8 @@ async def create_hierarchy( string ``''``. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that all values must have the same ``zarr_format`` -- it is an error to mix zarr versions in the same hierarchy. + + Leading "/" characters from keys will be removed. overwrite : bool Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is raised instead of overwriting an existing array or group. diff --git a/src/zarr/core/sync_group.py b/src/zarr/core/sync_group.py index 52f1326c16..d1238933e2 100644 --- a/src/zarr/core/sync_group.py +++ b/src/zarr/core/sync_group.py @@ -77,6 +77,8 @@ def create_hierarchy( string ``''``. The values are instances of ``GroupMetadata`` or ``ArrayMetadata``. Note that all values must have the same ``zarr_format`` -- it is an error to mix zarr versions in the same hierarchy. + + Leading "/" characters from keys will be removed. overwrite : bool Whether to overwrite existing nodes. Defaults to ``False``, in which case an error is raised instead of overwriting an existing array or group. From 77264e4b02b678b27768d1153ef2087465703111 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Tue, 18 Feb 2025 16:15:17 +0100 Subject: [PATCH 59/59] docstring examples --- src/zarr/core/group.py | 20 +++++++++----------- src/zarr/core/sync_group.py | 16 ++++++++-------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index c34bce4189..a7f8a6c022 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -2942,17 +2942,15 @@ async def create_hierarchy( Examples -------- - from zarr.api.asynchronous import create_hierarchy - from zarr.storage import MemoryStore - from zarr.core.group import GroupMetadata - import asyncio - store = MemoryStore() - nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} - - async def run(): - print(dict([x async for x in create_hierarchy(store=store, nodes=nodes)])) - - asyncio.run(run()) + >>> from zarr.api.asynchronous import create_hierarchy + >>> from zarr.storage import MemoryStore + >>> from zarr.core.group import GroupMetadata + >>> import asyncio + >>> store = MemoryStore() + >>> nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} + >>> async def run(): + ... print(dict([x async for x in create_hierarchy(store=store, nodes=nodes)])) + >>> asyncio.run(run()) # {'a': , '': } """ # normalize the keys to be valid paths diff --git a/src/zarr/core/sync_group.py b/src/zarr/core/sync_group.py index d1238933e2..39d8a17992 100644 --- a/src/zarr/core/sync_group.py +++ b/src/zarr/core/sync_group.py @@ -94,14 +94,14 @@ def create_hierarchy( Examples -------- - from zarr import create_hierarchy - from zarr.storage import MemoryStore - from zarr.core.group import GroupMetadata - - store = MemoryStore() - nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} - nodes_created = dict(create_hierarchy(store=store, nodes=nodes)) - print(nodes) + >>> from zarr import create_hierarchy + >>> from zarr.storage import MemoryStore + >>> from zarr.core.group import GroupMetadata + + >>> store = MemoryStore() + >>> nodes = {'a': GroupMetadata(attributes={'name': 'leaf'})} + >>> nodes_created = dict(create_hierarchy(store=store, nodes=nodes)) + >>> print(nodes) # {'a': GroupMetadata(attributes={'name': 'leaf'}, zarr_format=3, consolidated_metadata=None, node_type='group')} """ coro = create_hierarchy_async(store=store, nodes=nodes, overwrite=overwrite)