From 09519718887e78d62156bf55589590b089e76797 Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Mon, 2 Dec 2024 11:07:37 -0800
Subject: [PATCH 1/2] Update list of CI users (#1340)

* Update list of CI users

Signed-off-by: Tim Moon

* Update list of CI users

Signed-off-by: Tim Moon

---------

Signed-off-by: Tim Moon
---
 .github/workflows/trigger-ci.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml
index c2317c6509..586abd0541 100644
--- a/.github/workflows/trigger-ci.yml
+++ b/.github/workflows/trigger-ci.yml
@@ -40,6 +40,8 @@ jobs:
            || github.actor == 'vasunvidia'
            || github.actor == 'erhoo82'
            || github.actor == 'kocchop'
+           || github.actor == 'youngeunkwon0405'
+           || github.actor == 'KshitijLakhani'
          )
     steps:
       - name: Check if comment is issued by authorized person

From 64126aa8c469b2a97ace01f925f3d5786d5fd1bb Mon Sep 17 00:00:00 2001
From: Youngeun Kwon
Date: Mon, 2 Dec 2024 12:26:48 -0800
Subject: [PATCH 2/2] Improving communication overlap for the case of multi
 kernel queue usage (#1308)

* draft implementation

Signed-off-by: Youngeun Kwon

* compile error fix

Signed-off-by: Youngeun Kwon

* fix compile error

Signed-off-by: Youngeun Kwon

* remove print

Signed-off-by: Youngeun Kwon

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Edit comments

Signed-off-by: Youngeun Kwon

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* edit the bulk-overlap test case

Signed-off-by: Youngeun Kwon

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add version guard

Signed-off-by: Youngeun Kwon

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add runtime version guard

Signed-off-by: Youngeun Kwon

* fix the version guard

Signed-off-by: Youngeun Kwon

---------

Signed-off-by: Youngeun Kwon
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../distributed/test_comm_gemm_overlap.py     |  34 +++--
 .../comm_gemm_overlap/comm_gemm_overlap.cpp   |  30 ++++-
 .../userbuffers/userbuffers.cu                | 116 ++++++++++++++----
 .../userbuffers/userbuffers.h                 |  18 ++-
 .../transformer_engine/comm_gemm_overlap.h    |   2 +-
 5 files changed, 157 insertions(+), 43 deletions(-)

diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py
index ce46a72189..f81fbae1fe 100644
--- a/tests/pytorch/distributed/test_comm_gemm_overlap.py
+++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py
@@ -209,19 +209,39 @@ def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out):
 
 
 @pytest.mark.parametrize(
-    "comm_type,fp8",
+    "comm_type, fp8, connections",
     [
-        ("AG", False),
-        ("RS", False),
-        ("RS", True),
+        ("AG", False, 1),
+        ("RS", False, 1),
+        ("RS", True, 1),
+        ("AG", False, 8),
+        ("RS", False, 8),
+        ("RS", True, 8),
+    ],
+    ids=[
+        "ALL-GATHER - BF16 - 1 connection",
+        "REDUCE-SCATTER - BF16 - 1 connection",
+        "REDUCE-SCATTER - FP8 - 1 connection",
+        "ALL-GATHER - BF16 - 8 connections",
+        "REDUCE-SCATTER - BF16 - 8 connections",
+        "REDUCE-SCATTER - FP8 - 8 connections",
     ],
-    ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "],
 )
-def test_bulk_overlaps(comm_type, fp8):
+def test_bulk_overlaps(comm_type, fp8, connections):
     """
     Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm.
     """
-    _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
+    if connections == 8:
+        if torch.cuda.get_device_properties(0).major != 9:
+            pytest.skip(
+                "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute"
+                " capability 9.0 (Hopper architecture)."
+            )
+        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
+        _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
+        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    else:
+        _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False)
 
 
 @pytest.mark.parametrize(
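A note on the knob these new cases exercise: CUDA_DEVICE_MAX_CONNECTIONS sets how many
hardware work queues a CUDA context creates. With more than one queue, the communication
and GEMM kernels can be submitted through different queues, which is exactly the
situation the launch-ordering changes below address. The variable is read when a CUDA
context is created, so the test relies on the worker processes spawned by
_run_gemm_with_overlap inheriting it. A minimal standalone sketch of the same gating,
outside pytest (the main harness is illustrative, not part of this patch):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

int main() {
  // Must be visible before the CUDA context is created (POSIX setenv).
  setenv("CUDA_DEVICE_MAX_CONNECTIONS", "8", /*overwrite=*/1);

  // Mirror the test's skip condition: the 8-connection path is only exercised
  // on compute capability 9.0 (Hopper), where the persistent communication
  // kernels need an explicit launch order.
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) return 1;
  if (prop.major != 9) {
    std::printf("skip: needs sm_90, found sm_%d%d\n", prop.major, prop.minor);
    return 0;
  }
  std::printf("ready to run the overlap workload with 8 queues\n");
  return 0;
}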
diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
index a663385b68..c6f0f870ff 100644
--- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
@@ -90,6 +90,23 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
   cudaEventCreateWithFlags(&_stop_compute, 0);
   cudaEventCreateWithFlags(&_start_comm, 0);
   cudaEventCreateWithFlags(&_stop_comm, 0);
+
+  /*
+    Define the launch order between the communication and GEMM kernels
+    using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS > 1.
+    The event is used to schedule the communication kernel launch before the GEMM.
+    This is needed only on Hopper, which uses persistent CTA execution.
+  */
+  int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
+  int runtime_version = 0;
+  cudaRuntimeGetVersion(&runtime_version);
+  cudaDeviceProp deviceProp;
+  cudaGetDeviceProperties(&deviceProp, 0);
+  if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
+    cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
+  } else {
+    _comm_launch_event = 0;
+  }
 }
 
 CommOverlapCore::~CommOverlapCore() {
@@ -97,6 +114,7 @@ CommOverlapCore::~CommOverlapCore() {
   cudaEventDestroy(_start_comm);
   cudaEventDestroy(_stop_compute);
   cudaEventDestroy(_start_compute);
+  if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
 
   if (_atomic_gemm) cudaFree(_counter.dptr());
@@ -168,7 +186,8 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
   // Communication: AG and RS
   int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size();  // UBUF uses 2Byte element size
   if (comm_type == CommOverlapType::AG) {
-    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
+    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
+                                (cudaEvent_t)_comm_launch_event);
   } else {
     if (_ubuf.element_size() == 1) {
       assert(_ubuf_scale_inv_initialized);
@@ -178,13 +197,18 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper
       assert(rs_output.element_size() == 2);
       char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
       reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
-                                                 comm_elements, _ub_comm, _stream_comm);
+                                                 comm_elements, _ub_comm, _stream_comm,
+                                                 (cudaEvent_t)_comm_launch_event);
     } else {
-      reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm);
+      reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
+                                      (cudaEvent_t)_comm_launch_event);
     }
   }
 
   assert(pre_gelu_out.numel() == 0);
+  // When the launch order is enforced, make the GEMM launch wait until the communication kernel has launched
+  if (_comm_launch_event)
+    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
   nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
                    grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
                    stream_main);
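The shape of the change above, in isolation: the constructor creates _comm_launch_event
only when the runtime is 12.3+, the device is Hopper, and more than one work queue is
enabled; bulk_overlap then makes the main stream wait on that event, so the GEMM cannot
be scheduled ahead of the communication kernel. A minimal sketch of that consumer side
(names are placeholders and error handling is omitted):

#include <cuda_runtime.h>

// Create the launch-ordering event only where it is needed, as the constructor
// above does; everywhere else a null event means "old behaviour".
cudaEvent_t make_comm_launch_event() {
  int rt = 0;
  cudaRuntimeGetVersion(&rt);
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  cudaEvent_t evt = 0;
  if (rt >= 12030 && prop.major == 9)
    cudaEventCreateWithFlags(&evt, cudaEventDisableTiming);
  return evt;
}

// Gate a consumer stream on the communication kernel's *launch*; the event is
// recorded by the launch-completion attribute shown in userbuffers.cu below.
void gate_on_comm_launch(cudaStream_t stream_main, cudaEvent_t comm_launch_event) {
  if (comm_launch_event) cudaStreamWaitEvent(stream_main, comm_launch_event, 0);
  // ... launch the GEMM on stream_main here ...
}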
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
index 26843d8107..91667958e7 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
@@ -1366,6 +1366,28 @@ __global__ void __launch_bounds__(MAX_THREADS)
   cfg.attrs = attribute_ub;                                                                   \
   cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;
 
+#if (CUDART_VERSION >= 12030)
+#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)                          \
+  attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent;                              \
+  attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event;
+#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3
+#else
+#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)
+#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2
+#endif
+
+#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event)    \
+  cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0};                                \
+  cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {};                 \
+  ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event)                                \
+  attribute_ub[1].id = cudaLaunchAttributeClusterDimension;                                   \
+  attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1;          \
+  attribute_ub[1].val.clusterDim.y = 1;                                                       \
+  attribute_ub[1].val.clusterDim.z = 1;                                                       \
+  attribute_ub[0].id = cudaLaunchAttributeCooperative;                                        \
+  cfg.attrs = attribute_ub;                                                                   \
+  cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH;
+
 #define callranks_ag(x)                                                                       \
   if (ar_nvsize == x) {                                                                       \
     int arg1 = op - NVTE_MAX_OPS,                                                             \
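Written out, the new macro builds the same cooperative-plus-cluster configuration as
SETUP_LAUNCH_CONFIG and, on CUDA 12.3+, appends the launch-completion attribute, whose
event fires once the kernel has been launched rather than completed. The equivalent
configuration as a plain function (a sketch; the helper name and caller-owned attribute
array are assumptions, not TE code):

#include <cuda_runtime.h>

#if CUDART_VERSION >= 12030
// `attrs` is caller-owned so the returned config's attribute pointer stays valid.
cudaLaunchConfig_t fdl_config(int sms, int threads, cudaStream_t stream,
                              int cga_size, cudaEvent_t comm_launch_event,
                              cudaLaunchAttribute (&attrs)[3]) {
  cudaLaunchConfig_t cfg = {};
  cfg.gridDim = dim3(sms);
  cfg.blockDim = dim3(threads);
  cfg.stream = stream;
  attrs[0].id = cudaLaunchAttributeCooperative;       // kernels use grid-wide sync
  attrs[0].val.cooperative = 1;
  attrs[1].id = cudaLaunchAttributeClusterDimension;  // CGA size, as before
  attrs[1].val.clusterDim.x = (sms % cga_size == 0) ? cga_size : 1;
  attrs[1].val.clusterDim.y = 1;
  attrs[1].val.clusterDim.z = 1;
  attrs[2].id = cudaLaunchAttributeLaunchCompletionEvent;  // the new attribute
  attrs[2].val.launchCompletionEvent.event = comm_launch_event;
  cfg.attrs = attrs;
  cfg.numAttrs = 3;
  return cfg;
}
#endif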
@@ -1753,7 +1775,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
 }
 
 void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                 communicator *comm, cudaStream_t stream) {
+                                 communicator *comm, cudaStream_t stream,
+                                 cudaEvent_t comm_launch_event) {
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
       op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1766,11 +1789,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-    callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+    } else {
+      callranks_ag(2) callranks_ag(4) callranks_ag(8)
+    }
   } else {
-    callranks_ag(2) callranks_ag(4) callranks_ag(8)
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+    } else {
+      callranks_ag(2) callranks_ag(4) callranks_ag(8)
+    }
   }
 }
@@ -1790,7 +1822,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con
 }
 
 void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                     communicator *comm, cudaStream_t stream) {
+                                     communicator *comm, cudaStream_t stream,
+                                     cudaEvent_t comm_launch_event) {
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
       op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
@@ -1803,17 +1836,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-    callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+    } else {
+      callranks_rs(2) callranks_rs(4) callranks_rs(8)
+    }
   } else {
-    callranks_rs(2) callranks_rs(4) callranks_rs(8)
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+    } else {
+      callranks_rs(2) callranks_rs(4) callranks_rs(8)
+    }
   }
 }
 
 void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
                                            const int rowelements, const int colelements,
                                            const int strideelements, communicator *comm,
-                                           cudaStream_t stream) {
+                                           cudaStream_t stream, cudaEvent_t comm_launch_event) {
   const int elements = rowelements * colelements;
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
@@ -1827,23 +1869,35 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-    callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+    } else {
+      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+    }
   } else {
-    callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
+      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+    } else {
+      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+    }
   }
 }
 
 void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
-                             communicator *comm, cudaStream_t stream) {
-  reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
+                             communicator *comm, cudaStream_t stream,
+                             cudaEvent_t comm_launch_event) {
+  reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream,
+                                        comm_launch_event);
 }
 
 template <typename fp8type>
 void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
                                                const int offset, const int rowelements,
                                                const int colelements, const int strideelements,
-                                               communicator *comm, cudaStream_t stream) {
+                                               communicator *comm, cudaStream_t stream,
+                                               cudaEvent_t comm_launch_event) {
   const int elements = rowelements * colelements;
   const int op = userbuffers_allreduceop_nonsharp2;
   const int ar_firstgpu =
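Each collective above repeats the callranks_* dispatch in both branches because the two
SETUP_* macros each declare the local cfg and attribute_ub that the dispatch launches
through, so the two configurations need separate scopes. Condensed to its shape, every
updated function now follows the pattern below (a sketch with a stand-in kernel; the
real dispatch selects a kernel by ranks per node and multicast support):

#include <cuda_runtime.h>

__global__ void collective_kernel() {}  // stand-in for the userbuffers kernels

void launch_collective(int sms, int threads, cudaStream_t stream,
                       cudaEvent_t comm_launch_event) {
#if CUDART_VERSION >= 12030
  if (comm_launch_event) {
    // FDL path: the launch itself records the event.
    cudaLaunchConfig_t cfg = {};
    cfg.gridDim = dim3(sms);
    cfg.blockDim = dim3(threads);
    cfg.stream = stream;
    cudaLaunchAttribute attr[1] = {};
    attr[0].id = cudaLaunchAttributeLaunchCompletionEvent;
    attr[0].val.launchCompletionEvent.event = comm_launch_event;
    cfg.attrs = attr;
    cfg.numAttrs = 1;
    cudaLaunchKernelEx(&cfg, collective_kernel);
    return;
  }
#endif
  // Legacy path: behaviour identical to before the patch.
  collective_kernel<<<sms, threads, 0, stream>>>();
}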
@@ -1857,33 +1911,43 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
   int warps = comm->threads / 32;
   if (warps < ar_nvsize) warps = ar_nvsize;
 
-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-  callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+  if (comm_launch_event) {
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+  } else {
+    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+  }
 }
 
 template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
+    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
+    cudaEvent_t comm_launch_event);
 template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
-    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream);
+    const int colelements, const int strideelements, communicator *comm, cudaStream_t stream,
+    cudaEvent_t comm_launch_event);
 
 template <typename fp8type>
 void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
-                                 const int elements, communicator *comm, cudaStream_t stream) {
+                                 const int elements, communicator *comm, cudaStream_t stream,
+                                 cudaEvent_t comm_launch_event) {
   reducescatter2_userbuff_stridedoutput_fp8<fp8type>(output, scale, handler, offset, elements, 1, 0,
-                                                     comm, stream);
+                                                     comm, stream, comm_launch_event);
 }
 
 template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale,
                                                          const int handler, const int offset,
                                                          const int elements, communicator *comm,
-                                                         cudaStream_t stream);
+                                                         cudaStream_t stream,
+                                                         cudaEvent_t comm_launch_event);
 template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale,
                                                          const int handler, const int offset,
                                                          const int elements, communicator *comm,
-                                                         cudaStream_t stream);
+                                                         cudaStream_t stream,
+                                                         cudaEvent_t comm_launch_event);
 
 template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>(
     void *output, float *scale, const int handler, const int offset, const int rowelements,
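The explicit "template void ... <__nv_fp8_e5m2>" / "<__nv_fp8_e4m3>" lines are needed
because these template definitions live in a .cu file: callers in other translation
units can only link against the two FP8 types instantiated here, now with the extra
event parameter. A hedged call-site sketch (assumes the declarations from
userbuffers.h; names are placeholders):

#include <cuda_fp8.h>
#include "userbuffers.h"  // declares the templates and `communicator`

// Links because the <__nv_fp8_e4m3> instantiation exists in userbuffers.cu;
// any other fp8type would compile here but fail at link time.
void rs_fp8_e4m3(void *out, float *scale_inv, int handler, int elements,
                 communicator *comm, cudaStream_t stream, cudaEvent_t evt) {
  reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(out, scale_inv, handler, /*offset=*/0,
                                             elements, comm, stream, evt);
}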
diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
index 57e68afce0..75655ef691 100644
--- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
+++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h
@@ -213,7 +214,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
 // for TP-parallelism, only single node is implemented
 void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                 communicator *comm, cudaStream_t stream = 0);
+                                 communicator *comm, cudaStream_t stream = 0,
+                                 cudaEvent_t comm_launch_event = 0);
 /*
 each Rank input is
 allgather2_userbuff_inplace: offset+myrank*elements
@@ -228,21 +229,26 @@ for(int slice=0;slice<nslices;slice++)
 */
 void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
-                                     communicator *comm, cudaStream_t stream = 0);
+                                     communicator *comm, cudaStream_t stream = 0,
+                                     cudaEvent_t comm_launch_event = 0);
 void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
                                            const int rowelements, const int colelements,
                                            const int strideelements, communicator *comm,
-                                           cudaStream_t stream = 0);
+                                           cudaStream_t stream = 0,
+                                           cudaEvent_t comm_launch_event = 0);
 void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
-                             communicator *comm, cudaStream_t stream = 0);
+                             communicator *comm, cudaStream_t stream = 0,
+                             cudaEvent_t comm_launch_event = 0);
 template <typename fp8type>
 void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler,
                                                const int offset, const int rowelements,
                                                const int colelements, const int strideelements,
-                                               communicator *comm, cudaStream_t stream = 0);
+                                               communicator *comm, cudaStream_t stream = 0,
+                                               cudaEvent_t comm_launch_event = 0);
 template <typename fp8type>
 void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset,
-                                 const int elements, communicator *comm, cudaStream_t stream = 0);
+                                 const int elements, communicator *comm, cudaStream_t stream = 0,
+                                 cudaEvent_t comm_launch_event = 0);
 template <typename fp8type>
 void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler,
                                                 const int offset, const int rowelements,
diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
index 17ecca5ff0..1d5d192a39 100644
--- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
+++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
@@ -62,7 +62,7 @@ class CommOverlapCore {
   bool _ubuf_scale_inv_initialized{false};
 
   std::vector<cudaStream_t> _stream_compute;
-  cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm;
+  cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
 
  public:
   CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
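Because every new comm_launch_event parameter defaults to 0, the API change is
source-compatible: existing call sites keep the old behaviour, and only callers that
pass a real event opt in to the enforced launch order. For example (a sketch assuming
userbuffers.h):

#include "userbuffers.h"

void example(int handler, int elements, communicator *comm,
             cudaStream_t stream, cudaEvent_t comm_launch_event) {
  // Pre-patch call site: compiles unchanged, event defaults to 0 (no ordering).
  allgather2_userbuff_inplace(handler, /*offset=*/0, elements, comm, stream);
  // Opt in: the all-gather records its launch on `comm_launch_event`, which a
  // consumer stream can wait on (see CommOverlapBase::bulk_overlap above).
  allgather2_userbuff_inplace(handler, /*offset=*/0, elements, comm, stream,
                              comm_launch_event);
}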