create _prepare_fsdp to pre-prepare FSDP model training #3213

177 changes: 175 additions & 2 deletions src/accelerate/accelerator.py
@@ -23,7 +23,7 @@
import shutil
import sys
import warnings
from collections import OrderedDict
from collections import Counter, OrderedDict
from contextlib import contextmanager
from functools import partial
from types import MethodType
@@ -32,6 +32,7 @@
import torch
import torch.utils.hooks as hooks
from huggingface_hub import split_torch_state_dict_into_shards
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
@@ -1328,6 +1329,12 @@ def prepare(self, *args, device_placement=None):
args = self._prepare_ipex_or_xpu(*args)
if self.fp8_backend == "TE":
args = self._prepare_te(*args)
if self.distributed_type == DistributedType.FSDP and not self.state.fsdp_plugin.use_orig_params:
# Wrap models with FSDP and update the optimizers' parameters.
# Other types of wrapping are handled in the next if-else block.
args = self._prepare_fsdp(*args, device_placement=device_placement)
# Release objects that are no longer needed from all types of memory.
release_memory()
if self.distributed_type == DistributedType.DEEPSPEED:
result = self._prepare_deepspeed(*args)
elif self.distributed_type == DistributedType.MEGATRON_LM:
@@ -1472,7 +1479,6 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again
@@ -2075,6 +2081,173 @@ def _prepare_msamp(self, *args, device_placement):
device_placement[optimizer_index] = False
return tuple(result), device_placement

def _prepare_fsdp(self, *args, device_placement):
# These dictionaries map argument positions to objects so the original argument order is preserved.
models = {}
optimizers = {}
for idx, obj in enumerate(args):
if getattr(obj, "_is_accelerate_prepared", False):
continue
if isinstance(obj, torch.nn.Module):
models[idx] = obj
elif isinstance(obj, torch.optim.Optimizer):
optimizers[idx] = obj

# Validate the presence of models and optimizers
if not models and not optimizers:
return args

# If the first parameter name already ends with `_flat_param`, the model is already FSDP-wrapped, which implies the optimizers have already been processed.
if next(next(iter(models.values())).named_parameters())[0].endswith("_flat_param"):
return args

if len(models) != len(optimizers):
raise ValueError(
f"The number of models ({len(models)}) must be equal to the number of optimizers ({len(optimizers)})."
)
Comment on lines +2096 to +2107
Member:

Would it make sense to move these checks to the very start of the method?

Contributor Author:

@BenjaminBossan Which method do you mean, .prepare? What are the benefits of doing so?

Member:

I meant _prepare_fsdp. It is a common pattern to perform all checks as early as possible, so following this makes the code easier to understand for readers. This is especially so if there are early returns.

The reason why it's common is that we don't do any unnecessary work if the checks fail anyway. In this case, there is no need to determine models or optimizers if we're going to raise an error later. By skipping the unnecessary work, we ensure faster execution and prevent possibly unwanted side effects (this might not be relevant here right now, but the code will change in the future and then it could become true).
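A minimal sketch of the suggested reordering (the helper name _validate_fsdp_args is hypothetical, not part of this diff): the cheap presence and count checks run before any index maps are built, so a mismatch fails fast without doing the matching work.

import torch

def _validate_fsdp_args(*args):
    # Early validation, per the review suggestion: fail fast before any mapping work.
    models = [obj for obj in args if isinstance(obj, torch.nn.Module)]
    optimizers = [obj for obj in args if isinstance(obj, torch.optim.Optimizer)]
    if not models and not optimizers:
        return False  # nothing to prepare
    if len(models) != len(optimizers):
        raise ValueError(
            f"The number of models ({len(models)}) must be equal to the number of optimizers ({len(optimizers)})."
        )
    return True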


# Create a mapping from each model's index `model_idx` to its optimizer's index `opt_idx`.
model_optimizer_map = {}
temp_optimizers = optimizers.copy()

# Iterate over models and optimizers to create the mapping
for model_idx, model in models.items():
# Iterate over a snapshot of the items so optimizers can be popped during iteration
for opt_idx, opt in list(temp_optimizers.items()):
found = False
for group in opt.param_groups:
for opt_param in group["params"]:
for param in model.parameters():
if param is opt_param:
model_optimizer_map[model_idx] = opt_idx
# Remove the matched optimizer so it is not considered again
temp_optimizers.pop(opt_idx)
found = True
break
if found:
break
if found:
break
if found:
break

# Ensure every model has a corresponding optimizer.
if len(model_optimizer_map) != len(models):
unmatched_models = [model_idx for model_idx in models if model_idx not in model_optimizer_map]
raise ValueError(
f"We couldn't match the following models with optimizers. Unmatched model indexes: {unmatched_models}"
)

# We need to determine which parameter (by name) belongs to which optimizer group `group_idx`.
# FSDP transforms a block of layers into a single flattened parameter.
# By tracking the mapping of blocks to flattened parameters, we can create new optimizers.
# These new optimizers will retain, if possible, the parameter group partitions
# and their attributes such as learning rate, weight decay, etc.
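# For illustration only (hypothetical names, not from this PR): an original parameter such as
# "encoder.layers.0.attention.weight" is absorbed after wrapping into a single flattened
# parameter named like "_fsdp_wrapped_module.encoder.layers.0._fsdp_wrapped_module._flat_param",
# so one flat parameter stands in for every parameter of the block it wraps.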

model_layer_group_map = {}
# Iterate over models and their parameters to build the mapping.
for model_idx, model in models.items():
optimizer = optimizers[model_optimizer_map[model_idx]]
model_layer_group_map[model_idx] = {}
for name, param in model.named_parameters():
for group_idx, opt_group in enumerate(optimizer.param_groups):
# Check if the current parameter is in the group and map it
for opt_param in opt_group["params"]:
if param is opt_param:
model_layer_group_map[model_idx][name] = group_idx
break # Exit after the first match to avoid redundant mappings

if name in model_layer_group_map[model_idx]:
break

# Clear the original parameter lists; they will be repopulated with the FSDP flattened parameters.
for opt in optimizers.values():
for group in opt.param_groups:
group["params"].clear()

fsdp_2_base_layer_map = {}

for model_idx, model in models.items():
# Get original parameter names
base_layer_name = list(next(zip(*model.named_parameters())))

# Prepare (wrap) the model with FSDP
model = self.prepare_model(model=model, device_placement=device_placement)
models[model_idx] = model

# Get FSDP-wrapped parameter names, skipping the first (_flat_param)
fsdp_layers_name = next(zip(*model.named_parameters()))[1:]
# Remove the added prefix '_fsdp_wrapped_module.' and the suffix '._flat_param'.
# These indicate the shared root of blocks of layers that form an FSDP unit.
fsdp_layers_name_cleaned = list(
map(
lambda string: string.replace("_fsdp_wrapped_module.", "").replace("._flat_param", ""),
fsdp_layers_name,
)
)

# Create a mapping of cleaned names to FSDP names
fsdp_layers_cleaned_to_name = dict(zip(fsdp_layers_name_cleaned, fsdp_layers_name))
# Sort cleaned names by depth to start with the most deeply wrapped layers.
fsdp_layers_name_cleaned.sort(key=lambda e: len(e.split(".")), reverse=True)

# Initialize mapping for the current model
fsdp_2_base_layer_map[model_idx] = {}
map_parms = fsdp_2_base_layer_map[model_idx]

for fsdp_cleaned in fsdp_layers_name_cleaned:
fsdp_name = fsdp_layers_cleaned_to_name[fsdp_cleaned]
map_parms[fsdp_name] = []
# Group layers that share the same FSDP wrap.
for layer in base_layer_name:
if layer.startswith(fsdp_cleaned):
map_parms[fsdp_name].append(layer)
# Remove taken layers
for layer in map_parms[fsdp_name]:
base_layer_name.remove(layer)

# Map remaining base layers under the overall FSDP wrap
map_parms["_fsdp_wrapped_module._flat_param"] = base_layer_name

# Replace optimizer parameter groups with the flattened ones.
for model_idx, opt_idx in model_optimizer_map.items():
local_fsdp_map = fsdp_2_base_layer_map[model_idx]
local_layer_group = model_layer_group_map[model_idx]
for fsdp_layer, param in models[model_idx].named_parameters():
local_groups = []
for layer in local_fsdp_map[fsdp_layer]:
local_groups.append(local_layer_group[layer])
if local_groups:
counter_groups = Counter(local_groups)
# Perform majority vote.
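# For example (illustrative): if the flattened parameter absorbs base parameters that
# belonged to groups [0, 0, 1], it is assigned to group 0.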
group_idx = counter_groups.most_common(1)[0][0]
else:
logger.warning(
f"No parameters group found for {fsdp_layer}. Default: set parameter to first group."
)
group_idx = 0
# Add FSDP flattened parameter to the appropriate group.
optimizer = optimizers[model_optimizer_map[model_idx]]
optimizer.param_groups[group_idx]["params"].append(param)

# Remove empty groups.
for opt in optimizers.values():
remove_ids = []
# Collect the indices of empty groups.
for idx, group in enumerate(opt.param_groups):
if not group["params"]:
remove_ids.append(idx)
# Remove them in reverse order so earlier indices remain valid.
for idx in remove_ids[::-1]:
opt.param_groups.pop(idx)

# Replace models with their FSDP versions.
result = list(args)
for model_idx, model in models.items():
result[model_idx] = model
# The optimizers have been modified in place.
return result

def prepare_data_loader(
self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
):
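For context, a minimal usage sketch of the code path this diff targets (the toy model, optimizer, and hyperparameters below are illustrative, not taken from the PR): when the FSDP plugin is configured with use_orig_params=False, models and optimizers passed to accelerator.prepare are first routed through _prepare_fsdp, which wraps the models and rebuilds the optimizers' param_groups around the flattened parameters.

import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Illustrative sketch only: assumes a distributed FSDP launch (e.g. via `accelerate launch`).
fsdp_plugin = FullyShardedDataParallelPlugin(use_orig_params=False)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# prepare() now wraps the model with FSDP and repopulates the optimizer's param_groups
# with the flattened FSDP parameters, preserving group attributes where possible.
model, optimizer = accelerator.prepare(model, optimizer)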
1 change: 1 addition & 0 deletions src/accelerate/utils/fsdp_utils.py
@@ -333,6 +333,7 @@ def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.c
# if no tied names just passthrough
return param_init_fn

_tied_names = model._tied_weights_keys
# get map of parameter instances to params.
# - needed for replacement later
_tied_params = {}