
Support Low Rank Adaptation (LoRA). #745

Merged: 8 commits merged into NVIDIA:main on Apr 16, 2024

Conversation

@mingxu1067 (Collaborator) commented Apr 2, 2024

  • Implemented LoRA and related tests.
  • Added LoRA scope control to TransformerLayer and MultiHeadAttention.
  • LoRA implementation details:
  • Only len(axis) <= 5 and len(features) <= 5 are supported.
  • When features has multiple dimensions, LoRA transforms the last dimension only.
    For example, with X of shape (B, S, Hin), features = (3, Hout), and axis = (2,), LoRA computes (B, S, Hin) x (Hin, 3, rank) = (B, S, 3, rank), then (B, S, 3, rank) x (3, rank, Hout) = (B, S, 3, Hout); see the sketch after this list.
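
A minimal JAX sketch of the contraction in the example above; the shapes follow the example, while the variable names (lora_a, lora_b) and the zero initialization of the up-projection are illustrative assumptions, not the actual TransformerEngine kernels.

import jax
import jax.numpy as jnp

B, S, Hin, Hout, rank = 2, 16, 64, 32, 4
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (B, S, Hin))          # input, features on axis 2
lora_a = jax.random.normal(key, (Hin, 3, rank))  # down-projection kernel
lora_b = jnp.zeros((3, rank, Hout))              # up-projection kernel, zero-initialized

# (B, S, Hin) x (Hin, 3, rank) -> (B, S, 3, rank)
mid = jnp.einsum('bsh,hgr->bsgr', x, lora_a)
# (B, S, 3, rank) x (3, rank, Hout) -> (B, S, 3, Hout)
out = jnp.einsum('bsgr,gro->bsgo', mid, lora_b)
print(out.shape)  # (2, 16, 3, 32)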

@zlsh80826 (Collaborator) commented:

Generally LGTM!

SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
SCOPE_EX_MLP = 'exclude_mlp'

assert scope in [

Collaborator:
I noticed that low_rank_adaptation_scope is expected to be a string, which could confuse users about whether to pass None or the string 'None'. It would be better to handle None explicitly: either accept None and convert it to the string 'None', or raise an error telling the user to pass the string 'None'.

Collaborator Author:

This is a good point, thank you for bringing it up. Added handling for None.
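
A minimal sketch of one way such None handling could look, assuming the scope constants quoted above; SCOPE_NONE, its 'none' value, and canonicalize_lora_scope are hypothetical names for illustration, not necessarily the merged implementation.

SCOPE_NONE = 'none'  # hypothetical constant for "no LoRA applied"
SCOPE_EX_OUTPUT_PROJ = 'exclude_output_proj'
SCOPE_EX_MLP = 'exclude_mlp'

def canonicalize_lora_scope(scope):
    """Accept Python None as an alias for the no-LoRA scope, then validate."""
    if scope is None:
        scope = SCOPE_NONE
    allowed = [SCOPE_NONE, SCOPE_EX_OUTPUT_PROJ, SCOPE_EX_MLP]
    assert scope in allowed, (
        f"low_rank_adaptation_scope must be one of {allowed} or None, got {scope!r}")
    return scope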

8 commits, each Signed-off-by: Ming Huang <mingh@nvidia.com>
@mingxu1067 (Collaborator Author) commented:

/te-ci jax

@denera (Collaborator) left a comment:

LGTM!

@yhtang commented Apr 16, 2024

When could this get merged? I will cherry-pick this into the JAX 24.04 NGC release once it is merged into TE.

@denera merged commit 7c1828f into NVIDIA:main on Apr 16, 2024
15 checks passed
@denera (Collaborator) commented Apr 16, 2024

@yhtang Just merged. Thanks for the heads up!

pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 9, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 16, 2024
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 23, 2024
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
4 participants