Bump FlashAttn version and add deterministic option for FAv2 (#585)
* Deterministic FA, bump minimum supported version

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix MQA/GQA

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ksivaman authored Jan 6, 2024
1 parent e2a7531 commit f2bd53c
Showing 2 changed files with 80 additions and 92 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -284,7 +284,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:

     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6,<=2.3.3,!=2.0.9,!=2.1.0"])
+        add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
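
The new requirement string narrows the accepted flash-attn range to 2.0.6 through 2.4.2, still excluding 2.0.9 and 2.1.0. As a rough illustration of which versions that admits, here is a minimal sketch using only the standard library. The helper names `parse_release` and `is_supported` are hypothetical and not part of this commit, and the comparison assumes plain numeric release strings (no pre-release or local version tags); pip itself applies the full PEP 440 rules.

```python
from typing import Tuple

# Bounds copied from the new requirement string in setup.py:
#   flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0
MIN_VERSION = (2, 0, 6)
MAX_VERSION = (2, 4, 2)
EXCLUDED = {(2, 0, 9), (2, 1, 0)}


def parse_release(version: str) -> Tuple[int, ...]:
    """Turn a plain release string such as '2.3.3' into a tuple of ints."""
    return tuple(int(part) for part in version.split("."))


def is_supported(version: str) -> bool:
    """Hypothetical helper: True if `version` falls inside the new range."""
    release = parse_release(version)
    in_range = MIN_VERSION <= release <= MAX_VERSION
    return in_range and release not in EXCLUDED


if __name__ == "__main__":
    for candidate in ["1.0.6", "2.0.6", "2.0.9", "2.1.0", "2.3.3", "2.4.2", "2.4.3"]:
        status = "ok" if is_supported(candidate) else "rejected"
        print(f"flash-attn {candidate}: {status}")
```

Under these assumptions, 1.0.6 (the old lower bound) is now rejected, while 2.4.2 (the new upper bound) is accepted.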