diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py
index bb4b40bc2..2c3e68dcd 100644
--- a/dicp/dicp/vendor/AscendGraph/ascend_op.py
+++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -63,9 +63,9 @@ def __init__(self):
         super().__init__("Range")
 
     def infer_result(self, start, limit=None, delta=None):
-        start, start_dtype, _, _ = get_op_const_arg_kwarg(start)
-        limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
-        delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta)
+        [start], start_dtype, _, _ = get_op_const_arg_kwarg(start)
+        [limit], limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
+        [delta], delta_dtype, _, _ = get_op_const_arg_kwarg(delta)
         assert start is not None, (
             self.__class__.__name__
             + ": input 'start' can't be None!"
@@ -972,6 +972,34 @@ def infer_result(self, x, multiples):
         return torch.ops.aten.repeat.default(x, multiples)
 
 
+class TileWithAxis(Operator):
+    def __init__(self):
+        super().__init__("TileWithAxis")
+        self.torch_op = aten.repeat_interleave.self_int
+
+
+class TensorScatterUpdate(Operator):
+    def __init__(self):
+        super().__init__("TensorScatterUpdate")
+
+    def infer_result(self, x, indices, updates):
+        _, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
+        _, indices_shape, _, indices_dtype = get_fake_tensor_meta_val(indices)
+        _, updates_shape, _, _ = get_fake_tensor_meta_val(updates)
+        assert indices_dtype in (torch.int32, torch.int64)
+
+        # the following shape constraints are from:
+        # https://tensorflow.google.cn/versions/r2.15/api_docs/
+        # python/tf/tensor_scatter_nd_update
+        assert indices.dim() >= 2
+        index_depth = indices_shape[-1]
+        batch_shape = indices_shape[:-1]
+        assert index_depth <= x_dim
+        inner_shape = x_shape[index_depth:]
+        assert updates_shape == batch_shape + inner_shape
+        return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
+
+
 def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     return a, b, c
 
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index d43046233..c367627af 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -909,7 +909,7 @@ def Squeeze(name, x, dim):
         return op.to_node()
 
     @staticmethod
-    def Identity(name, input, index):
+    def Identity(name, input, index=None):
         op = OP(name, "Identity")
         if index is not None and isinstance(index, int):
             op.set_input_with_index("x", input, index)
@@ -1565,3 +1565,19 @@ def LogicalOr(name, x, y):
         op.set_input("x1", x)
         op.set_input("x2", y)
         return op.to_node()
+
+    @staticmethod
+    def TileWithAxis(name, x, axis, tiles):
+        op = OP(name, "TileWithAxis")
+        op.set_input("x", x)
+        op.set_attr_int("axis", axis)
+        op.set_attr_int("tiles", tiles)
+        return op.to_node()
+
+    @staticmethod
+    def TensorScatterUpdate(name, x, indices, updates):
+        op = OP(name, "TensorScatterUpdate")
+        op.set_input("x", x)
+        op.set_input("indices", indices)
+        op.set_input("updates", updates)
+        return op.to_node()
diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index e57c5133e..fd4955e1b 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -739,6 +739,117 @@ def unsafe_index(self, x, index):
                 return gather
         return self.index_base(x, 0, index)
 
+    def compute_stacked_indices(self, indices, src_shape):
+        assert len(indices) <= len(src_shape)
+        # Check whether all non-None tensors in indices are 'contiguous',
+        # e.g. [None, a, b, None, None] is 'contiguous',
+        #      [None, a, None, b, None] is not 'contiguous',
+        #      because there is a None between a and b.
+        #
+        # Also count the number of None entries in indices. We do not use
+        # indices.count(None) here, to avoid torch.fx.proxy.TraceError:
+        # symbolically traced variables cannot be used as inputs to control flow
+        tensor_none_flag = False
+        contiguous_flag = True
+        none_count_in_indices = 0
+        for i in range(len(indices)):
+            if i < len(indices) - 1:
+                if indices[i] is not None and indices[i + 1] is None:
+                    tensor_none_flag = True
+                if tensor_none_flag and indices[i] is None and indices[i + 1] is not None:
+                    contiguous_flag = False
+            if indices[i] is None:
+                none_count_in_indices += 1
+
+        # collect None dim sizes and the tensor reshape shapes
+        tensor_reshape_shape = []
+        none_dim_size = []
+        first_tensor_pos = -1
+        tensor_unsqueeze_len = 0
+        for i, index in enumerate(indices):
+            if index is None:
+                assert not isinstance(src_shape[i], torch.SymInt)
+                none_dim_size.append(src_shape[i])
+            else:
+                assert isinstance(index.node.meta['val'], torch.Tensor)
+                if first_tensor_pos == -1:
+                    first_tensor_pos = i
+                    tensor_unsqueeze_len = none_count_in_indices - i if contiguous_flag \
+                        else none_count_in_indices
+                indice_i_shape = index.node.meta['val'].shape
+                assert not symint_in_shape(indice_i_shape)
+                tensor_reshape_shape.append(list(indice_i_shape) + [1] * tensor_unsqueeze_len)
+        assert first_tensor_pos != -1, "all elements of indices are None, unsupported"
+        tensor_broadcast_shape = list(torch.broadcast_shapes(*tensor_reshape_shape))
+
+        # If contiguous_flag is True, e.g. [None, None, a, b, None],
+        # the tensor_broadcast_shape of (a, b) is inserted into the final shape at
+        # the original start position of the contiguous tensors, e.g. final shape:
+        # [src_shape[0], src_shape[1], *broadcast_shape_of(a, b), src_shape[4]]
+        #
+        # If contiguous_flag is False, e.g. [None, a, None, b, None],
+        # the tensor_broadcast_shape of (a, b) is inserted at the
+        # start of the final shape, e.g. final shape:
+        # [*broadcast_shape_of(a, b), src_shape[0], src_shape[2], src_shape[4]]
+        tensor_shape_insert_pos = first_tensor_pos if contiguous_flag else 0
+        # collect None reshape shapes
+        none_reshape_shape = []
+        none_idx = 0
+        for i, index in enumerate(indices):
+            if index is None:
+                if i < tensor_shape_insert_pos:
+                    none_unsqueeze_len = tensor_shape_insert_pos - i - 1 + len(tensor_broadcast_shape)
+                else:
+                    none_unsqueeze_len = none_count_in_indices - 1 - none_idx
+                none_reshape_shape.append([none_dim_size[none_idx]] + [1] * none_unsqueeze_len)
+                none_idx += 1
+
+        # stack (pack) all indices
+        target_indices_broadcast_shape = list(torch.broadcast_shapes(*(none_reshape_shape + tensor_reshape_shape)))
+        target_broadcast_shape_proxy = self.get_const_proxy(target_indices_broadcast_shape, torch.int32)
+        stack_input_list = []
+        const_int32_0 = self.get_const_proxy(0, torch.int32, target_shape=[])
+        const_int32_1 = self.get_const_proxy(1, torch.int32, target_shape=[])
+        none_idx = 0
+        tensor_idx = 0
+        for index in indices:
+            if index is None:
+                # for a None index, build a Range over the corresponding dim size,
+                # unsqueeze some dims and then broadcast to the target result shape
+                const_range_max = self.get_const_proxy(none_dim_size[none_idx], torch.int32, target_shape=[])
+                to_be_reshape_proxy = self.get_proxy(ascend_op.Range, (const_int32_0, const_range_max, const_int32_1))
+                index_reshape_shape = none_reshape_shape[none_idx]
+                none_idx += 1
+            else:
+                to_be_reshape_proxy = index
+                index_reshape_shape = tensor_reshape_shape[tensor_idx]
+                tensor_idx += 1
+            reshape_shape_proxy = self.get_const_proxy(index_reshape_shape, torch.int32)
+            reshape_proxy = self.get_proxy(ascend_op.Reshape, (to_be_reshape_proxy, reshape_shape_proxy))
+            stack_input_list.append(self.get_proxy(ascend_op.BroadcastTo, (reshape_proxy, target_broadcast_shape_proxy)))
+        return self.get_proxy(ascend_op.Pack, (stack_input_list, len(target_indices_broadcast_shape))), \
+            target_indices_broadcast_shape, len(stack_input_list)
+
+    @register_conversion(torch.ops.aten.index_put.default)
+    def index_put_default(self, x, indices, values):
+        # the following comment is from tensorflow's tensor_scatter_nd_update:
+        #     index_depth = indices.shape[-1]
+        #     batch_shape = indices.shape[:-1]
+        #     assert index_depth <= tf.rank(x)
+        #     outer_shape = x.shape[:index_depth]
+        #     inner_shape = x.shape[index_depth:]
+        #     assert values.shape == batch_shape + inner_shape
+        #
+        # The param 'indices' of tf.tensor_scatter_nd_update is different from the
+        # indices of torch.ops.aten.index_put.default, so we use broadcast and
+        # stack to construct the param 'indices' of tf.tensor_scatter_nd_update.
+        x_shape = list(x.node.meta['val'].shape)
+        stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \
+            self.compute_stacked_indices(indices, x.node.meta['val'].shape)
+        values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:]  # batch_shape + inner_shape
+        values_broadcast_shape_op = self.get_const_proxy(values_broadcast_shape, torch.int32)
+        broadcasted_values = self.get_proxy(ascend_op.BroadcastTo, (values, values_broadcast_shape_op))
+        return self.get_proxy(ascend_op.TensorScatterUpdate, (x, stacked_indices, broadcasted_values))
+
     @register_conversion(torch.ops.aten.index_select.default)
     def index_arg3_(self, x, dim, index):
         return self.index_base(x, dim, index)
@@ -1075,14 +1186,28 @@ def log_softmax_backward_data(self, grad_output, output, dim, input_dtype):
         dim = [dim] if not isinstance(dim, list) else dim
         return self.get_proxy(ascend_op.LogSoftmaxGrad, (grad_output, output, dim))
 
-    @register_conversion(torch.ops.aten.repeat_interleave)
-    def repeat_interleave(self, repeats, output_size=1):
-        x_shape = list(repeats.node.meta['val'].shape)
-        assert len(x_shape) == 1
-        assert x_shape[0] == 1
-        # TODO! fix implementation of repeatinterleave
-        # Consider situation for repeats > 1
-        return self.get_const_proxy(0, torch.int64, target_shape=[1])
+    @register_conversion(torch.ops.aten.repeat_interleave.self_int)
+    def repeat_interleave(self, x, repeats, dim):
+        # dim=None and Tensor-valued repeats are not supported yet
+        assert dim is not None
+        assert not isinstance(repeats, torch.Tensor)
+        if repeats == 1:
+            return self.get_proxy(ascend_op.Identity, (x, ))
+        tile_with_axis = self.get_proxy(ascend_op.TileWithAxis, (x, dim, repeats))
+        x_shape = list(x.node.meta['val'].shape)
+        reshape_shape = x_shape[:dim] + [repeats] + x_shape[dim:]
+        reshape_shape_proxy = self.get_const_proxy(reshape_shape, torch.int32)
+        reshape_proxy = self.get_proxy(ascend_op.Reshape, (tile_with_axis, reshape_shape_proxy))
+
+        transpose_perm = list(range(len(reshape_shape)))
+        transpose_perm[dim], transpose_perm[dim + 1] = \
+            transpose_perm[dim + 1], transpose_perm[dim]
+        transpose_perm_proxy = self.get_shape_proxy(transpose_perm)
+        transpose_proxy = self.get_proxy(ascend_op.Transpose, (reshape_proxy, transpose_perm_proxy))
+
+        result_reshape = x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim+1:]
+        result_reshape_shape_proxy = self.get_const_proxy(result_reshape, torch.int32)
+        return self.get_proxy(ascend_op.Reshape, (transpose_proxy, result_reshape_shape_proxy))
 
     @register_conversion([aten.lift_fresh_copy, aten.lift_fresh_copy.default])
     def lift_fresh_copy(self, tensor_constant):
@@ -1305,3 +1430,39 @@ def Ge(self, x, y):
     @register_conversion(torch.ops.aten.logical_or.default)
     def LogicalOr(self, x, y):
         return self.get_proxy(ascend_op.LogicalOr, (x, y))
+
+    @register_conversion(torch.ops.aten.slice_scatter.default)
+    def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
+        # modified from torchair
+        if start is None:
+            start = 0
+        if end is None:
+            end = 9223372036854775807
+        if (isinstance(start, int) and start == 0) and (
+                isinstance(end, int) and end == 9223372036854775807) and (
+                isinstance(step, int) and step == 1):
+            return self.get_proxy(ascend_op.Identity, (src, ))
+
+        # the default dtype of the output of ascend_op Shape is int32
+        operand_shape = self.get_proxy(ascend_op.Shape, (operand, ))
+        start = self.get_const_proxy(start, torch.int32, target_shape=[])
+        if end == 9223372036854775807:
+            gather_axis = self.get_const_proxy(0, torch.int32, target_shape=[])
+            end = self.get_proxy(ascend_op.GatherV2, (operand_shape, dim, gather_axis))
+        else:
+            end = self.get_const_proxy(end, torch.int32, target_shape=[])
+        step = self.get_const_proxy(step, torch.int32, target_shape=[])
+        indices = self.get_proxy(ascend_op.Range, (start, end, step))
+
+        dims_to_expand = [i for i in range(src.node.meta['val'].dim())]
+        dims_to_expand.remove(dim)
+
+        if dims_to_expand:
+            indices_unsqueezed = self.get_proxy(ascend_op.Unsqueeze, (indices, dims_to_expand))
+            src_shape = self.get_proxy(ascend_op.Shape, (src, ))
+            indices_expanded = self.get_proxy(ascend_op.Expand, (indices_unsqueezed, src_shape))
+        else:
+            indices_expanded = indices
+        return self.get_proxy(ascend_op.ScatterElements,
+                              (operand, indices_expanded, src, dim))
diff --git a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py
index c8480a196..5aa8122c0 100644
--- a/dicp/dicp/vendor/AscendGraph/pattern_replacement.py
+++ b/dicp/dicp/vendor/AscendGraph/pattern_replacement.py
@@ -27,6 +27,22 @@ def replacement(input, dims):
         varVal = torch.ops.aten.var(input, dims, correction=1, keepdim=True)
         return ascend_op.ret_tuple(varVal, meanVal)
 
+@register_aten_pattern
+class FusedRepeatInterleaveSelfInt(BackendPatternBase):
+    @staticmethod
+    def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape,
+                expand_1_shape, repeat_interleave_output_size):
+        empty = torch.ops.aten.empty.memory_format(input_shape, dtype=torch.int64, layout=torch.strided, device=empty_device)
+        fill = torch.ops.aten.fill.Scalar(empty, repeat)
+        view_1 = torch.ops.aten.view.default(fill, view_1_shape)
+        expand_1 = torch.ops.aten.expand.default(view_1, expand_1_shape)
+        repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size=repeat_interleave_output_size)
+        index_select = torch.ops.aten.index_select.default(self, dim, repeat_interleave)
+        return index_select
+
+    @staticmethod
+    def replacement(self, repeat, dim):
+        return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim)
 
 Const = torch.fx.wrap(ascend_op.Const.get_singleton())
 Transpose = torch.fx.wrap(ascend_op.Transpose.get_singleton())
diff --git a/dicp/test/ascend_scripts/ops/static.ini b/dicp/test/ascend_scripts/ops/static.ini
index d408dc19a..5870062fe 100644
--- a/dicp/test/ascend_scripts/ops/static.ini
+++ b/dicp/test/ascend_scripts/ops/static.ini
@@ -29,6 +29,7 @@ python_files =
     test_gather.py
     test_ge.py
    test_getitem.py
+    test_index_put.py
     test_index.py
     test_le.py
;    test_lift_fresh_copy.py
@@ -54,11 +55,13 @@ python_files =
     test_scatter.py
     test_select.py
     test_sigmoid.py
+    test_slice_scatter.py
     test_slice.py
     test_sqrt.py
     test_squeeze.py
     test_sub.py
     test_sum.py
+    test_repeat_interleave.py
;    test_transpose.py
     test_tril.py
     test_unsqueeze.py
diff --git a/dicp/test/op/test_index_put.py b/dicp/test/op/test_index_put.py
new file mode 100644
index 000000000..98beeb8c3
--- /dev/null
+++ b/dicp/test/op/test_index_put.py
@@ -0,0 +1,61 @@
+import pytest
+from ..common.utils import (
+    torch,
+    dynamo,
+    parse_args,
+    compile_model,
+    get_device,
+    Size,
+    update_dynamo_config,
+)
+
+
+class OpModule(torch.nn.Module):
+    def forward(self, x, indices, values):
+        res_Tensor = torch.ops.aten.index_put.default(x, indices, values)
+        return res_Tensor
+
+
+model = OpModule()
+args = parse_args()
+compiled_model = compile_model(model, args.backend, args.dynamic)
+
+
+class TestIndexPut():
+    @pytest.mark.parametrize("dtype", [torch.float16])
+    @pytest.mark.parametrize("sizes", [Size(((1, 32, 208, 128), (None, None, (6,)), (32, 6, 128)),
+                                            ((1, 32, 208, 128), (None, None, (6,)), (32, 6, 128))),
+                                       Size(((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), (1, 2, 1), None),
+                                             (1, 1, 4, 1, 3, 11)),
+                                            ((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), (1, 2, 1), None),
+                                             (1, 1, 4, 1, 3, 11))),
+                                       Size(((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
+                                             (4, 2, 3, 1, 2, 7)),
+                                            ((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
+                                             (4, 2, 3, 1, 2, 7)))])
+    @pytest.mark.parametrize("compiled_model", compiled_model)
+    def test_torch_index_put(self, sizes, dtype, compiled_model):
+        device = get_device()
+        size = sizes.dynamic if compiled_model.dynamic else sizes.static
+        x_size = size[0]
+        indices_size_tuple = size[1]
+        values_size = size[2]
+
+        input1 = torch.randn(x_size, dtype=dtype)
+        indices = []
+        for dim_idx, idx_size in enumerate(indices_size_tuple):
+            if idx_size is None:
+                indices.append(None)
+            else:
+                indices.append(torch.randint(x_size[dim_idx], idx_size, dtype=torch.int32))
+        value = torch.randn(values_size, dtype=dtype)
+        dicp_input1 = input1.to(device)
+        dicp_indices = [None if index is None else index.to(device) for index in indices]
+        dicp_value = value.to(device)
+
+        output = model(input1, indices, value)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output = compiled_model.model(dicp_input1, dicp_indices, dicp_value)
+
+        assert torch.allclose(output.cpu(), dicp_output.cpu(), equal_nan=True)
diff --git a/dicp/test/op/test_repeat_interleave.py b/dicp/test/op/test_repeat_interleave.py
new file mode 100644
index 000000000..c09ff5b4c
--- /dev/null
+++ b/dicp/test/op/test_repeat_interleave.py
@@ -0,0 +1,44 @@
+import pytest
+from ..common.utils import (
+    torch,
+    dynamo,
+    parse_args,
+    compile_model,
+    get_device,
+    Size,
+    update_dynamo_config,
+)
+
+
+class OpModule(torch.nn.Module):
+    def forward(self, x, repeats, dim):
+        res_default = x.repeat_interleave(repeats, dim)
+        return res_default
+
+
+model = OpModule()
+args = parse_args()
+compiled_model = compile_model(model, args.backend, args.dynamic)
+
+
+class TestRepeat():
+    @pytest.mark.parametrize("dtype", [torch.float32])
+    @pytest.mark.parametrize("sizes", [Size(((3, 5), 5, 0), ((3, 5), 5, 0)),
+                                       Size(((4, 6, 8), 2, 1), ((4, 6, 8), 2, 1)),
+                                       Size(((4, 3, 2, 2), 3, -1), ((4, 3, 2, 2), 3, -1))])
+    @pytest.mark.parametrize("compiled_model", compiled_model)
+    def test_torch_repeat_self_int(self, sizes, dtype, compiled_model):
+        device = get_device()
+        size = sizes.dynamic if compiled_model.dynamic else sizes.static
+        input1 = torch.randn(size[0], dtype=dtype)
+        repeats = size[1]
+        dim = size[2]
+
+        dicp_input1 = input1.to(device)
+
+        output = model(input1, repeats, dim)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output = compiled_model.model(dicp_input1, repeats, dim)
+
+        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
diff --git a/dicp/test/op/test_slice_scatter.py b/dicp/test/op/test_slice_scatter.py
index 6cdb4fe79..f3f43f251 100644
--- a/dicp/test/op/test_slice_scatter.py
+++ b/dicp/test/op/test_slice_scatter.py
@@ -11,8 +11,9 @@
 
 
 class OpModule(torch.nn.Module):
-    def forward(self, x, slice, dim):
-        res_default = torch.ops.aten.slice_scatter.default(x, slice, dim=dim, start=2, end=16, step=1)
+    def forward(self, x, slice, dim, start, end, step):
+        res_default = torch.ops.aten.slice_scatter.default(
+            x, slice, dim=dim, start=start, end=end, step=step)
         return res_default
 
 
@@ -21,25 +22,42 @@ def forward(self, x, slice, dim):
 compiled_model = compile_model(model, args.backend, args.dynamic)
 
 
+def torch_slice_scatter_test_base(sizes, dtype, compiled_model):
+    device = get_device()
+    size = sizes.dynamic if compiled_model.dynamic else sizes.static
+    input1 = torch.randn(size[0], dtype=dtype)
+    input2 = torch.randn(size[1], dtype=dtype)
+    dim = size[2]
+    start, end, step = size[3:]
+
+    dicp_input1 = input1.to(device)
+    dicp_input2 = input2.to(device)
+
+    output = model(input1, input2, dim, start, end, step)
+    dynamo.reset()
+    update_dynamo_config(compiled_model.dynamic)
+    dicp_output = compiled_model.model(dicp_input1, dicp_input2, dim, start, end, step)
+
+    assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
+
+
 class TestSliceScatter():
+    @pytest.mark.skipif(args.backend != "ascendgraph",
+                        reason="This is the test case for slice_scatter in ascendgraph!")
+    @pytest.mark.parametrize("dtype", [torch.float32])
+    @pytest.mark.parametrize("sizes", [Size(((1, 32, 208, 128), (1, 32, 208, 128), 0, 0, 9223372036854775807, 1),
+                                            ((1, 32, 208, 128), (1, 32, 208, 128), 0, 0, 9223372036854775807, 1)),
+                                       Size(((1, 32, 208, 128), (1, 32, 208, 128), 1, 0, 9223372036854775807, 1),
+                                            ((1, 32, 208, 128), (1, 32, 208, 128), 1, 0, 9223372036854775807, 1))])
+    @pytest.mark.parametrize("compiled_model", compiled_model)
+    def test_torch_slice_scatter_ascend(self, sizes, dtype, compiled_model):
+        torch_slice_scatter_test_base(sizes, dtype, compiled_model)
+
     @pytest.mark.parametrize("dtype", [torch.float32])
-    @pytest.mark.parametrize("sizes", [Size(((16,), (14,), 0), ((16, 32), (3, 32), 0)),
-                                       Size(((32, 16), (32, 14), 1), ((32, 16), (32, 14), 1)),
-                                       Size(((32, 64, 16), (32, 64, 14), 2), ((32, 64), (33, 62), 2))])
+    @pytest.mark.parametrize("sizes", [Size(((16,), (14,), 0, 2, 16, 1), ((16, 32), (3, 32), 0, 2, 16, 1)),
+                                       Size(((32, 16), (32, 14), 1, 2, 16, 1), ((32, 16), (32, 14), 1, 2, 16, 1)),
+                                       Size(((32, 64, 16), (32, 64, 14), 2, 2, 16, 1),
+                                            ((32, 64), (33, 62), 2, 2, 16, 1))])
     @pytest.mark.parametrize("compiled_model", compiled_model)
-    def test_torch_slice_scatter(self, sizes, dtype, compiled_model):
-        device = get_device()
-        size = sizes.dynamic if compiled_model.dynamic else sizes.static
-        input1 = torch.randn(size[0], dtype=dtype)
-        input2 = torch.randn(size[1], dtype=dtype)
-        dim = size[2]
-
-        dicp_input1 = input1.to(device)
-        dicp_input2 = input2.to(device)
-
-        output = model(input1, input2, dim)
-        dynamo.reset()
-        update_dynamo_config(compiled_model.dynamic)
-        dicp_output = compiled_model.model(dicp_input1, dicp_input2, dim)
-
-        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
+    def test_torch_slice_scatter_all(self, sizes, dtype, compiled_model):
+        torch_slice_scatter_test_base(sizes, dtype, compiled_model)
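
Reviewer note (illustrative only, not part of the patch): the plain-PyTorch sketch below mirrors what compute_stacked_indices plus TensorScatterUpdate do for an indices list of the form [None, None, idx] — the None dimensions become broadcast aranges, the given index tensor is broadcast alongside them, and writing values at the stacked coordinates reproduces aten.index_put. The helper name stacked_indices_reference is made up for this note.

import torch

def stacked_indices_reference(x, idx, values):
    # One index tensor per input dimension: broadcast aranges for the None
    # positions, the given index tensor for the indexed position
    # (mirrors Range + Reshape + BroadcastTo + Pack in the conversion).
    d0 = torch.arange(x.shape[0]).reshape(-1, 1, 1)
    d1 = torch.arange(x.shape[1]).reshape(1, -1, 1)
    d2 = idx.reshape(1, 1, -1)
    d0, d1, d2 = torch.broadcast_tensors(d0, d1, d2)
    out = x.clone()
    # Rough analogue of TensorScatterUpdate: write `values` at the stacked coordinates.
    out[d0, d1, d2] = values
    return out

x = torch.randn(2, 3, 5)
idx = torch.tensor([0, 4])
values = torch.randn(2, 3, 2)  # batch_shape + inner_shape
expected = torch.ops.aten.index_put.default(x, [None, None, idx], values)
assert torch.allclose(stacked_indices_reference(x, idx, values), expected)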
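
Reviewer note (illustrative only, not part of the patch): a rough plain-torch equivalence check for the TileWithAxis -> Reshape -> Transpose -> Reshape lowering of aten.repeat_interleave.self_int used in the conversion above; torch.cat stands in for TileWithAxis here, and dim is assumed to be non-negative as in the first two test cases.

import torch

def repeat_interleave_reference(x, repeats, dim):
    x_shape = list(x.shape)
    tiled = torch.cat([x] * repeats, dim=dim)      # stands in for TileWithAxis
    reshaped = tiled.reshape(x_shape[:dim] + [repeats] + x_shape[dim:])
    transposed = reshaped.transpose(dim, dim + 1)  # move the repeat axis next to dim
    return transposed.reshape(
        x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim + 1:])

x = torch.randn(3, 5)
assert torch.equal(repeat_interleave_reference(x, 4, dim=0),
                   x.repeat_interleave(4, dim=0))
assert torch.equal(repeat_interleave_reference(x, 2, dim=1),
                   x.repeat_interleave(2, dim=1))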