make TransformerLayer accept a `bshd` or `sbhd` tensor format #557
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
LGTM. Once the CI passes, can merge.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…to plumb_tensor_format_thru_transformer_layer
Please fix a few small things. Looks good to me. Thanks!
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…to plumb_tensor_format_thru_transformer_layer
@sudhakarsingh27 Please resolve the merge conflicts.
@cyanguwa Could you re-review?
…to plumb_tensor_format_thru_transformer_layer Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci/pytorch
/te-ci pytorch
1 similar comment
/te-ci pytorch
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
LGTM
@sudhakarsingh27 Please fix CI before merging -- it looks like there are a few failed jobs.
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…#557)

* make TransformerLayer accept a `bshd` or `sbhd` tensor format
* Fixes from feedback
* more feedback fixes
* remove incorrect info from docstring
* fix from feedback

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Changes:
`DotProductAttention` can interpret the input tensor in `sbhd` or `bshd` format, but `MultiHeadAttention` and `TransformerLayer` aren't aware of this. This PR plumbs this information through `TransformerLayer` -> `MultiHeadAttention` -> `DotProductAttention`. `RotaryPositionEmbedding` also needs to be aware of this format to correctly apply the RoPE cache to the `q` and `k` inputs.

Uses:
When wrapping `LlamaDecoderLayer` with `TransformerLayer`, it helps to have this control, since HF is `bshd` by default and TE is `sbhd`.
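The two layouts differ only in the order of the first two axes (`s` = sequence length, `b` = batch, `h` = heads, `d` = head dim). A minimal plain-Python sketch of the conversion a caller would otherwise have to do by hand (the function name and shapes here are illustrative, not part of the TE API):

```python
def swap_layout_shape(shape, src, dst):
    """Map a 4-D attention tensor shape between 'sbhd' and 'bshd' layouts.

    The two formats are a transpose of the first two axes, so converting
    either direction just swaps them; h and d are untouched.
    """
    assert src in ("sbhd", "bshd") and dst in ("sbhd", "bshd")
    if src == dst:
        return shape
    first, second, h, d = shape
    return (second, first, h, d)

# Example: HF-style bshd (batch=2, seq=128, heads=16, dim=64)
# -> TE-default sbhd
print(swap_layout_shape((2, 128, 16, 64), "bshd", "sbhd"))  # (128, 2, 16, 64)
```

With this PR, the layout is presumably selected once via an argument on `TransformerLayer` and plumbed down, so HF-style `bshd` tensors need no manual transpose at the call site; the exact argument name isn't shown in this thread.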