Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix argmax and argmin #84

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fix argmax
  • Loading branch information
adityagoel4512 committed Oct 24, 2024
commit a09cac99abb72985aae773a05fd37b71fe5601b3
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -7,6 +7,12 @@
Changelog
=========

0.9.3 (unreleased)
------------------

- Reduced the number of unnecessary casts in :func:`ndonnx.argmax` and :func:`ndonnx.argmin`.


0.9.2 (2024-10-03)
------------------

Binary file modified docs/_static/classify_iris.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
64 changes: 30 additions & 34 deletions ndonnx/_core/_numericimpl.py
Original file line number Diff line number Diff line change
@@ -355,43 +355,39 @@ def matrix_transpose(self, x) -> ndx.Array:

@validate_core
def argmax(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_max(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x])
out = via_upcast(
lambda x: opx.arg_max(
x,
axis=axis or 0,
keepdims=int(keepdims),
),
[ndx.reshape(x, [-1]) if axis is None else x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

while keepdims and out.ndim < x.ndim:
out = ndx.expand_dims(out, axis=0)
return out

@validate_core
def argmin(self, x, axis=None, keepdims=False):
if axis is None:
reshaped_x = ndx.reshape(x, [-1])._core()
if keepdims:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64),
)
)
else:
return from_corearray(
opx.reshape(
opx.arg_min(reshaped_x, axis=0, keepdims=False),
opx.const([], dtype=dtypes.int64),
)
)
return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x])
out = via_upcast(
lambda x: opx.arg_min(
x,
axis=axis or 0,
keepdims=int(keepdims),
),
[ndx.reshape(x, [-1]) if axis is None else x],
cast_return=False,
int_dtype=ndx.int32,
float_dtype=ndx.float64,
)

while keepdims and out.ndim < x.ndim:
out = ndx.expand_dims(out, axis=0)
return out

@validate_core
def nonzero(self, x) -> tuple[Array, ...]:
2 changes: 1 addition & 1 deletion ndonnx/_funcs.py
Original file line number Diff line number Diff line change
@@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False):

def argmin(x, axis=None, keepdims=False):
if (
out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims)
out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims)
) is not NotImplemented:
return out
raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'")
46 changes: 46 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
from __future__ import annotations

import re
import warnings

import numpy as np
import pytest
@@ -992,3 +993,48 @@ def test_no_unsafe_cumulative_sum_cast():
):
a = ndx.asarray([1, 2, 3], ndx.int32)
ndx.cumulative_sum(a, dtype=ndx.uint64)


@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)),
(np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)),
(np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)),
],
)
def test_argmaxmin(func, x, keepdims):
np_result = func(x, keepdims=keepdims)
ndx_result = getattr(ndx, func.__name__)(
ndx.asarray(x), keepdims=keepdims
).to_numpy()
assert_array_equal(np_result, ndx_result)


# Pending ORT 1.19 conda-forge release before this becomes supported:
# https://github.com/conda-forge/onnxruntime-feedstock/pull/128
@pytest.mark.parametrize(
"func, x",
[
(np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
(np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)),
],
)
def test_argmaxmin_unsupported_kernels(func, x):
import onnxruntime as ort

if ort.__version__.startswith("19"):
warnings.warn(
"Please remove this test and update `argmax` and `argmin` to reflect expanded kernel support.",
Warning,
)

with pytest.raises(TypeError):
getattr(ndx, func.__name__)(ndx.asarray(x))
2 changes: 0 additions & 2 deletions xfails.txt
Original file line number Diff line number Diff line change
@@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit
array_api_tests/test_operators_and_elementwise_functions.py::test_sinh
array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt
array_api_tests/test_operators_and_elementwise_functions.py::test_tan
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_searching_functions.py::test_where