I noticed in the TransformerEngine source code there are two environment variables related to LayerNorm/RMSNorm:
NVTE_FWD_LAYERNORM_SM_MARGIN
NVTE_BWD_LAYERNORM_SM_MARGIN
In NVIDIA’s submission for MLPerf Training 4.1 results, these variables were set to 8. The comments indicate that setting these two variables can improve p2p overlap performance on H100 GPUs:
# source: https://github.com/mlcommons/training_results_v4.1/blob/8821c7037ffd06e3775398fd39361a4c591d2235/NVIDIA/benchmarks/gpt3/implementations/eos-dfw_n1452_ngc24.04_nemo/config_common.sh#L9
# This is to improve p2p overlap on H100, and it shouldn't affect A100:
export NVTE_FWD_LAYERNORM_SM_MARGIN=8
export NVTE_BWD_LAYERNORM_SM_MARGIN=8
Could you please clarify which type of P2P communication these variables affect (p2p in pipeline parallelism, p2p in context parallelism, or p2p in tp-overlap)? Additionally, could you provide some tuning recommendations for these parameters specifically for H800 and H20 GPUs?
Thanks!
When running with complicated communication schemes, we've found that it's often beneficial to dedicate some of the SMs ("streaming multiprocessors", basically GPU cores) to communication so that it can overlap with compute. In other words, we explicitly tell compute kernels (GEMMs, LayerNorm, etc.) not to use all available SMs but to leave a few for NCCL.
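To make the mechanism concrete, here is a minimal sketch of the idea (not TransformerEngine's actual implementation): cap the number of SMs a compute kernel may use based on the margin read from the environment. The effective_sm_count helper is hypothetical and only for illustration.

import os
import torch

def effective_sm_count(margin_env_var: str, default_margin: int = 0) -> int:
    """Hypothetical helper: number of SMs a compute kernel should use after
    reserving a margin for concurrent communication (e.g. NCCL) kernels."""
    total_sms = torch.cuda.get_device_properties(
        torch.cuda.current_device()
    ).multi_processor_count
    margin = int(os.getenv(margin_env_var, str(default_margin)))
    # Leave `margin` SMs idle so NCCL kernels can be scheduled alongside compute.
    return max(total_sms - margin, 1)

# On an H100 SXM (132 SMs) with NVTE_FWD_LAYERNORM_SM_MARGIN=8, the forward
# LayerNorm kernel would be limited to 124 SMs; the backward kernel uses
# NVTE_BWD_LAYERNORM_SM_MARGIN analogously.
fwd_sms = effective_sm_count("NVTE_FWD_LAYERNORM_SM_MARGIN")
bwd_sms = effective_sm_count("NVTE_BWD_LAYERNORM_SM_MARGIN")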
Figuring out the best value requires some experimentation; I would recommend capturing an Nsight profile to see how many SMs each kernel occupies. In my experience, NCCL kernels on H100 take 8 SMs, so that's a good starting point. I don't remember off the top of my head, but I believe this specific optimization in LayerNorm is meant to help overlap LayerNorm with pipeline-parallel communication.
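For example, one could start from the MLPerf value of 8 and re-tune after profiling. Setting the variables from Python is equivalent to the shell exports above, assuming they are set before TransformerEngine reads them; treat 8 as a starting guess rather than a tuned value for H800 or H20.

import os

# Starting guess mirroring the MLPerf config above; adjust after checking
# how many SMs the NCCL kernels occupy in an Nsight Systems profile.
os.environ.setdefault("NVTE_FWD_LAYERNORM_SM_MARGIN", "8")
os.environ.setdefault("NVTE_BWD_LAYERNORM_SM_MARGIN", "8")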