Skip to content

Commit

Permalink
support slicing with symints in non-strict (pytorch#143217)
Browse files Browse the repository at this point in the history
Differential Revision: [D67215745](https://our.internmc.facebook.com/intern/diff/D67215745/)
Pull Request resolved: pytorch#143217
Approved by: https://github.com/tugsbayasgalan
  • Loading branch information
avikchaudhuri authored and pytorchmergebot committed Dec 14, 2024
1 parent 9933e59 commit de48413
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
4 changes: 2 additions & 2 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -3799,7 +3799,7 @@ def forward(self, scores, score_thr, topk: torch.Tensor, results=None):
nms_pre = torch.tensor(4)
inputs = (score, score_thr, nms_pre, dict(bbox_pred=bbox_pred))

ep = torch.export.export(M(), inputs)
ep = export(M(), inputs)
orig_res = M()(*inputs)
ep_res = ep.module()(*inputs)
self.assertTrue(torch.allclose(orig_res[0], ep_res[0]))
Expand Down Expand Up @@ -5097,7 +5097,7 @@ def forward(self, start_pos: torch.Tensor):
torch._check(pos <= 4)
return self.freq[pos] * self.freq[pos]

ep = torch.export.export(M(), (torch.tensor(1),))
ep = export(M(), (torch.tensor(1),))
FileCheck().check_count(
"torch.ops.aten._assert_scalar.default", 2, exactly=True
).run(ep.graph_module.code)
Expand Down
36 changes: 30 additions & 6 deletions torch/_export/non_strict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,18 +532,40 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
In particular, conditions on unbacked symints can appear outside such
calls, and as such are not handled here.
2. Handles line-of-code logging for each torch function call in non-strict.
2. Overrides torch functions that are known to cause problems in non-strict.
Certain Python features, such as indexing/slicing, cannot be intercepted
in non-strict. When these features need special handling in the compiler,
tracing can fail in non-strict (yet surprisingly succeed in strict).
Fortunately, redirecting to other torch functions can often fix such issues.
3. Handles line-of-code logging for each torch function call in non-strict.
Usage: TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC=1 TORCH_LOGS="+export" ...
"""

def _override(self, func, args, kwargs):
if func is torch.tensor:
# Redirect to Python implementation of torch.tensor for data with symints.
# NOTE(avik): We don't unconditionally redirect to this implementation
# because it has some known incompletenesses, e.g., it doesn't support
# empty data. See https://github.com/pytorch/pytorch/issues/143216
if any(
isinstance(a, torch.SymInt) for a in pytree.tree_flatten(args[0])[0]
):
return torch._refs.tensor, args, kwargs
if func.__name__ == "__getitem__" and isinstance(args[0], torch.Tensor):
# Redirect to torch.select for indexing with symint.
if isinstance(args[1], torch.SymInt):
return torch.select, [args[0], 0, args[1]], {}
return func, args, kwargs

def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if (
not torch.compiler.is_dynamo_compiling()
and log.isEnabledFor(logging.DEBUG)
and config.extended_debug_current_loc
):
if torch.compiler.is_dynamo_compiling():
return func(*args, **kwargs)

if log.isEnabledFor(logging.DEBUG) and config.extended_debug_current_loc:
frame = _find_user_code_frame()
if frame is not None:
log.debug(
Expand All @@ -553,6 +575,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
frame.f_lineno,
frame.f_code.co_name,
)

func, args, kwargs = self._override(func, args, kwargs)
try:
return func(*args, **kwargs)
except GuardOnDataDependentSymNode as e:
Expand Down

0 comments on commit de48413

Please sign in to comment.