Atomic gemm for TP-AR and TP-RS overlap with P2P exchanges
Signed-off-by: Sangkug Lym <slym@nvidia.com>
erhoo82 committed Mar 23, 2024
1 parent c1a68f6 commit 1fd33af
Showing 5 changed files with 118 additions and 99 deletions.
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -137,7 +137,10 @@ def fp8_gemm(
extra_output_tensor is not None
), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
args = tuple(args + (extra_output_tensor,))
_ = fn(*args)
if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
out = fn(*args)
else:
_ = fn(*args)

return out, gelu_input

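Why the new branch is needed: with ATOMIC_GEMM_AG_P2P the C++ extension now allocates and returns a freshly re-ordered output tensor (see D_return in comm_gemm_overlap.h below), so the wrapper must capture the return value; for every other algorithm the result is still written in place into the pre-allocated `out`. An annotated sketch of the dispatch, using the names from the diff above:

if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
    out = fn(*args)  # extension returns a new, re-ordered output tensor
else:
    _ = fn(*args)    # output is written in place into the pre-allocated `out`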
156 changes: 65 additions & 91 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
@@ -623,26 +623,20 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_ubuf_scale_inv_initialized = false;

_atomic_gemm = atomic_gemm;
_self_chunk_id = _tp_id;
if (_atomic_gemm) {
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
counter = torch::zeros({tp_size * 2}, counter_options);
counter.index_put_({Slice(None, tp_size)}, 1);
_self_chunk_id = _tp_id;

if (!is_reduce_scatter) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
const char *env_p = std::getenv("NVTE_AG_P2P_INDIV_P2P");
if (rank == 0 && env_p != nullptr) {
if (env_p[0] == '1') {
printf("!!userbuffers_sendrecv_atomic\n");
} else if (env_p[0] == '2') {
printf("!!userbuffers_sendrecv_multiatomic\n");
} else if (env_p[0] == '3') {
printf("!!userbuffers_sendrecv_multiatomic_shuffle\n");
_self_chunk_id = 0;
} else {
printf("!!userbuffers_sendrecv\n");
printf("!!userbuffers_sendrecv_indiv_atomic_shuffle\n");
}
}
_self_chunk_id = 0;
counter.index_put_({_self_chunk_id}, 0);
}
}
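For reference, a small PyTorch sketch of the counter state this constructor sets up for the AG path; tp_size = 4 is an assumed value and the names are illustrative only. The first tp_size slots gate the ring-exchange chunks, and the local chunk's slot is cleared because it needs no communication:

import torch

tp_size = 4
counter = torch.zeros(tp_size * 2, dtype=torch.int32)
counter[:tp_size] = 1        # chunks still waiting on a P2P receive
self_chunk_id = 0            # with the shuffled AG ordering, the local chunk sits in slot 0
counter[self_chunk_id] = 0   # the local chunk is already present, nothing to wait for
print(counter)               # tensor([0, 1, 1, 1, 0, 0, 0, 0], dtype=torch.int32)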
@@ -675,13 +669,17 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
const int n_chunk = _ubufs[0].size(0);
const int n = _ubuf.size(0);
const int n_chunk = n / _tp_size;

// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();

// Create a GEMM output buffer with N+1 chunks in contiguous memory
torch::Tensor D_buffer = torch::empty({n_chunk * (_tp_size + 1), m}, D.options());
D = torch::from_blob(D_buffer.data_ptr(), {D.size(0), D.size(1)}, D.options());

// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
@@ -692,100 +690,75 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (B_scale_inverse.numel())
B_scale_inverse = B_scale_inverse[B_fp8_tensor];

at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));

assert(pre_gelu_out.numel() == 0);

// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));

torch::Tensor output_chunk = torch::from_blob(output_ptr, {_ubuf.size(0), m}, D.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
for (int i = 0; i < _tp_size; i++) {

for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubuf[rank]. This is to
// have the AG output on all ranks be contiguous after the ring
// exchanges.
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_chunk_id = i;
int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;

if (i < _tp_size - 1) {
const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC");
if (env_p != nullptr && env_p[0] == '1') {
userbuffers_sendrecv_atomic(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes,
_ub_comm, _next_rank, _prev_rank, &counter_ptr[recv_chunk_id],
(cudaStream_t)_stream_recv);
} else if (env_p != nullptr && env_p[0] == '2') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, false, (cudaStream_t)_stream_recv);
}
} else if (env_p != nullptr && env_p[0] == '3') {
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, true, (cudaStream_t)_stream_recv);
}
} else {
// P2P communication
// userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset,
// comm_bytes, _ub_comm,
// _next_rank, (cudaStream_t)_stream_send);
// userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset,
// comm_bytes, _ub_comm,
// _prev_rank, (cudaStream_t)_stream_recv);
// CHECK_CUDA(cudaEventRecord(_stop_recv,
// (cudaStream_t)_stream_recv));
// CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send,
// _stop_recv, 0));
userbuffers_sendrecv(_ub_reg, _ub_reg, send_offset, recv_offset, comm_bytes, _ub_comm,
_next_rank, _prev_rank, (cudaStream_t)_stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
}
if (i == 0) {
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter);
}
const char *env_p = std::getenv("NVTE_AG_P2P_INDIV_P2P");
if (env_p != nullptr && env_p[0] == '1') {
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, _next_rank, (cudaStream_t) _stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, _prev_rank, (cudaStream_t) _stream_recv);
producer(counter_ptr, recv_chunk_id, (cudaStream_t)_stream_recv);
} else {
// GEMM
// userbuffers_send_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _next_rank, _tp_size, comm_bytes, comm_bytes,
// (cudaStream_t)_stream_send);
// userbuffers_recv_multiatomic(_ub_reg, 0, _ub_reg, 0, comm_bytes,
// _ub_comm,
// _prev_rank, _tp_size, counter_ptr,
// (cudaStream_t)_stream_recv);
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
if (i == 0) {
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size,
counter_ptr, true, (cudaStream_t)_stream_recv);
}
}
}
for (int i = 0; i < _tp_size; i++) {
if (i != _self_chunk_id) {
consumer(counter_ptr, i, (cudaStream_t)_stream_compute[0]);
if (i == 0) {
te_atomic_gemm(A, A_scale_inverse, A_type, transa, _ubuf, B_scale_inverse, B_type, transb,
D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, false, counter);
}
}
at::cuda::setCurrentCUDAStream(stream_main);
CHECK_CUDA(cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[0]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));

return D;
// Store the input activation for backprop
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_self_chunk_id].data_ptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaEventRecord(_stop_send, (cudaStream_t)_stream_send));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_send, 0));
}

// Reset atomic counters
consumer_batch(counter_ptr, 1, _tp_size, (cudaStream_t)stream_main);

// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.data_ptr());
CHECK_CUDA(cudaMemcpyAsync(
src_ptr + (D.numel() * D.element_size()),
src_ptr,
n_chunk * m * D.element_size(),
cudaMemcpyDeviceToDevice,
(cudaStream_t) stream_main)
);
// Return the last N rows of D_buffer
torch::Tensor D_return = D_buffer.narrow(0, n_chunk, n);
return D_return;
} // atomic_gemm_overlap_ag
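For intuition, a minimal PyTorch sketch of the (tp_size + 1)-chunk output-buffer trick used above: the atomic GEMM fills the first tp_size chunks of D_buffer, the first chunk is then copied into the extra slot at the end, and the returned view skips the first chunk, so the chunks come out in the order 1, 2, ..., tp_size-1, 0. Shapes and values are illustrative only:

import torch

tp_size, n_chunk, m = 4, 2, 3
n = tp_size * n_chunk

# Allocate N + 1 chunks of n_chunk rows each in one contiguous buffer.
D_buffer = torch.empty(n_chunk * (tp_size + 1), m)

# Stand-in for the atomic GEMM: output chunk i is filled with the value i.
for i in range(tp_size):
    D_buffer[i * n_chunk:(i + 1) * n_chunk] = i

# Copy the first GEMM output chunk into the extra chunk slot at the end ...
D_buffer[n:n + n_chunk] = D_buffer[:n_chunk]

# ... and return the last N rows, i.e. chunks [1, 2, ..., tp_size-1, 0].
D_return = D_buffer.narrow(0, n_chunk, n)
print(D_return[:, 0].tolist())  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0]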

/*
@@ -1018,6 +991,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));

// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
@@ -1031,14 +1005,14 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;

consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, send_rank, (cudaStream_t) _stream_recv);
_ub_comm, send_rank, (cudaStream_t) _stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes,
_ub_comm, recv_rank, (cudaStream_t) _stream_recv);
_ub_comm, recv_rank, (cudaStream_t) _stream_recv);
}
CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0));
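As a quick sanity check of the swapped send/recv rank formulas above, evaluating them for assumed values tp_size = 4, tp_id = 1, rank_round_tp = 0 (and a few values of i) shows that the new lines reverse the ring direction relative to the removed ones:

tp_size, tp_id, rank_round_tp = 4, 1, 0
for i in range(1, tp_size):  # a few iterations for illustration
    send_rank = (tp_size + tp_id - i) % tp_size + rank_round_tp
    recv_rank = (tp_id + i) % tp_size + rank_round_tp
    print(i, send_rank, recv_rank)
# i=1: send to rank 0, receive from rank 2
# i=2: send to rank 3, receive from rank 3
# i=3: send to rank 2, receive from rank 0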
@@ -1174,7 +1148,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS)
NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
20 changes: 20 additions & 0 deletions transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu
@@ -3655,6 +3655,20 @@ static __global__ void consumer_kernel(void *atomic_ptr, int chunk_i) {
}
}

// consumer_batch
static __global__ void consumer_batch_kernel(void *atomic_ptr, int first_chunk_i, int num_chunks) {
// Wait for the producer to change the value to 0, which signals that the chunk is ready
if (blockIdx.x == 0 && threadIdx.x == 0) {
int old_val;
for (int i = first_chunk_i; i < num_chunks; i++) {
while (0 != (old_val = atomicCAS((unsigned int *)atomic_ptr + i, 0, 0))) {
}
((unsigned int *)atomic_ptr)[i] = 1;
asm volatile("fence.sc.gpu;\n");
}
}
}

void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
@@ -3667,6 +3681,12 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) {
consumer_kernel<<<grid, block, 0, stream>>>(atomic_ptr, chunk_i);
}

void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) {
dim3 block(1);
dim3 grid(1);
consumer_batch_kernel<<<grid, block, 0, stream>>>(atomic_ptr, first_chunk_i, num_chunks);
}

template <typename fp8type>
__global__ void __launch_bounds__(MAX_THREADS / 4)
reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale,
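For intuition, a host-side Python analogue of the counter handshake implemented by producer, consumer, and the new consumer_batch above. This is a sketch only: the real versions are single-thread CUDA kernels spinning on global-memory atomics, and num_chunks acts as an exclusive end index, matching the kernel loop.

import threading
import time

def producer(counters, chunk_i):
    # Signal that chunk `chunk_i` is ready: 0 means "produced".
    counters[chunk_i] = 0

def consumer(counters, chunk_i):
    # Spin until the producer clears the flag, then re-arm it to 1.
    while counters[chunk_i] != 0:
        time.sleep(0)
    counters[chunk_i] = 1

def consumer_batch(counters, first_chunk_i, num_chunks):
    # Wait on chunks [first_chunk_i, num_chunks) in order, re-arming each flag.
    for i in range(first_chunk_i, num_chunks):
        consumer(counters, i)

if __name__ == "__main__":
    tp_size = 4
    counters = [1] * tp_size  # 1 == "not yet produced"
    t = threading.Thread(target=lambda: [producer(counters, i) for i in range(tp_size)])
    t.start()
    consumer_batch(counters, 0, tp_size)
    t.join()
    print(counters)  # [1, 1, 1, 1] -- every flag re-armed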
1 change: 1 addition & 0 deletions transformer_engine/pytorch/csrc/userbuffers/userbuffers.h
@@ -151,6 +151,7 @@ typedef struct communicator communicator;

void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */

35 changes: 28 additions & 7 deletions transformer_engine/pytorch/module/base.py
@@ -138,6 +138,12 @@ def initialize_ub(
}
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]

# AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop":"proj_fprop", "fc1_fprop":"fc2_fprop"}
rs_ag_pairs = {v : k for k, v in ag_rs_pairs.items()}
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []

def get_method(name):
for method, names in methods.items():
if name in names:
@@ -160,20 +166,35 @@ def add_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
)
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if is_reduce_scatter and method == "ring_exchange":
raise ValueError(
"Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method."
)
if method == 'bulk':
warnings.warn(
"Atoimic GEMM not is supported for a bulk overlap."
f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`."
)
atomic_gemm = 0
if not is_reduce_scatter and method == 'pipeline':
raise ValueError(
"`pipeline` overlap method is not supported for AllGather."
f"At {name}, `pipeline` overlap method is not supported for AllGather."
)
# Check if both AG and RS overlaps use `atomic GEMM` + `p2p ring-exchange`.
# Using atomic GEMM + p2p ring-exchange in only one of the pair breaks functionality.
global layers_atomic_ring_exchange
if atomic_gemm and method == "ring_exchange" and name in ag_rs_pairs:
layers_atomic_ring_exchange += [name, ag_rs_pairs[name]]
if name in rs_ag_pairs:
assert_message = (
f"At {name}, atomic AG-GEMM overlap with `ring_exchange` shuffles GEMM chunk "
"outputs, and RS-GEMM overlap un-shuffles them. When one of the GEMM-AG and "
"GEMM-RS overlaps forming a TP block (e.g., qkv_fprop and proj_fprop) uses "
"`atomic gemm` and `ring_exchange`, its pair must use the same overlap config "
"for functionality."
)
if name in layers_atomic_ring_exchange:
assert atomic_gemm and method == "ring_exchange", assert_message
else:
if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message

sample_buffer = torch.empty(
shape,
dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype,
@@ -213,7 +234,7 @@ def add_ub(
method = ub_cfg["method"] if "method" in ub_cfg else get_method(name)
num_sm = ub_cfg["num_sm"] if "num_sm" in ub_cfg else 16
cga_size = ub_cfg["cga_size"] if "cga_size" in ub_cfg else 2
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0
num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 4
set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0
aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0
atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0
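To make the pairing rule enforced in add_ub above concrete, here is a compact standalone sketch of the same check; the function name, argument layout, and helper are illustrative and not part of Transformer Engine's API:

ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}

def _uses_atomic_ring(cfg):
    return bool(cfg.get("atomic_gemm")) and cfg.get("method") == "ring_exchange"

def check_atomic_ring_exchange_pairs(ub_cfgs):
    # ub_cfgs: {layer_name: {"atomic_gemm": bool, "method": str}} -- assumed layout.
    # If either member of an AG/RS pair uses atomic GEMM with `ring_exchange`,
    # its partner must too: the AG side shuffles the GEMM output chunks and the
    # matching RS side is what un-shuffles them.
    for ag_name, rs_name in ag_rs_pairs.items():
        ag_ok = _uses_atomic_ring(ub_cfgs.get(ag_name, {}))
        rs_ok = _uses_atomic_ring(ub_cfgs.get(rs_name, {}))
        assert ag_ok == rs_ok, (
            f"{ag_name} and {rs_name} must both use atomic GEMM with ring_exchange, or neither"
        )

# Example: a consistent pair passes; mixing configs within a pair would raise.
check_atomic_ring_exchange_pairs({
    "qkv_fprop": {"atomic_gemm": True, "method": "ring_exchange"},
    "proj_fprop": {"atomic_gemm": True, "method": "ring_exchange"},
})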
