Skip to content

Commit

Permalink
Rachitg/dp carveout (#722)
Browse files Browse the repository at this point in the history
* Fix the perf regression caused by constantly polling the device properties

Signed-off-by: Rachit Garg <rachitg@nvidia.com>

* Fix lint error

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Rachit Garg <rachitg@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Rachit Garg <rachitg@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
  • Loading branch information
3 people authored Mar 15, 2024
1 parent ffa2447 commit 1ec33ae
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <cuda.h>
#include <cuda_fp8.h>
#include "common/util/system.h"
#include "common/util/cuda_runtime.h"

namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
Expand Down Expand Up @@ -320,10 +321,9 @@ at::Tensor te_gemm_ts(at::Tensor A,

// Set an external SM Margin to all the GEMMs.
// This comes in handy when DP is overlapped with GEMMs
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
int num_math_sms = prop.multiProcessorCount \
- transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);

const int sm_count = transformer_engine::cuda::sm_count();
int num_math_sms = sm_count - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", sm_count);

if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];
Expand Down

0 comments on commit 1ec33ae

Please sign in to comment.