From 8f871f41f17a8a8ca90596ef0aa0c4d813182fdb Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Tue, 5 Dec 2023 11:21:44 -0500
Subject: [PATCH] Check notebook launcher for 3090+ (#2212)

* Include dist launch

* Better way

* CLean

* Just do it always

* Account for notebook launcher

* Use better gpu check

* Clean output

* Set logic
---
 src/accelerate/launchers.py         | 24 +++++++++++---
 src/accelerate/state.py             | 15 +++++----
 src/accelerate/utils/environment.py | 49 ++++++++++++++++++++++++-----
 3 files changed, 67 insertions(+), 21 deletions(-)

diff --git a/src/accelerate/launchers.py b/src/accelerate/launchers.py
index 0e32d84d06d..b01f6b9f98e 100644
--- a/src/accelerate/launchers.py
+++ b/src/accelerate/launchers.py
@@ -19,7 +19,14 @@
 import torch
 
 from .state import AcceleratorState, PartialState
-from .utils import PrecisionType, PrepareForLaunch, are_libraries_initialized, is_mps_available, patch_environment
+from .utils import (
+    PrecisionType,
+    PrepareForLaunch,
+    are_libraries_initialized,
+    check_cuda_p2p_ib_support,
+    is_mps_available,
+    patch_environment,
+)
 
 
 def test_launch():
@@ -153,16 +160,23 @@ def train(*args):
             err += f"\n\t* `{lib_name}`"
         raise RuntimeError(err)
 
-    # torch.distributed will expect a few environment variable to be here. We set the ones common to each
-    # process here (the other ones will be set be the launcher).
-    with patch_environment(
+    patched_env = dict(
         nproc=num_processes,
         node_rank=node_rank,
         world_size=num_nodes * num_processes,
        master_addr=master_addr,
         master_port=use_port,
         mixed_precision=mixed_precision,
-    ):
+    )
+
+    # Check for CUDA P2P and IB issues
+    if not check_cuda_p2p_ib_support():
+        patched_env["nccl_p2p_disable"] = "1"
+        patched_env["nccl_ib_disable"] = "1"
+
+    # torch.distributed will expect a few environment variables to be set here. We set the ones common to each
+    # process here (the other ones will be set by the launcher).
+    with patch_environment(**patched_env):
         # First dummy launch
         if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
             launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
diff --git a/src/accelerate/state.py b/src/accelerate/state.py
index a8fa33e3e13..b4f95b03fbd 100644
--- a/src/accelerate/state.py
+++ b/src/accelerate/state.py
@@ -182,14 +182,6 @@ def __init__(self, cpu: bool = False, **kwargs):
                 self.backend = "nccl"
             dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
 
-            if not check_cuda_p2p_ib_support():
-                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
-                    raise NotImplementedError(
-                        "Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. "
-                        'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which '
-                        "will do this automatically."
-                    )
-
             self.num_processes = torch.distributed.get_world_size()
             self.process_index = torch.distributed.get_rank()
             self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
@@ -206,6 +198,13 @@ def __init__(self, cpu: bool = False, **kwargs):
                 self.device = torch.device("cuda", self.local_process_index)
             if self.device is not None:
                 torch.cuda.set_device(self.device)
+            if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
+                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
+                    raise NotImplementedError(
+                        "Using RTX 3090 or 4000 series doesn't support faster communication bandwidth via P2P or IB. "
" + 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which ' + "will do this automatically." + ) self._mixed_precision = "no" # deepspeed handles mixed_precision using deepspeed_config elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu and torch.cuda.is_available(): self.distributed_type = DistributedType.MULTI_GPU diff --git a/src/accelerate/utils/environment.py b/src/accelerate/utils/environment.py index 99e153bf04e..3de72e51c15 100644 --- a/src/accelerate/utils/environment.py +++ b/src/accelerate/utils/environment.py @@ -13,11 +13,12 @@ # limitations under the License. import os +import platform +import subprocess import sys +from distutils import spawn from typing import Dict -import torch - def str_to_bool(value) -> int: """ @@ -61,17 +62,49 @@ def are_libraries_initialized(*library_names: str) -> Dict[str, bool]: return [lib_name for lib_name in library_names if lib_name in sys.modules] +def get_gpu_info(): + """ + Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA. + + Largely based on the `gputil` library. + """ + if platform.system() == "Windows": + # If platform is Windows and nvidia-smi can't be found in path + # try from systemd rive with default installation path + command = spawn.find_executable("nvidia-smi") + if command is None: + command = "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" % os.environ["systemdrive"] + else: + command = "nvidia-smi" + # Returns as list of `n` GPUs and their names + output = subprocess.check_output( + [command, "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True + ) + output = output.strip() + gpus = output.split(os.linesep) + # Get names from output + gpu_count = len(gpus) + gpu_names = [gpu.split(",")[1].strip() for gpu in gpus] + return gpu_names, gpu_count + + def check_cuda_p2p_ib_support(): """ Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after the 3090. + + Noteably uses `nvidia-smi` instead of torch to not initialize CUDA. """ - if torch.cuda.is_available(): - # Get the first device/default - device_name = torch.cuda.get_device_name() - device_count = torch.cuda.device_count() - unsupported_devices = ["RTX 3090", "RTX 40"] + try: + device_names, device_count = get_gpu_info() + unsupported_devices = {"RTX 3090", "RTX 40"} if device_count > 1: - if any(device in device_name for device in unsupported_devices): + if any( + unsupported_device in device_name + for device_name in device_names + for unsupported_device in unsupported_devices + ): return False + except Exception: + pass return True