Skip to content

Commit

Permalink
Update to numpy 2.0 (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 authored Jun 24, 2024
1 parent 9cedf62 commit 25185a9
Show file tree
Hide file tree
Showing 13 changed files with 1,831 additions and 1,077 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ jobs:
- py310
- py311
- py312
- np1x
steps:
- name: Checkout branch
uses: actions/checkout@v4
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ It has a couple of key features:
- It implements the [`Array API`](https://data-apis.org/array-api/) standard. Standard compliant code can be executed without changes across numerous backends such as like `NumPy`, `JAX` and now `ndonnx`.

```python
import numpy.array_api as npx
import numpy as np
import ndonnx as ndx
from jax.experimental import array_api as jxp

def mean_drop_outliers(a, low=-5, high=5):
xp = a.__array_namespace__()
return xp.mean(a[(low < a) & (a < high)])

np_result = mean_drop_outliers(npx.asarray([-10, 0.5, 1, 5]))
np_result = mean_drop_outliers(np.asarray([-10, 0.5, 1, 5]))
jax_result = mean_drop_outliers(jxp.asarray([-10, 0.5, 1, 5]))
onnx_result = mean_drop_outliers(ndx.asarray([-10, 0.5, 1, 5]))

Expand Down
15 changes: 8 additions & 7 deletions docs/experimental/experimental.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,17 @@ Note that we will use the ``_experimental`` submodule.
author="myextensionlibrary",
)
def _parse_input(self, input: np.ndarray) -> dict:
def _parse_input(self, x: np.ndarray) -> dict:
# We accept numpy arrays of dtype datetime64[s] as input only.
if input.dtype == np.dtype("datetime64[s]"):
unix_timestamp = input.astype(np.int64)
if x.dtype == np.dtype("datetime64[s]"):
unix_timestamp = x.astype(np.int64)
return {
"unix_timestamp": self._fields()["unix_timestamp"]._parse_input(unix_timestamp),
}
else:
raise ValueError(f"Cannot parse input of dtype {input.dtype} to {self}")
raise ValueError(f"Cannot parse input of dtype {x.dtype} to {self}")
return {
"unix_timestamp": self._fields()["unix_timestamp"]._parse_input(unix_timestamp),
}
def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray:
return fields["unix_timestamp"].astype("datetime64[s]")
Expand Down
3 changes: 1 addition & 2 deletions docs/intros/gettingstarted.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,12 @@ Writing code in a strictly Array API compliant fashion makes it instantly reusab
import ndonnx as ndx
import numpy as np
import numpy.array_api as npx
def mean_drop_outliers(a, low=-5, high=5):
xp = a.__array_namespace__()
return xp.mean(a[(low < a) & (a < high)])
np_result = mean_drop_outliers(npx.asarray([-10, 0.5, 1, 4]))
np_result = mean_drop_outliers(np.asarray([-10, 0.5, 1, 4]))
onnx_result = mean_drop_outliers(ndx.asarray([-10, 0.5, 1, 4]))
np.testing.assert_equal(np_result, onnx_result.to_numpy())
Expand Down
10 changes: 5 additions & 5 deletions ndonnx/_data_types/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ class _NullableCore(Nullable[CoreType], CastMixin):
def copy(self) -> Self:
return self

def _parse_input(self, input: np.ndarray) -> dict:
if not isinstance(input, np.ma.MaskedArray):
raise TypeError(f"Expected numpy MaskedArray, got {type(input)}")
def _parse_input(self, x: np.ndarray) -> dict:
if not isinstance(x, np.ma.MaskedArray):
raise TypeError(f"Expected numpy MaskedArray, got {type(x)}")
return {
"values": self.values._parse_input(input.data),
"null": self.null._parse_input(input.mask),
"values": self.values._parse_input(x.data),
"null": self.null._parse_input(x.mask),
}

def _assemble_output(self, fields: dict[str, np.ndarray]) -> np.ndarray:
Expand Down
10 changes: 5 additions & 5 deletions ndonnx/_data_types/structtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _fields(self) -> dict[str, StructType | CoreType]:
...

@abstractmethod
def _parse_input(self, input: np.ndarray) -> dict:
def _parse_input(self, x: np.ndarray) -> dict:
"""This method may be used by runtime utilities. It should break down the input
numpy array into a dictionary of values with an entry for each field in this
StructType.
Expand All @@ -44,14 +44,14 @@ def __init__(self):
.. code-block:: python
def _parse_input(self, input: np.ndarray) -> dict[str, np.ndarray]:
def _parse_input(self, x: np.ndarray) -> dict[str, np.ndarray]:
# We expect a numpy array of python int objects.
if input.dtype != object:
if x.dtype != object:
raise TypeError("Input must be an object array")
mask = (1 << 64) - 1
return {
"low": (input & mask).astype(np.uint64),
"high": (input >> 64).astype(np.uint64) & mask,
"low": (x & mask).astype(np.uint64),
"high": (x >> 64).astype(np.uint64) & mask,
}
"""
...
Expand Down
Loading

0 comments on commit 25185a9

Please sign in to comment.