PR #22593: [XLA:GPU] Fix triton dot op on sm120 (RTX50xx) #22690

Merged 1 commit on Feb 13, 2025.
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
@@ -16,5 +16,6 @@ those to this list.
 temporary_patch_list = [
     "//third_party/triton:temporary/fix_fence_insertion_race.patch",
     "//third_party/triton:temporary/enable_peer_access.patch",
+    "//third_party/triton:temporary/sm120.patch",
     # Add new patches just above this line
 ]
13 changes: 13 additions & 0 deletions third_party/triton/temporary/sm120.patch
@@ -0,0 +1,13 @@
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index c66c9f4ae..3415d6a91 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -33,6 +33,8 @@ static int getMMAVersionSafe(int computeCapability, DotOp op) {
     versionsSupported = {3, 2};
   } else if (computeCapability < 110) {
     versionsSupported = {5, 2};
+  } else if (computeCapability == 120) {
+    versionsSupported = {2};
   } else {
     assert(false && "computeCapability not supported");
   }
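
In effect, the patch extends Triton's AccelerateMatmul pass so that compute capability 12.0 (consumer Blackwell, sm_120 on RTX 50xx) selects MMAv2 instead of falling through to the assert. A minimal standalone sketch of the dispatch shape follows; the guard on the first branch, and any branches for older architectures, are not visible in the hunk and are hypothetical here:

#include <cassert>
#include <vector>

// Sketch only: mirrors the branch structure shown in the hunk above,
// not the full Triton getMMAVersionSafe function.
std::vector<int> GetMmaVersionsSketch(int computeCapability) {
  std::vector<int> versionsSupported;
  if (computeCapability < 100) {  // hypothetical guard; elided in the hunk
    versionsSupported = {3, 2};
  } else if (computeCapability < 110) {
    versionsSupported = {5, 2};   // datacenter Blackwell (sm_100)
  } else if (computeCapability == 120) {
    versionsSupported = {2};      // new branch: sm_120 advertises only MMAv2
  } else {
    assert(false && "computeCapability not supported");
  }
  return versionsSupported;
}

Without the new branch, computeCapability == 120 reached the assert (and, in release builds where the assert compiles away, produced no supported MMA version at all), which is what forced the XLA-side arch-string workaround removed below.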
5 changes: 0 additions & 5 deletions xla/backends/gpu/codegen/triton/fusion_emitter.cc
@@ -1244,11 +1244,6 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
   const auto& cc = device_info.gpu_compute_capability();
   std::string arch_name =
       std::visit([](auto& cc) { return cc.ToString(); }, cc);
-  if (arch_name == "12.0") {
-    LOG(WARNING) << "Triton does not support sm_120 yet. Passing CC 10.0 to "
-                    "avoid spurious \"unsupported conversion\" errors";
-    arch_name = "10.0";
-  }
   if (std::holds_alternative<se::CudaComputeCapability>(cc)) {
     auto ccCuda = std::get<se::CudaComputeCapability>(cc);
     if (!ccCuda.IsAtLeastAmpere()) {
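
With sm120.patch in place, Triton accepts compute capability 12.0 directly, so the XLA-side workaround that disguised sm_120 as CC 10.0 is deleted and arch_name now passes through unmodified. For reference, a self-contained sketch of the std::visit stringification pattern that remains, using stand-in types (the real se::CudaComputeCapability and se::RocmComputeCapability live in XLA's stream_executor and are not reproduced here):

#include <iostream>
#include <string>
#include <variant>

// Stand-in types for illustration only; not XLA's actual classes.
struct CudaCcStandIn {
  int major, minor;
  std::string ToString() const {
    return std::to_string(major) + "." + std::to_string(minor);
  }
};
struct RocmCcStandIn {
  std::string gfx_version;
  std::string ToString() const { return gfx_version; }
};

int main() {
  std::variant<CudaCcStandIn, RocmCcStandIn> cc = CudaCcStandIn{12, 0};
  // Each alternative provides ToString(), so one generic lambda handles both.
  std::string arch_name =
      std::visit([](auto& c) { return c.ToString(); }, cc);
  std::cout << arch_name << "\n";  // prints "12.0"; no longer rewritten to "10.0"
  return 0;
}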