diff --git a/desc/__init__.py b/desc/__init__.py index 840b9985b9..2e5f9003fb 100644 --- a/desc/__init__.py +++ b/desc/__init__.py @@ -61,11 +61,12 @@ def __getattr__(name): config = {"device": None, "avail_mem": None, "kind": None} -def set_device(kind="cpu"): +def set_device(kind="cpu", gpuid=None): """Sets the device to use for computation. - If kind==``'gpu'``, checks available GPUs and selects the one with the most - available memory. + If kind==``'gpu'`` and a gpuid is specified, uses the specified GPU. If + gpuid==``None`` or a wrong GPU id is given, checks available GPUs and selects the + one with the most available memory. Respects environment variable CUDA_VISIBLE_DEVICES for selecting from multiple available GPUs @@ -127,14 +128,26 @@ def set_device(kind="cpu"): set_device(kind="cpu") return devices = [dev for dev in devices if dev["index"] in gpu_ids] - for dev in devices: - mem = dev["mem_total"] - dev["mem_used"] - if mem > maxmem: - maxmem = mem - selected_gpu = dev + + if gpuid is not None and (str(gpuid) in gpu_ids): + selected_gpu = [dev for dev in devices if dev["index"] == str(gpuid)][0] + else: + for dev in devices: + mem = dev["mem_total"] - dev["mem_used"] + if mem > maxmem: + maxmem = mem + selected_gpu = dev config["device"] = selected_gpu["type"] + " (id={})".format( selected_gpu["index"] ) + if gpuid is not None and not (str(gpuid) in gpu_ids): + warnings.warn( + colored( + "Specified gpuid {} not found, falling back to ".format(str(gpuid)) + + config["device"], + "yellow", + ) + ) config["avail_mem"] = ( selected_gpu["mem_total"] - selected_gpu["mem_used"] ) / 1024 # in GB