
How to skip certain experts in grouped GEMM masked layout #20

Open
DiegoD94 opened this issue Feb 27, 2025 · 4 comments

@DiegoD94

Hi, thanks for the great implementation.
I'm trying to use this grouped GEMM kernel in FusedMoE. I tried setting masked_m to 0 for the experts I want to skip, but the output came back as NaN. Is there a best practice for skipping certain experts? This should be a very common situation in the low-batch-size decoding stage.

@LyricZhao
Collaborator

LyricZhao commented Feb 27, 2025

The current implementation already supports this; try the test below and note masked_m[1] = 0, which skips expert 1:

def test_m_grouped_gemm_masked() -> None:
    print('Testing grouped masked GEMM:')

    for num_groups, m in ((2, 512), (4, 256)):
        for k, n in ((7168, 4096), (2048, 7168), ):
            # Test correctness
            masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
            for i in range(10):
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
                for j in range(num_groups):
                    masked_m[j] = random.choice(masked_m_candidates)
                masked_m[1] = 0  # !!!! Skipping the expert 1
                expected_m = min(int(masked_m.float().mean()) + 1, m)
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
                for j in range(num_groups):
                    diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
                    # Skipping expert 1 checks
                    assert j == 1 or diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'

            # noinspection PyShadowingNames
            def test_func():
                # Construct new tensors every time to avoid L2 cache acceleration
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)

            # Test performance with fixed shapes
            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
            print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
                  f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
                  f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
    print()

For the skipped experts, the NaN output is expected: no computation is performed for them and no memory is written (the output rows are not cleared to zero).
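
In other words, only out[j, :masked_m[j]] is defined after the call. A minimal sketch of handling the unwritten rows, reusing the variables from the test above and assuming out can be any contiguous (num_groups, m, n) bf16 tensor like the one construct_grouped builds:

# Sketch only: pre-zero the output so skipped experts read as zeros, or consume only the
# valid prefix per group. Shapes and dtypes follow the test above.
out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.bfloat16)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
for j in range(num_groups):
    valid = out[j, :masked_m[j].item()]  # only these rows were written by the kernel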

@DiegoD94
Author

Ah, thanks for your reply! That makes sense. The assertion I hit at the beginning was triggered by the NaN diff, but that is indeed the expected behavior. Closing this for now :)

@DiegoD94
Author

DiegoD94 commented Feb 27, 2025

Sorry for re-opening this issue, but I observed an unexpected benchmark result.
I was trying to benchmark a low-batch-size decoding situation: assuming a batch size of 8, at most 64 experts should be routed to, and the rest should be skipped.
I set the group number to 256 for the TP case and 32 for the EP case, then tried to skip certain experts via the mask. However, the latency barely drops even with many masked_m = 0 entries. Could it be that I am misunderstanding how to use the kernel?
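
For context, masked_m in this setting would come from the router. A minimal sketch, assuming a standard top-k router that yields a (batch, top_k) tensor of expert indices (the names and top_k value here are illustrative assumptions, not part of the kernel API):

import torch

# Hypothetical router output: each of the 8 decoding tokens is routed to top_k experts,
# so at most batch * top_k = 64 of the 256 experts receive any tokens.
batch, top_k, num_experts = 8, 8, 256
topk_ids = torch.randint(0, num_experts, (batch, top_k), device='cuda')
# Per-expert token counts; experts with a count of 0 should be skipped by the masked kernel.
masked_m = torch.bincount(topk_ids.flatten(), minlength=num_experts).to(torch.int)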

Below are my latency benchmark and my test code. Everything was tested on an H200 GPU:

Testing grouped masked GEMM:
 > Performance (num_groups=256, m_per_group=  64, real_m=   0, n=7168, k=4096): 1847 us | throughput:  521 TFLOPS, 4233 GB/s
 > Performance (num_groups=256, m_per_group=  64, real_m=  16, n=7168, k=4096): 1847 us | throughput:  521 TFLOPS, 4233 GB/s
 > Performance (num_groups=256, m_per_group=  64, real_m=  32, n=7168, k=4096): 1847 us | throughput:  521 TFLOPS, 4233 GB/s
 > Performance (num_groups=256, m_per_group=  64, real_m=  64, n=7168, k=4096): 1846 us | throughput:  521 TFLOPS, 4235 GB/s

def test_m_grouped_gemm_masked() -> None:
    print('Testing grouped masked GEMM:')

    for num_groups in [256]:
        m = 64
        for k, n in ((4096, 7168), (7168, 2048)):
            for real_m in [0, 16, 32, 64, 128, 256]:
                # Test correctness
                for i in range(10):
                    x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                    masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
                    for j in range(num_groups):
                        masked_m[j] = real_m
                    if real_m == 0:
                        masked_m[-1] = 1  # Give at least one expert one token to avoid the assertion
                    expected_m = max(real_m, 1)  # expected_m has to be > 0 to avoid the assertion
                    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)

                # noinspection PyShadowingNames
                def test_func():
                    # Construct new tensors every time to avoid L2 cache acceleration
                    x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                    masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
                    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)

                # Test performance with fixed shapes
                t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
                print(f' > Performance ({num_groups=}, m_per_group={m:4}, real_m={real_m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
                    f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
                    f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')
    print()

@DiegoD94 DiegoD94 reopened this Feb 27, 2025
@xuzhean
Collaborator

xuzhean commented Feb 27, 2025

@DiegoD94 It seems that real_m is not set correctly in test_func.

            def test_func():
                # Construct new tensors every time to avoid L2 cache acceleration
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * real_m
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, real_m)
