
[BUG] Numerical Incorrectness for BERT model training with #497 #1010

Open
ZYHowell opened this issue Jul 12, 2024 · 1 comment

ZYHowell commented Jul 12, 2024

I'm running the Megatron-LM BERT example with Wikipedia data and observed a loss divergence between TE v1.1 and v1.2. I then debugged with the Megatron-LM and torch versions fixed, and binary-searched the TE commits until I found the precise commit: 32db392 (#497). The loss curve is shown below.
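
For reference, the bisection looked roughly like this; the rebuild and test commands are placeholders for my actual build and (shortened) training runs:

cd TransformerEngine
git bisect start
git bisect bad v1.2      # first known-bad release
git bisect good v1.1     # last known-good release
# at each step: rebuild TE at the checked-out commit, rerun a short training job,
# compare the loss curves, then mark the commit
pip install -v .         # placeholder rebuild command
git bisect good          # or: git bisect bad
# ...repeat until git reports the first bad commit (32db392 here)...
git bisect reset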

Update: I also noticed a large performance drop after this commit.

(image: loss curves before and after commit 32db392)

Any suggestions on what else I can do to help with the debugging?

Megatron-LM version: nightly
apex commit id: f8e60c47c5c3034ddf8181e33910f3da5b289f25 (v0.1)
CUDA version: 12.1; cudnn version in cudnn_version.h: 8.9.7
torch version: 2.3.1
flash_attn version (BERT model cannot use flash_attn): 2.3.3

launch script:

export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FLASH_ATTN=0

GPUS_PER_NODE=4
NNODES=1
WORK_DIR=/home/ubuntu/megatron-mcore-test
CKPT_DIR=/data/ckpt/test_mcore_te
DATA_PATH=/data/wikipedia_text_sentence
VOCAB_FILE=/data/bert-large-uncased-vocab.txt

# Change for multinode config
head_node_ip=$<MY_HEAD_NODE_IP>

DISTRIBUTED_ARGS=(
    --nproc_per_node $GPUS_PER_NODE
    --nnodes $NNODES
    --rdzv_id $RANDOM
    --rdzv_backend c10d
    --rdzv_endpoint $head_node_ip:29502
)

BERT_MODEL_ARGS=(
    --num-layers 24 
    --hidden-size 1024 
    --num-attention-heads 16 
    --seq-length 512
    --max-position-embeddings 512
    --bert-no-binary-head
)

TRAINING_ARGS=(
    --micro-batch-size 64
    --global-batch-size 256
    --train-iters 1000000 
    --weight-decay 1e-2 
    --clip-grad 1.0 
    --bf16
    --lr 0.0001
    --lr-decay-iters 990000 
    --lr-decay-style cosine
    --min-lr 1.0e-5 
    --weight-decay 1e-2 
    --lr-warmup-fraction .01 
    --clip-grad 1.0 
)

MODEL_PARALLEL_ARGS=(
    --tensor-model-parallel-size 4
    --pipeline-model-parallel-size 1
)

DATA_ARGS=(
    --data-path $DATA_PATH 
    --vocab-file $VOCAB_FILE 
    --split 998,1,1
)

EVAL_AND_LOGGING_ARGS=(
    --log-interval 100
    --save-interval 20000 
    --eval-interval 500 
    --save $CKPT_DIR
    --load $CKPT_DIR
    --eval-iters 10
    --tensorboard-dir $CKPT_DIR
    --wandb-project=wiki-infra-test
    --wandb-exp-name=mcore-base-test-te
)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO

echo TE VERSION:
python3 -m pip show transformer_engine

command="torchrun ${DISTRIBUTED_ARGS[@]} ${WORK_DIR}/pretrain_bert.py \
    ${BERT_MODEL_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${EVAL_AND_LOGGING_ARGS[@]}"
echo ${command}

torchrun ${DISTRIBUTED_ARGS[@]} ${WORK_DIR}/pretrain_bert.py \
    ${BERT_MODEL_ARGS[@]} \
    ${TRAINING_ARGS[@]} \
    ${MODEL_PARALLEL_ARGS[@]} \
    ${DATA_ARGS[@]} \
    ${EVAL_AND_LOGGING_ARGS[@]}
cyanguwa (Collaborator) commented Jul 17, 2024

Hi @ZYHowell ,

Could you please try the latest TE and also cuDNN 9.0+? If the divergence disappears, we can work backwards to see whether a particular TE version/commit or cuDNN version is the problem.
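
To double-check which versions the run actually picks up, something like the following should work (torch reports the cuDNN build it links against, which can differ from the headers in cudnn_version.h):

python3 -c "import torch; print(torch.version.cuda, torch.backends.cudnn.version())"
python3 -m pip show transformer_engine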

In the latest TE, you can run with NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 to see more details about which backend is used and why the other backends are disabled. For TE 1.1 or 1.2, you can add a print line manually here:

print('backends: ', use_flash_attention, use_fused_attention)

For commits before PR 497, I wonder if you were using the UnfusedDotProductAttention backend, and after PR 497, the FusedAttention backend. Either way, it would be helpful to figure out which backend was used and then focus on that backend when debugging.
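
As a rough way to isolate this, you can force one path at a time with environment variables; the exact set of overrides depends on the TE version, so treat this as a sketch, where <launch_script.sh> stands for your script above:

# print backend-selection details (latest TE)
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 bash <launch_script.sh>

# flash attention is already disabled in your script (NVTE_FLASH_ATTN=0);
# additionally disabling fused attention should fall back to UnfusedDotProductAttention
NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 bash <launch_script.sh>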

Thanks,
Charlene

cyanguwa self-assigned this Jul 17, 2024