diff --git a/src/accelerate/state.py b/src/accelerate/state.py index 8fa43207692..dd9eed26976 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -297,7 +297,7 @@ def __init__(self, cpu: bool = False, **kwargs): if self.device.type == "hpu": # we should do this in optimum-habana somehow and not here - from optimum.habana.distributed import parallel_state + from optimum.habana.distributed import parallel_state # noqa: F401 if self.distributed_type != DistributedType.DEEPSPEED: context_parallel_size = 1