Skip to content

Commit

Permalink
Fix bug when calling einx ops from multiple threads
Browse files Browse the repository at this point in the history
  • Loading branch information
fferflo committed Apr 26, 2024
1 parent fd896ee commit 85c2f51
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 22 deletions.
17 changes: 11 additions & 6 deletions einx/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,17 @@ def inner(*args):


_thread_local = threading.local()
_thread_local.default_backend_stack = []


def _get_backend_stack():
if not hasattr(_thread_local, "backend_stack"):
_thread_local.backend_stack = []
return _thread_local.backend_stack


def get_default():
if len(_thread_local.default_backend_stack) > 0:
return _thread_local.default_backend_stack[-1]
if len(_get_backend_stack()) > 0:
return _get_backend_stack()[-1]
else:
return None

Expand All @@ -71,12 +76,12 @@ class Backend:
decorators = []

def __enter__(backend):
_thread_local.default_backend_stack.append(backend)
_get_backend_stack().append(backend)
return backend

def __exit__(backend, *args):
assert _thread_local.default_backend_stack[-1] is backend
_thread_local.default_backend_stack.pop()
assert _get_backend_stack()[-1] is backend
_get_backend_stack().pop()

@staticmethod
def _decorate_construct_graph(f):
Expand Down
23 changes: 19 additions & 4 deletions einx/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,22 @@ def include_frame(fname):


thread_local = threading.local()
thread_local.in_reraise = False


def _set_in_reraise():
if not hasattr(thread_local, "in_reraise"):
thread_local.in_reraise = False
assert not thread_local.in_reraise
thread_local.in_reraise = True


def _unset_in_reraise():
assert thread_local.in_reraise
thread_local.in_reraise = False


def _is_in_reraise():
return getattr(thread_local, "in_reraise", False)


def _filter_tb(tb):
Expand Down Expand Up @@ -46,8 +61,8 @@ def filter(func):

@functools.wraps(func)
def func_with_reraise(*args, **kwargs):
if not thread_local.in_reraise:
thread_local.in_reraise = True
if not _is_in_reraise():
_set_in_reraise()
tb = None
try:
return func(*args, **kwargs)
Expand All @@ -56,7 +71,7 @@ def func_with_reraise(*args, **kwargs):
raise e.with_traceback(tb) from None
finally:
del tb
thread_local.in_reraise = False
_unset_in_reraise()
else:
return func(*args, **kwargs)

Expand Down
19 changes: 12 additions & 7 deletions einx/tracer/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,23 +112,28 @@ def lru_cache(func):


_thread_local = threading.local()
_thread_local.stack = []


def _get_trace_stack():
if not hasattr(_thread_local, "stack"):
_thread_local.stack = []
return _thread_local.stack


class _trace_context:
def __init__(self, backend):
self.backend = backend

def __enter__(self):
_thread_local.stack.append(self)
_get_trace_stack().append(self)

def __exit__(self, *args):
assert id(_thread_local.stack[-1]) == id(self)
_thread_local.stack.pop()
assert id(_get_trace_stack()[-1]) == id(self)
_get_trace_stack().pop()


def _is_tracing():
return len(_thread_local.stack) > 0
return len(_get_trace_stack()) > 0


trace_all = lambda t, c: lambda *args, **kwargs: c(
Expand Down Expand Up @@ -187,8 +192,8 @@ def func_jit(*args, backend=None, graph=False, **kwargs):
if _is_tracing():
assert not graph
if backend is None:
backend = _thread_local.stack[-1].backend
elif backend != _thread_local.stack[-1].backend:
backend = _get_trace_stack()[-1].backend
elif backend != _get_trace_stack()[-1].backend:
raise ValueError("Cannot change backend during tracing")

return func(*args, backend=backend, **kwargs)
Expand Down
15 changes: 10 additions & 5 deletions einx/tracer/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,29 @@ def check(got_output, expected_output):
signature=signature,
inplace_updates=inplace_updates,
comment=comment,
depend_on=depend_on + _thread_local.depend_on,
depend_on=depend_on + _get_depend_on_stack(),
).output


_thread_local = threading.local()
_thread_local.depend_on = []


def _get_depend_on_stack():
if not hasattr(_thread_local, "depend_on"):
_thread_local.depend_on = []
return _thread_local.depend_on


class depend_on:
def __init__(self, tracers):
self.tracer = list(einx.tree_util.tree_flatten(tracers))

def __enter__(self):
_thread_local.depend_on.append(self.tracer)
_get_depend_on_stack().append(self.tracer)

def __exit__(self, *args):
assert _thread_local.depend_on[-1] is self.tracer
_thread_local.depend_on.pop()
assert _get_depend_on_stack()[-1] is self.tracer
_get_depend_on_stack().pop()


class Tracer:
Expand Down
53 changes: 53 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import numpy as np
import types
import einx
import threading
import multiprocessing

tests = []

Expand Down Expand Up @@ -53,6 +55,55 @@ def op3(*args, **kwargs):
return op3


def in_new_thread(op):
def inner(*args, **kwargs):
result = [None, None]

def run(result):
try:
result[0] = op(*args, **kwargs)
except Exception as e:
result[1] = e

thread = threading.Thread(target=run, args=(result,))
thread.start()
thread.join()
if result[1] is not None:
raise result[1]
else:
return result[0]

return inner


einx_multithread = WrappedEinx(in_new_thread, "multithreading", inline_args=True)


def in_new_process(op):
def inner(*args, **kwargs):
result = multiprocessing.Queue()
exception = multiprocessing.Queue()

def run(result, exception):
try:
result.put(op(*args, **kwargs))
except Exception as e:
exception.put(e)

process = multiprocessing.Process(target=run, args=(result, exception))
process.start()
process.join()
if not exception.empty():
raise exception.get()
else:
return result.get()

return inner


einx_multiprocess = WrappedEinx(in_new_process, "multiprocessing", inline_args=True)


# numpy is always available
import numpy as np

Expand All @@ -65,6 +116,8 @@ def op3(*args, **kwargs):
)

tests.append((einx, backend, test))
tests.append((einx_multithread, backend, test))
# tests.append((einx_multiprocess, backend, test)) # too slow


if importlib.util.find_spec("jax"):
Expand Down

0 comments on commit 85c2f51

Please sign in to comment.