Skip to content

Commit

Permalink
BUG: Avoid data race in PyArray_CheckFromAny_int (numpy#28154)
Browse files Browse the repository at this point in the history
* BUG: Avoid data race in PyArray_CheckFromAny_int

* TST: add test

* MAINT: simplify byteswapping code in PyArray_CheckFromAny_int

* MAINT: drop ISBYTESWAPPED check
  • Loading branch information
ngoldbaum authored Jan 15, 2025
1 parent bbf4836 commit 1e10174
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
16 changes: 5 additions & 11 deletions numpy/_core/src/multiarray/ctors.c
Original file line number Diff line number Diff line change
Expand Up @@ -1829,18 +1829,12 @@ PyArray_CheckFromAny_int(PyObject *op, PyArray_Descr *in_descr,
{
PyObject *obj;
if (requires & NPY_ARRAY_NOTSWAPPED) {
if (!in_descr && PyArray_Check(op) &&
PyArray_ISBYTESWAPPED((PyArrayObject* )op)) {
in_descr = PyArray_DescrNew(PyArray_DESCR((PyArrayObject *)op));
if (in_descr == NULL) {
return NULL;
}
}
else if (in_descr && !PyArray_ISNBO(in_descr->byteorder)) {
PyArray_DESCR_REPLACE(in_descr);
if (!in_descr && PyArray_Check(op)) {
in_descr = PyArray_DESCR((PyArrayObject *)op);
Py_INCREF(in_descr);
}
if (in_descr && in_descr->byteorder != NPY_IGNORE) {
in_descr->byteorder = NPY_NATIVE;
if (in_descr) {
PyArray_DESCR_REPLACE_CANONICAL(in_descr);
}
}

Expand Down
5 changes: 5 additions & 0 deletions numpy/_core/src/multiarray/dtypemeta.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ PyArray_SETITEM(PyArrayObject *arr, char *itemptr, PyObject *v)
v, itemptr, arr);
}

// Like PyArray_DESCR_REPLACE, but calls ensure_canonical instead of DescrNew
#define PyArray_DESCR_REPLACE_CANONICAL(descr) do { \
PyArray_Descr *_new_ = NPY_DT_CALL_ensure_canonical(descr); \
Py_XSETREF(descr, _new_); \
} while(0)


#endif /* NUMPY_CORE_SRC_MULTIARRAY_DTYPEMETA_H_ */
13 changes: 13 additions & 0 deletions numpy/_core/tests/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def closure(b):


def test_parallel_flat_iterator():
# gh-28042
x = np.arange(20).reshape(5, 4).T

def closure(b):
Expand All @@ -142,3 +143,15 @@ def closure(b):
list(x.flat)

run_threaded(closure, outer_iterations=100, pass_barrier=True)

# gh-28143
def prepare_args():
return [np.arange(10)]

def closure(x, b):
b.wait()
for _ in range(100):
y = np.arange(10)
y.flat[x] = x

run_threaded(closure, pass_barrier=True, prepare_args=prepare_args)
8 changes: 6 additions & 2 deletions numpy/testing/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2687,12 +2687,16 @@ def _get_glibc_version():


def run_threaded(func, iters=8, pass_count=False, max_workers=8,
pass_barrier=False, outer_iterations=1):
pass_barrier=False, outer_iterations=1,
prepare_args=None):
"""Runs a function many times in parallel"""
for _ in range(outer_iterations):
with (concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
as tpe):
args = []
if prepare_args is None:
args = []
else:
args = prepare_args()
if pass_barrier:
if max_workers != iters:
raise RuntimeError(
Expand Down

0 comments on commit 1e10174

Please sign in to comment.