Commit
[JAX] Fused attention unit tests fixes and refinements (#1352)
* Add util functions to attn_mask_type
* Add util functions to qkv_layout
* Fix THD cross reference code
* Remove explicit segment_pad, encoding it into segment_ids
* Add jax.jit, replace _token with segment_ids, rename bias shape enum
* Add comment for make_mask
* Clean up code
* Add docstrings for the added functions
* Remove cache for fa deterministic, which caused UT failures
* Rename fixture to avoid conflict

Signed-off-by: Reese Wang <rewang@nvidia.com>
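The change that drops the explicit segment_pad array and encodes padding in segment_ids can be illustrated with a minimal sketch. This is not the code from this commit: the function name `sketch_segment_mask`, the convention that segment id 0 marks padding, and the choice that `True` means "attention allowed" are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sketch_segment_mask(q_segment_ids, kv_segment_ids):
    """Hypothetical helper: build a boolean attention mask from segment ids.

    Assumptions (not taken from the commit):
      * segment id 0 marks a padded position, so no separate pad array is needed
      * True means the query position may attend to the key/value position
    Shapes: q_segment_ids [batch, q_len], kv_segment_ids [batch, kv_len].
    """
    # A query may only attend to keys in the same packed sequence.
    same_segment = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
    # Padded positions (id 0) never attend and are never attended to.
    q_valid = (q_segment_ids != 0)[:, :, None]
    kv_valid = (kv_segment_ids != 0)[:, None, :]
    return same_segment & q_valid & kv_valid


# Two packed sequences (lengths 2 and 3) followed by one padded position.
segment_ids = jnp.array([[1, 1, 2, 2, 2, 0]])
mask = sketch_segment_mask(segment_ids, segment_ids)
```

With this encoding a single integer array carries both the sequence boundaries and the padding information, which is one plausible reason to remove a separate pad array.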