How to skip certain expert in grouped gemm masked layout #20
The current impl already supports that, try this test (notice the `masked_m[1] = 0` line and that the assertion skips expert 1):

```python
def test_m_grouped_gemm_masked() -> None:
    print('Testing grouped masked GEMM:')

    for num_groups, m in ((2, 512), (4, 256)):
        for k, n in ((7168, 4096), (2048, 7168), ):
            # Test correctness
            masked_m_candidates = list(filter(lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)))
            for i in range(10):
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
                for j in range(num_groups):
                    masked_m[j] = random.choice(masked_m_candidates)
                masked_m[1] = 0  # !!!! Skipping the expert 1
                expected_m = min(int(masked_m.float().mean()) + 1, m)
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
                for j in range(num_groups):
                    diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
                    # Skipping expert 1 checks
                    assert j == 1 or diff < 0.001, f'{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'

            # noinspection PyShadowingNames
            def test_func():
                # Construct new tensors every time to avoid L2 cache acceleration
                x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
                masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * m
                deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, m)

            # Test performance with fixed shapes
            t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
            print(f' > Performance ({num_groups=}, m_per_group={m:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
                  f'throughput: {2 * num_groups * m * n * k / t / 1e12:4.0f} TFLOPS, '
                  f'{(num_groups * (m * k + k * n + m * n * 2)) / 1e9 / t:4.0f} GB/s')

    print()
```

For the skipped experts, the NaN output is expected: no computation is done for those experts and no memory is written (the rows are not cleared to zero).
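In other words, rows belonging to a skipped expert (`masked_m[j] == 0`) are never written by the kernel, so they hold stale or NaN data. Below is a minimal sketch of how a caller might deal with that on its side; the helper name and the zero-filling strategy are illustrative, not part of DeepGEMM's API:

```python
import torch

def gather_valid_rows(out: torch.Tensor, masked_m: torch.Tensor) -> list:
    # out: (num_groups, m, n) bf16 output of the masked grouped GEMM.
    # Only the first masked_m[j] rows of group j were written by the kernel;
    # everything else (including all rows of skipped experts) is uninitialized.
    return [out[j, :int(masked_m[j].item())] for j in range(out.shape[0])]

# Alternative: if downstream code reduces over all rows, pre-clear the buffer so
# that skipped experts contribute zeros instead of garbage/NaN:
#   out.zero_()
#   deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)
```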
Ahh, thanks for your reply! That makes sense; I hit the assertion on the NaN diff at the beginning, but yeah, that should be the expected behavior. Closing this for now :)
Sorry for re-opening this issue, but I observed an unexpected benchmark result. Below are my latency benchmark and my test code (all tested on an H200 GPU):
@DiegoD94 It seems that your benchmark measures this `test_func`:

```python
def test_func():
    # Construct new tensors every time to avoid L2 cache acceleration
    x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
    masked_m = torch.ones((num_groups, ), device='cuda', dtype=torch.int) * real_m
    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, real_m)
```
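If the intent of the benchmark is to measure the cost with some experts actually skipped, the timed call needs a `masked_m` that contains zeros for those experts, rather than `real_m` for every group. A hypothetical sketch along the lines of the repo's test harness (`construct_grouped` and `bench_kineto` come from the test file; the function name and the skip pattern below are made up for illustration):

```python
import torch
import deep_gemm

def bench_skipped(num_groups, m, k, n, num_active):
    # Time the masked grouped GEMM when only `num_active` of the groups do work.
    def run():
        # Construct new tensors every time to avoid L2 cache acceleration,
        # mirroring the repo's test_func.
        x_fp8, y_fp8, out, ref_out = construct_grouped(num_groups, m, k, n, is_masked=True)
        masked_m = torch.zeros((num_groups, ), device='cuda', dtype=torch.int)
        masked_m[:num_active] = m                            # active experts get m rows each
        expected_m = max(m * num_active // num_groups, 1)    # average per-group workload hint
        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m)

    return bench_kineto(run, 'fp8_gemm', suppress_kineto_output=True)
```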
Hi, thanks for the great implementation.
I'm trying to use this grouped GEMM kernel in a FusedMoE layer. I tried setting masked_m to 0 for the experts I want to skip, but the output came back as NaN. Is there a best practice for skipping certain experts? This should be a very common situation in the low-batch-size decoding stage.
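For context, in a FusedMoE forward pass `masked_m` would typically come from the router's per-expert token counts, so experts that receive no tokens naturally end up with `masked_m == 0`. A rough sketch under that assumption (`build_masked_m` and `topk_ids` are illustrative names, not DeepGEMM API):

```python
import torch

def build_masked_m(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # topk_ids: (num_tokens, top_k) expert indices chosen by the router.
    # Count how many tokens were routed to each expert; experts that received
    # no tokens get masked_m == 0 and are skipped by the masked grouped GEMM.
    counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
    return counts.to(device='cuda', dtype=torch.int)
```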