Commit 4f6e831: move dask code into xarray/backends/locks.py

kmuehlbauer committed Jan 3, 2024
1 parent fe5b4a9 commit 4f6e831
Showing 3 changed files with 77 additions and 112 deletions.
107 changes: 0 additions & 107 deletions xarray/backends/dask_lock.py

This file was deleted.

79 changes: 76 additions & 3 deletions xarray/backends/locks.py
@@ -2,11 +2,84 @@

 import multiprocessing
 import threading
+import uuid
 import weakref
-from collections.abc import MutableMapping
-from typing import Any
+from collections.abc import Hashable, MutableMapping
+from typing import Any, ClassVar
+from weakref import WeakValueDictionary
+
+
+# SerializableLock is adapted from Dask:
+# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224
+# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
+class SerializableLock:
+    """A Serializable per-process Lock
+
+    This wraps a normal ``threading.Lock`` object and satisfies the same
+    interface. However, this lock can also be serialized and sent to different
+    processes. It will not block concurrent operations between processes (for
+    this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file``)
+    but will consistently deserialize into the same lock.
+
+    So if we make a lock in one process::
+
+        lock = SerializableLock()
+
+    And then send it over to another process multiple times::
+
+        bytes = pickle.dumps(lock)
+        a = pickle.loads(bytes)
+        b = pickle.loads(bytes)
+
+    Then the deserialized objects will operate as though they were the same
+    lock, and collide as appropriate.
+
+    This is useful for consistently protecting resources on a per-process
+    level.
+
+    The creation of locks is itself not threadsafe.
+    """
+
+    _locks: ClassVar[
+        WeakValueDictionary[Hashable, threading.Lock]
+    ] = WeakValueDictionary()
+    token: Hashable
+    lock: threading.Lock
+
+    def __init__(self, token: Hashable | None = None):
+        self.token = token or str(uuid.uuid4())
+        if self.token in SerializableLock._locks:
+            self.lock = SerializableLock._locks[self.token]
+        else:
+            self.lock = threading.Lock()
+            SerializableLock._locks[self.token] = self.lock
+
+    def acquire(self, *args, **kwargs):
+        return self.lock.acquire(*args, **kwargs)
+
+    def release(self, *args, **kwargs):
+        return self.lock.release(*args, **kwargs)
+
+    def __enter__(self):
+        self.lock.__enter__()
+
+    def __exit__(self, *args):
+        self.lock.__exit__(*args)
+
+    def locked(self):
+        return self.lock.locked()
+
+    def __getstate__(self):
+        return self.token
+
+    def __setstate__(self, token):
+        self.__init__(token)
+
+    def __str__(self):
+        return f"<{self.__class__.__name__}: {self.token}>"
+
+    __repr__ = __str__

-from xarray.backends.dask_lock import SerializableLock

# Locks used by multiple backends.
# Neither HDF5 nor the netCDF-C library are thread-safe.
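For context outside the diff: the docstring above describes the key property of ``SerializableLock``. Copies unpickled from the same lock share one underlying ``threading.Lock``, looked up by ``token`` in the class-level ``WeakValueDictionary``. A minimal sketch of that behavior (not part of the commit; it assumes the post-commit import path):

    import pickle

    from xarray.backends.locks import SerializableLock

    # A fresh lock gets a random UUID token and registers its underlying
    # threading.Lock in SerializableLock._locks (a WeakValueDictionary).
    lock = SerializableLock()

    # Pickling serializes only the token, not the OS-level lock state.
    payload = pickle.dumps(lock)
    a = pickle.loads(payload)
    b = pickle.loads(payload)

    # Within this process, both copies resolve the token to the same
    # threading.Lock, so they collide as the docstring describes.
    assert a.token == b.token == lock.token
    assert a.lock is b.lock is lock.lock
    with a:
        assert b.locked()

    # In another process the same token resolves to a new threading.Lock,
    # so SerializableLock does not synchronize across processes.
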
3 changes: 1 addition & 2 deletions xarray/tests/test_distributed.py
@@ -27,8 +27,7 @@
)

import xarray as xr
-from xarray.backends.dask_lock import SerializableLock
-from xarray.backends.locks import HDF5_LOCK, CombinedLock
+from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock
from xarray.tests import (
assert_allclose,
assert_identical,
