Skip to content

Commit

Permalink
Handle numpy's new shape typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Jan 30, 2025
1 parent 5f62c5b commit a5a1534
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 5 additions & 1 deletion py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,11 @@ def _np_ndarray_component_type(t: type) -> Optional[type]:
elif nargs == 2: # for npt.NDArray[np.int64], etc.
a0 = t.__args__[0]
a1 = t.__args__[1]
if a0 == typing.Any and isinstance(a1, types.GenericAlias): # novermin
# a0 is typing.Any before numpy 2.2.0 or a generic alias of tuple[int, ...] in numpy 2.2.0+. The latter
# is to support shape typing for numpy arrays. e.g. np.ndarray[tuple[Literal[2], Literal[3]], np.int32]
# is a 2x3 array of int32.
if ((a0 == typing.Any or (isinstance(a0, types.GenericAlias) and a0.__origin__ is tuple))
and isinstance(a1, types.GenericAlias)): # novermin
component_type = a1.__args__[0]
return component_type

Expand Down
13 changes: 11 additions & 2 deletions py/server/tests/test_udf_return.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import datetime
import typing
import unittest
from typing import List, Union, Tuple, Sequence, Optional
from typing import List, Union, Optional

import numba as nb
import numpy as np
import numpy.typing as npt
import pandas as pd
Expand Down Expand Up @@ -206,6 +205,7 @@ def f4557_1(x, y) -> np.ndarray[np.int64]:
return np.array(x) + y

# Testing https://github.com/deephaven/deephaven-core/issues/4562
import numba as nb
@nb.guvectorize([(nb.int32[:], nb.int32, nb.int32[:])], "(m),()->(m)", nopython=True)
def f4562_1(x, y, res):
res[:] = x + y
Expand Down Expand Up @@ -321,6 +321,15 @@ def udf() -> List[dtypes.Instant]:
t = empty_table(10).update(["X1 = udf()"])
self.assertEqual(t.columns[0].data_type, dtypes.instant_array)

def test_alternative_np_typehint(self):
import numpy.typing as npt

def f() -> npt.NDArray[np.int64]:
return np.array([1, 2], dtype=np.int64)

t = empty_table(10).update(["X1 = f()"])
self.assertEqual(t.columns[0].data_type, dtypes.long_array)


if __name__ == '__main__':
unittest.main()

0 comments on commit a5a1534

Please sign in to comment.