[torch][cuda] fix race condition in cuda initialization (pytorch#143238)
Access to the lazy-init callbacks (`_lazy_seed_tracker` and `_queued_calls`) is not synchronized with the initialization lock.

This exposes us to the following race:
1. start `_lazy_init`
2. take `_initialization_lock`
3. flush `_queued_calls` and run them all
4. another thread comes in and calls `_lazy_call` to put something on the queue (in our case, the `manual_seed` callback)
5. the original thread finishes initializing but never runs that queued call
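
The race is easiest to see in a stripped-down model of the lazy-init machinery. The sketch below is not the actual torch.cuda code: it models only the `_queued_calls` path (the seed tracker is omitted), and `_lazy_call_unsafe`, `_lazy_call_fixed`, and the `_initialized` flag are simplified stand-ins.

```python
import threading

# Simplified stand-ins for the real module state (assumption: only the
# _queued_calls path is modeled; _lazy_seed_tracker is omitted).
_initialization_lock = threading.Lock()
_initialized = False
_queued_calls = []


def _lazy_call_unsafe(fn):
    # Pre-fix shape of _lazy_call: the queue is checked and appended to
    # without holding _initialization_lock.
    if _initialized:
        fn()
    else:
        _queued_calls.append(fn)


def _lazy_init():
    global _initialized
    with _initialization_lock:
        if _initialized:
            return
        for fn in _queued_calls:  # step 3: flush the queue and run everything
            fn()
        # A callback appended by another thread between the flush above and
        # this assignment (step 4) is never run (step 5).
        _initialized = True


def _lazy_call_fixed(fn):
    # Post-fix shape: enqueueing is serialized with initialization, so a
    # callback either sees the initialized state and runs immediately, or is
    # on the queue before _lazy_init's flush starts.
    with _initialization_lock:
        if _initialized:
            fn()
        else:
            _queued_calls.append(fn)
```

With the lock taken in `_lazy_call`, a caller that loses the race to `_lazy_init` blocks until initialization completes and then sees the initialized state, so it runs its callback directly instead of leaving it on a queue that will never be flushed.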

Pull Request resolved: pytorch#143238
Approved by: https://github.com/ngimel
suo authored and pytorchmergebot committed Dec 14, 2024
1 parent 28d8297 commit 9933e59
Showing 1 changed file with 14 additions and 13 deletions: torch/cuda/__init__.py
@@ -245,20 +245,21 @@ def is_initialized():
 
 
 def _lazy_call(callable, **kwargs):
-    if is_initialized():
-        callable()
-    else:
-        # TODO(torch_deploy): this accesses linecache, which attempts to read the
-        # file system to get traceback info. Patch linecache or do something
-        # else here if this ends up being important.
-        global _lazy_seed_tracker
-        if kwargs.get("seed_all", False):
-            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
-        elif kwargs.get("seed", False):
-            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
+    with _initialization_lock:
+        if is_initialized():
+            callable()
         else:
-            # Don't store the actual traceback to avoid memory cycle
-            _queued_calls.append((callable, traceback.format_stack()))
+            # TODO(torch_deploy): this accesses linecache, which attempts to read the
+            # file system to get traceback info. Patch linecache or do something
+            # else here if this ends up being important.
+            global _lazy_seed_tracker
+            if kwargs.get("seed_all", False):
+                _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
+            elif kwargs.get("seed", False):
+                _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
+            else:
+                # Don't store the actual traceback to avoid memory cycle
+                _queued_calls.append((callable, traceback.format_stack()))
 
 
 _lazy_call(_check_capability)
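
For reference, the shape of a scenario that could hit this before the fix is sketched below. This is a hypothetical illustration, not a test from the PR; it assumes a CUDA-capable machine, and `init_worker`/`seed_worker` are made-up helper names. One thread drives CUDA initialization while another seeds the RNG through `torch.cuda.manual_seed`, which goes through `_lazy_call` when CUDA is not yet initialized.

```python
import threading
import torch

def init_worker():
    # Forces CUDA initialization, which flushes the queued lazy calls.
    torch.cuda.init()

def seed_worker():
    # If CUDA is not initialized yet, this is routed through _lazy_call;
    # before the fix it could be enqueued after the flush and silently dropped.
    torch.cuda.manual_seed(0)

t1 = threading.Thread(target=init_worker)
t2 = threading.Thread(target=seed_worker)
t1.start()
t2.start()
t1.join()
t2.join()
```

Whether the bad interleaving actually occurs is timing-dependent; the fix removes the window entirely by making `_lazy_call` take `_initialization_lock`.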