Skip to content

Commit

Permalink
Merge branch 'main' into multicastRS
Browse files Browse the repository at this point in the history
  • Loading branch information
ksivaman authored Apr 3, 2024
2 parents 390f166 + 47276e1 commit d189c9e
Show file tree
Hide file tree
Showing 31 changed files with 1,797 additions and 2,333 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Transformer Engine
Latest News
==================


* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_

.. image:: docs/examples/H200-NeMo-performance.png
Expand Down Expand Up @@ -226,7 +226,7 @@ Transformer Engine has been integrated with popular LLM frameworks such as:
* `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
* `NVIDIA NeMo Framework <https://github.com/NVIDIA/NeMo-Megatron-Launcher>`_
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features-v2-tensor-parallelism.html>`
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features-v2-tensor-parallelism.html>`_
* `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
* `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!
* `GPT-NeoX <https://github.com/EleutherAI/gpt-neox>`_ - Coming soon!
Expand Down
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.5.0.dev0
1.6.0.dev0
51 changes: 30 additions & 21 deletions docs/examples/te_llama/te_llama.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from transformers.utils.hub import get_checkpoint_shard_files

@contextmanager
def replace_decoder(te_decodder_cls):
def replace_decoder(te_decoder_cls):
"""
Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
"""
original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decodder_cls
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
try:
yield
finally:
Expand Down Expand Up @@ -56,6 +56,7 @@ def __init__(self, config, *args, **kwargs):
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads
)
te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
Expand Down Expand Up @@ -84,7 +85,7 @@ class is monkey-patched with `TELlamaDecoderLayer` class before
"""

def __new__(cls, config: LlamaConfig):
with replace_decoder(te_decodder_cls=TELlamaDecoderLayer):
with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):
llama_for_causal_lm = LlamaForCausalLM(config)
return llama_for_causal_lm

Expand Down Expand Up @@ -120,53 +121,61 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
assert not isinstance(resolved_archive_file, list)
resolved_archive_file = [resolved_archive_file]

error_msgs = []
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
replaced_layers = replace_params(state_dict, vanilla_model.state_dict())

error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
# replace_params copies parameters relevant only to TransformerEngine
replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")

# Force mem release. Taken from huggingface code
del state_dict
gc.collect()

return vanilla_model

def replace_params(hf_state_dict, te_state_dict):
def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = 'model.layers.\d+.'
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())



for layer_prefix in all_layer_prefixes:
# When loading weights into models with less number of layers, skip the
# copy if the corresponding layer doesn't exist in TE model
if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict:
# copy if the corresponding layer doesn't exist in HF model
if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]

if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict:
if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]

if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict:
if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]

if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict:
if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]

if layer_prefix + 'self_attention.proj.weight' in te_state_dict:
if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]

if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict:
if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]

if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0)

if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict:

# It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to
# load them separately.
if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data

if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data

if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]

return all_layer_prefixes
Loading

0 comments on commit d189c9e

Please sign in to comment.