pass accelerator tests

IlyasMoutawwakil committed Feb 6, 2025
1 parent fa1bc44 commit d3e24c5

Showing 9 changed files with 42 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/accelerate/accelerator.py
@@ -1487,6 +1487,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
):
if any(p.requires_grad for p in model.parameters()):
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
@@ -3270,6 +3271,7 @@ def _inner(folder):
DistributedType.MULTI_MLU,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_HPU,
):
map_location = "on_device"
else:
6 changes: 5 additions & 1 deletion src/accelerate/checkpointing.py
@@ -32,6 +32,8 @@
SCHEDULER_NAME,
WEIGHTS_NAME,
get_pretty_name,
is_cuda_available,
is_hpu_available,
is_mlu_available,
is_musa_available,
is_torch_xla_available,
@@ -155,7 +157,9 @@ def save_accelerator_state(
states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
if is_musa_available():
states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
- else:
+ if is_hpu_available():
+ states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
+ if is_cuda_available():
states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
if is_torch_xla_available():
states["xm_seed"] = xm.get_rng_state()
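For context, a minimal sketch (not the library's exact code; the non-device key names are assumptions) of the RNG-capture pattern the hunk above switches to: each backend records its state behind its own availability check, so CUDA and HPU no longer share an `else:` branch.

import random

import numpy as np
import torch

states = {
    "random_state": random.getstate(),
    "numpy_random_seed": np.random.get_state(),
    "torch_manual_seed": torch.get_rng_state(),
}
# Independent checks instead of if/else: several backends can be recorded at once.
if torch.cuda.is_available():
    states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
# torch.hpu only exists once habana_frameworks.torch has been imported.
if hasattr(torch, "hpu") and torch.hpu.is_available():
    states["torch_hpu_manual_seed"] = torch.hpu.get_rng_state_all()
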
1 change: 1 addition & 0 deletions src/accelerate/test_utils/__init__.py
@@ -33,6 +33,7 @@
require_multi_xpu,
require_musa,
require_non_cpu,
require_non_hpu,
require_non_torch_xla,
require_non_xpu,
require_npu,
10 changes: 10 additions & 0 deletions src/accelerate/test_utils/testing.py
@@ -41,6 +41,7 @@
is_datasets_available,
is_deepspeed_available,
is_dvclive_available,
is_hpu_available,
is_import_timer_available,
is_mlu_available,
is_mps_available,
@@ -81,6 +82,8 @@ def get_backend():
return "npu", torch.npu.device_count(), torch.npu.memory_allocated
elif is_xpu_available():
return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated
elif is_hpu_available():
return "hpu", torch.xpu.device_count(), torch.xpu.memory_allocated
else:
return "cpu", 1, lambda: 0

@@ -189,6 +192,13 @@ def require_non_xpu(test_case):
return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case)


def require_non_hpu(test_case):
"""
Decorator marking a test that should be skipped for HPU.
"""
return unittest.skipUnless(torch_device != "hpu", "test requires a non-HPU")(test_case)


def require_mlu(test_case):
"""
Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available.
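A hypothetical usage sketch (not part of this commit; the test class and assertion are made up for illustration) of the two helpers touched above: get_backend() reports the detected device triple, and require_non_hpu skips a test when that device is an HPU.

import unittest

from accelerate.test_utils.testing import get_backend, require_non_hpu

torch_device, device_count, memory_allocated_fn = get_backend()


class ExampleTests(unittest.TestCase):
    @require_non_hpu  # skipped when get_backend() detects "hpu"
    def test_memory_counter_is_nonnegative(self):
        # memory_allocated_fn is torch.<backend>.memory_allocated, or lambda: 0 on CPU.
        self.assertGreaterEqual(memory_allocated_fn(), 0)
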
3 changes: 2 additions & 1 deletion src/accelerate/utils/imports.py
@@ -380,10 +380,11 @@ def is_npu_available(check_device=False):

@lru_cache
def is_hpu_available(check_device=False):
"Checks if `torch_hpu` is installed and potentially if a HPU is in the environment"
"Checks if `torch.hpu` is installed and potentially if a HPU is in the environment"
if importlib.util.find_spec("habana_frameworks") is None:
return False

# not sure if this should be here but keeping it for now
import habana_frameworks.torch # noqa: F401
import habana_frameworks.torch.distributed.hccl as hccl # noqa: F401

4 changes: 4 additions & 0 deletions src/accelerate/utils/memory.py
@@ -28,6 +28,7 @@

from .imports import (
is_cuda_available,
is_hpu_available,
is_ipex_available,
is_mlu_available,
is_mps_available,
@@ -58,6 +59,9 @@ def clear_device_cache(garbage_collection=False):
torch.mps.empty_cache()
elif is_cuda_available():
torch.cuda.empty_cache()
elif is_hpu_available():
# doesn't have an empty_cache method
pass


def release_memory(*objects):
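As a usage note (an assumption about the call site, not shown in the commit): clear_device_cache stays the backend-agnostic entry point, so callers do not need to special-case HPU; on Gaudi the cache-clearing step above is simply a no-op.

from accelerate.utils.memory import clear_device_cache

# garbage_collection=True asks the helper to run Python GC before clearing
# whatever per-backend cache exists (none on HPU).
clear_device_cache(garbage_collection=True)
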
8 changes: 4 additions & 4 deletions src/accelerate/utils/modeling.py
@@ -88,15 +88,15 @@ def check_device_same(first_device, second_device):
if first_device.type != second_device.type:
return False

if first_device.type == "cuda" and first_device.index is None:
if first_device.type in ["cuda", "hpu"] and first_device.index is None:
# In case the first_device is a cuda device and have
# the index attribute set to `None`, default it to `0`
first_device = torch.device("cuda", index=0)
first_device = torch.device(first_device.type, index=0)

if second_device.type == "cuda" and second_device.index is None:
if second_device.type in ["cuda", "hpu"] and second_device.index is None:
# In case the second_device is a cuda device and have
# the index attribute set to `None`, default it to `0`
second_device = torch.device("cuda", index=0)
second_device = torch.device(second_device.type, index=0)

return first_device == second_device

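To illustrate the effect of the change above (a sketch that imports the internal helper directly, which is an assumption about how you would exercise it): an index-less "hpu" device now normalizes to index 0 before comparison, mirroring the existing CUDA behavior.

import torch

from accelerate.utils.modeling import check_device_same

# torch.device construction does not require the hardware to be present.
print(check_device_same(torch.device("cuda"), torch.device("cuda", 0)))  # True (existing behavior)
print(check_device_same(torch.device("hpu"), torch.device("hpu", 0)))    # True with this change
print(check_device_same(torch.device("cpu"), torch.device("hpu", 0)))    # False: types differ
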
11 changes: 10 additions & 1 deletion src/accelerate/utils/random.py
@@ -21,7 +21,14 @@
from ..state import AcceleratorState
from .constants import CUDA_DISTRIBUTED_TYPES
from .dataclasses import DistributedType, RNGType
- from .imports import is_mlu_available, is_musa_available, is_npu_available, is_torch_xla_available, is_xpu_available
+ from .imports import (
+ is_hpu_available,
+ is_mlu_available,
+ is_musa_available,
+ is_npu_available,
+ is_torch_xla_available,
+ is_xpu_available,
+ )


if is_torch_xla_available():
@@ -53,6 +60,8 @@ def set_seed(seed: int, device_specific: bool = False, deterministic: bool = Fal
torch.mlu.manual_seed_all(seed)
elif is_musa_available():
torch.musa.manual_seed_all(seed)
elif is_hpu_available():
torch.hpu.manual_seed_all(seed)
else:
torch.cuda.manual_seed_all(seed)
# ^^ safe to call this function even if cuda is not available
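Call-site sketch (assumption: set_seed is used through its public re-export in accelerate.utils, as elsewhere in the library): on a Gaudi machine the new elif above forwards the seed to torch.hpu.manual_seed_all in addition to the Python, NumPy, and torch generators.

from accelerate.utils import set_seed

# One call seeds random, numpy, torch, and the detected accelerator backend;
# with the hunk above, that backend can now be HPU.
set_seed(42)
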
5 changes: 4 additions & 1 deletion tests/test_accelerator.py
@@ -34,6 +34,7 @@
require_huggingface_suite,
require_multi_device,
require_non_cpu,
require_non_hpu,
require_transformer_engine,
slow,
torch_device,
@@ -173,7 +174,7 @@ def test_accelerator_state_after_reset(self):
def test_accelerator_can_be_reinstantiated(self):
_ = Accelerator()
assert PartialState._shared_state["_cpu"] is False
- assert PartialState._shared_state["device"].type in ["cuda", "mps", "npu", "xpu", "xla"]
+ assert PartialState._shared_state["device"].type in ["cuda", "mps", "npu", "xpu", "xla", "hpu"]
with self.assertRaises(ValueError):
_ = Accelerator(cpu=True)

@@ -215,6 +216,7 @@ def test_prepared_objects_are_referenced(self):
assert prepared_train_dl in accelerator._dataloaders
assert prepared_valid_dl in accelerator._dataloaders

@require_non_hpu
def test_free_memory_dereferences_prepared_components(self):
accelerator = Accelerator()
# Free up refs with empty_cache() and gc.collect()
@@ -583,6 +585,7 @@ def test_can_unwrap_model_te(self):
model_loaded(inputs)

@require_non_cpu
@require_non_hpu
def test_can_unwrap_model_fp16(self):
# test for a regression introduced in #872
# before the fix, after unwrapping with keep_fp32_wrapper=False, there would be the following error:
