Array api #1289
Conversation
This is so nice! Excited to use it. 🚀
python/src/array.cpp
```cpp
},
"api_version"_a = nb::none(),
R"pbdoc(
    Used to apply updates at the given indices.
```
Nit: doc string? Took me a sec to figure out this is how einops knows to call `mx.permute_dims` etc.
Ah that's a copy typo, good catch.
```cpp
@@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
    throw std::invalid_argument(msg.str());
  }

  // TODO: Fix GPU kernel
```
Thanks for fixing this!
Great to see progress towards gh-48! FYI, there has been interest in adding a fully-compliant wrapper to array-api-compat in data-apis/array-api-compat#162.
Cool, thanks for sharing! If there's anything we can do to help on the MLX side let us know.
The main thing that would help is if you could let us know what your plans are for MLX and the array API. If you're planning to have full compliance in MLX itself, then there's not much that we need to do in array-api-compat, other than documentation and possibly adding helper functions. This is similar to how JAX and sparse work. On the other hand, if there are some functionalities that you won't be able to support for backwards compatibility, we should add wrappers to array-api-compat. (If there are functions that you can't support because they simply aren't supported by the hardware, then there's no way to wrap that and we just have to document it.)
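As background, a minimal sketch of how consumers typically obtain a compliant namespace through array-api-compat (`array_namespace` is that package's entry point; the function below is just an illustration):

```python
import array_api_compat

def portable_sum_of_squares(x):
    # Returns the native namespace for fully-compliant libraries (like JAX)
    # or a thin compatibility wrapper for libraries that need one
    # (like NumPy or PyTorch today).
    xp = array_api_compat.array_namespace(x)
    return xp.sum(x * x)
```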
Looking through the array API I think for 95+% of it we already have an implementation or would like to have one. There are a few exceptions that I noticed:
Let me know what you think about that.
I believe there are also the `float64` and `complex128` dtypes required by the spec.
Good point, I missed that. And we don't plan to add them as they are not natively supported on Apple GPUs.
The lack of those dtypes is a shame, but understandable. I think how to deal with the fact that mlx is designed with a "native" unified memory experience (I don't know a better short way to describe it :D) will be interesting to figure out. On the face of it I think mlx could do what NumPy does: have a `device` attribute and a `to_device` method that don't really do anything. Maybe at the level of array API standard compatibility it isn't such a big deal, but as an "array consumer" in scikit-learn I am currently still scratching my head over how we'd write code that works with mlx and (say) PyTorch inputs. Without a bunch of `if` statements, that is.
The default device is the GPU so you wouldn't need to place any operations explicitly unless you wanted to change the default. To avoid needing a bunch of `if` statements, one possibility is to use a stream context manager, e.g.:

```python
with mx.stream(mx.cpu):
    # call the function / do the work
```

For the most part though, running everything on the GPU works pretty well. I'm not sure how you feel about using that as the default though..
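A slightly fuller sketch of that pattern (shapes and values here are just placeholders):

```python
import mlx.core as mx

a = mx.ones((1024, 1024))

# Runs on the default device (the GPU on Apple silicon).
b = a @ a

# Runs on the CPU stream instead, with no per-call if statements.
with mx.stream(mx.cpu):
    c = a @ a

mx.eval(b, c)  # force evaluation of the lazy computation graphs
```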
You might want to use the test suite to see what is missing in more detail (although we may need to update it so that tests for unsupported dtypes can be skipped).
This is similar to NumPy where there is only one device (`'cpu'`). NumPy added these methods but they don't really do anything.
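For concreteness, this is what that no-op device plumbing looks like in NumPy (the `.device` attribute and `to_device` method landed in NumPy 2.0, to the best of my knowledge):

```python
import numpy as np  # NumPy >= 2.0

x = np.zeros(3)
print(x.device)          # "cpu" — the only device NumPy exposes
y = x.to_device("cpu")   # accepted for array API compatibility; a no-op
```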
Are you saying that things like `empty` don't work? If you don't have a real "empty" function you could always alias it to `zeros`.
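If such an alias were ever needed, it would be a one-liner (a sketch; `mx.zeros` is real, and the only cost is losing the uninitialized-memory optimization that `empty` allows):

```python
import mlx.core as mx

def empty(shape, dtype=mx.float32):
    # The standard's empty() leaves array contents unspecified, so
    # returning zeros is a valid, if slightly slower, implementation.
    return mx.zeros(shape, dtype=dtype)
```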
This is OK. The standard explicitly states that this functionality isn't required (see https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). Note that the test suite doesn't yet have a way to disable these tests, so you'll have to just ignore those failures.

The important thing for the array API is consistency across libraries so that people can write one function that works with all array libraries. So if a function is implemented, it should behave the way the standard specifies.

Correct me if I am wrong, but it sounds like you are wanting to implement full compatibility in MLX (excluding data-dependent operations and missing dtypes). If that's the case, there is no need to implement wrappers in array-api-compat. It also sounds like an approach like NumPy's device handling, where the methods exist but don't really do anything, would work here.
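To make the consistency point concrete, here is the kind of library-agnostic function the standard enables (a sketch; `__array_namespace__` is the standard's entry point, which this PR adds to `mx.array`):

```python
def standardize(x):
    # Any array implementing the array API exposes its namespace this way,
    # so the same code runs on NumPy, PyTorch, JAX, or MLX inputs.
    xp = x.__array_namespace__()
    return (x - xp.mean(x)) / xp.std(x)
```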
Indeed I wasn't aware of it. Good idea.
No they work as expected. They just don't give any guarantees about the underlying memory buffer being unchanged.
Seems reasonable!
👍
Maybe something like `device` returning `'uma'` would make sense, as it explicitly conveys that data resides on the unified memory architecture.
I was trying to use MLX with scikit-learn through the array API back-end. It failed because MLX doesn't have `float64`. It would be super cool if we could get MLX to work with scikit-learn.. but I'm not sure if there is a workaround for the missing `float64`. Here's the example I tried for reference:

```python
from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import mlx.core as mx

X_np, y_np = make_classification(random_state=0)
X_mx = mx.array(X_np)
y_mx = mx.array(y_np)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)
    print(type(X_trans))
```
Do you just get an error message like "`float64` is not supported"? If so, and I'm not saying this is a good solution, things should run if you wrap MLX such that `float64` is an alias for `float32`.
Yea exactly. And yes that's a good suggestion for a patch.. though long term it might be worth thinking about how to deal with that in the array API / compatibility layer.
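A rough sketch of what such a patch could look like (entirely hypothetical: `MLXFloat64Shim` is not an existing package, and a real fix would also need the arrays' `__array_namespace__` to return the shim):

```python
import mlx.core as mx

class MLXFloat64Shim:
    """Forwards everything to mlx.core but aliases float64 to float32."""

    float64 = mx.float32  # pretend double precision exists by downcasting

    def __getattr__(self, name):
        return getattr(mx, name)

xp = MLXFloat64Shim()
x = xp.ones((3,), dtype=xp.float64)  # actually float32 under the hood
```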
The functionality exists to query which dtypes a library supports: https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.__array_namespace_info__.html

Indeed, based on data-apis/array-api#640 (comment), it sounds like the spec will be relaxed at some point to accommodate missing dtypes. Consumers being able to tell which dtypes they can use is possible via this introspection method, but we're going to want a helper where we can say: "give me `float64` if it is supported, otherwise the widest float that is".
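A sketch of that kind of helper built on the standard's introspection API (`max_float` is a hypothetical name; the `dtypes(kind=...)` query is part of the spec):

```python
def max_float(xp):
    # Ask the namespace which real floating-point dtypes it supports.
    dtypes = xp.__array_namespace_info__().dtypes(kind="real floating")
    # Prefer float64 when available; otherwise fall back to float32.
    return dtypes.get("float64", xp.float32)
```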
Probably something that should belong in array-api-compat; however, I can't immediately see how we would want this to work. Most dtype selections happen through `dtype=...` keyword arguments.
It could get quite complex if there are varying policies for different parts of a library, like if we are perfectly happy to fall back from `float64` to `float32` in some places, but not others. I suppose this case in itself is not too bad, but it would be pretty complex to generalize for weirder (by the standards of today...) sets of implemented dtypes.

There is also the case of testing. It doesn't seem ideal to have to add a line to skip based on whether the dtype is supported in every test which is parameterized by standard dtypes, but I suppose that would work.
@awni could you open an issue on the scikit-learn repository with your code snippet and a few words of context (for those who haven't followed this thread)? That would be a great way to forge some links between the two projects. The reason it fails is that we currently special case torch (on MPS) as it is the only namespace without `float64`. Having another library with full array API support and a missing dtype would be a good reason to work on making this more general (and eventually using the introspection mechanism).
Yes, done!
Some updates to match the array API and remove deprecated constants. These updates allow `mx.array` to be used with einops:
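A quick sketch of what that enables (hedged: assumes an einops version with the experimental array API support, exposed as `einops.array_api` in recent releases):

```python
import mlx.core as mx
from einops.array_api import rearrange  # experimental array API entry point

x = mx.arange(24).reshape(2, 3, 4)
# einops discovers MLX's operations through x.__array_namespace__(),
# which is what this PR adds (e.g. it calls mx.permute_dims under the hood).
y = rearrange(x, "b h w -> b (h w)")
print(y.shape)  # (2, 12)
```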