Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array api #1289

Merged
merged 3 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str());
}

// TODO: Fix GPU kernel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
std::ostringstream msg;
msg << "[argsort] GPU sort cannot handle sort axis of >= 2M elements,"
<< " got array with sort axis size " << a.shape(axis) << "."
<< " Please place this operation on the CPU instead.";
throw std::runtime_error(msg.str());
}

return array(
a.shape(), uint32, std::make_shared<ArgSort>(to_stream(s), axis), {a});
}
Expand Down
23 changes: 23 additions & 0 deletions python/src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,29 @@ void init_array(nb::module_& m) {
Returns:
array: The array with type ``dtype``.
)pbdoc")
.def(
"__array_namespace__",
[](const array& a, const std::optional<std::string>& api_version) {
if (api_version) {
throw std::invalid_argument(
"Explicitly specifying api_version is not yet implemented.");
}
return nb::module_::import_("mlx.core");
},
"api_version"_a = nb::none(),
R"pbdoc(
Returns an object that has all the array API functions on it.

See the `Python array API <https://data-apis.org/array-api/latest/index.html>`_
for more information.

Args:
api_version (str, optional): String representing the version
of the array API spec to return. Default: ``None``.

Returns:
out (Any): An object representing the array API namespace.
)pbdoc")
.def("__getitem__", mlx_get_item, nb::arg().none())
.def("__setitem__", mlx_set_item, nb::arg().none(), nb::arg())
.def_prop_ro(
Expand Down
9 changes: 0 additions & 9 deletions python/src/constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,9 @@
namespace nb = nanobind;

void init_constants(nb::module_& m) {
m.attr("Inf") = std::numeric_limits<double>::infinity();
m.attr("Infinity") = std::numeric_limits<double>::infinity();
m.attr("NAN") = NAN;
m.attr("NINF") = -std::numeric_limits<double>::infinity();
m.attr("NZERO") = -0.0;
m.attr("NaN") = NAN;
m.attr("PINF") = std::numeric_limits<double>::infinity();
m.attr("PZERO") = 0.0;
m.attr("e") = 2.71828182845904523536028747135266249775724709369995;
m.attr("euler_gamma") = 0.5772156649015328606065120900824024310421;
m.attr("inf") = std::numeric_limits<double>::infinity();
m.attr("infty") = std::numeric_limits<double>::infinity();
m.attr("nan") = NAN;
m.attr("newaxis") = nb::none();
m.attr("pi") = 3.1415926535897932384626433;
Expand Down
42 changes: 41 additions & 1 deletion python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2061,7 +2061,7 @@ void init_ops(nb::module_& m) {
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s);
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
Expand All @@ -2083,6 +2083,26 @@ void init_ops(nb::module_& m) {
Returns:
array: The transposed array.
)pbdoc");
m.def(
"permute_dims",
[](const array& a,
const std::optional<std::vector<int>>& axes,
StreamOrDevice s) {
if (axes.has_value()) {
return transpose(a, *axes, s);
} else {
return transpose(a, s);
}
},
nb::arg(),
"axes"_a = nb::none(),
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def permute_dims(a: array, /, axes: Optional[Sequence[int]] = None, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`transpose`.
)pbdoc");
m.def(
"sum",
[](const array& a,
Expand Down Expand Up @@ -2666,6 +2686,26 @@ void init_ops(nb::module_& m) {
Returns:
array: The concatenated array.
)pbdoc");
m.def(
"concat",
[](const std::vector<array>& arrays,
std::optional<int> axis,
StreamOrDevice s) {
if (axis) {
return concatenate(arrays, *axis, s);
} else {
return concatenate(arrays, s);
}
},
nb::arg(),
"axis"_a.none() = 0,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def concat(arrays: List[array], axis: Optional[int] = 0, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
See :func:`concatenate`.
)pbdoc");
m.def(
"stack",
[](const std::vector<array>& arrays,
Expand Down
6 changes: 6 additions & 0 deletions python/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,12 @@ def test_setitem_with_list(self):
anp[:, idx] = 4
self.assertTrue(np.array_equal(a, anp))

def test_array_namespace(self):
a = mx.array(1.0)
api = a.__array_namespace__()
self.assertTrue(hasattr(api, "array"))
self.assertTrue(hasattr(api, "add"))


if __name__ == "__main__":
unittest.main()
18 changes: 0 additions & 18 deletions python/tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,40 +10,22 @@
class TestConstants(mlx_tests.MLXTestCase):
def test_constants_values(self):
# Check if mlx constants match expected values
self.assertAlmostEqual(mx.Inf, float("inf"))
self.assertAlmostEqual(mx.Infinity, float("inf"))
self.assertTrue(np.isnan(mx.NAN))
self.assertAlmostEqual(mx.NINF, float("-inf"))
self.assertEqual(mx.NZERO, -0.0)
self.assertTrue(np.isnan(mx.NaN))
self.assertAlmostEqual(mx.PINF, float("inf"))
self.assertEqual(mx.PZERO, 0.0)
self.assertAlmostEqual(
mx.e, 2.71828182845904523536028747135266249775724709369995
)
self.assertAlmostEqual(
mx.euler_gamma, 0.5772156649015328606065120900824024310421
)
self.assertAlmostEqual(mx.inf, float("inf"))
self.assertAlmostEqual(mx.infty, float("inf"))
self.assertTrue(np.isnan(mx.nan))
self.assertIsNone(mx.newaxis)
self.assertAlmostEqual(mx.pi, 3.1415926535897932384626433)

def test_constants_availability(self):
# Check if mlx constants are available
self.assertTrue(hasattr(mx, "Inf"))
self.assertTrue(hasattr(mx, "Infinity"))
self.assertTrue(hasattr(mx, "NAN"))
self.assertTrue(hasattr(mx, "NINF"))
self.assertTrue(hasattr(mx, "NaN"))
self.assertTrue(hasattr(mx, "PINF"))
self.assertTrue(hasattr(mx, "NZERO"))
self.assertTrue(hasattr(mx, "PZERO"))
self.assertTrue(hasattr(mx, "e"))
self.assertTrue(hasattr(mx, "euler_gamma"))
self.assertTrue(hasattr(mx, "inf"))
self.assertTrue(hasattr(mx, "infty"))
self.assertTrue(hasattr(mx, "nan"))
self.assertTrue(hasattr(mx, "newaxis"))
self.assertTrue(hasattr(mx, "pi"))
Expand Down