diff --git a/zeus/show_env.py b/zeus/show_env.py index c8478051..931e0f12 100644 --- a/zeus/show_env.py +++ b/zeus/show_env.py @@ -15,6 +15,8 @@ from zeus.utils import framework from zeus.device import get_gpus, get_cpus from zeus.device.cpu import RAPLCPUs +from zeus.device.gpu.common import ZeusGPUInitError, EmptyGPUs +from zeus.device.cpu.common import ZeusCPUInitError, EmptyCPUs SECTION_SEPARATOR = "=" * 80 + "\n" @@ -29,15 +31,37 @@ def show_env(): package_availability = "\nPackage availability and versions:\n" package_availability += f" Zeus: {zeus.__version__}\n" - if framework.torch_is_available(): + try: + torch_available = framework.torch_is_available() + torch_cuda_available = True + except RuntimeError: + torch_available = framework.torch_is_available(ensure_cuda=False) + torch_cuda_available = False + + if torch_available and torch_cuda_available: + torch = framework.MODULE_CACHE["torch"] + package_availability += f" PyTorch: {torch.__version__} (with CUDA support)\n" + elif torch_available and not torch_cuda_available: torch = framework.MODULE_CACHE["torch"] - package_availability += f" PyTorch: {torch.__version__}\n" + package_availability += ( + f" PyTorch: {torch.__version__} (without CUDA support)\n" + ) else: package_availability += " PyTorch: not available\n" - if framework.jax_is_available(): + try: + jax_available = framework.jax_is_available() + jax_cuda_available = True + except RuntimeError: + jax_available = framework.jax_is_available(ensure_cuda=False) + jax_cuda_available = False + + if jax_available and jax_cuda_available: + jax = framework.MODULE_CACHE["jax"] + package_availability += f" JAX: {jax.__version__} (with CUDA support)\n" + elif jax_available and not jax_cuda_available: jax = framework.MODULE_CACHE["jax"] - package_availability += f" JAX: {jax.__version__}\n" + package_availability += f" JAX: {jax.__version__} (without CUDA support)\n" else: package_availability += " JAX: not available\n" @@ -45,7 +69,10 @@ def show_env(): print(SECTION_SEPARATOR) gpu_availability = "\nGPU availability:\n" - gpus = get_gpus() + try: + gpus = get_gpus() + except ZeusGPUInitError: + gpus = EmptyGPUs() if len(gpus) > 0: for i in range(len(gpus)): gpu_availability += f" GPU {i}: {gpus.getName(i)}\n" @@ -55,9 +82,12 @@ def show_env(): print(SECTION_SEPARATOR) cpu_availability = "\nCPU availability:\n" - cpus = get_cpus() - assert isinstance(cpus, RAPLCPUs) + try: + cpus = get_cpus() + except ZeusCPUInitError: + cpus = EmptyCPUs() if len(cpus) > 0: + assert isinstance(cpus, RAPLCPUs) for i in range(len(cpus)): cpu_availability += f" CPU {i}:\n CPU measurements available ({cpus.cpus[i].rapl_file.path})\n" if cpus.supportsGetDramEnergyConsumption(i): diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index 26aa0251..ca4d7cf7 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -13,14 +13,13 @@ @lru_cache(maxsize=1) -def torch_is_available(ensure_available: bool = False): +def torch_is_available(ensure_available: bool = False, ensure_cuda: bool = True): """Check if PyTorch is available.""" try: import torch - assert ( - torch.cuda.is_available() - ), "PyTorch is available but does not have CUDA support." + if ensure_cuda and not torch.cuda.is_available(): + raise RuntimeError("PyTorch is available but does not have CUDA support.") MODULE_CACHE["torch"] = torch logger.info("PyTorch with CUDA support is available.") return True @@ -32,12 +31,13 @@ def torch_is_available(ensure_available: bool = False): @lru_cache(maxsize=1) -def jax_is_available(ensure_available: bool = False): +def jax_is_available(ensure_available: bool = False, ensure_cuda: bool = True): """Check if JAX is available.""" try: import jax # type: ignore - assert jax.devices("gpu"), "JAX is available but does not have CUDA support." + if ensure_cuda and not jax.devices("gpu"): + raise RuntimeError("JAX is available but does not have CUDA support.") MODULE_CACHE["jax"] = jax logger.info("JAX with CUDA support is available.") return True