PR #22593: [XLA:GPU] Fix triton dot op on sm120 (RTX50xx)
Imported from GitHub PR #22593

Triton doesn't currently support sm120 GPUs, so this change adds a patch to fix that; upstream support should be available soon.

Converting the "12.0" arch into "10.0" is not correct, since the two are not compatible; this change removes that workaround.
Copybara import of the project:

--
5b5752d by Sergey Kozub <skozub@nvidia.com>:

[XLA:GPU] Fix triton dot op on sm120 (RTX50xx)

Merging this change closes #22593

FUTURE_COPYBARA_INTEGRATE_REVIEW=#22593 from openxla:skozub/sm120_dot 5b5752d
PiperOrigin-RevId: 725968626
sergey-kozub authored and Google-ML-Automation committed Feb 13, 2025
1 parent 6b470af commit 1854cb5
Showing 3 changed files with 14 additions and 5 deletions.
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -16,5 +16,6 @@ those to this list.
 temporary_patch_list = [
     "//third_party/triton:temporary/fix_fence_insertion_race.patch",
     "//third_party/triton:temporary/enable_peer_access.patch",
+    "//third_party/triton:temporary/sm120.patch",
     # Add new patches just above this line
 ]
13 changes: 13 additions & 0 deletions third_party/triton/temporary/sm120.patch
@@ -0,0 +1,13 @@
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index c66c9f4ae..3415d6a91 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
     versionsSupported = {3, 2};
   } else if (computeCapability < 110) {
     versionsSupported = {5, 2};
+  } else if (computeCapability == 120) {
+    versionsSupported = {2};
   } else {
     assert(false && "computeCapability not supported");
   }
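
For context, here is a minimal, self-contained sketch of the capability-to-MMA-version dispatch this patch extends. The helper name and the pre-Hopper fallback branch are assumptions for illustration; only the branches visible in the hunk above come from the source.

#include <cassert>
#include <cstdio>
#include <vector>

// Sketch of the dispatch in AccelerateMatmul.cpp after the patch.
// Compute capability is encoded as major * 10 + minor (sm_120 -> 120).
// The `< 90` branch is an assumption; the rest mirrors the hunk above.
static std::vector<int> mmaVersionsFor(int computeCapability) {
  if (computeCapability < 90) return {2};      // assumption: pre-Hopper
  if (computeCapability < 100) return {3, 2};  // Hopper (from the hunk)
  if (computeCapability < 110) return {5, 2};  // Blackwell DC (from the hunk)
  if (computeCapability == 120) return {2};    // sm120: added by this patch
  assert(false && "computeCapability not supported");
  return {};
}

int main() {
  for (int cc : {80, 90, 100, 120}) {
    std::printf("cc=%d -> preferred MMA version %d\n", cc,
                mmaVersionsFor(cc).front());
  }
  return 0;
}

The key point is that sm120 (consumer Blackwell) falls back to MMA version 2 rather than taking the version-5 path used for datacenter Blackwell.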
5 changes: 0 additions & 5 deletions xla/backends/gpu/codegen/triton/fusion_emitter.cc
@@ -1244,11 +1244,6 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
   const auto& cc = device_info.gpu_compute_capability();
   std::string arch_name =
       std::visit([](auto& cc) { return cc.ToString(); }, cc);
-  if (arch_name == "12.0") {
-    LOG(WARNING) << "Triton does not support sm_120 yet. Passing CC 10.0 to "
-                    "avoid spurious \"unsupported conversion\" errors";
-    arch_name = "10.0";
-  }
   if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
     auto ccCuda = std::get<se::CudaComputeCapability>(cc);
     if (!ccCuda.IsAtLeastAmpere()) {
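
With the remapping gone, the "12.0" arch string produced by the std::visit above reaches Triton unchanged, where the new sm120.patch handles it. Below is a stand-alone illustration of that visit pattern; the types are simplified stand-ins assumed for this sketch, not XLA's real se:: classes.

#include <iostream>
#include <string>
#include <variant>

// Simplified stand-ins (assumptions) for the compute-capability types.
struct CudaCC {
  int major, minor;
  std::string ToString() const {
    return std::to_string(major) + "." + std::to_string(minor);
  }
};
struct RocmCC {
  std::string gfx_version;
  std::string ToString() const { return gfx_version; }
};
using ComputeCapability = std::variant<CudaCC, RocmCC>;

int main() {
  ComputeCapability cc = CudaCC{12, 0};  // RTX 50xx (sm_120)
  // Same visit pattern as the context above: stringify whichever
  // alternative the variant currently holds.
  std::string arch_name =
      std::visit([](auto& c) { return c.ToString(); }, cc);
  std::cout << arch_name << "\n";  // prints "12.0", no longer remapped
  return 0;
}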
