From 1854cb54c08c92131791e48d553bb775fdd1ca6a Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Wed, 12 Feb 2025 02:23:13 -0800 Subject: [PATCH] PR #22593: [XLA:GPU] Fix triton dot op on sm120 (RTX50xx) Imported from GitHub PR https://github.com/openxla/xla/pull/22593 Triton doesn't currently support sm120 GPUs - adding a patch to fix that, the upstream support should be available soon. Converting "12.0" arch into "10.0" is not correct, as they're not compatible - removing this. Copybara import of the project: -- 5b5752dc3b3ae3611ffeca0f55a87f130ff1e8bb by Sergey Kozub : [XLA:GPU] Fix triton dot op on sm120 (RTX50xx) Merging this change closes #22593 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/22593 from openxla:skozub/sm120_dot 5b5752dc3b3ae3611ffeca0f55a87f130ff1e8bb PiperOrigin-RevId: 725968626 --- third_party/triton/temporary/series.bzl | 1 + third_party/triton/temporary/sm120.patch | 13 +++++++++++++ xla/backends/gpu/codegen/triton/fusion_emitter.cc | 5 ----- 3 files changed, 14 insertions(+), 5 deletions(-) create mode 100644 third_party/triton/temporary/sm120.patch diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index ad02e2075db19..6cbc285e2db0a 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -16,5 +16,6 @@ those to this list. temporary_patch_list = [ "//third_party/triton:temporary/fix_fence_insertion_race.patch", "//third_party/triton:temporary/enable_peer_access.patch", + "//third_party/triton:temporary/sm120.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/sm120.patch b/third_party/triton/temporary/sm120.patch new file mode 100644 index 0000000000000..252ec47ad1240 --- /dev/null +++ b/third_party/triton/temporary/sm120.patch @@ -0,0 +1,13 @@ +diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +index c66c9f4ae..3415d6a91 100644 +--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +@@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) { + versionsSupported = {3, 2}; + } else if (computeCapability < 110) { + versionsSupported = {5, 2}; ++ } else if (computeCapability == 120) { ++ versionsSupported = {2}; + } else { + assert(false && "computeCapability not supported"); + } \ No newline at end of file diff --git a/xla/backends/gpu/codegen/triton/fusion_emitter.cc b/xla/backends/gpu/codegen/triton/fusion_emitter.cc index d7927c10e5490..8f24b88ce60ba 100644 --- a/xla/backends/gpu/codegen/triton/fusion_emitter.cc +++ b/xla/backends/gpu/codegen/triton/fusion_emitter.cc @@ -1244,11 +1244,6 @@ absl::StatusOr CompileTritonToLLVM( const auto& cc = device_info.gpu_compute_capability(); std::string arch_name = std::visit([](auto& cc) { return cc.ToString(); }, cc); - if (arch_name == "12.0") { - LOG(WARNING) << "Triton does not support sm_120 yet. Passing CC 10.0 to " - "avoid spurious \"unsupported conversion\" errors"; - arch_name = "10.0"; - } if (std::holds_alternative(cc)) { auto ccCuda = std::get(cc); if (!ccCuda.IsAtLeastAmpere()) {