[RFC] Support FSDP2 #3231
@@ -1464,13 +1464,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 elif self.distributed_type == DistributedType.FSDP:
     # We need to fix the optimizer *before* sharding the model
     from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
+    from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy, FSDPModule, OffloadPolicy
+    from torch.distributed.fsdp.api import ShardingStrategy

     # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
     # don't wrap it again
     # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
     # is a FSDP model, don't wrap it again
-    is_type_fsdp = isinstance(model, FSDP) or (
-        is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
+    # We check for FSDPModule instead of the FSDP class for FSDP v2
+    is_type_fsdp = isinstance(model, FSDPModule) or (
+        is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
     )

     if not is_type_fsdp:
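A note on the FSDPModule check: unlike FSDP1, fully_shard does not return a wrapper class; it rewrites the module's class in place so that it also derives from FSDPModule, which is why the isinstance check changes. A minimal sketch, not from the PR, assuming a distributed run launched via torchrun with one CUDA device per process (the file name is illustrative):

# Run with: torchrun --nproc_per_node=1 fsdp2_check.py
import torch
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard, FSDPModule

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = torch.nn.Linear(16, 16, device="cuda")
fully_shard(model)  # mutates `model` in place; its class now subclasses FSDPModule

# FSDP1 requires isinstance(model, FullyShardedDataParallel) on a wrapper object;
# with FSDP2 the original object itself passes the FSDPModule check used above.
assert isinstance(model, FSDPModule)

dist.destroy_process_group()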
@@ -1498,8 +1500,100 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
     "ignored_modules": fsdp_plugin.ignored_modules,
     "limit_all_gathers": fsdp_plugin.limit_all_gathers,
     "device_id": self.device,
 }
+#######
+# fsdp2_kwargs holds all the args supported by FSDP2 through the fully_shard API.
+# Most FSDP2 args can be deduced from the existing FSDP1 args; some of the existing
+# FSDP1 args are not supported or are set to True by default.
+# More information can be found here - https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
+#######
+fsdp2_kwargs = {
+    "reshard_after_forward": True,
+    "mesh": None,
+    "mp_policy": MixedPrecisionPolicy(),
+    "offload_policy": OffloadPolicy(),
+    # shard_placement_fn is a fairly recent feature
+    # "shard_placement_fn": None
+}
-model = FSDP(model, **kwargs)
+
+#######
+# Preparation of mesh and reshard_after_forward.
+# Both of these params may be exposed directly to the user through the FSDP config;
+# the other option would be to hide them and set them based on the sharding strategy.
+#
+# Deduction of mesh and reshard_after_forward from the sharding strategy, analogy
+# borrowed from https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
+# 1 process_group + FULL_SHARD                           ==> 1D mesh + reshard_after_forward=True
+# 1 process_group + SHARD_GRAD_OP                        ==> 1D mesh + reshard_after_forward=False
+# 2 process_groups/2D device_mesh + HYBRID_SHARD         ==> 2D mesh + reshard_after_forward=True
+# 2 process_groups/2D device_mesh + _HYBRID_SHARD_ZERO2  ==> 2D mesh + reshard_after_forward=False
+#######
+if kwargs["sharding_strategy"] == ShardingStrategy.FULL_SHARD:
+    # mesh: no need to prepare a mesh, go with the default
+    fsdp2_kwargs["reshard_after_forward"] = True
+elif kwargs["sharding_strategy"] == ShardingStrategy.SHARD_GRAD_OP:
+    # mesh: no need to prepare a mesh, go with the default
+    fsdp2_kwargs["reshard_after_forward"] = False
+elif kwargs["sharding_strategy"] == ShardingStrategy.HYBRID_SHARD:
+    # mesh: at this point, PyTorch does not set a 2D mesh by default based on the
+    # inter-/intra-node assumption - https://github.com/pytorch/pytorch/issues/140102
+    fsdp2_kwargs["reshard_after_forward"] = True
+elif kwargs["sharding_strategy"] == ShardingStrategy._HYBRID_SHARD_ZERO2:
+    # mesh: at this point, PyTorch does not set a 2D mesh by default based on the
+    # inter-/intra-node assumption - https://github.com/pytorch/pytorch/issues/140102
+    fsdp2_kwargs["reshard_after_forward"] = False
+
+#######
+# The mixed precision policy can be mapped from the FSDP1 to the FSDP2 arg classes,
+# except for output_dtype, which is new to FSDP2 and has to come from the user.
+#######
+if kwargs["mixed_precision"] is not None:
+    # MixedPrecisionPolicy is from the new _composable design
+    fsdp2_kwargs["mp_policy"] = MixedPrecisionPolicy(
+        param_dtype=kwargs["mixed_precision"].param_dtype,
+        reduce_dtype=kwargs["mixed_precision"].reduce_dtype,
+        cast_forward_inputs=kwargs["mixed_precision"].cast_forward_inputs,
+        # output_dtype cannot be deduced from the FSDP1 args and has to come from the user
+        # buffer_dtype is not available - is it not required for FSDP2?
+    )
+
+#######
+# The offload policy can be mapped from the FSDP1 to the FSDP2 arg classes.
+# Pinning memory seems to be a new feature in FSDP2; offloading params is the default behaviour.
+#######
+if kwargs["cpu_offload"] is not None and kwargs["cpu_offload"].offload_params:
+    # CPUOffloadPolicy is from the new _composable design
+    fsdp2_kwargs["offload_policy"] = CPUOffloadPolicy(
+        # pin_memory cannot be deduced from the FSDP1 args and has to come from the user
+        # offloading params is the default behaviour
+    )
+
+#######
+# auto_wrap_policy is not yet supported by FSDP2,
+# therefore manual wrapping has to be done like below
+#######
+for layer in model.model.layers:
+    fully_shard(layer, **fsdp2_kwargs)
+fully_shard(model, **fsdp2_kwargs)
Review comment: I am a bit unsure what the reason is for this final fully_shard(model, **fsdp2_kwargs) call.
+#######
+# Does the existing activation_checkpointing API work out of the box with FSDP2?
+#######
 if fsdp_plugin.activation_checkpointing:
     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
         CheckpointImpl,
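On the HYBRID_SHARD branches above: since PyTorch does not build the 2D mesh automatically (pytorch/pytorch#140102), the mesh would presumably have to be constructed explicitly and passed via fsdp2_kwargs["mesh"]. A hedged sketch of what that could look like, assuming one shard group per node; gpus_per_node is an illustrative value, not an existing FSDP plugin field:

# Illustrative only: build a (replicate, shard) mesh for HSDP-style FSDP2 sharding.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

gpus_per_node = 8                                  # assumed intra-node group size
num_nodes = dist.get_world_size() // gpus_per_node

# With a 2D mesh, fully_shard replicates across dim 0 and shards across dim 1.
mesh_2d = init_device_mesh(
    "cuda",
    (num_nodes, gpus_per_node),
    mesh_dim_names=("replicate", "shard"),
)
fsdp2_kwargs["mesh"] = mesh_2d                     # fsdp2_kwargs as defined in the diff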
@@ -2364,7 +2458,11 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
     parameters = [p for p in parameters]
     for model in self._models:
         if parameters == [p for p in model.parameters()]:
-            return model.clip_grad_norm_(max_norm, norm_type)
+            #######
+            # The gradient clipping function is not part of the FSDP class object like in FSDP v1;
+            # rather, it has been removed, so fall back to the generic torch utility.
+            #######
+            return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type)
 elif self.distributed_type == DistributedType.DEEPSPEED:
     # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
     # We cannot return the gradient norm because DeepSpeed does it.
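As a usage-level illustration of the clipping change (a sketch, not code from this PR): with FSDP2 there is no clip_grad_norm_ method on the sharded module, so a training step would clip through the standard utility instead:

# Sketch of a training step with an FSDP2-sharded model; names are illustrative.
import torch

def training_step(model, optimizer, batch, max_norm=1.0):
    loss = model(**batch).loss  # assumes an HF-style model that returns .loss
    loss.backward()
    # FSDP1: model.clip_grad_norm_(max_norm) on the FullyShardedDataParallel wrapper.
    # FSDP2: fall back to the generic utility over the model's parameters.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=2.0)
    optimizer.step()
    optimizer.zero_grad()
    return loss.detach(), grad_norm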
Review comment: This one doesn't seem to apply to the general use case. It feels like it should be something like the snippet below, which checks the module structure and applies fully_shard from the bottom up.
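The reviewer's snippet itself was not captured here; the following is only a hypothetical sketch of that kind of bottom-up wrapping, where block_types is an assumed heuristic for identifying repeated transformer blocks rather than an existing accelerate API:

# Hypothetical sketch of generic bottom-up wrapping (not the reviewer's actual code).
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

def shard_bottom_up(module: nn.Module, block_types: tuple, **fsdp2_kwargs) -> nn.Module:
    # Recurse first so children are sharded before their parents (bottom up).
    for child in module.children():
        shard_bottom_up(child, block_types, **fsdp2_kwargs)
        if isinstance(child, block_types):
            fully_shard(child, **fsdp2_kwargs)
    return module

# Usage (illustrative): shard each decoder block, then the root model, so parameters
# not owned by any wrapped block (embeddings, final norm, head) land in the root group.
# shard_bottom_up(model, block_types=(MyDecoderLayer,), **fsdp2_kwargs)
# fully_shard(model, **fsdp2_kwargs)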