add more ops
Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 committed Jan 21, 2025
1 parent e4a4fb3 commit de59f73
Showing 1 changed file with 10 additions and 17 deletions.
@@ -25,7 +25,7 @@
from torch.fx.subgraph_rewriter import Match
from typing_extensions import TypeAlias

from neural_compressor.common import utils
from neural_compressor.common import logger, utils

# =============================================================================
# Search and replace patterns
@@ -53,32 +53,24 @@ class PatternPair:
FuncArgsMappingType: TypeAlias = Dict[TorchFuncType, Tuple[torch.Tensor, ...]]


# Align with https://pytorch.org/docs/stable/amp.html#cpu-ops-that-can-autocast-to-bfloat16
# conv1d, conv2d, conv3d, bmm, mm, linalg_vecdot, baddbmm, addmm, addbmm,
# linear, matmul, _convolution, conv_tbc, mkldnn_rnn_layer, conv_transpose1d,
# conv_transpose2d, conv_transpose3d, prelu, scaled_dot_product_attention, _native_multi_head_attention

# Align with xiq, as it relies on xiq's set_module_xx capability
FN_ARGS_MAPPING: FuncArgsMappingType = {
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # linear w/o bias
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # linear w/ bias
    torch.nn.functional.conv2d: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # conv2d w/o bias
    torch.nn.functional.conv2d: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0), torch.randn(0)),  # conv2d w/ bias
    torch.bmm: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # bmm
    torch.mm: (torch.randn(0, 0), torch.randn(0, 0)),  # mm
    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1)),  # conv2d w/o bias
    torch.nn.functional.conv2d: (torch.randn(1, 1, 1, 1), torch.randn(1, 1, 1, 1), torch.randn(1)),  # conv2d w/ bias
    torch.matmul: (torch.randn(0, 0), torch.randn(0, 0)),  # matmul
    torch.matmul: (torch.randn(0, 0, 0), torch.randn(0, 0, 0)),  # matmul
    torch.matmul: (torch.randn(0, 0, 0, 0), torch.randn(0, 0, 0, 0)),  # matmul
}

# module cls -> function name
NN_MODULES_MAPPING = {
# module cls <-> function name
NN_MODULES_TO_NN_FN = {
    torch.nn.Linear: torch.nn.functional.linear,
    torch.nn.Conv2d: torch.nn.functional.conv2d,
    torch.nn.MaxPool2d: torch.nn.functional.max_pool2d,
}

for nn_cls, fn in NN_MODULES_MAPPING.items():
    if fn in FN_ARGS_MAPPING:
        FN_ARGS_MAPPING[nn_cls] = FN_ARGS_MAPPING[fn]


# Use the mapping from xiq
FN_ATEN_OPS_MAPPING = xiq._map_module_function_to_aten_operator_type()

@@ -117,6 +109,7 @@ def replace_fn_wrapper(fn_args, fn):

def _register_pattern_pair(dtype: torch.dtype) -> None:
    for fn, fn_args in FN_ARGS_MAPPING.items():
        logger.debug(f"Registering search and replace patterns for {fn} with args: {fn_args}.")
        pattern_pair = pattern_factory(fn, fn_args)
        HALF_PRECISION_PATTERN_REGISTRY[dtype][fn] = pattern_pair
        utils.logger.debug(

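A note on the FN_ARGS_MAPPING literal above: a Python dict literal keeps only the last value bound to a repeated key, so where the same callable appears more than once (the two torch.nn.functional.linear entries, the conv2d variants, the three torch.matmul entries) only the final argument tuple survives. A minimal, self-contained illustration:

import torch

# A repeated key in a dict literal is silently overwritten; only the last
# argument tuple per operator remains.
demo = {
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0)),  # w/o bias (overwritten)
    torch.nn.functional.linear: (torch.randn(0, 0), torch.randn(0, 0), torch.randn(0)),  # w/ bias (kept)
}
print(len(demo))  # 1
print(len(demo[torch.nn.functional.linear]))  # 3 -- only the biased variant remains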
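For readers wondering what the dummy tensors are for: they appear to serve only as example inputs of the right rank so each op can be traced into a search pattern and replaced by a half-precision call (zero-sized shapes keep tracing cheap; conv2d presumably needs at least one element per dimension, which would explain torch.randn(1, 1, 1, 1)). The sketch below shows the general fp32-to-fp16 rewrite idea with torch.fx's subgraph rewriter under those assumptions; it is not the module's pattern_factory/PatternPair implementation, and fp32_linear, half_linear, and TinyModel are hypothetical names.

import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern


def fp32_linear(x, w):
    # Search pattern: a plain float32 linear call.
    return torch.nn.functional.linear(x, w)


def half_linear(x, w):
    # Replacement: the same op with inputs cast to half precision.
    return torch.nn.functional.linear(x.to(torch.float16), w.to(torch.float16))


class TinyModel(torch.nn.Module):
    def forward(self, x, w):
        return torch.nn.functional.linear(x, w)


gm = symbolic_trace(TinyModel())
replace_pattern(gm, fp32_linear, half_linear)
print(gm.code)  # the rewritten graph now casts x and w to float16 before calling linear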