diff --git a/setup.py b/setup.py
index e6bfe496ff..d50a9b8706 100644
--- a/setup.py
+++ b/setup.py
@@ -265,7 +265,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:
 
 # Framework-specific requirements
 if "pytorch" in frameworks():
-    add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"])
+    add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"])
     add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
 if "jax" in frameworks():
     if not found_pybind11():
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 924e2bb97d..f5e7753e6a 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -58,7 +58,6 @@
 
 _flash_attn_version = packaging.version.Version(version("flash-attn"))
 _flash_attn_version_required = packaging.version.Version("2.0.6")
-_flash_attn_max_version = packaging.version.Version("2.5.6")
 _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1")
 _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3")
 _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4")
@@ -1658,9 +1657,6 @@ def __init__(
         assert (
             _flash_attn_version >= _flash_attn_version_required
         ), f"FlashAttention minimum version {_flash_attn_version_required} is required."
-        assert (
-            _flash_attn_version <= _flash_attn_max_version
-        ), f"FlashAttention maximum version {_flash_attn_max_version} is supported."
 
         self.norm_factor = norm_factor
         self.attention_dropout_ctx = attention_dropout_ctx
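
With the runtime maximum-version assert removed from attention.py, the `<=2.4.2` pin in setup.py becomes the only remaining upper bound on the installed FlashAttention version. A minimal sketch of how that requirement string evaluates, using packaging.specifiers from the same `packaging` library attention.py already relies on; this standalone check is illustrative and not code from this diff:

    # Illustrative only; assumes SpecifierSet matching mirrors pip's
    # handling of the install_requires string in setup.py.
    from packaging.specifiers import SpecifierSet

    flash_attn_spec = SpecifierSet(">=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0")

    print("2.4.2" in flash_attn_spec)  # True  -- new ceiling, last accepted release
    print("2.5.6" in flash_attn_spec)  # False -- old ceiling, now rejected
    print("2.0.9" in flash_attn_spec)  # False -- explicitly excluded
    print("2.1.0" in flash_attn_spec)  # False -- explicitly excluded

Versions outside the pin are now rejected only at install time; a flash-attn build above 2.4.2 that is already present in the environment will no longer trip an assertion when FlashAttention is constructed.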