Implement persistent matmul scheduling #3812

Open
jacobhinkle wants to merge 31 commits into main

Conversation

@jacobhinkle (Collaborator) commented on Feb 3, 2025

Stacked on #3642

This is a follow-up to #3792 that implements persistent scheduling.

There is a current limitation that affects both persistent scheduling and "grid swizzling": if a MatmulOp or LinearOp is present in the fusion, we hit inlining errors. This is because in that case we have a non-trivial AxisMapping on the MmaOp, and the missing input dimensions are not tracked through the scheduling transforms (merges and splits) required for either grid swizzling or persistent scheduling. Because of this, I introduced three new parametrized tests matching the original MLPBenchmarkTests but with a _BroadcastInputs suffix. These tests use fusedMultiplySum instead of linear. The persistent variants of the non-BroadcastInputs tests are skipped until we fix the inlining issue.
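
As a purely illustrative sketch (not code from this PR), a _BroadcastInputs-style fusion can be built from explicitly broadcast operands with fusedMultiplySum, so that every MmaOp dimension maps to a producer dimension; the tensor names, dtypes, and TN-style layout below are assumptions:

  // Hypothetical sketch, not taken from this PR. Assumes the usual nvfuser
  // headers and test helpers (e.g. makeContigTensor) are in scope.
  void defineBroadcastInputsMatmul(Fusion* fusion) {
    FusionGuard fg(fusion);
    TensorView* tv0 = makeContigTensor(2, DataType::Half); // [M, K]
    TensorView* tv1 = makeContigTensor(2, DataType::Half); // [N, K]
    fusion->addInput(tv0);
    fusion->addInput(tv1);
    TensorView* tv0b = broadcast(tv0, {false, true, false}); // [M, 1, K]
    TensorView* tv1b = broadcast(tv1, {true, false, false}); // [1, N, K]
    // Multiply elementwise and reduce over K; no LinearOp/MatmulOp translation
    // (and hence no non-trivial AxisMapping) is involved.
    TensorView* tv2 = fusedMultiplySum(tv0b, tv1b, {-1}); // [M, N]
    fusion->addOutput(tv2);
  }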

I currently observe a correctness issue in the MLPBenchmarkTest.FwdEpilogueFusion_BroadcastInputs test regardless of parametrization, meaning we get incorrect results even for data-parallel scheduling. I confirmed this test also fails on main. I currently skip this test with a warning message.

jacobhinkle and others added 29 commits December 23, 2024 20:54
I think this covers the motivation for #3616
There is still one case that fails, which we should fix. I'll create an
issue for it.
Co-authored-by: Ryan Spring <rspring@nvidia.com>
@jacobhinkle requested a review from rdspring1 on February 3, 2025 16:04
@jacobhinkle (Collaborator, Author) commented:

!test


github-actions bot commented Feb 3, 2025

Description

  • Implement persistent matmul scheduling (see the conceptual sketch after this list)

  • Add support for warp specialization on Hopper

  • Parametrize MLP Benchmark tests for different configurations

  • Add new tests for broadcast inputs in MLP benchmarks
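
For context, here is a minimal conceptual sketch of persistent scheduling (illustrative standalone C++, not the scheduler's implementation): rather than launching one CTA per output tile, a fixed grid of CTAs, roughly one per SM, is launched and each CTA loops over several output tiles. The helper names and the simple strided tile assignment are assumptions:

  #include <algorithm>
  #include <cstdint>

  // Illustrative only: how a persistent CTA could walk the output tiles.
  constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

  void persistentTileLoop(int64_t M, int64_t N, int64_t cta_tile_m,
                          int64_t cta_tile_n, int64_t num_sms, int64_t cta_id) {
    const int64_t tiles_m = ceilDiv(M, cta_tile_m);
    const int64_t tiles_n = ceilDiv(N, cta_tile_n);
    const int64_t num_tiles = tiles_m * tiles_n;
    const int64_t num_ctas = std::min(num_tiles, num_sms);
    // Each of the num_ctas persistent CTAs strides over the output tiles.
    for (int64_t tile = cta_id; tile < num_tiles; tile += num_ctas) {
      const int64_t tile_m = tile / tiles_n;
      const int64_t tile_n = tile % tiles_n;
      // ... load operand tiles for (tile_m, tile_n), run the MMA main loop,
      // and write the epilogue for this output tile ...
      (void)tile_m;
      (void)tile_n;
    }
  }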


Changes walkthrough 📝

Relevant files

Enhancement

csrc/device_lower/pass/index.cpp: Update stmatrix dimension check (+1/-1)
  • Update error check to allow more than 2 dimensions for stmatrix

csrc/scheduler/hopper_multi_matmul.cpp: Enhance Hopper matmul scheduling (+76/-28)
  • Split Hopper MMA by warp-tile before instruction tile
  • Add new methods for transforming TensorView like mma output with and without K
  • Validate persistent kernel scheduling
  • Schedule operands and epilogue with new transformation methods
  • Update TMA tile sizes
  • Schedule split-K sum with new transformation method

csrc/scheduler/matmul_utils.cpp: Set default warp specialization (+4/-0)
  • Set warp specialization as default on Hopper

tests/cpp/test_matmul.cpp: Parametrize MLP tests and add broadcast input tests (+344/-76)
  • Parametrize MLP Benchmark tests for different configurations (see the test-parametrization sketch after this walkthrough)
  • Add new tests for broadcast inputs in MLP benchmarks
  • Skip persistent kernel tests for LinearOp translation

tests/cpp/test_translate_mma.cpp: Extend CUDA arch range and skip failing test (+6/-1)
  • Extend CUDA arch range for matmul node translation test
  • Temporarily skip failing test case on Hopper

csrc/scheduler/hopper_multi_matmul.h: Add new transformation methods (+6/-1)
  • Add new methods for transforming TensorView like mma output with and without K
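
As a rough illustration of how the MLP benchmark tests could be parametrized with standard GoogleTest value parameterization (hypothetical fixture, parameter, and test names; the actual code in tests/cpp/test_matmul.cpp may differ):

  #include <gtest/gtest.h>

  // Hypothetical sketch only: the parameter fields, test name, and
  // instantiation values are assumptions, not copied from this PR.
  struct MLPBenchmarkTestParams {
    bool warp_specialization;
    bool persistent_kernel;
  };

  class MLPBenchmarkTest
      : public NVFuserTest, // nvfuser's GTest fixture base class
        public ::testing::WithParamInterface<MLPBenchmarkTestParams> {};

  TEST_P(MLPBenchmarkTest, FwdGEMM_BroadcastInputs) {
    const MLPBenchmarkTestParams params = GetParam();
    // ... build the fusion with fusedMultiplySum, apply the requested
    // warp-specialization / persistent-kernel options to the matmul
    // scheduler parameters, then compile, run, and validate ...
    (void)params;
  }

  INSTANTIATE_TEST_SUITE_P(
      MatmulSchedulerConfigs,
      MLPBenchmarkTest,
      ::testing::Values(
          MLPBenchmarkTestParams{false, false},
          MLPBenchmarkTestParams{true, false},
          MLPBenchmarkTestParams{true, true}));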

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The test FwdEpilogueFusion_BroadcastInputs is currently failing. This needs to be investigated to understand the root cause and fix the issue.

    TEST_P(MLPBenchmarkTest, FwdEpilogueFusion_BroadcastInputs) {
      GTEST_SKIP() << "THIS TEST IS CURRENTLY FAILING" << std::endl;
    
Performance Concern

The test FwdHorizontalFusion has some failing checks (commented out below) due to improper syncing of horizontally fused kernels. This should be addressed to ensure the performance benefits of horizontal fusion are realized.

    // EXPECT_TRUE(cg_outputs[0].allclose(tv3_ref, 1e-6 * K, 1e-6 * K));
    EXPECT_TRUE(cg_outputs[1].allclose(tv10_ref, 1e-6 * K, 1e-6 * K));
    // EXPECT_TRUE(cg_outputs[2].allclose(tv12_ref, 1e-2, 1e-1));
Code Clarity

The new methods transformLikeMmaOutputWithK and transformLikeMmaOutputWithoutK have detailed comments but could benefit from additional comments explaining the purpose of each step in the transformation process. This would help future maintainers understand the logic more easily.

    void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithK(
        TensorView* tv) {
      // The input is originally block tiled so that the inner dims are the CTA tile
      // size
      //
      // We split this into warp tiles then instruction tiles
      // Original: [..., M, N, K]
      tv->split(-3, params_->tile_sizes.warp_tile.m);
      tv->split(-3, getM(params_->mma_macro));
      tv->split(-2, params_->tile_sizes.warp_tile.n);
      tv->split(-2, getN(params_->mma_macro));
      // K dimension is present for mma_result
      // We don't need to split by warp_tile.k, since we always have
      // cta_tile.k==warp_tile.k
      tv->split(-1, getK(params_->mma_macro));
      // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Kw, Ki]
      tv->reorder({
          {-8, -8}, // Mo
          {-7, -6}, // Mw
          {-6, -3}, // Mi
          {-5, -7}, // No
          {-4, -5}, // Nw
          {-3, -2}, // Ni
          {-2, -4}, // Kw
          {-1, -1}, // Ki
      });
      // After Reorder: [..., Mo, No, Mw, Nw, Kw, Mi, Ni, Ki]
      tv->merge(-8);
  // After Merge: [..., Mo * No, Mw, Nw, Kw, Mi, Ni, Ki]
      tv->axis(-7)->parallelize(ParallelType::TIDy);
      // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Kw, Mi, Ni, Ki]
    }
    
    void HopperMultipleMatmulScheduler::transformLikeMmaOutputWithoutK(
        TensorView* tv) {
      // TODO Add constraints
