make TransformerLayer accept a bshd or sbhd tensor format #557

sudhakarsingh27

Changes:

  1. DotProductAttention can interpret its input tensors in either sbhd or bshd format, but MultiHeadAttention and TransformerLayer are not aware of this. This PR plumbs the format through TransformerLayer -> MultiHeadAttention -> DotProductAttention.
  2. RotaryPositionEmbedding also needs to know the format so that the rope cache is applied to the q and k inputs along the sequence dimension.
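
To illustrate why the rotary embedding has to know the layout, here is a minimal sketch in NumPy (used as a stand-in for PyTorch tensors). The helpers `rope_freqs`, `rotate_half`, and `apply_rope`, and the `tensor_format` argument name, are hypothetical and are not taken from the Transformer Engine code. The sketch only shows that the per-position angles must broadcast along whichever axis holds the sequence.

```python
import numpy as np

def rope_freqs(seq_len, head_dim, base=10000.0):
    """Hypothetical helper: rotary angles of shape [s, d]."""
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    angles = np.outer(np.arange(seq_len), inv_freq)    # [s, d/2]
    return np.concatenate([angles, angles], axis=-1)   # [s, d]

def rotate_half(x):
    """Split the last axis in two halves and rotate them: (x1, x2) -> (-x2, x1)."""
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([-x2, x1], axis=-1)

def apply_rope(t, freqs, tensor_format="sbhd"):
    """Apply rotary embedding to t, laid out as sbhd or bshd (assumed names)."""
    if tensor_format == "sbhd":
        f = freqs[:, None, None, :]    # broadcast as [s, 1, 1, d]
    elif tensor_format == "bshd":
        f = freqs[None, :, None, :]    # broadcast as [1, s, 1, d]
    else:
        raise ValueError(f"unsupported tensor_format: {tensor_format}")
    return t * np.cos(f) + rotate_half(t) * np.sin(f)

s, b, h, d = 5, 2, 3, 8
rng = np.random.default_rng(0)
q_sbhd = rng.standard_normal((s, b, h, d))
q_bshd = q_sbhd.transpose(1, 0, 2, 3)   # same data, batch-first layout

freqs = rope_freqs(s, d)
out_sbhd = apply_rope(q_sbhd, freqs, "sbhd")
out_bshd = apply_rope(q_bshd, freqs, "bshd")

# Both layouts give the same result once the axes are aligned.
assert np.allclose(out_sbhd.transpose(1, 0, 2, 3), out_bshd)
```

If the sbhd broadcast were applied to a bshd tensor, the angles would be indexed by batch position instead of sequence position. That fails with a shape error when the batch size and sequence length differ, and silently produces wrong values when they happen to match.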

Uses:

  1. When replacing layers in Hugging Face (HF) models, such as LlamaDecoderLayer, with TransformerLayer, this control helps because HF uses bshd by default while Transformer Engine (TE) uses sbhd.
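
As a rough illustration of the mismatch described above, the sketch below uses NumPy arrays in place of PyTorch tensors. It shows the batch/sequence axis swap a caller would otherwise have to perform around a sequence-first layer. The helper `convert_layout` is hypothetical and is not part of HF or TE. Passing the format down to the layer, as this PR does, is meant to make this kind of manual transpose unnecessary.

```python
import numpy as np

def convert_layout(x, src, dst):
    """Hypothetical helper: move between batch-first (bshd) and
    sequence-first (sbhd) layouts by swapping the first two axes."""
    valid = ("bshd", "sbhd")
    if src not in valid or dst not in valid:
        raise ValueError(f"formats must be one of {valid}")
    if src == dst:
        return x
    return np.swapaxes(x, 0, 1)

batch, seq, hidden = 2, 4, 6
# Batch-first hidden states, as an HF-style decoder layer would receive them.
hf_hidden_states = np.arange(batch * seq * hidden, dtype=np.float32).reshape(
    batch, seq, hidden
)

# Sequence-first view, as a layer using the sbhd layout would expect.
te_style = convert_layout(hf_hidden_states, "bshd", "sbhd")
round_trip = convert_layout(te_style, "sbhd", "bshd")

# Token 1 of sample 0 is found at [1, 0] after the swap.
assert np.array_equal(te_style[1, 0], hf_hidden_states[0, 1])
```

Here the hidden dimension stands for the merged heads and head-size axes; the swap itself only touches the first two axes.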

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 self-assigned this Dec 8, 2023

@cyanguwa cyanguwa left a comment

LGTM. Once the CI passes, can merge.

transformer_engine/pytorch/attention.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…to plumb_tensor_format_thru_transformer_layer

@cyanguwa cyanguwa left a comment

Please fix a few small things. Looks good to me. Thanks!

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…to plumb_tensor_format_thru_transformer_layer
@ptrendx ptrendx added the 1.3.0 label Jan 17, 2024
ptrendx commented Jan 17, 2024

@sudhakarsingh27 Please resolve the merge conflicts

@cyanguwa Could you rereview?

…to plumb_tensor_format_thru_transformer_layer

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27

/te-ci/pytorch

@sudhakarsingh27

/te-ci pytorch

@cyanguwa

/te-ci pytorch

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

@cyanguwa cyanguwa left a comment

LGTM

@sudhakarsingh27 sudhakarsingh27 merged commit 36047fd into NVIDIA:main Jan 18, 2024
9 checks passed
@cyanguwa

@sudhakarsingh27 Please fix CI before merging -- looks like there are a few failed jobs.

sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request Jan 19, 2024
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
ksivaman pushed a commit that referenced this pull request Jan 20, 2024
fix failing tests due to PR #557

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Wong4j pushed a commit to Wong4j/TransformerEngine that referenced this pull request Jan 22, 2024
…#557)

* make TransformerLayer accept a `bshd` or `sbhd` tensor format

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* Fixes from feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* more feedback fixes

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove incorrect info from docstring

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix from feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
ptrendx pushed a commit that referenced this pull request Jan 22, 2024
* make TransformerLayer accept a `bshd` or `sbhd` tensor format

ptrendx pushed a commit that referenced this pull request Jan 22, 2024
fix failing tests due to PR #557

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>