Check notebook launcher for 3090+ (#2212)
* Include dist launch

* Better way

* Clean

* Just do it always

* Account for notebook launcher

* Use better gpu check

* Clean output

* Set logic
muellerzr authored Dec 5, 2023
1 parent 47e6c36 commit 8f871f4
Showing 3 changed files with 67 additions and 21 deletions.
24 changes: 19 additions & 5 deletions src/accelerate/launchers.py
@@ -19,7 +19,14 @@
 import torch
 
 from .state import AcceleratorState, PartialState
-from .utils import PrecisionType, PrepareForLaunch, are_libraries_initialized, is_mps_available, patch_environment
+from .utils import (
+    PrecisionType,
+    PrepareForLaunch,
+    are_libraries_initialized,
+    check_cuda_p2p_ib_support,
+    is_mps_available,
+    patch_environment,
+)
 
 
 def test_launch():
@@ -153,16 +160,23 @@ def train(*args):
                     err += f"\n\t* `{lib_name}`"
                 raise RuntimeError(err)
 
-        # torch.distributed will expect a few environment variables to be here. We set the ones common to each
-        # process here (the other ones will be set by the launcher).
-        with patch_environment(
+        patched_env = dict(
             nproc=num_processes,
             node_rank=node_rank,
             world_size=num_nodes * num_processes,
             master_addr=master_addr,
             master_port=use_port,
             mixed_precision=mixed_precision,
-        ):
+        )
+
+        # Check for CUDA P2P and IB issues
+        if not check_cuda_p2p_ib_support():
+            patched_env["nccl_p2p_disable"] = "1"
+            patched_env["nccl_ib_disable"] = "1"
+
+        # torch.distributed will expect a few environment variables to be here. We set the ones common to each
+        # process here (the other ones will be set by the launcher).
+        with patch_environment(**patched_env):
            # First dummy launch
            if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
                launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
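For context: `patch_environment` is a context manager from `accelerate.utils` that, judging from its usage above (lowercase keys passed in, `NCCL_P2P_DISABLE`-style variables checked elsewhere in this commit), uppercases each keyword argument and exports it as an environment variable for the duration of the block, so every spawned worker sees `NCCL_P2P_DISABLE=1`. A minimal sketch of that behavior, as an assumption rather than the library's exact implementation:

import os
from contextlib import contextmanager


@contextmanager
def patch_environment(**kwargs):
    # Sketch: uppercase each key, set it for the duration of the block,
    # then restore (or remove) the previous value on exit.
    previous = {}
    for key, value in kwargs.items():
        key = key.upper()
        previous[key] = os.environ.get(key)
        os.environ[key] = str(value)
    try:
        yield
    finally:
        for key, value in previous.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value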
15 changes: 7 additions & 8 deletions src/accelerate/state.py
@@ -182,14 +182,6 @@ def __init__(self, cpu: bool = False, **kwargs):
                 self.backend = "nccl"
                 dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
 
-                if not check_cuda_p2p_ib_support():
-                    if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
-                        raise NotImplementedError(
-                            "Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. "
-                            'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
-                            "will do this automatically."
-                        )
-
             self.num_processes = torch.distributed.get_world_size()
             self.process_index = torch.distributed.get_rank()
             self.local_process_index = int(os.environ.get("LOCAL_RANK", -1))
@@ -206,6 +198,13 @@ def __init__(self, cpu: bool = False, **kwargs):
                     self.device = torch.device("cuda", self.local_process_index)
             if self.device is not None:
                 torch.cuda.set_device(self.device)
+            if self.device.type == "cuda" and not check_cuda_p2p_ib_support():
+                if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ:
+                    raise NotImplementedError(
+                        "Using RTX 3090 or 4000 series doesn't support faster communication broadband via P2P or IB. "
+                        'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1"` or use `accelerate launch` which '
+                        "will do this automatically."
+                    )
             self._mixed_precision = "no"  # deepspeed handles mixed_precision using deepspeed_config
         elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu and torch.cuda.is_available():
             self.distributed_type = DistributedType.MULTI_GPU
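The relocated check still fires only when the user has not already set the NCCL variables. For code that cannot go through `accelerate launch`, the manual workaround named in the error message looks like the following sketch, assuming the variables are set before any process group is initialized:

import os

# Disable P2P and IB before torch.distributed / DeepSpeed initializes NCCL,
# as the error above requires for RTX 3090 and RTX 40xx consumer cards.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

from accelerate import Accelerator

accelerator = Accelerator()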
49 changes: 41 additions & 8 deletions src/accelerate/utils/environment.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import os
+import platform
+import subprocess
 import sys
+from distutils import spawn
 from typing import Dict
 
-import torch
-
 
 def str_to_bool(value) -> int:
     """
@@ -61,17 +62,49 @@ def are_libraries_initialized(*library_names: str) -> Dict[str, bool]:
     return [lib_name for lib_name in library_names if lib_name in sys.modules]
 
 
+def get_gpu_info():
+    """
+    Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA.
+
+    Largely based on the `gputil` library.
+    """
+    if platform.system() == "Windows":
+        # If platform is Windows and nvidia-smi can't be found in path,
+        # try from the system drive with the default installation path
+        command = spawn.find_executable("nvidia-smi")
+        if command is None:
+            command = "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" % os.environ["systemdrive"]
+    else:
+        command = "nvidia-smi"
+    # Returns as list of `n` GPUs and their names
+    output = subprocess.check_output(
+        [command, "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True
+    )
+    output = output.strip()
+    gpus = output.split(os.linesep)
+    # Get names from output
+    gpu_count = len(gpus)
+    gpu_names = [gpu.split(",")[1].strip() for gpu in gpus]
+    return gpu_names, gpu_count
+
+
 def check_cuda_p2p_ib_support():
     """
     Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
     the 3090.
+
+    Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
     """
-    if torch.cuda.is_available():
-        # Get the first device/default
-        device_name = torch.cuda.get_device_name()
-        device_count = torch.cuda.device_count()
-        unsupported_devices = ["RTX 3090", "RTX 40"]
+    try:
+        device_names, device_count = get_gpu_info()
+        unsupported_devices = {"RTX 3090", "RTX 40"}
         if device_count > 1:
-            if any(device in device_name for device in unsupported_devices):
+            if any(
+                unsupported_device in device_name
+                for device_name in device_names
+                for unsupported_device in unsupported_devices
+            ):
                 return False
+    except Exception:
+        pass
     return True
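Since `--query-gpu=count,name --format=csv,noheader` prints one `count, name` row per device, the helper can derive the device count from the number of rows. A hedged usage sketch of the two new helpers (device names and results illustrative, for a hypothetical machine with two RTX 40-series cards):

from accelerate.utils.environment import check_cuda_p2p_ib_support, get_gpu_info

# On the hypothetical box, nvidia-smi reports two consumer GPUs:
names, count = get_gpu_info()
print(names, count)  # ['NVIDIA GeForce RTX 4090', 'NVIDIA GeForce RTX 4090'] 2

# Two RTX 40-series cards -> P2P/IB flagged as unsupported:
print(check_cuda_p2p_ib_support())  # False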
