Skip to content

Commit

Permalink
[JAX] Skip V100 encoder tests (#1262)
Browse files Browse the repository at this point in the history
* Skip encoder tests on V100

* Fix multiprocessing jax.distributed.init

* Remove XLA xla_gpu_deterministic_ops which causes segfault

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
  • Loading branch information
zlsh80826 authored Oct 22, 2024
1 parent 7b18f23 commit 35f7d26
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 6 deletions.
14 changes: 14 additions & 0 deletions examples/jax/encoder/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Shared functions for the encoder tests"""
from functools import lru_cache

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@lru_cache
def is_bf16_supported():
    """Return True when the first visible GPU supports BF16 in hardware.

    BF16 requires compute capability 8.0 (Ampere) or newer; the result is
    cached so the device query runs at most once per process.
    """
    return get_device_compute_capability(0) >= 80
4 changes: 4 additions & 0 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
NAMED_BROADCAST_AXIS = "my_broadcast_axis"
Expand Down Expand Up @@ -434,6 +436,7 @@ def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
Expand All @@ -446,6 +449,7 @@ def test_te_fp8(self):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
Expand Down
3 changes: 3 additions & 0 deletions examples/jax/encoder/test_multigpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
Expand Down Expand Up @@ -402,6 +404,7 @@ def setUpClass(cls):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
Expand Down
12 changes: 8 additions & 4 deletions examples/jax/encoder/test_multiprocessing_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
Expand Down Expand Up @@ -552,8 +554,9 @@ def encoder_parser(args):
def query_gpu(q):
"""Query GPU info on the system"""
gpu_has_fp8, reason = te.fp8.is_fp8_available()
gpu_has_bf16 = is_bf16_supported()
num_gpu = len(jax.devices())
q.put([num_gpu, gpu_has_fp8, reason])
q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason])


def unittest_query_gpu():
Expand All @@ -566,15 +569,15 @@ def unittest_query_gpu():
q = mp.Queue()
p = mp.Process(target=query_gpu, args=(q,))
p.start()
num_gpu, gpu_has_fp8, reason = q.get()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get()
p.join()
return num_gpu, gpu_has_fp8, reason
return num_gpu, gpu_has_fp8, gpu_has_bf16, reason


class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

num_gpu, gpu_has_fp8, reason = unittest_query_gpu()
num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu()

def exec(self, use_fp8):
"""Run 3 epochs for testing"""
Expand All @@ -598,6 +601,7 @@ def exec(self, use_fp8):

return results

@unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
results = self.exec(False)
Expand Down
3 changes: 3 additions & 0 deletions examples/jax/encoder/test_single_gpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax

from common import is_bf16_supported

PARAMS_KEY = "params"
DROPOUT_KEY = "dropout"
INPUT_KEY = "input_rng"
Expand Down Expand Up @@ -321,6 +323,7 @@ def setUpClass(cls):
"""Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
Expand Down
2 changes: 0 additions & 2 deletions qa/L0_jax_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt

pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist

# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py

0 comments on commit 35f7d26

Please sign in to comment.