From 1ec33ae1191ae6644365155f8e8f618145c44cd7 Mon Sep 17 00:00:00 2001 From: Rachit Garg Date: Fri, 15 Mar 2024 13:44:15 -0700 Subject: [PATCH] Rachitg/dp carveout (#722) * fix the perf regression because of constant property polling of the device Signed-off-by: Rachit Garg * Fix lint error Signed-off-by: Tim Moon --------- Signed-off-by: Rachit Garg Signed-off-by: Tim Moon Co-authored-by: Rachit Garg Co-authored-by: Tim Moon --- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 71402d2001..a7217d4570 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -9,6 +9,7 @@ #include #include #include "common/util/system.h" +#include "common/util/cuda_runtime.h" namespace { transformer_engine::DType reverse_map_dtype(int64_t dtype) { @@ -320,10 +321,9 @@ at::Tensor te_gemm_ts(at::Tensor A, // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - int num_math_sms = prop.multiProcessorCount \ - - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + const int sm_count = transformer_engine::cuda::sm_count(); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];