Fix non-persistent buffer dispatch (#1941)
* offload only persistent buffers

* add tests and fix naming

* remove_non_persistent=True by default

* style

* style again

* fix hooks

* fix logic
SunMarc authored Nov 20, 2023
1 parent fbe00d7 commit 35b0206
Showing 4 changed files with 108 additions and 6 deletions.
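
Background for the change: buffers registered with persistent=False are excluded from a module's state_dict, and the offloaded weights map is built from that state dict, so non-persistent buffers can never be reloaded from CPU or disk the way parameters and persistent buffers are. A minimal standalone sketch of the distinction (the buffer names are made up for illustration, not taken from this commit):

import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("running_total", torch.zeros(4))              # persistent: included in state_dict
m.register_buffer("scratch", torch.zeros(4), persistent=False)  # non-persistent: excluded from state_dict

print(sorted(m.state_dict().keys()))  # ['running_total']
print(m._non_persistent_buffers_set)  # {'scratch'}

This is the set that the new get_non_persistent_buffers helper below reads off the module.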
18 changes: 14 additions & 4 deletions src/accelerate/hooks.py
@@ -26,6 +26,7 @@
send_to_device,
set_module_tensor_to_device,
)
from .utils.modeling import get_non_persistent_buffers


class ModelHook:
@@ -262,22 +263,28 @@ def init_hook(self, module):
module, include_buffers=self.offload_buffers, recurse=self.place_submodules
)
}

for name, _ in named_module_tensors(
- module, include_buffers=self.offload_buffers, recurse=self.place_submodules
+ module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
):
set_module_tensor_to_device(module, name, "meta")
if not self.offload_buffers and self.execution_device is not None:
for name, _ in module.named_buffers(recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)
elif self.offload_buffers and self.execution_device is not None:
for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
set_module_tensor_to_device(module, name, self.execution_device)

return module

def pre_forward(self, module, *args, **kwargs):
if self.io_same_device:
self.input_device = find_device([args, kwargs])
if self.offload:
for name, _ in named_module_tensors(
- module, include_buffers=self.offload_buffers, recurse=self.place_submodules
+ module,
+ include_buffers=self.offload_buffers,
+ recurse=self.place_submodules,
+ remove_non_persistent=True,
):
fp16_statistics = None
if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
@@ -294,7 +301,10 @@ def pre_forward(self, module, *args, **kwargs):
def post_forward(self, module, output):
if self.offload:
for name, _ in named_module_tensors(
- module, include_buffers=self.offload_buffers, recurse=self.place_submodules
+ module,
+ include_buffers=self.offload_buffers,
+ recurse=self.place_submodules,
+ remove_non_persistent=True,
):
set_module_tensor_to_device(module, name, "meta")
if type(module).__name__ == "Linear8bitLt":
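
Taken together, the hook changes implement a simple placement policy: persistent buffers are treated like parameters (sent to the meta device and streamed back in from the offloaded weights map around each forward), while non-persistent buffers are pinned on the execution device once in init_hook and then skipped by every named_module_tensors loop thanks to remove_non_persistent=True. A condensed sketch of that policy, simplified from the hook code above (not a drop-in replacement for it):

import torch.nn as nn

from accelerate.utils import set_module_tensor_to_device
from accelerate.utils.modeling import get_non_persistent_buffers


def place_buffers(module: nn.Module, offload_buffers: bool, execution_device, place_submodules: bool):
    # Mirrors the end of AlignDevicesHook.init_hook, after the offloaded tensors were sent to "meta".
    if execution_device is None:
        return
    if not offload_buffers:
        # Buffers are not offloaded at all: keep every buffer on the execution device.
        for name, _ in module.named_buffers(recurse=place_submodules):
            set_module_tensor_to_device(module, name, execution_device)
    else:
        # Persistent buffers live in the weights map and stay on "meta" between forwards;
        # non-persistent buffers are not in the weights map, so they must stay on the device.
        for name in get_non_persistent_buffers(module, recurse=place_submodules):
            set_module_tensor_to_device(module, name, execution_device)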
33 changes: 31 additions & 2 deletions src/accelerate/utils/modeling.py
@@ -365,7 +365,9 @@ def set_module_tensor_to_device(
torch.cuda.empty_cache()


- def named_module_tensors(module: nn.Module, include_buffers: bool = True, recurse: bool = False):
+ def named_module_tensors(
+ module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
+ ):
"""
A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
@@ -377,13 +379,40 @@ def named_module_tensors(module: nn.Module, include_buffers: bool = True, recurs
Whether or not to include the buffers in the result.
recurse (`bool`, *optional*, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
remove_non_persistent (`bool`, *optional*, defaults to `False`):
Whether or not to remove the non-persistent buffers from the buffers. Useful only when
`include_buffers=True`.
"""
for named_parameter in module.named_parameters(recurse=recurse):
yield named_parameter

if include_buffers:
non_persistent_buffers = set()
if remove_non_persistent:
non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
- yield named_buffer
+ name, _ = named_buffer
+ if name not in non_persistent_buffers:
+ yield named_buffer


def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
"""
Gather all the non-persistent buffers of a given module into a set.
Args:
module (`nn.Module`):
The module we want the non-persistent buffers from.
recurse (`bool`, *optional*, defaults to `False`):
Whether or not to go look in every submodule or just return the direct non-persistent buffers.
"""

non_persistent_buffers_set = module._non_persistent_buffers_set
if recurse:
for _, m in module.named_modules():
non_persistent_buffers_set |= m._non_persistent_buffers_set

return non_persistent_buffers_set


class FindTiedParametersResult(list):
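
A quick usage sketch of the new remove_non_persistent flag and the helper (the extra buffer registered here is made up for the example; the tests below exercise the same behaviour with dedicated module classes):

import torch
import torch.nn as nn

from accelerate.utils.modeling import get_non_persistent_buffers, named_module_tensors

lin = nn.Linear(3, 4)
lin.register_buffer("scale", torch.ones(4), persistent=False)  # hypothetical non-persistent buffer

# Default behaviour: parameters plus every buffer, persistent or not.
print([name for name, _ in named_module_tensors(lin, include_buffers=True)])
# ['weight', 'bias', 'scale']

# With the new flag, non-persistent buffers are filtered out of the iteration.
print([name for name, _ in named_module_tensors(lin, include_buffers=True, remove_non_persistent=True)])
# ['weight', 'bias']

print(get_non_persistent_buffers(lin))
# {'scale'}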
39 changes: 39 additions & 0 deletions tests/test_big_modeling.py
@@ -45,6 +45,33 @@ def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))


class LinearWithNonPersistentBuffers(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.ones((out_features, in_features), **factory_kwargs))
if bias:
self.register_buffer("bias", torch.ones(out_features, **factory_kwargs), persistent=False)
else:
self.register_buffer("bias", None)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(input, self.weight, self.bias)


class ModelForTestNonPersistentBuffers(nn.Module):
def __init__(self):
super().__init__()
self.linear1 = LinearWithNonPersistentBuffers(3, 4)
self.batchnorm = nn.BatchNorm1d(4)
self.linear2 = LinearWithNonPersistentBuffers(4, 5)

def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))


class ModelForTestCopy(nn.Module):
def __init__(self, id: int):
super().__init__()
@@ -302,6 +329,18 @@ def test_dispatch_model(self):
output = model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))

@require_cuda
def test_dispatch_model_with_non_persistent_buffers(self):
model = ModelForTestNonPersistentBuffers()
device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": "disk"}
x = torch.randn(2, 3)
expected = model(x)

with TemporaryDirectory() as tmp_dir:
dispatch_model(model, device_map, offload_dir=tmp_dir, offload_buffers=True)
output = model(x)
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
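
The reason the two kinds of buffers need different treatment is visible without any accelerate machinery: a state_dict round trip restores persistent buffers but leaves non-persistent ones at whatever value a freshly constructed module has, which is why they cannot be offloaded and reloaded. A small standalone illustration (class and buffer names are made up for the example):

import torch
import torch.nn as nn


class WithScratch(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mean", torch.zeros(3))                       # persistent
        self.register_buffer("scratch", torch.zeros(3), persistent=False)  # non-persistent

src = WithScratch()
src.mean += 1.0
src.scratch += 1.0

dst = WithScratch()
dst.load_state_dict(src.state_dict())
print(dst.mean)     # tensor([1., 1., 1.]) -> restored from the state dict
print(dst.scratch)  # tensor([0., 0., 0.]) -> never in the state dict, so not restored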

@require_mps
def test_dispatch_model_mps(self):
model = ModelForTest()
24 changes: 24 additions & 0 deletions tests/test_modeling_utils.py
@@ -51,6 +51,22 @@ def forward(self, x):
return self.linear2(self.batchnorm(self.linear1(x)))


class LinearWithNonPersistentBuffers(nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.empty((out_features, in_features), **factory_kwargs))
if bias:
self.register_buffer("bias", torch.empty(out_features, **factory_kwargs), persistent=False)
else:
self.register_buffer("bias", None)

def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(input, self.weight, self.bias)


def sequential_model(num_layers):
layers = OrderedDict([(f"linear{i}", nn.Linear(1000, 1000)) for i in range(1, num_layers + 1)])
return nn.Sequential(layers)
@@ -187,6 +203,14 @@ def test_named_tensors(self):
["linear1.weight", "linear1.bias", "batchnorm.weight", "batchnorm.bias", "linear2.weight", "linear2.bias"],
)

model = LinearWithNonPersistentBuffers(10, 10)

named_tensors = named_module_tensors(model, include_buffers=True, remove_non_persistent=False)
self.assertListEqual([name for name, _ in named_tensors], ["weight", "bias"])

named_tensors = named_module_tensors(model, include_buffers=True, remove_non_persistent=True)
self.assertListEqual([name for name, _ in named_tensors], ["weight"])

def test_find_tied_parameters(self):
model = sequential_model(4)
self.assertListEqual(find_tied_parameters(model), [])
