FP8 reduction for atomic TP-RS with p2p exchange
Signed-off-by: Sangkug Lym <slym@nvidia.com>
erhoo82 committed Mar 28, 2024
1 parent 695edb9 commit 61be0f0
Showing 2 changed files with 17 additions and 9 deletions.
20 changes: 14 additions & 6 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -738,7 +738,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
       assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
       assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
       CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(),
-                                 _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
+                                 _ubufs[_self_chunk_id].numel() *
+                                     _ubufs[_self_chunk_id].element_size(),
                                  cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
       CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
       CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
@@ -754,8 +755,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
                               src_ptr,
                               n_chunk * m * D.element_size(),
                               cudaMemcpyDeviceToDevice,
-                              (cudaStream_t) stream_main)
-    );
+                              (cudaStream_t) stream_main));
     // Return the last N rows of D_buffer
     torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
     return D_return;
@@ -1019,9 +1019,17 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
 
     // Reduce GEMM output chunks
     char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].data_ptr());
-    torch::Tensor reduce_buf = torch::from_blob(
-        reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
-    torch::sum_out(rs_output, reduce_buf, 0);
+    if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
+      assert(_ubuf_scale_inv_initialized);
+      float *d_scale_inv_ptr = reinterpret_cast<float *>(_ubuf_scale_inv.data_ptr());
+      char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
+      reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr,
+                                            _tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main);
+    } else {
+      torch::Tensor reduce_buf = torch::from_blob(
+          reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
+      torch::sum_out(rs_output, reduce_buf, 0);
+    }
   }
 
   /*
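For intuition, here is a host-side PyTorch sketch (not the commit's CUDA kernel) of what the new FP8 branch computes: each FP8 (E4M3) chunk in the userbuffer is dequantized with the stored scale-inverse, the chunks are summed across TP ranks, and the result is written out in BF16. The shapes, names, and scale value below are assumed for illustration only, and the snippet requires a PyTorch build that provides torch.float8_e4m3fn.

import torch

# Illustrative emulation only; the commit performs this reduction on the GPU via
# reduce_fp8_in_bf16_out<__nv_fp8_e4m3>. Shapes, names, and the scale value are assumptions.
tp_size, rows, cols = 4, 8, 16
scale_inv = torch.tensor(0.5)  # dequantization factor, analogous to _ubuf_scale_inv

# Each rank's GEMM output chunk sits in the userbuffer as FP8 E4M3.
chunks = torch.randn(tp_size, rows, cols).to(torch.float8_e4m3fn)

# Dequantize to FP32, reduce over the TP dimension, then store the result as BF16,
# mirroring the element_size() == 1 (FP8 in) / element_size() == 2 (BF16 out) branch.
rs_output = (chunks.to(torch.float32) * scale_inv).sum(dim=0).to(torch.bfloat16)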
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/module/base.py
@@ -182,18 +182,18 @@ def add_ub(
         if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
             layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
         if name in rs_ag_pairs:
-            assert_massage = (
+            assert_message = (
                 f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
                 "outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
                 "GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
                 "`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
                 "for functionality."
             )
             if name in layers_atomic_ring_exchange:
-                assert atomic_gemm and method == "ring_exchange", assert_massage
+                assert atomic_gemm and method == "ring_exchange", assert_message
             else:
                 if atomic_gemm and method == "ring_exchange":
-                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_massage
+                    assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
 
         sample_buffer = torch.empty(
             shape,
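To make the constraint behind assert_message concrete, the minimal sketch below shows the symmetry being enforced: the AG and RS overlaps of a TP block must both use atomic GEMM with ring_exchange, or neither. The pair names and config dicts are hypothetical stand-ins, not the module's actual data structures.

# Hypothetical illustration of the pairing rule; ag_rs_pairs/ub_cfgs contents are made up.
ag_rs_pairs = {"qkv_fprop": "proj_fprop"}
rs_ag_pairs = {rs: ag for ag, rs in ag_rs_pairs.items()}

ub_cfgs = {
    "qkv_fprop": {"method": "ring_exchange", "atomic_gemm": True},
    "proj_fprop": {"method": "ring_exchange", "atomic_gemm": True},  # must mirror its AG partner
}

def uses_atomic_ring_exchange(cfg):
    return cfg["atomic_gemm"] and cfg["method"] == "ring_exchange"

for rs_name, ag_name in rs_ag_pairs.items():
    assert uses_atomic_ring_exchange(ub_cfgs[ag_name]) == uses_atomic_ring_exchange(ub_cfgs[rs_name]), (
        f"{ag_name} and {rs_name} must use the same atomic ring_exchange overlap config"
    )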
