[PyTorch] Remove unnecessary Pylint overrides (NVIDIA#794)
* Remove unnecessary Pylint overrides

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fixes to lint

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
2 people authored and pggPL committed May 23, 2024
1 parent 3ba02f1 commit fab53a4
Showing 6 changed files with 46 additions and 36 deletions.
8 changes: 5 additions & 3 deletions transformer_engine/pytorch/__init__.py
@@ -3,6 +3,8 @@
 # See LICENSE for license information.

 """Transformer Engine bindings for pyTorch"""
+import torch
+
 from .module import LayerNormLinear
 from .module import Linear
 from .module import LayerNormMLP
@@ -32,8 +34,8 @@
     onnx_rmsnorm_fwd,
     onnx_rmsnorm_fwd_fp8
 )

 try:
-    import torch
     torch._dynamo.config.error_on_nested_jit_trace = False
-except: # pylint: disable=bare-except
-    pass
+except AttributeError:
+    pass  # error_on_nested_jit_trace was added in PyTorch 2.2.0
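
As an aside, the narrowed `except` restates a general idiom: catch only the failure mode you expect so that real errors still propagate. A minimal standalone version of the change above (nothing here beyond the `error_on_nested_jit_trace` flag is taken from the codebase):

import torch

# Catch only AttributeError: the flag exists from PyTorch 2.2.0 onward, and a
# bare `except:` would also have swallowed unrelated failures such as
# KeyboardInterrupt or SystemExit.
try:
    torch._dynamo.config.error_on_nested_jit_trace = False
except AttributeError:
    pass  # Older PyTorch simply does not expose the flag; safe to ignore.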
29 changes: 19 additions & 10 deletions transformer_engine/pytorch/cpu_offload.py
@@ -3,8 +3,10 @@
 # See LICENSE for license information.

 """Functionality for CPU offloading of tensors saved for backward pass."""
-from typing import Any
+from __future__ import annotations
+from contextlib import nullcontext
+from typing import Any, Dict, Optional

 import torch

 from .float8_tensor import Float8Tensor
@@ -99,10 +101,17 @@ class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook):
     and `tensor_pop` interface. How the offload-handler manages the offloading, recovering
     or prefetching timing is transparent to this hook.
     """
-    def __init__(self, offload_handler, handler_extra_kwargs={}, debug=False) -> None: # pylint: disable=dangerous-default-value
-        self.debug = debug
-        self.offload_handler = offload_handler
-        self.handler_extra_kwargs = handler_extra_kwargs
+    def __init__(
+        self,
+        offload_handler: OffloadHandler,
+        handler_extra_kwargs: Optional[Dict[str,Any]] = None,
+        debug: bool = False,
+    ) -> None:
+        if handler_extra_kwargs is None:
+            handler_extra_kwargs = {}
+        self.debug: bool = debug
+        self.offload_handler: OffloadHandler = offload_handler
+        self.handler_extra_kwargs: Dict[str,Any] = handler_extra_kwargs
         super().__init__()

     def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
@@ -290,10 +299,10 @@ def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag):
             allocate_new_buf = True
         else:
             tensor_buf = id_buf_map[tensor_id]
-            if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement
-                allocate_new_buf = True
-            else:
-                allocate_new_buf = False # in this case, reuse the old buffer
+            allocate_new_buf = (
+                tensor_buf.size() != tensor.size()
+                or tensor_buf.dtype != tensor.dtype
+            )

         if allocate_new_buf:
             # supposed to only execute once
@@ -491,7 +500,7 @@ def tensor_need_offloading_checker_activations(tensor):
     def tensor_need_offloading_checker_weights(tensor):
         return hasattr(tensor, "weight_offloading")

-    def tensor_need_offloading_checker_all(tensor): # pylint: disable=unused-argument
+    def tensor_need_offloading_checker_all(tensor):
         return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading"))

     if offload_activations and offload_weights:
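
The `handler_extra_kwargs={}` default that triggered `dangerous-default-value` is evaluated once at function definition, so every instance would have shared the same dict. A small self-contained illustration of the pitfall and the `None`-sentinel fix applied above (the `Risky`/`Safe` classes are hypothetical, not part of the codebase):

from typing import Any, Dict, Optional

class Risky:
    # The default dict is created once, when the function is defined, and is
    # then shared by every call that relies on the default.
    def __init__(self, extra_kwargs: Dict[str, Any] = {}) -> None:
        self.extra_kwargs = extra_kwargs

class Safe:
    # None sentinel: a fresh dict is built per instance instead.
    def __init__(self, extra_kwargs: Optional[Dict[str, Any]] = None) -> None:
        if extra_kwargs is None:
            extra_kwargs = {}
        self.extra_kwargs = extra_kwargs

a, b = Risky(), Risky()
a.extra_kwargs["debug"] = True
print(b.extra_kwargs)  # {'debug': True} -- b silently sees a's mutation

c, d = Safe(), Safe()
c.extra_kwargs["debug"] = True
print(d.extra_kwargs)  # {} -- each instance owns its own dict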
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/float8_tensor.py
@@ -730,8 +730,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
            return None

        # Slice op
-        # TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme
-        # if these slices are modified in-place
        if func == aten.slice.Tensor:
            tensor = args[0]
            data = tensor._data
20 changes: 10 additions & 10 deletions transformer_engine/pytorch/fp8.py
@@ -502,12 +502,12 @@ def fp8_model_init(enabled: bool = True) -> None:
     This functionality is *EXPERIMENTAL*.
     """
+    _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
+    FP8GlobalStateManager.FP8_PARAMETERS = enabled
     try:
-        _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
-        FP8GlobalStateManager.FP8_PARAMETERS = enabled
         yield
     finally:
-        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment
+        FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters


 @contextmanager
@@ -555,16 +555,16 @@ def fp8_autocast(
         distributed group over which amaxes for the fp8 tensors
         are reduced at the end of each training step.
     """
+    fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
+    FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
+                                             calibrating=calibrating,
+                                             fp8_recipe=fp8_recipe,
+                                             fp8_group=fp8_group,
+                                             _graph=_graph)
     try:
-        fp8_state = FP8GlobalStateManager.get_fp8_autocast_state()
-        FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled,
-                                                 calibrating=calibrating,
-                                                 fp8_recipe=fp8_recipe,
-                                                 fp8_group=fp8_group,
-                                                 _graph=_graph)
         yield
     finally:
-        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment
+        FP8GlobalStateManager.set_fp8_autocast_state(fp8_state)
         FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph)
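
Both `used-before-assignment` overrides disappear because the state capture now happens before the `try`: in the old layout, an exception raised inside the `try` before the assignment would have left the `finally` block referencing an unbound name. A hedged sketch of the pattern with a toy state holder (the names below are illustrative, not the Transformer Engine API):

from contextlib import contextmanager

_STATE = {"enabled": False}

def _enter(enabled: bool) -> None:
    if not isinstance(enabled, bool):
        raise TypeError("enabled must be a bool")
    _STATE["enabled"] = enabled

@contextmanager
def toy_autocast(enabled: bool = True):
    # Save and switch the state *before* the try block; if this ran inside the
    # try and raised early, `saved` could be unbound when finally executed.
    saved = _STATE["enabled"]
    _enter(enabled)
    try:
        yield
    finally:
        _STATE["enabled"] = saved  # finally always restores a defined value

with toy_autocast(enabled=True):
    assert _STATE["enabled"] is True
assert _STATE["enabled"] is False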
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/base.py
@@ -703,7 +703,7 @@ def grad_output_preprocess(
                     out=grad_output_c,
                 )
             else:
-                grad_output_c = grad_ouput_mat # pylint: disable=undefined-variable
+                grad_output_c = grad_output_mat
             if not ctx.ub_overlap_ag:
                 grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group)
             if not isinstance(grad_output_c, Float8Tensor):
21 changes: 11 additions & 10 deletions transformer_engine/pytorch/softmax.py
@@ -336,19 +336,20 @@ def forward(
             return self.forward_fused_softmax(inp, mask, scale)
         return self.forward_torch_softmax(inp, mask, scale)

-    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
+    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool: # pylint: disable=too-many-return-statements
         """Check FusedScaleMaskSoftmax kernel availability based on size"""
         attn_batches = b * np

-        if (  # pylint: disable=too-many-boolean-expressions
-            not self.scaled_masked_softmax_fusion  # user doesn't want to fuse
-            or not self.input_in_float16  # input must be fp16
-            or sk < 16
-            or sk > 16384  # sk must be 16 ~ 16384
-            or sk % 8 != 0  # sk must be divisor of 8
-            or self.attn_mask_type == "arbitrary"  # Custom masks not supported
-        ):
-            return False
+        if not self.scaled_masked_softmax_fusion:
+            return False  # user doesn't want to fuse
+        if not self.input_in_float16:
+            return False  # input must be fp16
+        if not 16 < sk < 16384:
+            return False  # sk must be 16 ~ 16384
+        if sk % 8 != 0:
+            return False  # sk must be divisor of 8
+        if self.attn_mask_type == "arbitrary":
+            return False  # Custom masks not supported

         if self.attn_mask_type == "causal":  # unfused causal softmax kernel
             return True
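
The guard-clause rewrite is why the `too-many-boolean-expressions` override goes away and a `too-many-return-statements` one appears on the signature instead: each requirement now gets its own early return and inline comment. A standalone sketch of the same style with made-up limits (not the real kernel constraints):

def kernel_supported(fused: bool, fp16: bool, seq_len: int) -> bool:
    """Early-return style: one requirement, one guard, one comment."""
    if not fused:
        return False  # caller disabled fusion
    if not fp16:
        return False  # kernel handles fp16 inputs only
    if not 16 < seq_len < 16384:
        return False  # sequence length outside the supported range
    if seq_len % 8 != 0:
        return False  # sequence length must be a multiple of 8
    return True

assert kernel_supported(True, True, 128)
assert not kernel_supported(True, True, 20000)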
