Skip to content

Commit

Permalink
Merge pull request #1350 from qiyulei-mt/musa_support
Browse files Browse the repository at this point in the history
support musa backend in FlagEmbedding
  • Loading branch information
hanhainebula authored Feb 7, 2025
2 parents fdc6786 + 6c9dba5 commit 44e5525
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
17 changes: 15 additions & 2 deletions FlagEmbedding/abc/inference/AbsEmbedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
import numpy as np
from transformers import is_torch_npu_available

+try:
+    import torch_musa
+except Exception:
+    pass

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -106,6 +111,8 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
elif is_torch_npu_available():
return [f"npu:{i}" for i in range(torch.npu.device_count())]
+            elif hasattr(torch, "musa") and torch.musa.is_available():
+                return [f"musa:{i}" for i in range(torch.musa.device_count())]
elif torch.backends.mps.is_available():
try:
return [f"mps:{i}" for i in range(torch.mps.device_count())]
Expand All @@ -116,12 +123,18 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
elif isinstance(devices, str):
return [devices]
elif isinstance(devices, int):
-            return [f"cuda:{devices}"]
+            if hasattr(torch, "musa") and torch.musa.is_available():
+                return [f"musa:{devices}"]
+            else:
+                return [f"cuda:{devices}"]
elif isinstance(devices, list):
if isinstance(devices[0], str):
return devices
elif isinstance(devices[0], int):
-                return [f"cuda:{device}" for device in devices]
+                if hasattr(torch, "musa") and torch.musa.is_available():
+                    return [f"musa:{device}" for device in devices]
+                else:
+                    return [f"cuda:{device}" for device in devices]
else:
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
else:
Expand Down
17 changes: 15 additions & 2 deletions FlagEmbedding/abc/inference/AbsReranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from tqdm import tqdm, trange
from transformers import is_torch_npu_available

+try:
+    import torch_musa
+except Exception:
+    pass

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -107,19 +112,27 @@ def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[s
return [f"cuda:{i}" for i in range(torch.cuda.device_count())]
elif is_torch_npu_available():
return [f"npu:{i}" for i in range(torch.npu.device_count())]
+            elif hasattr(torch, "musa") and torch.musa.is_available():
+                return [f"musa:{i}" for i in range(torch.musa.device_count())]
elif torch.backends.mps.is_available():
return ["mps"]
else:
return ["cpu"]
elif isinstance(devices, str):
return [devices]
elif isinstance(devices, int):
-            return [f"cuda:{devices}"]
+            if hasattr(torch, "musa") and torch.musa.is_available():
+                return [f"musa:{devices}"]
+            else:
+                return [f"cuda:{devices}"]
elif isinstance(devices, list):
if isinstance(devices[0], str):
return devices
elif isinstance(devices[0], int):
-                return [f"cuda:{device}" for device in devices]
+                if hasattr(torch, "musa") and torch.musa.is_available():
+                    return [f"musa:{device}" for device in devices]
+                else:
+                    return [f"cuda:{device}" for device in devices]
else:
raise ValueError("devices should be a string or an integer or a list of strings or a list of integers.")
else:
Expand Down

0 comments on commit 44e5525

Please sign in to comment.