Skip to content

Commit

Permalink
explicitly passing visible devices
Browse files (browse the repository at this point in the history)
  • Loading branch information
IlyasMoutawwakil committed Apr 5, 2024
1 parent 1f9d645 commit fe30aad
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions optimum_benchmark/launchers/isolation_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,7 +31,7 @@ def isolation_signal_handler(signum, frame):
signal.signal(signal.SIGUSR1, isolation_signal_handler)


def get_nvidia_devices_pids() -> Set[int]:
def get_nvidia_devices_pids(device_ids: str) -> Set[int]:
if not is_pynvml_available():
raise ValueError(
"The library pynvml is required to get the pids running on NVIDIA GPUs, but is not installed. "
Expand All @@ -41,7 +41,7 @@ def get_nvidia_devices_pids() -> Set[int]:
pynvml.nvmlInit()

devices_pids = set()
devices_ids = [int(device_id) for device_id in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
devices_ids = list(map(int, device_ids.split(",")))

for device_id in devices_ids:
device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
Expand All @@ -54,7 +54,7 @@ def get_nvidia_devices_pids() -> Set[int]:
return devices_pids


def get_amd_devices_pids() -> Set[int]:
def get_amd_devices_pids(device_ids: str) -> Set[int]:
if not is_amdsmi_available():
raise ValueError(
"The library amdsmi is required to get the pids running on AMD GPUs, but is not installed. "
Expand All @@ -64,7 +64,7 @@ def get_amd_devices_pids() -> Set[int]:
amdsmi.amdsmi_init()

devices_pids = set()
devices_ids = [int(device_id) for device_id in os.environ["ROCR_VISIBLE_DEVICES"].split(",")]
devices_ids = list(map(int, device_ids.split(",")))

processor_handles = amdsmi.amdsmi_get_processor_handles()
for device_id in devices_ids:
Expand Down Expand Up @@ -93,24 +93,25 @@ def get_amd_devices_pids() -> Set[int]:
return devices_pids


def get_pids_running_on_system_devices() -> Set[int]:
def get_pids_running_on_system_devices(device_ids: str) -> Set[int]:
"""Returns the set of pids running on the system device(s)."""
if is_nvidia_system():
devices_pids = get_nvidia_devices_pids()
devices_pids = get_nvidia_devices_pids(device_ids)
elif is_rocm_system():
devices_pids = get_amd_devices_pids()
devices_pids = get_amd_devices_pids(device_ids)
else:
raise ValueError("get_pids_running_on_system_device is only supported on NVIDIA and AMD GPUs")

return devices_pids


def assert_system_devices_isolation(isolated_pid: int) -> None:
def assert_system_devices_isolation(isolated_pid: int, device_ids: str):
setup_logging("ERROR")

isolation_pid = os.getpid()

while psutil.pid_exists(isolated_pid):
devices_pids = get_pids_running_on_system_devices()
devices_pids = get_pids_running_on_system_devices(device_ids=device_ids)
devices_pids = {pid for pid in devices_pids if psutil.pid_exists(pid)}
isolated_children_pids = {child.pid for child in psutil.Process(isolated_pid).children(recursive=True)}
isolation_children_pids = {child.pid for child in psutil.Process(isolation_pid).children(recursive=True)}
Expand Down Expand Up @@ -141,19 +142,27 @@ def device_isolation(enabled: bool, isolated_pid: int):
yield
return

if is_nvidia_system():
device_ids = os.environ.get("CUDA_VISIBLE_DEVICES", "")
elif is_rocm_system():
device_ids = os.environ.get("ROCR_VISIBLE_DEVICES", "")
else:
raise ValueError("Device isolation is only supported on NVIDIA and AMD GPUs")

isolation_process = Process(
target=assert_system_devices_isolation,
kwargs={"isolated_pid": isolated_pid},
kwargs={"isolated_pid": isolated_pid, "device_ids": device_ids},
daemon=True,
)
isolation_process.start()

LOGGER.info(f"\t+ Launched device(s) isolation process {isolation_process.pid}.")
LOGGER.info(f"\t+ Launched device(s) isolation process {isolation_process.pid}")
LOGGER.info(f"\t+ Isolating device(s) [{device_ids}]")

yield

LOGGER.info("\t+ Closing device(s) isolation process...")

isolation_process.kill()
isolation_process.join()
isolation_process.close()
if isolation_process.is_alive():
LOGGER.info("\t+ Closing device(s) isolation process...")
isolation_process.kill()
isolation_process.join()
isolation_process.close()

0 comments on commit fe30aad

Please sign in to comment.