From 64e10b5adb7dd3bb91ccd5cff0d6fd4e70c061c2 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 26 Jul 2024 07:35:52 -0700 Subject: [PATCH 1/3] some updates for numpy 2.0 and array api --- mlx/ops.cpp | 9 --------- python/src/array.cpp | 20 ++++++++++++++++++++ python/src/constants.cpp | 9 --------- python/tests/test_array.py | 6 ++++++ python/tests/test_constants.py | 18 ------------------ 5 files changed, 26 insertions(+), 36 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index f7801f24c1..a57a0df328 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) { throw std::invalid_argument(msg.str()); } - // TODO: Fix GPU kernel - if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) { - std::ostringstream msg; - msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements," - << " got array with sort axis size " << a.shape(axis) << "." - << " Please place this operation on the CPU instead."; - throw std::runtime_error(msg.str()); - } - return array( a.shape(), uint32, std::make_shared(to_stream(s), axis), {a}); } diff --git a/python/src/array.cpp b/python/src/array.cpp index da70a114ed..3a42b39f69 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -294,6 +294,26 @@ void init_array(nb::module_& m) { Returns: array: The array with type ``dtype``. )pbdoc") + .def( + "__array_namespace__", + [](const array& a, const std::optional& api_version) { + if (api_version) { + throw std::invalid_argument( + "Explicitly specifying api_version is not yet implemented."); + } + return nb::module_::import_("mlx.core"); + }, + "api_version"_a = nb::none(), + R"pbdoc( + Used to apply updates at the given indices. + + Args: + api_version (str, optional): String representing the version + of the array API spec to return. Default: ``None``. + + Returns: + out (Any): An object representing the array API namespace. + )pbdoc") .def("__getitem__", mlx_get_item, nb::arg().none()) .def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg()) .def_prop_ro( diff --git a/python/src/constants.cpp b/python/src/constants.cpp index 17e9ccd69e..0d9de75b7d 100644 --- a/python/src/constants.cpp +++ b/python/src/constants.cpp @@ -6,18 +6,9 @@ namespace nb = nanobind; void init_constants(nb::module_& m) { - m.attr("Inf") = std::numeric_limits::infinity(); - m.attr("Infinity") = std::numeric_limits::infinity(); - m.attr("NAN") = NAN; - m.attr("NINF") = -std::numeric_limits::infinity(); - m.attr("NZERO") = -0.0; - m.attr("NaN") = NAN; - m.attr("PINF") = std::numeric_limits::infinity(); - m.attr("PZERO") = 0.0; m.attr("e") = 2.71828182845904523536028747135266249775724709369995; m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421; m.attr("inf") = std::numeric_limits::infinity(); - m.attr("infty") = std::numeric_limits::infinity(); m.attr("nan") = NAN; m.attr("newaxis") = nb::none(); m.attr("pi") = 3.1415926535897932384626433; diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 7e9b68fa5c..0144c34c59 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1828,6 +1828,12 @@ def test_setitem_with_list(self): anp[:, idx] = 4 self.assertTrue(np.array_equal(a, anp)) + def test_array_namespace(self): + a = mx.array(1.0) + api = a.__array_namespace__() + self.assertTrue(hasattr(api, "array")) + self.assertTrue(hasattr(api, "add")) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_constants.py b/python/tests/test_constants.py index 11a466e039..104e7522d7 100644 --- a/python/tests/test_constants.py +++ b/python/tests/test_constants.py @@ -10,14 +10,6 @@ class TestConstants(mlx_tests.MLXTestCase): def test_constants_values(self): # Check if mlx constants match expected values - self.assertAlmostEqual(mx.Inf, float("inf")) - self.assertAlmostEqual(mx.Infinity, float("inf")) - self.assertTrue(np.isnan(mx.NAN)) - self.assertAlmostEqual(mx.NINF, float("-inf")) - self.assertEqual(mx.NZERO, -0.0) - self.assertTrue(np.isnan(mx.NaN)) - self.assertAlmostEqual(mx.PINF, float("inf")) - self.assertEqual(mx.PZERO, 0.0) self.assertAlmostEqual( mx.e, 2.71828182845904523536028747135266249775724709369995 ) @@ -25,25 +17,15 @@ def test_constants_values(self): mx.euler_gamma, 0.5772156649015328606065120900824024310421 ) self.assertAlmostEqual(mx.inf, float("inf")) - self.assertAlmostEqual(mx.infty, float("inf")) self.assertTrue(np.isnan(mx.nan)) self.assertIsNone(mx.newaxis) self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433) def test_constants_availability(self): # Check if mlx constants are available - self.assertTrue(hasattr(mx, "Inf")) - self.assertTrue(hasattr(mx, "Infinity")) - self.assertTrue(hasattr(mx, "NAN")) - self.assertTrue(hasattr(mx, "NINF")) - self.assertTrue(hasattr(mx, "NaN")) - self.assertTrue(hasattr(mx, "PINF")) - self.assertTrue(hasattr(mx, "NZERO")) - self.assertTrue(hasattr(mx, "PZERO")) self.assertTrue(hasattr(mx, "e")) self.assertTrue(hasattr(mx, "euler_gamma")) self.assertTrue(hasattr(mx, "inf")) - self.assertTrue(hasattr(mx, "infty")) self.assertTrue(hasattr(mx, "nan")) self.assertTrue(hasattr(mx, "newaxis")) self.assertTrue(hasattr(mx, "pi")) From 8b5663ed90d9c14b69b87ba04a2b791e9aadefc7 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 26 Jul 2024 08:11:21 -0700 Subject: [PATCH 2/3] some updates for numpy 2.0 and array api --- python/src/ops.cpp | 42 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 0f92204735..eba8043f16 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) { const std::optional>& axes, StreamOrDevice s) { if (axes.has_value()) { - return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s); + return transpose(a, *axes, s); } else { return transpose(a, s); } @@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) { Returns: array: The transposed array. )pbdoc"); + m.def( + "permute_dims", + [](const array& a, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value()) { + return transpose(a, *axes, s); + } else { + return transpose(a, s); + } + }, + nb::arg(), + "axes"_a = nb::none(), + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + See :func:`transpose`. + )pbdoc"); m.def( "sum", [](const array& a, @@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) { Returns: array: The concatenated array. )pbdoc"); + m.def( + "concat", + [](const std::vector& arrays, + std::optional axis, + StreamOrDevice s) { + if (axis) { + return concatenate(arrays, *axis, s); + } else { + return concatenate(arrays, s); + } + }, + nb::arg(), + "axis"_a.none() = 0, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + See :func:`concatenate`. + )pbdoc"); m.def( "stack", [](const std::vector& arrays, From af3a2d4fdf7b0a97e929afb79b6b67f22dfb8c43 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 26 Jul 2024 09:38:52 -0700 Subject: [PATCH 3/3] fix array api doc --- python/src/array.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 3a42b39f69..61641be88f 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -305,9 +305,12 @@ void init_array(nb::module_& m) { }, "api_version"_a = nb::none(), R"pbdoc( - Used to apply updates at the given indices. + Returns an object that has all the array API functions on it. + + See the `Python array API `_ + for more information. - Args: + Args: api_version (str, optional): String representing the version of the array API spec to return. Default: ``None``.