diff --git a/src/zarr/core/sync.py b/src/zarr/core/sync.py
index 6a2de855e8..2bb5f24802 100644
--- a/src/zarr/core/sync.py
+++ b/src/zarr/core/sync.py
@@ -3,6 +3,7 @@
 import asyncio
 import atexit
 import logging
+import os
 import threading
 from concurrent.futures import ThreadPoolExecutor, wait
 from typing import TYPE_CHECKING, TypeVar
@@ -89,6 +90,26 @@ def cleanup_resources() -> None:
 atexit.register(cleanup_resources)
 
 
+def reset_resources_after_fork() -> None:
+    """
+    Ensure that global resources are reset after a fork. Without this function,
+    forked processes will retain invalid references to the parent process's resources.
+    """
+    global loop, iothread, _executor
+    # These lines are excluded from coverage because this function only runs in a child process,
+    # which is not observed by the test coverage instrumentation. Despite the apparent lack of
+    # test coverage, this function should be adequately tested by any test that uses Zarr IO with
+    # multiprocessing.
+    loop[0] = None  # pragma: no cover
+    iothread[0] = None  # pragma: no cover
+    _executor = None  # pragma: no cover
+
+
+# this is only available on certain operating systems
+if hasattr(os, "register_at_fork"):
+    os.register_at_fork(after_in_child=reset_resources_after_fork)
+
+
 async def _runner(coro: Coroutine[Any, Any, T]) -> T | BaseException:
     """
     Await a coroutine and return the result of running it. If awaiting the coroutine raises an
diff --git a/tests/test_array.py b/tests/test_array.py
index e458ba106e..1b84d1d061 100644
--- a/tests/test_array.py
+++ b/tests/test_array.py
@@ -1,8 +1,10 @@
 import dataclasses
 import json
 import math
+import multiprocessing as mp
 import pickle
 import re
+import sys
 from itertools import accumulate
 from typing import TYPE_CHECKING, Any, Literal
 from unittest import mock
@@ -1382,3 +1384,39 @@ def test_roundtrip_numcodecs() -> None:
     metadata = root["test"].metadata.to_dict()
     expected = (*filters, BYTES_CODEC, *compressors)
     assert metadata["codecs"] == expected
+
+
+def _index_array(arr: Array, index: Any) -> Any:
+    return arr[index]
+
+
+@pytest.mark.parametrize(
+    "method",
+    [
+        pytest.param(
+            "fork",
+            marks=pytest.mark.skipif(
+                sys.platform in ("win32", "darwin"), reason="fork not supported on Windows or OSX"
+            ),
+        ),
+        "spawn",
+        pytest.param(
+            "forkserver",
+            marks=pytest.mark.skipif(
+                sys.platform == "win32", reason="forkserver not supported on Windows"
+            ),
+        ),
+    ],
+)
+@pytest.mark.parametrize("store", ["local"], indirect=True)
+def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkserver"]) -> None:
+    """
+    Test that arrays can be pickled and indexed in child processes
+    """
+    data = np.arange(100)
+    arr = zarr.create_array(store=store, data=data)
+    ctx = mp.get_context(method)
+    pool = ctx.Pool()
+
+    results = pool.starmap(_index_array, [(arr, slice(len(data)))])
+    assert all(np.array_equal(r, data) for r in results)
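
For context (not part of the patch): a minimal, self-contained sketch of the os.register_at_fork pattern applied in sync.py above, with a hypothetical module-level _worker thread standing in for Zarr's cached event loop, iothread, and executor.

# Minimal sketch, assuming a hypothetical module-level resource named _worker;
# it illustrates why state captured before a fork must be reset in the child.
import os
import threading

# Cached resource created lazily by the parent process, analogous to zarr's
# cached asyncio event loop and IO thread.
_worker: threading.Thread | None = None


def _reset_after_fork() -> None:
    # Runs only in the child process immediately after fork(); the parent's
    # thread object is meaningless there, so drop the reference and let the
    # child lazily create its own.
    global _worker
    _worker = None


# os.register_at_fork is not available on every platform (e.g. Windows),
# hence the hasattr guard, mirroring the patch above.
if hasattr(os, "register_at_fork"):
    os.register_at_fork(after_in_child=_reset_after_fork)

Without such a hook, a child created by the "fork" start method inherits references to the parent's event loop and threads, which do not run in the child; resetting them forces the child to rebuild its own resources on first use.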