
[JAX] Unifying GeLU and GeGLU in LayerNorm MLP #765

Merged · 9 commits · Apr 24, 2024

Conversation


@phu0ngng phu0ngng commented Apr 9, 2024

This PR unifies the GeLU and GeGLU implementations in LayerNormMLP via a generalized fused_layernorm_fp8_mlp. Previously, there were two separate APIs for these two activations. Compared to the old routines, the new routine takes two additional arguments: activation_type: Tuple and use_bias: bool.

This is a preparation step for adding more activations (e.g. SwiGLU) later.
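For illustration, here is a pure-NumPy sketch of the unification idea. The `activation_type` tuples `('gelu',)` and `('gelu', 'linear')` follow the PR description; the actual `fused_layernorm_fp8_mlp` additionally fuses LayerNorm and the FP8 GEMMs, so this is only the dispatch concept, not the real kernel:

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def apply_activation(x, activation_type):
    """Unified activation dispatch (illustrative only).

    ('gelu',)          -> plain GeLU (the old GeLU MLP path)
    ('gelu', 'linear') -> GeGLU: the last dim holds two stacked
                          projections; the GeLU branch gates the linear one.
    """
    if activation_type == ("gelu",):
        return gelu(x)
    if activation_type == ("gelu", "linear"):
        gate, lin = np.split(x, 2, axis=-1)
        return gelu(gate) * lin
    raise ValueError(f"unsupported activation_type: {activation_type}")

out_gelu = apply_activation(np.ones((2, 8)), ("gelu",))            # shape (2, 8)
out_geglu = apply_activation(np.ones((2, 8)), ("gelu", "linear"))  # shape (2, 4)
```

Note how the gated variant halves the hidden dimension, which is why a single routine with an `activation_type` argument can cover both layouts.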

@denera denera added enhancement New feature or request jax labels Apr 9, 2024
@denera denera linked an issue Apr 9, 2024 that may be closed by this pull request
phu0ngng added 2 commits April 9, 2024 22:44
* combined layernorm_geglu with layernorm_gelu into fused_layernorm
* fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py
@zlsh80826

/te-ci jax

@denera denera left a comment

LGTM!


denera commented Apr 15, 2024

@phu0ngng do you have CI permissions yet? If not, please check in with @ptrendx to get permissions and then trigger the CI run for this. We can merge once the tests come back clean.


denera commented Apr 15, 2024

@phu0ngng Also please fix the linting errors here before the CI. Thanks!

************* Module transformer_engine.jax.flax.module
transformer_engine/jax/flax/module.py:690:26: W1401: Anomalous backslash in string: '\m'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
transformer_engine/jax/flax/module.py:690:42: W1401: Anomalous backslash in string: '\s'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
transformer_engine/jax/flax/module.py:690:48: W1401: Anomalous backslash in string: '\m'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
transformer_engine/jax/flax/module.py:690:66: W1401: Anomalous backslash in string: '\e'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
transformer_engine/jax/flax/module.py:691:17: W1401: Anomalous backslash in string: '\g'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
transformer_engine/jax/flax/module.py:696:51: W1401: Anomalous backslash in string: '\g'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
transformer_engine/jax/flax/module.py:702:64: W1401: Anomalous backslash in string: '\g'. String constant might be missing an r prefix. (anomalous-backslash-in-string)
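The W1401 warnings above come from docstrings embedding LaTeX-style sequences such as `\m` and `\g`, which Python treats as unrecognized escapes. A minimal illustration of the usual fix (the string content here is hypothetical, not the actual docstring in module.py):

```python
# An unrecognized escape like '\m' happens to survive as backslash + 'm',
# but pylint (W1401) flags it because the behavior is accidental.
# Either escape the backslash explicitly, or use an r-prefixed raw string,
# which disables escape processing entirely and states the intent.
flagged = "\\mathrm{gelu}"  # explicit double backslash: pylint-clean
fixed = r"\mathrm{gelu}"    # raw string: also pylint-clean, same content
```

Both forms produce the identical string, so adding the `r` prefix is a safe, behavior-preserving fix for these docstrings.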

@zlsh80826 zlsh80826 left a comment

LGTM

@mingxu1067 mingxu1067 left a comment

Kindly remove the unnecessary comment in tests/jax/test_custom_call_compute.py#332.


@phu0ngng
Collaborator Author

/te-ci jax

phu0ngng and others added 2 commits April 22, 2024 16:50
* added partial fused calculation for dbias_1
* clean up (Co-authored-by: Alp Dener)
@phu0ngng
Collaborator Author

/te-ci jax

@phu0ngng
Collaborator Author

Hi @denera, @mingxu1067,
I have resolved all of your change requests.
Please have a look and let me know if you have any other suggestions.

@mingxu1067 mingxu1067 left a comment

LGTM

@denera denera merged commit dac0001 into NVIDIA:main Apr 24, 2024
15 checks passed
@zlsh80826

Congratulations on your first pull request, @phu0ngng!
This will really help future maintenance of the various activation types!

pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
* combined layernorm_geglu with layernorm_gelu into fused_layernorm

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixes to pass all unit tests in test_custom_call_compute.py,
test_layer.py, and test_praxis_layer.py

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* cleaning and formatting

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* renaming based on reviewers suggestions

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* implemented partial fused layernorm

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* geglu + bias passed tests

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* added partial fused calculation for dbias_1

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* clean up

Co-authored-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Co-authored-by: Alp Dener <adener@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
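The "added partial fused calculation for dbias_1" commit above refers to computing the first dense layer's bias gradient alongside the fused backward pass. Taken in isolation, that bias gradient is just a sum-reduction of the upstream activation gradient over every axis except the hidden one; a hypothetical NumPy sketch (shapes and names are illustrative, not the Transformer Engine kernel):

```python
import numpy as np

def dbias_from_dact(dact):
    # For y = x @ W + b, the gradient w.r.t. b is the upstream gradient
    # summed over all non-hidden axes (batch and sequence here).
    return dact.reshape(-1, dact.shape[-1]).sum(axis=0)

dact = np.ones((2, 3, 8))        # (batch, seq, hidden) upstream gradient
dbias_1 = dbias_from_dact(dact)  # shape (8,); each entry is 2 * 3 = 6.0
```

Fusing this reduction into the activation backward avoids a separate pass over the gradient tensor, which is presumably the point of the "partial fused" variant.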
Labels: enhancement (New feature or request), jax
Development

Successfully merging this pull request may close these issues.

5 participants