Skip to content

Commit c76cb31

Browse files
authored
Merge pull request #132 from tdixon97/main
Allow table evaluations to return np.ndarray also without numexpr
2 parents b678606 + c22d610 commit c76cb31

File tree

3 files changed

+46
-19
lines changed

3 files changed

+46
-19
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ test = [
7676
"pylegendtestdata",
7777
"pytest>=6.0",
7878
"pytest-cov",
79+
"dbetto",
7980
]
8081

8182
[project.scripts]

src/lgdo/types/table.py

+31-19
Original file line numberDiff line numberDiff line change
@@ -351,31 +351,39 @@ def eval(
351351
msg = f"evaluating {expr!r} with locals={(self_unwrap | parameters)} and {has_ak=}"
352352
log.debug(msg)
353353

354-
# use numexpr if we are only dealing with numpy data types (and no global dictionary)
355-
if not has_ak and modules is None:
356-
out_data = ne.evaluate(
357-
expr,
358-
local_dict=(self_unwrap | parameters),
359-
)
360-
361-
msg = f"...the result is {out_data!r}"
362-
log.debug(msg)
363-
364-
# need to convert back to LGDO
365-
# np.evaluate should always return a numpy thing?
366-
if out_data.ndim == 0:
367-
return Scalar(out_data.item())
368-
if out_data.ndim == 1:
369-
return Array(out_data)
370-
if out_data.ndim == 2:
371-
return ArrayOfEqualSizedArrays(nda=out_data)
354+
def _make_lgdo(data):
355+
if data.ndim == 0:
356+
return Scalar(data.item())
357+
if data.ndim == 1:
358+
return Array(data)
359+
if data.ndim == 2:
360+
return ArrayOfEqualSizedArrays(nda=data)
372361

373362
msg = (
374-
f"evaluation resulted in {out_data.ndim}-dimensional data, "
363+
f"evaluation resulted in {data.ndim}-dimensional data, "
375364
"I don't know which LGDO this corresponds to"
376365
)
377366
raise RuntimeError(msg)
378367

368+
# use numexpr if we are only dealing with numpy data types (and no global dictionary)
369+
if not has_ak and modules is None:
370+
try:
371+
out_data = ne.evaluate(
372+
expr,
373+
local_dict=(self_unwrap | parameters),
374+
)
375+
376+
msg = f"...the result is {out_data!r}"
377+
log.debug(msg)
378+
379+
# need to convert back to LGDO
380+
# np.evaluate should always return a numpy thing?
381+
return _make_lgdo(out_data)
382+
383+
except Exception:
384+
msg = f"Warning {expr} could not be evaluated with numexpr probably due to some not allowed characters, trying with eval()."
385+
log.debug(msg)
386+
379387
# resort to good ol' eval()
380388
globs = {"ak": ak, "np": np}
381389
if modules is not None:
@@ -392,6 +400,10 @@ def eval(
392400
return Array(out_data.to_numpy())
393401
return VectorOfVectors(out_data)
394402

403+
# modules can still produce numpy array
404+
if isinstance(out_data, np.ndarray):
405+
return _make_lgdo(out_data)
406+
395407
if np.isscalar(out_data):
396408
return Scalar(out_data)
397409

tests/types/test_table_eval.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import dbetto
34
import hist
45
import numpy as np
56
import pytest
@@ -85,6 +86,19 @@ def test_eval_dependency():
8586
res = obj.eval("lgdo.Array([1,2,3])", {}, modules={"lgdo": lgdo})
8687
assert res == lgdo.Array([1, 2, 3])
8788

89+
# test with module returning np.array
90+
assert obj.eval("np.sum(a)", {}, modules={"np": np}).value == np.int64(10)
91+
8892
# check bad type
8993
with pytest.raises(RuntimeError):
9094
obj.eval("hist.Hist()", modules={"hist": hist})
95+
96+
# check impossible numexpr can still run
97+
assert np.allclose(
98+
obj.eval(
99+
"a*args.value",
100+
{"args": dbetto.AttrsDict({"value": 2})},
101+
modules={"lgdo": lgdo},
102+
).view_as("np"),
103+
[2, 4, 6, 8],
104+
)

0 commit comments

Comments
 (0)