-
Notifications
You must be signed in to change notification settings - Fork 354
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Support segment_ids/pos as FA inputs (#1406)
* POC for segment_ids/segment_pos Signed-off-by: Reese Wang <rewang@nvidia.com> * Change segment_pos position Signed-off-by: Reese Wang <rewang@nvidia.com> * Use RemainingArgs to solve number of parameters mismatches Signed-off-by: Reese Wang <rewang@nvidia.com> * Test mask_descriptor for accomendating different mask representations Signed-off-by: Reese Wang <rewang@nvidia.com> * Fix bugs Signed-off-by: Reese Wang <rewang@nvidia.com> * Use descriptor in bwd Signed-off-by: Reese Wang <rewang@nvidia.com> * Primitives only accepts pure jnp array Signed-off-by: Reese Wang <rewang@nvidia.com> * segment_ids/pos support POC Signed-off-by: Reese Wang <rewang@nvidia.com> * Move seqlens/offsets generation to mask descriptor Signed-off-by: Reese Wang <rewang@nvidia.com> * Rename MaskDescriptor to SequenceDescriptor Signed-off-by: Reese Wang <rewang@nvidia.com> * Generalize get_seqlens_and_offsets Signed-off-by: Reese Wang <rewang@nvidia.com> * Utilize sequence desc on FA bwd Signed-off-by: Reese Wang <rewang@nvidia.com> * Migrate to new API Signed-off-by: Reese Wang <rewang@nvidia.com> * Add docstrings Signed-off-by: Reese Wang <rewang@nvidia.com> * Remove small inputs and test different input format Signed-off-by: Reese Wang <rewang@nvidia.com> * Fix lint Signed-off-by: Reese Wang <rewang@nvidia.com> * Fix seed shardings Signed-off-by: Reese Wang <rewang@nvidia.com> * Optimize sequence converting overhead Signed-off-by: Reese Wang <rewang@nvidia.com> * Optimize seq_offsets calculation Signed-off-by: Reese Wang <rewang@nvidia.com> * Fix up Signed-off-by: Reese Wang <rewang@nvidia.com> * fix lint Signed-off-by: Reese Wang <rewang@nvidia.com> * Fix conflicts Signed-off-by: Reese Wang <rewang@nvidia.com> * Remove reduntant line Signed-off-by: Reese Wang <rewang@nvidia.com> --------- Signed-off-by: Reese Wang <rewang@nvidia.com>
- Loading branch information
Showing
5 changed files
with
786 additions
and
297 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.