diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py new file mode 100644 index 0000000000..dcbfafc467 --- /dev/null +++ b/examples/jax/encoder/common.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Shared functions for the encoder tests""" +from functools import lru_cache + +from transformer_engine.transformer_engine_jax import get_device_compute_capability + + +@lru_cache +def is_bf16_supported(): + """Return whether the GPU supports BF16 in hardware (compute capability 8.0+)""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 80 diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 25d744887e..bafd9bd2fb 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -22,6 +22,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" NAMED_BROADCAST_AXIS = "my_broadcast_axis" @@ -434,6 +436,7 @@ def setUpClass(cls): """Run 3 epochs for testing""" cls.args = encoder_parser(["--epochs", "3"]) + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) @@ -446,6 +449,7 @@ def test_te_fp8(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.45 and actual[1] > 0.79 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 9d08254f4d..a4a19b43c2 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ 
b/examples/jax/encoder/test_multigpu_encoder.py @@ -22,6 +22,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" @@ -402,6 +404,7 @@ def setUpClass(cls): """Run 3 epochs for testing""" cls.args = encoder_parser(["--epochs", "3"]) + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index e581dbc3f9..f54deff69c 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -24,6 +24,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" @@ -552,8 +554,9 @@ def encoder_parser(args): def query_gpu(q): """Query GPU info on the system""" gpu_has_fp8, reason = te.fp8.is_fp8_available() + gpu_has_bf16 = is_bf16_supported() num_gpu = len(jax.devices()) - q.put([num_gpu, gpu_has_fp8, reason]) + q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason]) def unittest_query_gpu(): @@ -566,15 +569,15 @@ def unittest_query_gpu(): q = mp.Queue() p = mp.Process(target=query_gpu, args=(q,)) p.start() - num_gpu, gpu_has_fp8, reason = q.get() + num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get() p.join() - return num_gpu, gpu_has_fp8, reason + return num_gpu, gpu_has_fp8, gpu_has_bf16, reason class TestEncoder(unittest.TestCase): """Encoder unittests""" - num_gpu, gpu_has_fp8, reason = unittest_query_gpu() + num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() def exec(self, use_fp8): """Run 3 epochs for testing""" @@ -598,6 
+601,7 @@ def exec(self, use_fp8): return results + @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" results = self.exec(False) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 363759afea..ac71fe4c0e 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -19,6 +19,8 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax +from common import is_bf16_supported + PARAMS_KEY = "params" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -321,6 +323,7 @@ def setUpClass(cls): """Run 4 epochs for testing""" cls.args = encoder_parser(["--epochs", "3"]) + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index db3aa31951..9efec6f2e5 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -18,7 +18,5 @@ pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist -# Make encoder tests to have run-to-run deterministic to have the stable CI results -export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py