From c308230ace6aedf217cdd055fd4f434dda21e219 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Thu, 24 Oct 2024 22:09:09 +0100 Subject: [PATCH] Fix argmax --- tests/test_core.py | 16 ++++++++++++++++ xfails.txt | 2 -- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 5063954..1aa9242 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -992,3 +992,19 @@ def test_no_unsafe_cumulative_sum_cast(): ): a = ndx.asarray([1, 2, 3], ndx.int32) ndx.cumulative_sum(a, dtype=ndx.uint64) + + +@pytest.mark.parametrize( + "func, x", + [ + (ndx.argmax, np.array([1, 2, 3, 4, 5], dtype=ndx.int32)), + (ndx.argmax, np.array([1, 2, 3, 4, 5], dtype=ndx.float32)), + (ndx.argmin, np.array([1, 2, 3, 4, 5], dtype=ndx.float32)), + (ndx.argmin, np.array([1, 2, 3, 4, 5], dtype=ndx.float32)), + ], +) +def test_argmaxmin(func, x): + ndx_result = func(x) + np_result = getattr(np, func.__name__)(x) + breakpoint() + np.testing.assert_equal(ndx_result.to_numpy(), np_result) diff --git a/xfails.txt b/xfails.txt index 69431ee..9665406 100644 --- a/xfails.txt +++ b/xfails.txt @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit array_api_tests/test_operators_and_elementwise_functions.py::test_sinh array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_tan -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_searching_functions.py::test_where