Skip to content

Commit

Permalink
Use deterministic algo for test_fused_attn.py
Browse files Browse the repository at this point in the history
Signed-off-by: Reese Wang <rewang@nvidia.com>
  • Loading branch information
zlsh80826 committed Dec 10, 2024
1 parent f49f32f commit fea31d4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
3 changes: 0 additions & 3 deletions tests/jax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ def enable_fused_attn():
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ:
del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"]
11 changes: 11 additions & 0 deletions tests/jax/test_praxis_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@
FP8_FORMATS = [Format.E4M3, Format.HYBRID]


@pytest.fixture(autouse=True, scope="module")
def set_deterministic_algo():
"""
test_praxis_layers will run both flax and praxis interface and expect the results are the same.
We need to enable the deterministic algo to ensure the results are exactly the same.
"""
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
yield
del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"]


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, f"{key} not found in test dict {test_fd}"
Expand Down

0 comments on commit fea31d4

Please sign in to comment.