
[BUG] Code snippet runs 4x slower since mlx 0.19.0 #1918

Open
gabrieldemarmiesse opened this issue Mar 3, 2025 · 5 comments
gabrieldemarmiesse commented Mar 3, 2025

Describe the bug

While developing our model, we noticed that changing the version of mlx has a huge impact on performance: mlx 0.18.1 is faster than every mlx version released afterwards.

To Reproduce

Here is a minimal reproducible example:

# /// script
# requires-python = "==3.12.9"
# dependencies = [
#     "mlx==0.23.1",
# ]
# ///
import time

import mlx.core as mx
import mlx.nn as nn


def something(kv):
    # Repeatedly apply fused SDPA with a single-token query against the same keys/values.
    xs = mx.zeros(shape=(1, 32, 1, 128), dtype=mx.bfloat16)
    for _ in range(128):
        xs = xs + mx.fast.scaled_dot_product_attention(xs, kv, kv, scale=0.1)
    return xs


def main():
    WARMUP = 5
    TOTAL_STEPS = 40
    mx.random.seed(299792458)
    
    # Build the keys/values the way the real model does: reshape, slice out one
    # component, then transpose to (batch, n_heads, seq, head_dim).
    kv = mx.zeros(shape=(1, 1024, 4096), dtype=mx.bfloat16)
    kv = kv.reshape(1, -1, 2, 32, 128)
    kv = kv[:, :, 0].transpose(0, 2, 1, 3)
    mx.eval(kv)

    sum_times = 0
    for i in range(TOTAL_STEPS):
        t1 = time.time()
        mx.eval(something(kv))
        t2 = time.time()
        if i >= WARMUP:
            sum_times += t2 - t1

    print(f"average time per step: {(sum_times / (TOTAL_STEPS - WARMUP)) * 1000:1f} ms")

main()
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.17.1 && uv run bench.py
average time per step: 10.089261 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.18.1 && uv run bench.py
average time per step: 9.984575 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.19.0 && uv run bench.py
average time per step: 44.592115 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.19.3 && uv run bench.py
average time per step: 44.662435 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.20.0 && uv run bench.py
average time per step: 44.754716 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.21.1 && uv run bench.py
average time per step: 40.465389 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.22.1 && uv run bench.py
average time per step: 40.441854 ms
gabrieldemarmiesse@Kyutais-Laptop V2LLMs % uv add --script bench.py mlx==0.23.1 && uv run bench.py
average time per step: 40.197761 ms

Expected behavior
A newly released MLX version should be as fast as, or faster than, previous versions for any existing code.

Desktop (please complete the following information):
Model Name: MacBook Air
Model Identifier: Mac15,12
Chip: Apple M3
Total Number of Cores: 8 (4 performance and 4 efficiency)
Memory: 16 GB
System Firmware Version: 10151.81.1
OS Loader Version: 10151.81.1
ProductName: macOS
ProductVersion: 14.3
BuildVersion: 23D2057

As mentioned above, the mlx version influences the average time per step, with more recent versions performing worse. I run the script against a given version with uv add --script bench.py mlx==0.23.1 && uv run bench.py.

Additional information

I also reproduced this slowdown on a Mac mini (Mac16,11) with 24 GB of RAM and 12 cores. There the slowdown is between 4x and 7x depending on the version.

awni (Member) commented Mar 3, 2025

It runs a lot slower on an M3 Max going from 0.18.1 to 0.19.0 🤔

awni (Member) commented Mar 3, 2025

It must be related to the addition of the fused op in #1497, and it must be hitting an unusual case (since you're doing some slicing + transposes on the KV).

awni (Member) commented Mar 3, 2025

The problem is the transpose:

kv[...].transpose(0, 2, 1, 3)

incurs a copy of both the keys and values before calling the kernel and that slows things down a lot.
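
To see the cost concretely, here is a small sketch (shapes copied from the repro above; the helper name is illustrative and timings will vary by machine) that times the same SDPA loop with the transposed view versus an array allocated directly in the (1, 32, 512, 128) layout:

import time

import mlx.core as mx


def bench(kv, steps=20):
    # Same inner loop as the repro: a 1-token query attending to kv repeatedly.
    def step():
        xs = mx.zeros(shape=(1, 32, 1, 128), dtype=mx.bfloat16)
        for _ in range(128):
            xs = xs + mx.fast.scaled_dot_product_attention(xs, kv, kv, scale=0.1)
        return xs

    mx.eval(step())  # warmup
    t1 = time.time()
    for _ in range(steps):
        mx.eval(step())
    return (time.time() - t1) / steps * 1000


# Non-contiguous view: slice + transpose, as in the repro.
kv_view = mx.zeros(shape=(1, 1024, 4096), dtype=mx.bfloat16)
kv_view = kv_view.reshape(1, -1, 2, 32, 128)[:, :, 0].transpose(0, 2, 1, 3)
mx.eval(kv_view)

# Same logical data, but allocated directly in the layout SDPA expects.
kv_contig = mx.zeros(shape=(1, 32, 512, 128), dtype=mx.bfloat16)
mx.eval(kv_contig)

print(f"transposed view: {bench(kv_view):.2f} ms/step")
print(f"contiguous:      {bench(kv_contig):.2f} ms/step")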

I'm not sure if we can support transposed keys natively in the op... maybe.

Usually the KV cache does not need a transpose because you are updating an already transposed KV cache with the new key. Is your actual computation quite different from that?

gabrieldemarmiesse (Author) commented Mar 4, 2025

In the minimal reproducible example, the bug can be avoided by adding

kv = mx.array(np.array(kv.astype(mx.float32))).astype(mx.bfloat16)

just before the first eval (this also requires import numpy as np). Though it's quite counter-intuitive, it makes the bug go away, and we get down to about 6 ms per step with mlx 0.23.1.
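
For completeness, the workaround slots into main() from the repro roughly like this (the float32 round-trip is only there because NumPy has no bfloat16):

import numpy as np  # extra import at the top of the script

kv = kv[:, :, 0].transpose(0, 2, 1, 3)
# Round-trip through NumPy to force a contiguous copy of kv.
kv = mx.array(np.array(kv.astype(mx.float32))).astype(mx.bfloat16)
mx.eval(kv)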

A similar method was used on the full-size model and achieved a speedup on mlx 0.23.1. But while the full-size model is faster with this trick, it is still not as fast as with mlx 0.18.1 (0.18.1 remains ~20% faster than 0.23.1), so it seems there are other things at play here.

I could once again try to reduce the whole model into a minimal reproducible example with the numpy trick, but that's quite time-consuming, so maybe I could open a new issue once this one is fixed?

awni (Member) commented Mar 4, 2025

Though it's quite counter-intuitive, it makes the bug go away

It makes sense... since it results in the kv being contiguous. When you convert from numpy to MLX we usually copy into a contiguous array.

What would be helpful is if you could share more about how your actual computation works (not the toy example). The reason being, we want to make sure we optimize the right thing here / find an optimal solution.

In your actual computation do you use a growing KV cache? If so, the simplest fix would be to store that already transposed. If not, what are you doing instead?

Is it more like an encoder / decoder model with cross attention to the kv?
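
For reference, "store that already transposed" could look roughly like the sketch below (illustrative only, not MLX's built-in cache): keys and values are kept in (batch, n_heads, seq, head_dim) from the start and new steps are appended along the sequence axis, so no transpose or copy is needed before the fused SDPA call.

import mlx.core as mx


class TransposedKVCache:
    # Keys/values live in (batch, n_heads, seq, head_dim), the layout
    # mx.fast.scaled_dot_product_attention expects.
    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, new_keys, new_values):
        # new_keys / new_values: (batch, n_heads, new_steps, head_dim)
        if self.keys is None:
            self.keys, self.values = new_keys, new_values
        else:
            self.keys = mx.concatenate([self.keys, new_keys], axis=2)
            self.values = mx.concatenate([self.values, new_values], axis=2)
        return self.keys, self.values


# usage during decoding:
# k, v = cache.update(new_keys, new_values)
# out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)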
