Add kernel's config cache and improve TMA alignment #47
When testing deep_gemm.gemm_fp8_fp8_bf16_nt, we observed that for certain inputs, especially small matrix shapes, the overhead outside the GEMM kernel's execution was non-negligible. Using PyTorch's profiler, we found that the functions get_col_major_tma_aligned_tensor and get_best_configs introduce significant overhead.
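For reference, a minimal sketch of how this overhead can be observed with torch.profiler. The shapes, scale layout, and the tuple-based calling convention below follow DeepGEMM's documented usage but are illustrative rather than exact:

```python
import torch
import deep_gemm

m, k, n = 64, 7168, 4096  # a small-M case where host-side overhead dominates

# FP8 operands with per-128-channel scaling factors (illustrative layout).
x = torch.randn(m, k, device='cuda').to(torch.float8_e4m3fn)
x_scales = torch.rand(m, k // 128, device='cuda', dtype=torch.float32)
y = torch.randn(n, k, device='cuda').to(torch.float8_e4m3fn)
y_scales = torch.rand(n // 128, k // 128, device='cuda', dtype=torch.float32)
out = torch.empty(m, n, device='cuda', dtype=torch.bfloat16)

with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA]) as prof:
    deep_gemm.gemm_fp8_fp8_bf16_nt((x, x_scales), (y, y_scales), out)

# CPU time spent in get_col_major_tma_aligned_tensor / get_best_configs
# shows up here alongside the actual kernel time.
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=20))
```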

We tried to perform the memory alignment in advance by calling get_col_major_tma_aligned_tensor before invoking deep_gemm.gemm_fp8_fp8_bf16_nt, but the alignment-related operations were still triggered during the execution of deep_gemm.gemm_fp8_fp8_bf16_nt. Additionally, get_best_configs runs redundantly for identical inputs, producing the same results each time and introducing unnecessary overhead.
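The pre-alignment pattern we attempted looks roughly like the following sketch; the loop simply stands in for repeated calls with the same tensors:

```python
from deep_gemm import get_col_major_tma_aligned_tensor

# Align the LHS scale tensor once, outside the hot path. Before this fix,
# the 2D case was still treated as unaligned inside gemm_fp8_fp8_bf16_nt,
# so the alignment copy ran again on every call.
x_scales = get_col_major_tma_aligned_tensor(x_scales)

for _ in range(100):
    deep_gemm.gemm_fp8_fp8_bf16_nt((x, x_scales), (y, y_scales), out)
```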
First, we fixed the improper handling of 2D matrices in get_col_major_tma_aligned_tensor. Second, we added a cache so that the optimal configuration is obtained before the GEMM kernel launches without repeating the search, which reduces overhead for inputs of the same size.
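A minimal sketch of the shape-keyed cache, assuming get_best_configs is deterministic for a given problem shape and SM count; the exact signature and key fields in the real implementation may differ:

```python
import functools


@functools.lru_cache(maxsize=None)
def _cached_best_configs(m, n, k, num_groups, num_sms):
    # get_best_configs returns the same result for a fixed shape, so
    # identical inputs now pay the configuration search cost only once
    # per process. (Signature shown here is illustrative.)
    return get_best_configs(m, n, k, num_groups, num_sms)
```

The GEMM entry point would then call the cached wrapper instead of re-running the configuration search on every invocation.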