Skip to content

Commit

Permalink
Clean up some tests and drop numpy 1 tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 16, 2025
1 parent bcde9f0 commit d00f9ad
Show file tree
Hide file tree
Showing 9 changed files with 7,476 additions and 10,505 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
- py310
- py311
- py312
- np1x
- py313
steps:
- name: Checkout branch
uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down
8 changes: 4 additions & 4 deletions ndonnx/_core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

Expand Down Expand Up @@ -168,7 +168,7 @@ def via_upcast(
via_dtype = int_dtype
else:
raise TypeError(
f"Can't upcast unsigned type `{dtype}`. Available implementations are for `{*available_types,}`"
f"Can't upcast unsigned type `{dtype}`. Available implementations are for `{(*available_types,)}`"
)
elif isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)) or dtype in (
dtypes.nbool,
Expand All @@ -178,7 +178,7 @@ def via_upcast(
via_dtype = int_dtype
else:
raise TypeError(
f"Can't upcast signed type `{dtype}`. Available implementations are for `{*available_types,}`"
f"Can't upcast signed type `{dtype}`. Available implementations are for `{(*available_types,)}`"
)
elif isinstance(dtype, (dtypes.Floating, dtypes.NullableFloating)):
if (
Expand All @@ -188,7 +188,7 @@ def via_upcast(
via_dtype = float_dtype
else:
raise TypeError(
f"Can't upcast float type `{dtype}`. Available implementations are for `{*available_types,}`"
f"Can't upcast float type `{dtype}`. Available implementations are for `{(*available_types,)}`"
)
else:
raise TypeError(f"Expected numerical data type, found {dtype}")
Expand Down
17,887 changes: 7,462 additions & 10,425 deletions pixi.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ python = "3.10.*"
python = "3.11.*"
[feature.py312.dependencies]
python = "3.12.*"
[feature.np1x.dependencies]
python = "3.11.*"
numpy = "1.*"
[feature.py313.dependencies]
python = "3.13.*"


[environments]
default = ["test", "lint"]
py310 = ["py310", "test"]
py311 = ["py311", "test"]
py312 = ["py312", "test"]
np1x = ["np1x", "test"]
py313 = ["py313", "test"]
docs = ["docs"]
build = ["build"]
lint = { features = ["lint"], no-default-feature = true }
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.10"
dependencies = ["numpy", "spox>=0.10", "typing_extensions"]
Expand Down
56 changes: 1 addition & 55 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
Expand Down Expand Up @@ -554,60 +554,6 @@ def test_all(x):
assert_array_equal(ndx.all(x).to_numpy(), np.all(x.to_numpy()))


@pytest.mark.parametrize(
"side",
[
"left",
"right",
],
)
def test_searchsorted(side):
a_val = [0, 1, 2, 5, 5, 6, 10, 15]
b_val = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 10, 15, 20, 20]
c_val = np.searchsorted(a_val, b_val, side)

a = ndx.asarray(a_val, dtype=ndx.int64)
b = ndx.asarray(b_val, dtype=ndx.int64)
c = ndx.searchsorted(a, b, side=side)
assert_array_equal(c_val, c.to_numpy())


@pytest.mark.skip(reason="TODO: onnxruntime")
@pytest.mark.parametrize(
"side",
[
"left",
"right",
],
)
def test_searchsorted_nans(side):
a_val = np.array([0, 1, 2, 5, 5, 6, 10, 15, np.nan])
b_val = np.array([0, 1, 2, np.nan, np.nan])
c_val = np.searchsorted(a_val, b_val, side)

a = ndx.array(shape=(len(a_val),), dtype=ndx.float64)
b = ndx.array(shape=(len(b_val),), dtype=ndx.float64)
c = ndx.searchsorted(a, b, side=side)

model = ndx.build({"a": a, "b": b}, {"c": c})

assert_array_equal(c_val, run(model, dict(a=a_val, b=b_val))["c"])


def test_searchsorted_raises():
with pytest.raises(TypeError):
a = ndx.array(shape=(), dtype=ndx.int64)
b = ndx.array(shape=(), dtype=ndx.float64)

ndx.searchsorted(a, b)

with pytest.raises(ValueError):
a = ndx.array(shape=(3,), dtype=ndx.int64)
b = ndx.array(shape=(3,), dtype=ndx.int64)

ndx.searchsorted(a, b, side="middle") # type: ignore[arg-type]


def test_truediv():
x = ndx.asarray([1, 2, 3], dtype=ndx.int64)
y = ndx.asarray([2, 3, 3], dtype=ndx.int64)
Expand Down
13 changes: 0 additions & 13 deletions tests/test_include.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/test_iter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import pytest
Expand All @@ -10,7 +10,7 @@ def test_iter_for_loop():
n = 5
a = ndx.array(shape=(n,), dtype=ndx.int64)

for i, el in enumerate(a): # type: ignore
for i, el in enumerate(a):
assert isinstance(el, ndx.Array)
if i > n:
assert False, "Iterated past the number of elements"
Expand Down

0 comments on commit d00f9ad

Please sign in to comment.