
Commit 47276e1
Revert "Update FA version to 2.5.6 (#714)"
This reverts commit 965803c.
ksivaman committed Apr 3, 2024
1 parent 2dd6b14 commit 47276e1
Showing 2 changed files with 1 addition and 5 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -265,7 +265,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:

 # Framework-specific requirements
 if "pytorch" in frameworks():
-    add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"])
+    add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
     add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
 if "jax" in frameworks():
     if not found_pybind11():
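For reference, a minimal sketch (not part of this commit) of how the restored pip specifier behaves, using the packaging library to test a few candidate flash-attn releases against ">=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0":

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# The constraint restored by this revert for the "pytorch" framework requirements.
spec = SpecifierSet(">=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0")

# Illustrative candidates only: the excluded 2.0.9/2.1.0 releases and the
# now-rejected 2.5.6 fall outside the set.
for candidate in ["2.0.6", "2.0.9", "2.1.0", "2.4.2", "2.5.6"]:
    status = "allowed" if Version(candidate) in spec else "rejected"
    print(f"flash-attn=={candidate}: {status}")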
transformer_engine/pytorch/attention.py (4 changes: 0 additions & 4 deletions)
@@ -58,7 +58,6 @@

 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("2.0.6")
-_flash_attn_max_version = packaging.version.Version("2.5.6")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
@@ -1658,9 +1657,6 @@ def __init__(
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
-        assert (
-            _flash_attn_version <= _flash_attn_max_version
-        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."

         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
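With the maximum-version assert deleted, only the 2.0.6 lower bound is enforced at import time. A minimal, self-contained sketch of that remaining gate (assuming flash-attn is installed; illustrative, not copied from the repository beyond what the hunk above shows):

from importlib.metadata import version
import packaging.version

# Parse the installed flash-attn version and enforce only the lower bound,
# mirroring the check that survives this revert.
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_version_required = packaging.version.Version("2.0.6")

assert (
    _flash_attn_version >= _flash_attn_version_required
), f"FlashAttention minimum version {_flash_attn_version_required} is required."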
