Skip to content

Commit

Permalink
Handle cases where CUDA support, GPUs, or CPUs are missing in `zeus.show_env`
Browse files Browse the repository at this point in the history
  • Loading branch information
jaywonchung committed Sep 9, 2024
1 parent 8da36b4 commit 697d327
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
44 changes: 37 additions & 7 deletions zeus/show_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from zeus.utils import framework
from zeus.device import get_gpus, get_cpus
from zeus.device.cpu import RAPLCPUs
from zeus.device.gpu.common import ZeusGPUInitError, EmptyGPUs
from zeus.device.cpu.common import ZeusCPUInitError, EmptyCPUs


SECTION_SEPARATOR = "=" * 80 + "\n"
Expand All @@ -29,23 +31,48 @@ def show_env():
package_availability = "\nPackage availability and versions:\n"
package_availability += f" Zeus: {zeus.__version__}\n"

if framework.torch_is_available():
try:
torch_available = framework.torch_is_available()
torch_cuda_available = True
except RuntimeError:
torch_available = framework.torch_is_available(ensure_cuda=False)
torch_cuda_available = False

if torch_available and torch_cuda_available:
torch = framework.MODULE_CACHE["torch"]
package_availability += f" PyTorch: {torch.__version__} (with CUDA support)\n"
elif torch_available and not torch_cuda_available:
torch = framework.MODULE_CACHE["torch"]
package_availability += f" PyTorch: {torch.__version__}\n"
package_availability += (
f" PyTorch: {torch.__version__} (without CUDA support)\n"
)
else:
package_availability += " PyTorch: not available\n"

if framework.jax_is_available():
try:
jax_available = framework.jax_is_available()
jax_cuda_available = True
except RuntimeError:
jax_available = framework.jax_is_available(ensure_cuda=False)
jax_cuda_available = False

if jax_available and jax_cuda_available:
jax = framework.MODULE_CACHE["jax"]
package_availability += f" JAX: {jax.__version__} (with CUDA support)\n"
elif jax_available and not jax_cuda_available:
jax = framework.MODULE_CACHE["jax"]
package_availability += f" JAX: {jax.__version__}\n"
package_availability += f" JAX: {jax.__version__} (without CUDA support)\n"
else:
package_availability += " JAX: not available\n"

print(package_availability)

print(SECTION_SEPARATOR)
gpu_availability = "\nGPU availability:\n"
gpus = get_gpus()
try:
gpus = get_gpus()
except ZeusGPUInitError:
gpus = EmptyGPUs()
if len(gpus) > 0:
for i in range(len(gpus)):
gpu_availability += f" GPU {i}: {gpus.getName(i)}\n"
Expand All @@ -55,9 +82,12 @@ def show_env():

print(SECTION_SEPARATOR)
cpu_availability = "\nCPU availability:\n"
cpus = get_cpus()
assert isinstance(cpus, RAPLCPUs)
try:
cpus = get_cpus()
except ZeusCPUInitError:
cpus = EmptyCPUs()
if len(cpus) > 0:
assert isinstance(cpus, RAPLCPUs)
for i in range(len(cpus)):
cpu_availability += f" CPU {i}:\n CPU measurements available ({cpus.cpus[i].rapl_file.path})\n"
if cpus.supportsGetDramEnergyConsumption(i):
Expand Down
12 changes: 6 additions & 6 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@


@lru_cache(maxsize=1)
def torch_is_available(ensure_available: bool = False):
def torch_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
"""Check if PyTorch is available."""
try:
import torch

assert (
torch.cuda.is_available()
), "PyTorch is available but does not have CUDA support."
if ensure_cuda and not torch.cuda.is_available():
raise RuntimeError("PyTorch is available but does not have CUDA support.")
MODULE_CACHE["torch"] = torch
logger.info("PyTorch with CUDA support is available.")
return True
Expand All @@ -32,12 +31,13 @@ def torch_is_available(ensure_available: bool = False):


@lru_cache(maxsize=1)
def jax_is_available(ensure_available: bool = False):
def jax_is_available(ensure_available: bool = False, ensure_cuda: bool = True):
"""Check if JAX is available."""
try:
import jax # type: ignore

assert jax.devices("gpu"), "JAX is available but does not have CUDA support."
if ensure_cuda and not jax.devices("gpu"):
raise RuntimeError("JAX is available but does not have CUDA support.")
MODULE_CACHE["jax"] = jax
logger.info("JAX with CUDA support is available.")
return True
Expand Down

0 comments on commit 697d327

Please sign in to comment.