
Array api #1289

Merged: @awni merged 3 commits into main from array_api on Jul 26, 2024
Conversation

@awni (Member) commented Jul 26, 2024

Some updates to match the array API and remove deprecated constants.

These updates allow mx.array to be used with einops:

import mlx.core as mx
from einops.array_api import rearrange, reduce, repeat, pack, unpack

# rearrange elements according to the pattern
input_tensor = mx.zeros((2,3,4))
output_tensor = rearrange(input_tensor, 't b c -> b c t')
print(output_tensor.shape)
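# -> (3, 4, 2)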

# combine rearrangement and reduction
input_tensor = mx.zeros((2, 3, 2*3, 2*4))
output_tensor = reduce(input_tensor, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)
print(output_tensor.shape)
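# -> (2, 3, 4, 3)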

# copy along a new axis
input_tensor = mx.zeros((2, 3))
output_tensor = repeat(input_tensor, 'h w -> h w c', c=3)
print(output_tensor.shape)
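# -> (2, 3, 3)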

# pack and unpack allow reversibly 'packing' multiple tensors into one.
# Packed tensors may be of different dimensionality:
h, w = 100, 200
image_rgb = mx.zeros([h, w, 3])
image_depth = mx.zeros([h, w])
# but we can stack them
image_rgbd, ps = pack([image_rgb, image_depth], 'h w *')
print(image_rgbd.shape)
print(ps)
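# -> image_rgbd.shape == (100, 200, 4); ps == [(3,), ()]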

unpacked_rgb, unpacked_depth = unpack(image_rgbd, ps, 'h w *')
print(unpacked_rgb.shape, unpacked_depth.shape)                                                             
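# -> (100, 200, 3) and (100, 200)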

@awni requested review from jagrit06, angeloskath and barronalex, and removed the request for jagrit06, on Jul 26, 2024 15:13
@awni mentioned this pull request on Jul 26, 2024
@barronalex (Collaborator) left a comment


This is so nice! Excited to use it. 🚀

Inline review comment on the diff:

    },
    "api_version"_a = nb::none(),
    R"pbdoc(
    Used to apply updates at the given indices.
Collaborator:

Nit: doc string? Took me a sec to figure out this is how einops knows to call mx.permute_dims etc

@awni (Member Author):

Ah that's a copy typo, good catch.

Inline review comment on the diff:

    @@ -1839,15 +1839,6 @@ array argsort(const array& a, int axis, StreamOrDevice s /* = {} */) {
    throw std::invalid_argument(msg.str());
    }

    // TODO: Fix GPU kernel
Collaborator:

Thanks for fixing this!

@awni merged commit 7b456fd into main on Jul 26, 2024
3 checks passed
@awni deleted the array_api branch on Jul 26, 2024 17:40
@lucascolley commented
Great to see progress towards gh-48! FYI, there has been interest in adding a fully-compliant wrapper to array-api-compat in data-apis/array-api-compat#162.

@awni (Member Author) commented Aug 6, 2024

> FYI, there has been interest in adding a fully-compliant wrapper to array-api-compat in data-apis/array-api-compat#162

Cool, thanks for sharing! If there's anything we can do to help on the MLX side let us know.

@asmeurer commented Aug 6, 2024

The main thing that would help is if you could let us know what your plans are for MLX and the array API. If you're planning to have full compliance in MLX itself, then there's not much that we need to do in array-api-compat, other than documentation and possibly adding helper functions. This is similar to how JAX and sparse work.

On the other hand, if there is some functionality that you won't be able to support for backwards-compatibility reasons, we should add wrappers to array-api-compat. (If there are functions that you can't support because they simply aren't supported by the hardware, then there's no way to wrap that and we just have to document it.)

@awni (Member Author) commented Aug 6, 2024

Looking through the array API, I think we already have an implementation for 95+% of it, or would like to have one.

There are a few exceptions that I noticed:

  • Attributes and methods related to devices, like device and to_device. Arrays in MLX live in unified memory; they don't belong to a specific device and can't be moved to one.
  • The empty creation function. There are no true "in-place" ops in MLX; we manage memory buffer sharing under the hood, so something like empty doesn't really make sense.
  • Operations whose output shapes depend on input data (e.g. boolean indexing and similar ops). Those are hard to do in MLX and we may not be able to support them, at least in their typical form (see the sketch below).
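For example (a sketch; the boolean selection itself is exactly what MLX can't easily provide):

import mlx.core as mx

a = mx.array([1, 2, 3, 4])
mask = a > 2
# a[mask] would contain [3, 4] and have shape (2,), but that shape depends
# on the *values* in `a`, not just its shape, so it can't be known without
# evaluating `a` first -- hard for a lazily-evaluated framework like MLX.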

Let me know what you think about that.

@lucascolley commented
I believe there are also float64 and complex128 dtypes which don't exist in MLX.

@awni (Member Author) commented Aug 6, 2024

Good point, I missed that. And we don't plan to add them, as they aren't natively supported on Apple GPUs.

@betatim commented Aug 6, 2024

The lack of float64 shouldn't be a problem. For example, PyTorch with the MPS backend (their way of using the GPU in modern Apple laptops) also doesn't support float64, but PyTorch is still considered compatible.

I think how to deal with the fact that MLX is designed with a "native" unified memory experience (I don't know a better short way to describe it :D) will be interesting to figure out. On the face of it, I think MLX could do what NumPy does: have a device attribute that is always the same. In NumPy's case it is always the string "cpu". More interesting might be how to deal with the fact that in MLX operations can be placed on a device. This is an idea that doesn't exist in any(?) other array library and also doesn't exist in the array API standard.

Maybe at the level of array API standard compatibility it isn't such a big deal, but as an "array consumer" in scikit-learn I am still scratching my head over how we'd write code that works with MLX and (say) PyTorch inputs without a bunch of if statements to conditionally add stream=mx.gpu as an argument to operations that I'd like to have executed on the GPU.

@awni (Member Author) commented Aug 6, 2024

> Without a bunch of if statements to conditionally add stream=mx.gpu as an argument to operations that I'd like to have executed on the GPU.

The default device is the GPU so you wouldn't need to place any operations explicitly unless you wanted to change the default.

To avoid needing a bunch of if statements, one possibility is to use a stream context manager, e.g.:

with mx.stream(mx.cpu):
    # call the function / do the work; ops in this block run on the CPU
    y = mx.sum(mx.ones((4, 4)))

For the most part though, running everything on the GPU works pretty well. I'm not sure how you feel about using that as the default though.

@asmeurer commented Aug 6, 2024

You might want to use the test suite to see what is missing in more detail (although we may need to update it so that ARRAY_API_TESTS_SKIP_DTYPES works with float64; I'm not sure if it works right now): https://github.com/data-apis/array-api-tests

> Attributes and methods related to devices, like device and to_device. Arrays in MLX live in unified memory; they don't belong to a specific device and can't be moved to one.

This is similar to NumPy where there is only one device ('cpu'). NumPy added these methods but they don't really do anything.

> The empty creation function. There are no true "in-place" ops in MLX; we manage memory buffer sharing under the hood, so something like empty doesn't really make sense.

Are you saying that things like += don't work at all? That's a fairly big discrepancy if so. Note that the standard already says that mutation across views is not a required feature: https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html. For instance, JAX, which also manages memory for the user, is able to be compatible.

If you don't have a real "empty" function you could always alias empty = zeros (this is what JAX does).
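A minimal sketch of that alias on the MLX side (the placement is hypothetical; mx.zeros and mx.zeros_like are existing MLX functions):

import mlx.core as mx

# "empty" only promises an allocated array of the right shape and dtype,
# so a zero-filled array satisfies the same contract
empty = mx.zeros
empty_like = mx.zeros_like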

> Operations whose output shapes depend on input data (e.g. boolean indexing and similar ops). Those are hard to do in MLX and we may not be able to support them, at least in their typical form.

This is OK. The standard explicitly states that this functionality isn't required (see https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html). Note that the test suite doesn't yet have a way to disable these tests, so you'll have to just ignore those failures.

The important thing for the array API is consistency across libraries, so that people can write one function that works with all array libraries. So if a SciPy function uses empty or device, it should just work even if those functions aren't really doing anything interesting.

Correct me if I'm wrong, but it sounds like you want to implement full compatibility in MLX (excluding data-dependent operations and missing dtypes). If that's the case, there is no need to implement wrappers in array-api-compat. The main thing you should implement, in that case, is __array_namespace__, so that MLX will work with array-API-consuming code. (EDIT: just realized that's exactly what this PR did.)
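For reference, a minimal sketch of how array-API-consuming code picks up the namespace via __array_namespace__ (total_magnitude is an invented example):

def total_magnitude(x):
    xp = x.__array_namespace__()  # standard entry point for consumers
    return xp.sum(xp.abs(x))      # same code works for MLX, NumPy, etc.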

It also sounds like an is_mlx_array helper in array-api-compat would be useful since MLX has some features that are specific to it (like the per-device computation features) that consuming code might want to take advantage of.

@awni (Member Author) commented Aug 6, 2024

> You might want to use the test suite to see what is missing in more detail

Indeed, I wasn't aware of it. Good idea.

> Are you saying that things like += don't work at all?

No, they work as expected. They just don't give any guarantees about the underlying memory buffer being unchanged.
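A quick sketch of that behavior (assuming the semantics described above, where in-place operators may rebind rather than mutate a shared buffer):

import mlx.core as mx

a = mx.zeros((3,))
b = a    # b refers to the same logical array as a
a += 1   # a is now [1, 1, 1]; MLX may back it with a fresh buffer
# b still holds [0, 0, 0]: no guarantee of mutation visible through views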

> If you don't have a real "empty" function you could always alias empty = zeros

Seems reasonable!

> It also sounds like an is_mlx_array helper in array-api-compat would be useful since MLX has some features that are specific to it (like the per-device computation features) that consuming code might want to take advantage of.

👍

@altaic commented Aug 6, 2024

> This is similar to NumPy where there is only one device ('cpu'). NumPy added these methods but they don't really do anything.

Maybe having device return 'uma' would make sense, as that explicitly conveys that the data resides in a unified memory architecture.

@altaic mentioned this pull request on Aug 8, 2024
@awni (Member Author) commented Aug 10, 2024

I was trying to use MLX with scikit-learn through the array API backend. It failed because MLX doesn't have a float64 dtype.

It would be super cool if we could get MLX to work with scikit-learn, but I'm not sure if there is a workaround for the float64 issue. What do you all think?

Here's the example I tried for reference:

from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import mlx.core as mx

X_np, y_np = make_classification(random_state=0)
X_mx = mx.array(X_np)
y_mx = mx.array(y_np)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)

print(type(X_trans))

@lucascolley commented
Do you just get an error message like "mx has no attribute float64"?

If so (and I'm not saying this is a good solution), things should run if you wrap MLX such that xp.float64 is actually the float32 dtype but 'pretends' to be float64. (I may be missing something though.)
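A rough sketch of such a wrapper, purely to make the idea concrete (the class and its placement are hypothetical):

import mlx.core as mx

class _MLXFloat64Shim:
    float64 = mx.float32  # 'pretend' dtype: computation stays in float32

    def __getattr__(self, name):
        return getattr(mx, name)  # defer everything else to mlx.core

xp = _MLXFloat64Shim()
x = xp.zeros((3,), dtype=xp.float64)  # actually a float32 array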

@awni (Member Author) commented Aug 10, 2024

Yeah, exactly. And yes, that's a good suggestion for a patch, though long term it might be worth thinking about how to deal with that in the array API / compatibility layer.

@lucascolley commented Aug 10, 2024

The functionality exists to query which dtypes a library supports: https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.__array_namespace_info__.html

Indeed, based on data-apis/array-api#640 (comment), it sounds like the spec will be relaxed at some point to accommodate missing dtypes. Consumers can tell which dtypes they can use via this introspection method, but we're going to want a helper where we can say (sketched after this list):

  • If it exists, use {this dtype}
  • If not, fallback to {one of these dtypes}
  • Else, raise an exception
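A hypothetical helper along those lines, using the standard introspection API (the name select_dtype is invented):

def select_dtype(xp, preferred, fallbacks=()):
    # __array_namespace_info__().dtypes() returns a mapping from dtype
    # name to dtype object for the dtypes the library supports
    supported = xp.__array_namespace_info__().dtypes()
    candidates = (preferred, *fallbacks)
    for name in candidates:
        if name in supported:
            return supported[name]
    raise TypeError(f"none of {candidates} is supported")

# e.g. dtype = select_dtype(xp, "float64", fallbacks=("float32",))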

This is probably something that belongs in array-api-compat; however, I can't immediately see how we would want it to work. Most dtype selections happen through xp.asarray and xp.astype, but it's not clear to me whether we would want to:

  • wrap these functions themselves
  • wrap the input dtypes to these functions
  • or something else

It could get quite complex if there are varying policies for different parts of a library, for example if we are perfectly happy to fall back from float64 to float32 in some places but not others. I suppose this case in itself is not too bad, but it would be pretty complex to generalize to weirder (by today's standards...) sets of implemented dtypes.

There is also the case of testing. It doesn't seem ideal to have to add a skip line, based on whether the dtype is supported, to every test that is parametrized over the standard dtypes, but I suppose that would work.

@betatim commented Aug 14, 2024

@awni could you open an issue on the scikit-learn repository with your code snippet and a few words of context (for those who haven't followed this thread)? That would be a great way to forge some links between the two projects.

The reason it fails is that we currently special-case torch (on MPS), as it is the only namespace without float64. Having another library with full array API support and a missing dtype would be a good reason to work on making this more general (and eventually to use the introspection mechanism).

@awni (Member Author) commented Aug 14, 2024

> could you open an issue on the scikit-learn repository with your code snippet and a few words of context (for those who haven't followed this thread)?

Yes, done!
