Skip to content

Commit

Permalink
Multiprocessing support (#2815)
Browse files Browse the repository at this point in the history
* add failing multiprocessing test

* add hook to reset global vars after fork

* parametrize multiprocessing test over different methods

* guard execution of register_at_fork with a hasattr check

* exempt runs-in-a-forked-process code from coverage

* update literal type
  • Loading branch information
d-v-b authored Feb 11, 2025
1 parent 8b77464 commit 2f8b88a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/zarr/core/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import atexit
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor, wait
from typing import TYPE_CHECKING, TypeVar
Expand Down Expand Up @@ -89,6 +90,26 @@ def cleanup_resources() -> None:
atexit.register(cleanup_resources)


def reset_resources_after_fork() -> None:
    """
    Reset module-level IO resources after a fork.

    A forked child inherits the parent's module globals, so without this hook
    it would hold references to the parent's event loop, IO thread, and
    thread-pool executor, none of which are valid in the child.
    """
    global loop, iothread, _executor
    # Excluded from coverage: this body runs only inside a forked child
    # process, which coverage instrumentation does not observe. It is
    # exercised indirectly by any test that performs Zarr IO under
    # multiprocessing.
    _executor = None  # pragma: no cover
    iothread[0] = None  # pragma: no cover
    loop[0] = None  # pragma: no cover


# Install the reset hook where the platform supports it (EAFP): Windows and
# some other operating systems do not provide os.register_at_fork.
try:
    os.register_at_fork(after_in_child=reset_resources_after_fork)
except AttributeError:  # pragma: no cover
    pass


async def _runner(coro: Coroutine[Any, Any, T]) -> T | BaseException:
"""
Await a coroutine and return the result of running it. If awaiting the coroutine raises an
Expand Down
38 changes: 38 additions & 0 deletions tests/test_array.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import dataclasses
import json
import math
import multiprocessing as mp
import pickle
import re
import sys
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Literal
from unittest import mock
Expand Down Expand Up @@ -1382,3 +1384,39 @@ def test_roundtrip_numcodecs() -> None:
metadata = root["test"].metadata.to_dict()
expected = (*filters, BYTES_CODEC, *compressors)
assert metadata["codecs"] == expected


def _index_array(arr: Array, index: Any) -> Any:
    """Worker helper: return the selection *index* applied to *arr*."""
    selected = arr[index]
    return selected


@pytest.mark.parametrize(
    "method",
    [
        pytest.param(
            "fork",
            marks=pytest.mark.skipif(
                sys.platform in ("win32", "darwin"), reason="fork not supported on Windows or OSX"
            ),
        ),
        "spawn",
        pytest.param(
            "forkserver",
            marks=pytest.mark.skipif(
                sys.platform == "win32", reason="forkserver not supported on Windows"
            ),
        ),
    ],
)
@pytest.mark.parametrize("store", ["local"], indirect=True)
def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkserver"]) -> None:
    """
    Test that arrays can be pickled and indexed in child processes spawned by
    each available multiprocessing start method.
    """
    data = np.arange(100)
    arr = zarr.create_array(store=store, data=data)
    ctx = mp.get_context(method)
    # Use the pool as a context manager so worker processes are terminated
    # and joined when the test finishes, instead of leaking until interpreter
    # shutdown (the original code never closed the pool).
    with ctx.Pool() as pool:
        results = pool.starmap(_index_array, [(arr, slice(len(data)))])
    assert all(np.array_equal(r, data) for r in results)

0 comments on commit 2f8b88a

Please sign in to comment.