Skip to content

Commit

Permalink
Fixed formatting
Browse files Browse the repository at this point in the history
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
  • Loading branch information
Selvaraj Anandaraj committed Jan 18, 2024
1 parent d45ad8c commit 3e4b3d5
Showing 1 changed file with 31 additions and 31 deletions.
62 changes: 31 additions & 31 deletions transformer_engine/pytorch/cpu_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,35 +412,35 @@ def get_cpu_offload_context(enabled: bool = False,
"""

def tensor_need_offloading_checker_activations(tensor):
return not hasattr(tensor,"weight_offloading")

def tensor_need_offloading_checker_weights(tensor): #This includes the Gradient Accumulation Buffer
return hasattr(tensor,"weight_offloading")

def tensor_need_offloading_checker_all(tensor):
return True

if offload_activations and offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_all
elif offload_activations:
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
elif offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_weights
else:
raise ValueError("CPU Offloading is enabled while it is not mentioned what to offload (weights/activations)")

cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_prefetch_group=1,
tensor_need_offloading_checker=tensor_need_offloading_checker
)

def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor,cpu_offload_handler)

if enabled:
return CpuOffloadHookWithOffloadHandler(offload_handler = cpu_offload_handler), group_prefetch_offload_commit_async
else:
return nullcontext(), group_prefetch_offload_commit_async
def tensor_need_offloading_checker_activations(tensor):
return not hasattr(tensor,"weight_offloading")
def tensor_need_offloading_checker_weights(tensor): #This includes the Gradient Accumulation Buffer
return hasattr(tensor,"weight_offloading")
def tensor_need_offloading_checker_all(tensor):
return True
if offload_activations and offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_all
elif offload_activations:
tensor_need_offloading_checker = tensor_need_offloading_checker_activations
elif offload_weights:
tensor_need_offloading_checker = tensor_need_offloading_checker_weights
else:
raise ValueError("CPU Offloading is enabled while it is not mentioned what to offload (weights/activations)")
cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler(
num_offload_group=num_layers,
num_prefetch_group=1,
tensor_need_offloading_checker=tensor_need_offloading_checker
)
def group_prefetch_offload_commit_async(tensor):
return group_prefetch_offload_commit(tensor,cpu_offload_handler)
if enabled:
return CpuOffloadHookWithOffloadHandler(offload_handler = cpu_offload_handler), group_prefetch_offload_commit_async
else:
return nullcontext(), group_prefetch_offload_commit_async

0 comments on commit 3e4b3d5

Please sign in to comment.