I'm running the Megatron-LM BERT example with Wikipedia data and observed a loss divergence between TE v1.1 and v1.2. I then debugged by pinning the Megatron-LM and torch versions, and binary-searched the TE commits down to a precise commit id: 32db392 (#497). The loss curve is shown below.
Update: I also noticed a large performance drop after this commit.
Any suggestions on what else I can do to help with the debugging?
Megatron-LM version: nightly
apex commit id: f8e60c47c5c3034ddf8181e33910f3da5b289f25 (v0.1)
CUDA version: 12.1; cuDNN version (from cudnn_version.h): 8.9.7
torch version: 2.3.1
flash_attn version: 2.3.3 (the BERT model cannot use flash_attn)
Could you please try the latest TE and also cuDNN 9.0+? If the divergence disappears, we can work backwards to determine whether a particular TE version/commit or cuDNN version is the problem.
In the latest TE, you can run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to see more details about which backend was selected and why the other backends were disabled. For TE 1.1 or 1.2, you can add a print line manually here:
For commits before PR 497, I wonder if you were using the UnfusedDotProductAttention backend, and after PR 497, the FusedAttention backend. Either way, it would help to figure out which backend was used in each case so we can focus on that backend when debugging.
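The debug flags mentioned above are plain environment variables, so they can be set directly in the launch script before the training command. A minimal sketch (the `pretrain_bert.sh` name is a placeholder for the actual Megatron-LM launch script):

```shell
# NVTE_DEBUG=1 enables TransformerEngine debug output; NVTE_DEBUG_LEVEL=2
# additionally prints which attention backend was selected and why the
# other backends were disabled (available in recent TE releases).
export NVTE_DEBUG=1
export NVTE_DEBUG_LEVEL=2
# Placeholder for the real launch command, e.g.:
#   bash pretrain_bert.sh
echo "NVTE_DEBUG=$NVTE_DEBUG NVTE_DEBUG_LEVEL=$NVTE_DEBUG_LEVEL"
```

The backend-selection lines in the resulting log should show whether UnfusedDotProductAttention or FusedAttention is in use before and after PR 497.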