Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test clean up and add python 3.13 test feature #98

Merged
merged 6 commits into from
Jan 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ jobs:
- py310
- py311
- py312
- py313
- np1x
steps:
- name: Checkout branch
8 changes: 4 additions & 4 deletions ndonnx/_core/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

@@ -168,7 +168,7 @@ def via_upcast(
via_dtype = int_dtype
else:
raise TypeError(
f"Can't upcast unsigned type `{dtype}`. Available implementations are for `{*available_types,}`"
f"Can't upcast unsigned type `{dtype}`. Available implementations are for `{(*available_types,)}`"
)
elif isinstance(dtype, (dtypes.Integral, dtypes.NullableIntegral)) or dtype in (
dtypes.nbool,
@@ -178,7 +178,7 @@ def via_upcast(
via_dtype = int_dtype
else:
raise TypeError(
f"Can't upcast signed type `{dtype}`. Available implementations are for `{*available_types,}`"
f"Can't upcast signed type `{dtype}`. Available implementations are for `{(*available_types,)}`"
)
elif isinstance(dtype, (dtypes.Floating, dtypes.NullableFloating)):
if (
@@ -188,7 +188,7 @@ def via_upcast(
via_dtype = float_dtype
else:
raise TypeError(
f"Can't upcast float type `{dtype}`. Available implementations are for `{*available_types,}`"
f"Can't upcast float type `{dtype}`. Available implementations are for `{(*available_types,)}`"
)
else:
raise TypeError(f"Expected numerical data type, found {dtype}")
4 changes: 2 additions & 2 deletions ndonnx/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations
@@ -85,7 +85,7 @@ def asarray(
ret = x.copy() if copy is True else x

if dtype is not None:
ret = ret.astype(dtype)
ret = astype(ret, dtype=dtype, copy=copy)

return ret

19,512 changes: 9,286 additions & 10,226 deletions pixi.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
@@ -74,6 +74,8 @@ python = "3.10.*"
python = "3.11.*"
[feature.py312.dependencies]
python = "3.12.*"
[feature.py313.dependencies]
python = "3.13.*"
[feature.np1x.dependencies]
python = "3.11.*"
numpy = "1.*"
@@ -83,6 +85,7 @@ default = ["test", "lint"]
py310 = ["py310", "test"]
py311 = ["py311", "test"]
py312 = ["py312", "test"]
py313 = ["py313", "test"]
np1x = ["np1x", "test"]
docs = ["docs"]
build = ["build"]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
requires-python = ">=3.10"
dependencies = ["numpy", "spox>=0.10", "typing_extensions"]
4 changes: 2 additions & 2 deletions tests/test_additional.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import sys
@@ -155,7 +155,7 @@ def test_static_map_unimplemented_for_nullable():


@pytest.mark.skipif(
sys.platform.startswith("win"),
sys.platform.startswith("win") and np.__version__ < "2",
reason="ORT 1.18 not registering LabelEncoder(4) only on Windows.",
)
def test_isin():
54 changes: 0 additions & 54 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -554,60 +554,6 @@ def test_all(x):
assert_array_equal(ndx.all(x).to_numpy(), np.all(x.to_numpy()))


@pytest.mark.parametrize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are those tests redundant?

Copy link
Member Author

@adityagoel4512 adityagoel4512 Jan 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://github.com/Quantco/ndonnx/blob/main/tests/test_additional.py#L15-L66 (note that the link is test_*additional*.py).

"side",
[
"left",
"right",
],
)
def test_searchsorted(side):
a_val = [0, 1, 2, 5, 5, 6, 10, 15]
b_val = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 10, 15, 20, 20]
c_val = np.searchsorted(a_val, b_val, side)

a = ndx.asarray(a_val, dtype=ndx.int64)
b = ndx.asarray(b_val, dtype=ndx.int64)
c = ndx.searchsorted(a, b, side=side)
assert_array_equal(c_val, c.to_numpy())


@pytest.mark.skip(reason="TODO: onnxruntime")
@pytest.mark.parametrize(
"side",
[
"left",
"right",
],
)
def test_searchsorted_nans(side):
a_val = np.array([0, 1, 2, 5, 5, 6, 10, 15, np.nan])
b_val = np.array([0, 1, 2, np.nan, np.nan])
c_val = np.searchsorted(a_val, b_val, side)

a = ndx.array(shape=(len(a_val),), dtype=ndx.float64)
b = ndx.array(shape=(len(b_val),), dtype=ndx.float64)
c = ndx.searchsorted(a, b, side=side)

model = ndx.build({"a": a, "b": b}, {"c": c})

assert_array_equal(c_val, run(model, dict(a=a_val, b=b_val))["c"])


def test_searchsorted_raises():
with pytest.raises(TypeError):
a = ndx.array(shape=(), dtype=ndx.int64)
b = ndx.array(shape=(), dtype=ndx.float64)

ndx.searchsorted(a, b)

with pytest.raises(ValueError):
a = ndx.array(shape=(3,), dtype=ndx.int64)
b = ndx.array(shape=(3,), dtype=ndx.int64)

ndx.searchsorted(a, b, side="middle") # type: ignore[arg-type]


def test_truediv():
x = ndx.asarray([1, 2, 3], dtype=ndx.int64)
y = ndx.asarray([2, 3, 3], dtype=ndx.int64)
13 changes: 0 additions & 13 deletions tests/test_include.py

This file was deleted.

7 changes: 3 additions & 4 deletions tests/test_iter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) QuantCo 2023-2024
# Copyright (c) QuantCo 2023-2025
# SPDX-License-Identifier: BSD-3-Clause

import pytest
@@ -10,10 +10,9 @@ def test_iter_for_loop():
n = 5
a = ndx.array(shape=(n,), dtype=ndx.int64)

for i, el in enumerate(a): # type: ignore
for i, el in enumerate(a):
assert isinstance(el, ndx.Array)
if i > n:
assert False, "Iterated past the number of elements"
assert i < n, "Iterated past the number of elements"


@pytest.mark.parametrize(
Loading
Oops, something went wrong.