From 1c8c579e9ee8cfbd39ceb94ad3f08ace3a7b2855 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Thu, 7 Nov 2024 21:52:27 +0530
Subject: [PATCH 1/2] feat: add fsdp2

Signed-off-by: Mehant Kammakomati
---
 src/accelerate/accelerator.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index ab949c42e43..0a1d3ed1736 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1499,6 +1499,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                     "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                     "device_id": self.device,
                 }
+                print("using FSDP in accelerate")
                 model = FSDP(model, **kwargs)
                 if fsdp_plugin.activation_checkpointing:
                     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (

From 86211e875ca87f8dacc28d594c6ed1ee9ec2582e Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Fri, 8 Nov 2024 12:50:57 +0530
Subject: [PATCH 2/2] feat: add support for fsdp 2 and document reqs

Signed-off-by: Mehant Kammakomati
---
 src/accelerate/accelerator.py | 109 ++++++++++++++++++++++++++++++++--
 1 file changed, 103 insertions(+), 6 deletions(-)

diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 0a1d3ed1736..7fdb07ac6fa 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -1464,13 +1464,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif self.distributed_type == DistributedType.FSDP:
             # We need to fix the optimizer *before* sharding the model
             from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-
+            from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy, FSDPModule, OffloadPolicy
+            from torch.distributed.fsdp.api import ShardingStrategy
             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
             # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
             # is a FSDP model, don't wrap it again
-            is_type_fsdp = isinstance(model, FSDP) or (
-                is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
+            # We check for FSDPModule instead of the FSDP class for FSDP v2
+            is_type_fsdp = isinstance(model, FSDPModule) or (
+                is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
             )

             if not is_type_fsdp:
@@ -1498,9 +1500,100 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                     "ignored_modules": fsdp_plugin.ignored_modules,
                     "limit_all_gathers": fsdp_plugin.limit_all_gathers,
                     "device_id": self.device,
+                }
+                #######
+                # fsdp2_kwargs holds all the args supported by FSDP2 through the fully_shard API.
+                # Most of the FSDP2 args can be deduced from the existing FSDP1 args.
+                # Some of the existing FSDP1 args are either not supported or are set to True by default.
+                # More information can be found here - https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
+                #######
+                fsdp2_kwargs = {
+                    "reshard_after_forward": True,
+                    "mesh": None,
+                    "mp_policy": MixedPrecisionPolicy(),
+                    "offload_policy": OffloadPolicy(),
+                    # shard_placement_fn is a fairly recent addition to fully_shard
+                    # "shard_placement_fn": None,
                 }
-                print("using FSDP in accelerate")
-                model = FSDP(model, **kwargs)
+
+                #######
+                # Preparation of mesh and reshard_after_forward.
+                # Both of these params could be exposed directly to the user through the FSDP config;
+                # another way would be to hide them and set them based on the sharding strategy.
+                #
+                # The deduction of mesh and reshard_after_forward from the sharding strategy is
+                # borrowed from https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
+                # 1 process_group + FULL_SHARD                          ==> 1D mesh + reshard_after_forward=True
+                # 1 process_group + SHARD_GRAD_OP                       ==> 1D mesh + reshard_after_forward=False
+                # 2 process_groups/2D device_mesh + HYBRID_SHARD        ==> 2D mesh + reshard_after_forward=True
+                # 2 process_groups/2D device_mesh + _HYBRID_SHARD_ZERO2 ==> 2D mesh + reshard_after_forward=False
+                #######
+
+                if kwargs["sharding_strategy"] == ShardingStrategy.FULL_SHARD:
+                    # mesh: no need to prepare a mesh, go with the default
+                    fsdp2_kwargs["reshard_after_forward"] = True
+                elif kwargs["sharding_strategy"] == ShardingStrategy.SHARD_GRAD_OP:
+                    # mesh: no need to prepare a mesh, go with the default
+                    fsdp2_kwargs["reshard_after_forward"] = False
+                elif kwargs["sharding_strategy"] == ShardingStrategy.HYBRID_SHARD:
+                    # mesh: at this point, PyTorch does not set up a 2D mesh by default based on the
+                    # inter-/intra-node assumption, see https://github.com/pytorch/pytorch/issues/140102
+                    fsdp2_kwargs["reshard_after_forward"] = True
+                elif kwargs["sharding_strategy"] == ShardingStrategy._HYBRID_SHARD_ZERO2:
+                    # mesh: at this point, PyTorch does not set up a 2D mesh by default based on the
+                    # inter-/intra-node assumption, see https://github.com/pytorch/pytorch/issues/140102
+                    fsdp2_kwargs["reshard_after_forward"] = False
+
+                #######
+                # The mixed precision policy can be mapped from the FSDP1 to the FSDP2 arg classes,
+                # except for output_dtype, which is new to FSDP2 and has to come from the user.
+                #######
+
+                if kwargs["mixed_precision"] is not None:
+                    # MixedPrecisionPolicy is from the new _composable design
+                    fsdp2_kwargs["mp_policy"] = MixedPrecisionPolicy(
+                        param_dtype=kwargs["mixed_precision"].param_dtype,
+                        reduce_dtype=kwargs["mixed_precision"].reduce_dtype,
+                        cast_forward_inputs=kwargs["mixed_precision"].cast_forward_inputs,
+                        # output_dtype cannot be deduced from the FSDP1 args and has to come from the user
+                        # buffer_dtype is not available; is it not required for FSDP2?
+                    )

+                #######
+                # The offload policy can be mapped from the FSDP1 to the FSDP2 arg classes.
+                # Pinning memory appears to be a new feature in FSDP2.
+                # Offloading params is the default behaviour of CPUOffloadPolicy.
+                #######
+
+                if kwargs["cpu_offload"] is not None and kwargs["cpu_offload"].offload_params:
+                    # CPUOffloadPolicy is from the new _composable design
+                    fsdp2_kwargs["offload_policy"] = CPUOffloadPolicy(
+                        # pin_memory cannot be deduced from the FSDP1 args and has to come from the user
+                    )
+
+                #######
+                # auto_wrap_policy is not yet supported by FSDP2,
+                # therefore manual wrapping has to be done as below
+                # (this assumes a transformers-style model that exposes model.model.layers).
+                #######
+                for layer in model.model.layers:
+                    fully_shard(layer, **fsdp2_kwargs)
+                fully_shard(model, **fsdp2_kwargs)
+
+                #######
+                # Does the existing activation_checkpointing API work out of the box with FSDP2?
+                #######
                 if fsdp_plugin.activation_checkpointing:
                     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
                         CheckpointImpl,
                         apply_activation_checkpointing,
                         checkpoint_wrapper,
@@ -2365,7 +2458,11 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
             parameters = [p for p in parameters]
             for model in self._models:
                 if parameters == [p for p in model.parameters()]:
-                    return model.clip_grad_norm_(max_norm, norm_type)
+                    #######
+                    # The gradient clipping function is not a method on the FSDP2 module object as it
+                    # was in FSDP v1; it has been removed, so fall back to the standard utility.
+                    #######
+                    return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)
         elif self.distributed_type == DistributedType.DEEPSPEED:
             # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
             # We cannot return the gradient norm because DeepSpeed does it.
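
Reviewer notes (illustration only, not part of the patch):

The sharding-strategy table in the second hunk maps an FSDP1 ShardingStrategy onto the (mesh, reshard_after_forward) pair that fully_shard expects. A rough sketch of how that mapping could be factored into a helper is below, assuming a torchrun launch with the process group already initialised; the helper name, the device_type/replicas arguments and the replicate-outer/shard-inner mesh layout are assumptions for illustration, not part of this patch or of the accelerate API.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import ShardingStrategy


def map_sharding_strategy_to_fsdp2(strategy, device_type="cuda", replicas=1):
    """Return (mesh, reshard_after_forward) for fully_shard from an FSDP1 sharding strategy."""
    world_size = dist.get_world_size()
    if strategy in (ShardingStrategy.FULL_SHARD, ShardingStrategy.SHARD_GRAD_OP):
        # 1D mesh over the whole world; passing mesh=None to fully_shard has the same effect.
        mesh = init_device_mesh(device_type, (world_size,))
        return mesh, strategy == ShardingStrategy.FULL_SHARD
    # HYBRID_SHARD / _HYBRID_SHARD_ZERO2: a 2D mesh that replicates across groups and shards
    # inside each group; PyTorch does not build this mesh automatically (see pytorch/pytorch#140102).
    mesh = init_device_mesh(
        device_type,
        (replicas, world_size // replicas),
        mesh_dim_names=("replicate", "shard"),
    )
    return mesh, strategy == ShardingStrategy.HYBRID_SHARD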
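
Because auto_wrap_policy is not supported by fully_shard, the hunk wraps each decoder block and then the root module by hand. The following is a minimal, self-contained sketch of that flow outside accelerate, assuming a torchrun --nproc-per-node=N launch, an NCCL backend and a PyTorch build that ships torch.distributed._composable.fsdp (2.4 or newer); TinyModel is a made-up stand-in for the model.model.layers structure the patch iterates over.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


class TinyModel(nn.Module):
    def __init__(self, dim=1024, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = torch.relu(layer(x))
        return x


def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    model = TinyModel().cuda()

    fsdp2_kwargs = {
        "reshard_after_forward": True,  # FULL_SHARD-like behaviour
        "mp_policy": MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
        # "offload_policy": CPUOffloadPolicy(),  # optional, keeps sharded params in CPU memory
    }
    # No auto_wrap_policy in FSDP2: shard each block first, then the root module.
    for layer in model.layers:
        fully_shard(layer, **fsdp2_kwargs)
    fully_shard(model, **fsdp2_kwargs)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Wrapping the blocks before the root mirrors what auto_wrap_policy achieved in FSDP1: each fully_shard call becomes its own all-gather/reshard unit, so only one block's parameters need to be materialised at a time.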
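
The last hunk replaces the FSDP1 model.clip_grad_norm_(...) method, which does not exist on fully_shard-ed modules, with the plain torch.nn.utils.clip_grad_norm_. One caveat worth hedging: with FSDP2 the parameters (and therefore the gradients) are DTensors, so my understanding is that the returned total norm is itself a DTensor; a caller that wants an ordinary tensor may need to materialise it, e.g. as in the hypothetical helper below (not accelerate API).

import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch builds


def clip_grad_norm_fsdp2(module: torch.nn.Module, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    # Works on the DTensor parameters produced by fully_shard; plain tensors pass through unchanged.
    total_norm = torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm, norm_type=norm_type)
    if isinstance(total_norm, DTensor):
        # Materialise the distributed norm as a regular local tensor so callers can log or compare it.
        total_norm = total_norm.full_tensor()
    return total_norm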