Skip to content

Commit

Permalink
Array api (#1289)
Browse files Browse the repository at this point in the history
* some updates for numpy 2.0 and array api

* some updates for numpy 2.0 and array api

* fix array api doc
  • Loading branch information
awni authored Jul 26, 2024
1 parent e9e5385 commit 7b456fd
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 37 deletions.
9 changes: 0 additions & 9 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str());
}

// TODO: Fix GPU kernel
if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
std::ostringstream msg;
msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements,"
<< " got array with sort axis size " << a.shape(axis) << "."
<< " Please place this operation on the CPU instead.";
throw std::runtime_error(msg.str());
}

return array(
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
}
Expand Down
23 changes: 23 additions & 0 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,29 @@ void init_array(nb::module_& m) {
Returns:
array: The array with type ``dtype``.
)pbdoc")
.def(
"__array_namespace__",
[](const array& a, const std::optional<std::string>& api_version) {
if (api_version) {
throw std::invalid_argument(
"Explicitly specifying api_version is not yet implemented.");
}
return nb::module_::import_("mlx.core");
},
"api_version"_a = nb::none(),
R"pbdoc(
Returns an object that has all the array API functions on it.
See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_
for more information.
Args:
api_version (str, optional): String representing the version
of the array API spec to return. Default: ``None``.
Returns:
out (Any): An object representing the array API namespace.
)pbdoc")
.def("__getitem__", mlx_get_item, nb::arg().none())
.def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg())
.def_prop_ro(
Expand Down
9 changes: 0 additions & 9 deletions python/src/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,9 @@
namespace nb = nanobind;

void init_constants(nb::module_& m) {
m.attr("Inf") = std::numeric_limits<double>::infinity();
m.attr("Infinity") = std::numeric_limits<double>::infinity();
m.attr("NAN") = NAN;
m.attr("NINF") = -std::numeric_limits<double>::infinity();
m.attr("NZERO") = -0.0;
m.attr("NaN") = NAN;
m.attr("PINF") = std::numeric_limits<double>::infinity();
m.attr("PZERO") = 0.0;
m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
m.attr("inf") = std::numeric_limits<double>::infinity();
m.attr("infty") = std::numeric_limits<double>::infinity();
m.attr("nan") = NAN;
m.attr("newaxis") = nb::none();
m.attr("pi") = 3.1415926535897932384626433;
Expand Down
42 changes: 41 additions & 1 deletion python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) {
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s);
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
Expand All @@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) {
Returns:
array: The transposed array.
)pbdoc");
m.def(
"permute_dims",
[](const array& a,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
},
nb::arg(),
"axes"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`transpose`.
)pbdoc");
m.def(
"sum",
[](const array& a,
Expand Down Expand Up @@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) {
Returns:
array: The concatenated array.
)pbdoc");
m.def(
"concat",
[](const std::vector<array>& arrays,
std::optional<int> axis,
StreamOrDevice s) {
if (axis) {
return concatenate(arrays, *axis, s);
} else {
return concatenate(arrays, s);
}
},
nb::arg(),
"axis"_a.none() = 0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`concatenate`.
)pbdoc");
m.def(
"stack",
[](const std::vector<array>& arrays,
Expand Down
6 changes: 6 additions & 0 deletions python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,12 @@ def test_setitem_with_list(self):
anp[:, idx] = 4
self.assertTrue(np.array_equal(a, anp))

def test_array_namespace(self):
a = mx.array(1.0)
api = a.__array_namespace__()
self.assertTrue(hasattr(api, "array"))
self.assertTrue(hasattr(api, "add"))


if __name__ == "__main__":
unittest.main()
18 changes: 0 additions & 18 deletions python/tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,22 @@
class TestConstants(mlx_tests.MLXTestCase):
def test_constants_values(self):
# Check if mlx constants match expected values
self.assertAlmostEqual(mx.Inf, float("inf"))
self.assertAlmostEqual(mx.Infinity, float("inf"))
self.assertTrue(np.isnan(mx.NAN))
self.assertAlmostEqual(mx.NINF, float("-inf"))
self.assertEqual(mx.NZERO, -0.0)
self.assertTrue(np.isnan(mx.NaN))
self.assertAlmostEqual(mx.PINF, float("inf"))
self.assertEqual(mx.PZERO, 0.0)
self.assertAlmostEqual(
mx.e, 2.71828182845904523536028747135266249775724709369995
)
self.assertAlmostEqual(
mx.euler_gamma, 0.5772156649015328606065120900824024310421
)
self.assertAlmostEqual(mx.inf, float("inf"))
self.assertAlmostEqual(mx.infty, float("inf"))
self.assertTrue(np.isnan(mx.nan))
self.assertIsNone(mx.newaxis)
self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)

def test_constants_availability(self):
# Check if mlx constants are available
self.assertTrue(hasattr(mx, "Inf"))
self.assertTrue(hasattr(mx, "Infinity"))
self.assertTrue(hasattr(mx, "NAN"))
self.assertTrue(hasattr(mx, "NINF"))
self.assertTrue(hasattr(mx, "NaN"))
self.assertTrue(hasattr(mx, "PINF"))
self.assertTrue(hasattr(mx, "NZERO"))
self.assertTrue(hasattr(mx, "PZERO"))
self.assertTrue(hasattr(mx, "e"))
self.assertTrue(hasattr(mx, "euler_gamma"))
self.assertTrue(hasattr(mx, "inf"))
self.assertTrue(hasattr(mx, "infty"))
self.assertTrue(hasattr(mx, "nan"))
self.assertTrue(hasattr(mx, "newaxis"))
self.assertTrue(hasattr(mx, "pi"))
Expand Down

0 comments on commit 7b456fd

Please sign in to comment.