From 85c2f511c4fac829ae4335d38bbaf9e0b44e3bf2 Mon Sep 17 00:00:00 2001
From: Florian Fervers
Date: Fri, 26 Apr 2024 10:54:58 +0200
Subject: [PATCH] Fix bug when calling einx ops from multiple threads

---
 einx/backend/base.py     | 17 ++++++++-----
 einx/traceback_util.py   | 23 ++++++++++++++---
 einx/tracer/decorator.py | 19 ++++++++------
 einx/tracer/tracer.py    | 15 ++++++++----
 test/conftest.py         | 53 ++++++++++++++++++++++++++++++++++++++++
 5 files changed, 105 insertions(+), 22 deletions(-)

diff --git a/einx/backend/base.py b/einx/backend/base.py
index 745a18a..3b5b1fe 100644
--- a/einx/backend/base.py
+++ b/einx/backend/base.py
@@ -56,12 +56,17 @@ def inner(*args):
 
 
 _thread_local = threading.local()
-_thread_local.default_backend_stack = []
+
+
+def _get_backend_stack():
+    if not hasattr(_thread_local, "backend_stack"):
+        _thread_local.backend_stack = []
+    return _thread_local.backend_stack
 
 
 def get_default():
-    if len(_thread_local.default_backend_stack) > 0:
-        return _thread_local.default_backend_stack[-1]
+    if len(_get_backend_stack()) > 0:
+        return _get_backend_stack()[-1]
     else:
         return None
 
@@ -71,12 +76,12 @@ class Backend:
     decorators = []
 
     def __enter__(backend):
-        _thread_local.default_backend_stack.append(backend)
+        _get_backend_stack().append(backend)
         return backend
 
     def __exit__(backend, *args):
-        assert _thread_local.default_backend_stack[-1] is backend
-        _thread_local.default_backend_stack.pop()
+        assert _get_backend_stack()[-1] is backend
+        _get_backend_stack().pop()
 
     @staticmethod
     def _decorate_construct_graph(f):
diff --git a/einx/traceback_util.py b/einx/traceback_util.py
index ff8b6cd..c9d7ecd 100644
--- a/einx/traceback_util.py
+++ b/einx/traceback_util.py
@@ -12,7 +12,22 @@ def include_frame(fname):
 
 
 thread_local = threading.local()
-thread_local.in_reraise = False
+
+
+def _set_in_reraise():
+    if not hasattr(thread_local, "in_reraise"):
+        thread_local.in_reraise = False
+    assert not thread_local.in_reraise
+    thread_local.in_reraise = True
+
+
+def _unset_in_reraise():
+    assert thread_local.in_reraise
+    thread_local.in_reraise = False
+
+
+def _is_in_reraise():
+    return getattr(thread_local, "in_reraise", False)
 
 
 def _filter_tb(tb):
@@ -46,8 +61,8 @@ def filter(func):
 
     @functools.wraps(func)
     def func_with_reraise(*args, **kwargs):
-        if not thread_local.in_reraise:
-            thread_local.in_reraise = True
+        if not _is_in_reraise():
+            _set_in_reraise()
             tb = None
             try:
                 return func(*args, **kwargs)
@@ -56,7 +71,7 @@ def func_with_reraise(*args, **kwargs):
                 raise e.with_traceback(tb) from None
             finally:
                 del tb
-                thread_local.in_reraise = False
+                _unset_in_reraise()
         else:
             return func(*args, **kwargs)
 
diff --git a/einx/tracer/decorator.py b/einx/tracer/decorator.py
index 6e89bb4..b42588b 100644
--- a/einx/tracer/decorator.py
+++ b/einx/tracer/decorator.py
@@ -112,7 +112,12 @@ def lru_cache(func):
 
 
 _thread_local = threading.local()
-_thread_local.stack = []
+
+
+def _get_trace_stack():
+    if not hasattr(_thread_local, "stack"):
+        _thread_local.stack = []
+    return _thread_local.stack
 
 
 class _trace_context:
@@ -120,15 +125,15 @@ def __init__(self, backend):
         self.backend = backend
 
     def __enter__(self):
-        _thread_local.stack.append(self)
+        _get_trace_stack().append(self)
 
     def __exit__(self, *args):
-        assert id(_thread_local.stack[-1]) == id(self)
-        _thread_local.stack.pop()
+        assert id(_get_trace_stack()[-1]) == id(self)
+        _get_trace_stack().pop()
 
 
 def _is_tracing():
-    return len(_thread_local.stack) > 0
+    return len(_get_trace_stack()) > 0
 
 
 trace_all = lambda t, c: lambda *args, **kwargs: c(
@@ -187,8 +192,8 @@ def func_jit(*args, backend=None, graph=False, **kwargs):
         if _is_tracing():
             assert not graph
             if backend is None:
-                backend = _thread_local.stack[-1].backend
-            elif backend != _thread_local.stack[-1].backend:
+                backend = _get_trace_stack()[-1].backend
+            elif backend != _get_trace_stack()[-1].backend:
                 raise ValueError("Cannot change backend during tracing")
             return func(*args, backend=backend, **kwargs)
 
diff --git a/einx/tracer/tracer.py b/einx/tracer/tracer.py
index 366f3dc..a5eb370 100644
--- a/einx/tracer/tracer.py
+++ b/einx/tracer/tracer.py
@@ -84,12 +84,17 @@ def check(got_output, expected_output):
         signature=signature,
         inplace_updates=inplace_updates,
         comment=comment,
-        depend_on=depend_on + _thread_local.depend_on,
+        depend_on=depend_on + _get_depend_on_stack(),
     ).output
 
 
 _thread_local = threading.local()
-_thread_local.depend_on = []
+
+
+def _get_depend_on_stack():
+    if not hasattr(_thread_local, "depend_on"):
+        _thread_local.depend_on = []
+    return _thread_local.depend_on
 
 
 class depend_on:
@@ -97,11 +102,11 @@ def __init__(self, tracers):
         self.tracer = list(einx.tree_util.tree_flatten(tracers))
 
     def __enter__(self):
-        _thread_local.depend_on.append(self.tracer)
+        _get_depend_on_stack().append(self.tracer)
 
     def __exit__(self, *args):
-        assert _thread_local.depend_on[-1] is self.tracer
-        _thread_local.depend_on.pop()
+        assert _get_depend_on_stack()[-1] is self.tracer
+        _get_depend_on_stack().pop()
 
 
 class Tracer:
diff --git a/test/conftest.py b/test/conftest.py
index b2a3eaf..f76990c 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -2,6 +2,8 @@
 import numpy as np
 import types
 import einx
+import threading
+import multiprocessing
 
 tests = []
 
@@ -53,6 +55,55 @@ def op3(*args, **kwargs):
         return op3
 
 
+def in_new_thread(op):
+    def inner(*args, **kwargs):
+        result = [None, None]
+
+        def run(result):
+            try:
+                result[0] = op(*args, **kwargs)
+            except Exception as e:
+                result[1] = e
+
+        thread = threading.Thread(target=run, args=(result,))
+        thread.start()
+        thread.join()
+        if result[1] is not None:
+            raise result[1]
+        else:
+            return result[0]
+
+    return inner
+
+
+einx_multithread = WrappedEinx(in_new_thread, "multithreading", inline_args=True)
+
+
+def in_new_process(op):
+    def inner(*args, **kwargs):
+        result = multiprocessing.Queue()
+        exception = multiprocessing.Queue()
+
+        def run(result, exception):
+            try:
+                result.put(op(*args, **kwargs))
+            except Exception as e:
+                exception.put(e)
+
+        process = multiprocessing.Process(target=run, args=(result, exception))
+        process.start()
+        process.join()
+        if not exception.empty():
+            raise exception.get()
+        else:
+            return result.get()
+
+    return inner
+
+
+einx_multiprocess = WrappedEinx(in_new_process, "multiprocessing", inline_args=True)
+
+
 # numpy is always available
 import numpy as np
 
@@ -65,6 +116,8 @@
 )
 
 tests.append((einx, backend, test))
+tests.append((einx_multithread, backend, test))
+# tests.append((einx_multiprocess, backend, test)) # too slow
 
 
 if importlib.util.find_spec("jax"):
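
Note (illustrative sketch, not part of the patch): the underlying bug is that
attributes assigned to a threading.local() at import time exist only on the
thread that ran the import; any other thread sees a bare local object and hits
an AttributeError on first access. A minimal standalone script showing the
pitfall, using only the Python standard library (the variable names mirror
einx/tracer/decorator.py, but the script is independent of einx):

    import threading

    _thread_local = threading.local()
    _thread_local.stack = []  # set on the importing (main) thread only


    def worker():
        # Each thread gets its own attribute namespace on a threading.local(),
        # so on a freshly started thread "stack" does not exist yet.
        try:
            _thread_local.stack.append(1)
        except AttributeError:
            print("no per-thread 'stack' attribute on this thread")


    worker()  # main thread: works, the attribute was assigned here
    t = threading.Thread(target=worker)
    t.start()
    t.join()  # worker thread: prints the AttributeError message

The patch therefore replaces the module-level assignments with lazy per-thread
initialization (the _get_*_stack() and _is_in_reraise() helpers above), so each
thread creates its own stack or flag on first use.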