Skip to content

Commit

Permalink
Add decomposition for view_copy (pytorch#130938)
Browse files Browse the repository at this point in the history
* Extracted from pytorch#128416
Pull Request resolved: pytorch#130938
Approved by: https://github.com/peterbell10
ghstack dependencies: pytorch#130937
  • Loading branch information
rec authored and pytorchmergebot committed Jul 21, 2024
1 parent f628813 commit 500cbb5
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 7 deletions.
4 changes: 4 additions & 0 deletions test/expect/HasDecompTest.test_aten_core_operators.expect
Original file line number Diff line number Diff line change
Expand Up @@ -524,5 +524,9 @@ aten::var.correction_out
aten::var_mean.correction
aten::var_mean.correction_out
aten::view
aten::view_copy
aten::view_copy.dtype
aten::view_copy.dtype_out
aten::view_copy.out
aten::where.self
aten::where.self_out
4 changes: 0 additions & 4 deletions test/expect/HasDecompTest.test_has_decomposition.expect
Original file line number Diff line number Diff line change
Expand Up @@ -1345,10 +1345,6 @@ aten::view_as_complex_copy.out
aten::view_as_real
aten::view_as_real_copy
aten::view_as_real_copy.out
aten::view_copy
aten::view_copy.dtype
aten::view_copy.dtype_out
aten::view_copy.out
aten::zeros.names
aten::zeros.names_out
aten::zeros.out
1 change: 1 addition & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ def mps_ops_modifier(ops):
'view_as',
'view_as_real',
'view',
'view_copy',
'vsplit',
'zero_',
'zeros',
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2063,7 +2063,6 @@ def f(t):
xfail('scatter', ''),
xfail('take_along_dim', ''),
xfail('triangular_solve', ''),
xfail('view_copy', ''),

# SymIntArrayRef expected to contain only concrete
xfail('ones', ''),
Expand Down
1 change: 1 addition & 0 deletions tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@
"reshape",
"reshape_as",
"view_as",
"view_copy",
"roll",
"clone",
"block_diag",
Expand Down
2 changes: 2 additions & 0 deletions torch/_refs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@
"unsqueeze",
"view",
"view_as",
"view_copy",
"vsplit",
"vstack",
"view_as_complex",
Expand Down Expand Up @@ -6317,6 +6318,7 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int)
# TODO: This must return a sparse tensor if the input is sparse, but refs have
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
view_copy = _make_copy_from_view(aten.view)


# xref: isStorage in torch/csrc/DynamicTypes.cpp
Expand Down
16 changes: 14 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16873,14 +16873,21 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
)),
OpInfo('view_copy',
dtypes=all_types_and(torch.bool, torch.bfloat16, torch.float16),
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
ref=lambda x, newshape: np.reshape(x, newshape).copy(),
supports_out=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_autograd=True,
sample_inputs_func=sample_inputs_view_reshape,
error_inputs_func=error_inputs_view_reshape),
error_inputs_func=error_inputs_view_reshape,
skips=(
# RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
DecorateInfo(
unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"
),
)),
UnaryUfuncInfo('neg',
aliases=('negative', ),
ref=np.negative,
Expand Down Expand Up @@ -23913,6 +23920,11 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
"_refs.view_as",
torch_opinfo_name="view_as",
),
PythonRefInfo(
"_refs.view_copy",
torch_opinfo_name="view_copy",
supports_out=True,
),
PythonRefInfo(
"_refs.vstack",
torch_opinfo_name="vstack",
Expand Down

0 comments on commit 500cbb5

Please sign in to comment.