
Commit

[dicp][ascend] add ascendgraph index_put, repeat_interleave, slice_scatter (#669)
CyCle1024 authored Jan 25, 2024
1 parent 5832af8 commit 88b5a18
Showing 8 changed files with 380 additions and 33 deletions.
34 changes: 31 additions & 3 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -63,9 +63,9 @@ def __init__(self):
super().__init__("Range")

def infer_result(self, start, limit=None, delta=None):
start, start_dtype, _, _ = get_op_const_arg_kwarg(start)
limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta)
[start], start_dtype, _, _ = get_op_const_arg_kwarg(start)
[limit], limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
[delta], delta_dtype, _, _ = get_op_const_arg_kwarg(delta)

assert start is not None, (
self.__class__.__name__ + ": input 'start' can't be None!"
@@ -972,6 +972,34 @@ def infer_result(self, x, multiples):
return torch.ops.aten.repeat.default(x, multiples)


class TileWithAxis(Operator):
def __init__(self):
super().__init__("TileWithAxis")
self.torch_op = aten.repeat_interleave.self_int


class TensorScatterUpdate(Operator):
def __init__(self):
super().__init__("TensorScatterUpdate")

def infer_result(self, x, indices, updates):
_, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
_, indices_shape, _, indices_dtype = get_fake_tensor_meta_val(indices)
_, updates_shape, _, _ = get_fake_tensor_meta_val(updates)
assert indices_dtype in (torch.int32, torch.int64)

# following shape constraints are from:
# https://tensorflow.google.cn/versions/r2.15/api_docs/
# python/tf/tensor_scatter_nd_update
assert indices.dim() >= 2
index_depth = indices_shape[-1]
batch_shape = indices_shape[:-1]
assert index_depth <= x_dim
inner_shape = x_shape[index_depth:]
assert updates_shape == batch_shape + inner_shape
return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))


def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return a, b, c

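A quick way to read the shape bookkeeping in TensorScatterUpdate.infer_result above is the following self-contained sketch; the concrete shapes are made up for illustration and mirror the constraints from the TensorFlow doc linked in the code.

import torch

x = torch.zeros(4, 5, 6)
indices = torch.zeros(7, 2, dtype=torch.int64)   # index_depth = 2, batch_shape = (7,)
updates = torch.zeros(7, 6)                      # batch_shape + inner_shape

index_depth = indices.shape[-1]                  # 2
batch_shape = indices.shape[:-1]                 # (7,)
inner_shape = x.shape[index_depth:]              # (6,)
assert index_depth <= x.dim()
assert updates.shape == batch_shape + inner_shape
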
18 changes: 17 additions & 1 deletion dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -909,7 +909,7 @@ def Squeeze(name, x, dim):
return op.to_node()

@staticmethod
def Identity(name, input, index):
def Identity(name, input, index=None):
op = OP(name, "Identity")
if index is not None and isinstance(index, int):
op.set_input_with_index("x", input, index)
@@ -1565,3 +1565,19 @@ def LogicalOr(name, x, y):
op.set_input("x1", x)
op.set_input("x2", y)
return op.to_node()

@staticmethod
def TileWithAxis(name, x, axis, tiles):
op = OP(name, "TileWithAxis")
op.set_input("x", x)
op.set_attr_int("axis", axis)
op.set_attr_int("tiles", tiles)
return op.to_node()

@staticmethod
def TensorScatterUpdate(name, x, indices, updates):
op = OP(name, "TensorScatterUpdate")
op.set_input("x", x)
op.set_input("indices", indices)
op.set_input("updates", updates)
return op.to_node()
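
For context, the conversion below relies on TileWithAxis behaving as a single-axis tile: tiles whole copies of x concatenated along axis. A minimal eager sketch of that assumed semantics, with torch.cat standing in for the Ascend op:

import torch

x = torch.arange(6).reshape(2, 3)
axis, tiles = 1, 2
tiled = torch.cat([x] * tiles, dim=axis)   # stand-in for TileWithAxis(x, axis, tiles)
print(tiled)
# tensor([[0, 1, 2, 0, 1, 2],
#         [3, 4, 5, 3, 4, 5]])
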
177 changes: 169 additions & 8 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -739,6 +739,117 @@ def unsafe_index(self, x, index):
return gather
return self.index_base(x, 0, index)

def compute_stacked_indices(self, indices, src_shape):
assert len(indices) <= len(src_shape)
# Check whether all non-None tensors in indices are 'contiguous',
# e.g. [None, a, b, None, None] is 'contiguous', while
# [None, a, None, b, None] is not 'contiguous'
# because there is a None between a and b.
#
# Also count the number of None entries in indices.
# We avoid indices.count(None) because of torch.fx.proxy.TraceError:
# symbolically traced variables cannot be used as inputs to control flow
tensor_none_flag = False
contiguous_flag = True
none_count_in_indices = 0
for i in range(len(indices)):
if i < len(indices) - 1:
if indices[i] is not None and indices[i + 1] is None:
tensor_none_flag = True
if tensor_none_flag and indices[i] is None and indices[i + 1] is not None:
contiguous_flag = False
if indices[i] is None:
none_count_in_indices += 1

# collect None dim_size and tensor reshape shape
tensor_reshape_shape = []
none_dim_size = []
first_tensor_pos = -1
tensor_unsqueeze_len = 0
for i, index in enumerate(indices):
if index is None:
assert not isinstance(src_shape[i], torch.SymInt)
none_dim_size.append(src_shape[i])
else:
assert isinstance(index.node.meta['val'], torch.Tensor)
if first_tensor_pos == -1:
first_tensor_pos = i
tensor_unsqueeze_len = none_count_in_indices - i if contiguous_flag \
else none_count_in_indices
indice_i_shape = index.node.meta['val'].shape
assert not symint_in_shape(indice_i_shape)
tensor_reshape_shape.append(list(indice_i_shape) + [1] * tensor_unsqueeze_len)
assert first_tensor_pos != -1, "all elements of indices are None, unsupported"
tensor_broadcast_shape = list(torch.broadcast_shapes(*tensor_reshape_shape))

# In case contiguous_flag is True, e.g. [None, None, a, b, None],
# the tensor_broadcast_shape of (a, b) is inserted into the final shape at
# the original start position of the contiguous tensors, e.g. final shape:
# [src_shape[0], src_shape[1], *broadcast_shape_of(a, b), src_shape[4]]
#
# In case contiguous_flag is False, e.g. [None, a, None, b, None],
# the tensor_broadcast_shape of (a, b) is inserted at the
# start of the final shape, e.g. final shape:
# [*broadcast_shape_of(a, b), src_shape[0], src_shape[2], src_shape[4]]
tensor_shape_insert_pos = first_tensor_pos if contiguous_flag else 0
# collect None reshape shape
none_reshape_shape = []
none_idx = 0
for i, index in enumerate(indices):
if index is None:
if i < tensor_shape_insert_pos:
none_unsqueeze_len = tensor_shape_insert_pos - i - 1 + len(tensor_broadcast_shape)
else:
none_unsqueeze_len = none_count_in_indices - 1 - none_idx
none_reshape_shape.append([none_dim_size[none_idx]] + [1] * none_unsqueeze_len)
none_idx += 1

# stack(pack) all index
target_indices_broadcast_shape = list(torch.broadcast_shapes(*(none_reshape_shape + tensor_reshape_shape)))
target_broadcast_shape_proxy = self.get_const_proxy(target_indices_broadcast_shape, torch.int32)
stack_input_list = []
const_int32_0 = self.get_const_proxy(0, torch.int32, target_shape=[])
const_int32_1 = self.get_const_proxy(1, torch.int32, target_shape=[])
none_idx = 0
tensor_idx = 0
for index in indices:
if index is None:
# For an index that is None, build a Range over the corresponding dim size, unsqueeze some dims, then broadcast to the target result shape
const_range_max = self.get_const_proxy(none_dim_size[none_idx], torch.int32, target_shape=[])
to_be_reshape_proxy = self.get_proxy(ascend_op.Range, (const_int32_0, const_range_max, const_int32_1))
index_reshape_shape = none_reshape_shape[none_idx]
none_idx += 1
else:
to_be_reshape_proxy = index
index_reshape_shape = tensor_reshape_shape[tensor_idx]
tensor_idx += 1
reshape_shape_proxy = self.get_const_proxy(index_reshape_shape, torch.int32)
reshape_proxy = self.get_proxy(ascend_op.Reshape, (to_be_reshape_proxy, reshape_shape_proxy))
stack_input_list.append(self.get_proxy(ascend_op.BroadcastTo, (reshape_proxy, target_broadcast_shape_proxy)))
return self.get_proxy(ascend_op.Pack, (stack_input_list, len(target_indices_broadcast_shape))), \
target_indices_broadcast_shape, len(stack_input_list)
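
A minimal eager-PyTorch illustration of the indexing rule that contiguous_flag tracks (standard advanced-indexing semantics: contiguous tensor indices keep their position in the result, while non-contiguous ones move the broadcast index dims to the front); the shapes here are arbitrary:

import torch

x = torch.randn(2, 3, 4, 5, 6)
a = torch.zeros(2, 1, dtype=torch.int64)
b = torch.zeros(1, 3, dtype=torch.int64)
print(x[:, a, b, :, :].shape)   # contiguous:     torch.Size([2, 2, 3, 5, 6])
print(x[:, a, :, b, :].shape)   # non-contiguous: torch.Size([2, 3, 2, 4, 6])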

@register_conversion(torch.ops.aten.index_put.default)
def index_put_default(self, x, indices, values):
# following comment is from tensorflow tensor_scatter_nd_update:
# index_depth = indices.shape[-1]
# batch_shape = indices.shape[:-1]
# assert index_depth <= tf.rank(x)
# outer_shape = x.shape[:index_depth]
# inner_shape = x.shape[index_depth:]
# assert values.shape == batch_shape + inner_shape
#
# tf.tensor_scatter_nd_update's param 'indices' differs from the indices
# of torch.ops.aten.index_put.default, so we use broadcast and stack to
# construct the 'indices' expected by tf.tensor_scatter_nd_update
x_shape = list(x.node.meta['val'].shape)
stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \
self.compute_stacked_indices(indices, x.node.meta['val'].shape)
values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape
values_broadcast_shape_op = self.get_const_proxy(values_broadcast_shape, torch.int32)
broadcasted_values = self.get_proxy(ascend_op.BroadcastTo, (values, values_broadcast_shape_op))
return self.get_proxy(ascend_op.TensorScatterUpdate, (x, stacked_indices, broadcasted_values))
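
A tiny eager sketch of the mapping described in the comment above, for the simplest case (a single index tensor, so index_depth == 1 and each update row replaces a whole inner slice); the shapes are illustrative only:

import torch

x = torch.zeros(3, 4)
indices = [torch.tensor([0, 2])]
values = torch.randn(2, 4)
reference = torch.ops.aten.index_put.default(x, indices, values)

# tensor_scatter_nd_update-style form: one coordinate row per update
stacked = torch.stack(indices, dim=-1)   # shape (2, 1), last dim is index_depth
scattered = x.clone()
scattered[stacked[:, 0]] = values        # each row of values overwrites an inner slice
assert torch.equal(reference, scattered)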

@register_conversion(torch.ops.aten.index_select.default)
def index_arg3_(self, x, dim, index):
return self.index_base(x, dim, index)
@@ -1075,14 +1186,28 @@ def log_softmax_backward_data(self, grad_output, output, dim, input_dtype):
dim = [dim] if not isinstance(dim, list) else dim
return self.get_proxy(ascend_op.LogSoftmaxGrad, (grad_output, output, dim))

@register_conversion(torch.ops.aten.repeat_interleave)
def repeat_interleave(self, repeats, output_size=1):
x_shape = list(repeats.node.meta['val'].shape)
assert len(x_shape) == 1
assert x_shape[0] == 1
# TODO! fix implementation of repeatinterleave
# Consider situation for repeats > 1
return self.get_const_proxy(0, torch.int64, target_shape=[1])
@register_conversion(torch.ops.aten.repeat_interleave.self_int)
def repeat_interleave(self, x, repeats, dim):
# dim=None and Tensor-valued repeats are not supported yet
assert dim is not None
assert not isinstance(repeats, torch.Tensor)
if repeats == 1:
return self.get_proxy(ascend_op.Identity, (x, ))
tile_with_axis = self.get_proxy(ascend_op.TileWithAxis, (x, dim, repeats))
x_shape = list(x.node.meta['val'].shape)
reshape_shape = x_shape[:dim] + [repeats] + x_shape[dim:]
reshape_shape_proxy = self.get_const_proxy(reshape_shape, torch.int32)
reshape_proxy = self.get_proxy(ascend_op.Reshape, (tile_with_axis, reshape_shape_proxy))

transpose_perm = list(range(len(reshape_shape)))
transpose_perm[dim], transpose_perm[dim + 1] = \
transpose_perm[dim + 1], transpose_perm[dim]
transpose_perm_proxy = self.get_shape_proxy(transpose_perm)
transpose_proxy = self.get_proxy(ascend_op.Transpose, (reshape_proxy, transpose_perm_proxy))

result_reshape = x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim+1:]
result_reshape_shape_proxy = self.get_const_proxy(result_reshape, torch.int32)
return self.get_proxy(ascend_op.Reshape, (transpose_proxy, result_reshape_shape_proxy))
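
The tile -> reshape -> transpose -> reshape decomposition above can be checked against eager PyTorch; a small sketch, where torch.cat stands in for TileWithAxis (assumed to concatenate whole copies along one axis):

import torch

x = torch.arange(6).reshape(2, 3)
dim, repeats = 1, 2
tiled = torch.cat([x] * repeats, dim=dim)   # TileWithAxis stand-in
shaped = tiled.reshape(x.shape[:dim] + (repeats,) + x.shape[dim:])
swapped = shaped.transpose(dim, dim + 1)
out = swapped.reshape(x.shape[:dim] + (x.shape[dim] * repeats,) + x.shape[dim + 1:])
assert torch.equal(out, torch.repeat_interleave(x, repeats, dim=dim))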

@register_conversion([aten.lift_fresh_copy, aten.lift_fresh_copy.default])
def lift_fresh_copy(self, tensor_constant):
@@ -1305,3 +1430,39 @@ def Ge(self, x, y):
@register_conversion(torch.ops.aten.logical_or.default)
def LogicalOr(self, x, y):
return self.get_proxy(ascend_op.LogicalOr, (x, y))

@register_conversion(torch.ops.aten.slice_scatter.default)
def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
# modified from torchair
if start is None:
start = 0
if end is None:
end = 9223372036854775807
if (isinstance(start, int) and start == 0) and (
isinstance(end, int) and end == 9223372036854775807) and (
isinstance(step, int) and step == 1):
return self.get_proxy(ascend_op.Identity, (src, ))

# repeat = self.get_const_proxy([repeat], torch.int32, target_shape=[1])
# default dtype of the output of ascend_op Shape is int32
operand_shape = self.get_proxy(ascend_op.Shape, (operand, ))
start = self.get_const_proxy(start, torch.int32, target_shape=[])
if end == 9223372036854775807:
gather_axis = self.get_const_proxy(0, torch.int32, target_shape=[])
end = self.get_proxy(ascend_op.GatherV2, (operand_shape, dim, gather_axis))
else:
end = self.get_const_proxy(end, torch.int32, target_shape=[])
step = self.get_const_proxy(step, torch.int32, target_shape=[])
indices = self.get_proxy(ascend_op.Range, (start, end, step))

dims_to_expand = [i for i in range(src.node.meta['val'].dim())]
dims_to_expand.remove(dim)

if dims_to_expand:
indices_unsqueezed = self.get_proxy(ascend_op.Unsqueeze, (indices, dims_to_expand))
src_shape = self.get_proxy(ascend_op.Shape, (src, ))
indices_expanded = self.get_proxy(ascend_op.Expand, (indices_unsqueezed, src_shape))
else:
indices_expanded = indices
return self.get_proxy(ascend_op.ScatterElements,
(operand, indices_expanded, src, dim))
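
The slice_scatter lowering above amounts to ScatterElements with a Range of slice positions broadcast to src's shape. A hedged eager-mode sketch of the same idea, with torch.Tensor.scatter standing in for ScatterElements and made-up shapes:

import torch

operand = torch.zeros(4, 5)
src = torch.ones(4, 2)
dim, start, end, step = 1, 1, 5, 2

index = torch.arange(start, end, step)        # tensor([1, 3])
index = index.view(1, -1).expand(src.shape)   # unsqueeze the non-dim axes, expand to src
out = operand.scatter(dim, index, src)        # ScatterElements stand-in
assert torch.equal(out, torch.slice_scatter(operand, src, dim, start, end, step))
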
16 changes: 16 additions & 0 deletions dicp/dicp/vendor/AscendGraph/pattern_replacement.py
@@ -27,6 +27,22 @@ def replacement(input, dims):
varVal = torch.ops.aten.var(input, dims, correction=1, keepdim=True)
return ascend_op.ret_tuple(varVal, meanVal)

@register_aten_pattern
class FusedRepeatInterleaveSelfInt(BackendPatternBase):
@staticmethod
def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape,
expand_1_shape, repeat_interleave_output_size):
empty = torch.ops.aten.empty.memory_format(input_shape, dtype = torch.int64, layout = torch.strided, device=empty_device)
fill = torch.ops.aten.fill.Scalar(empty, repeat)
view_1 = torch.ops.aten.view.default(fill, view_1_shape)
expand_1 = torch.ops.aten.expand.default(view_1, expand_1_shape)
repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size = repeat_interleave_output_size)
index_select = torch.ops.aten.index_select.default(self, dim, repeat_interleave)
return index_select

@staticmethod
def replacement(self, repeat, dim):
return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim)
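
The fused pattern above rewrites an index_select over repeated indices into repeat_interleave.self_int; a short eager check of that equivalence with illustrative shapes:

import torch

x = torch.randn(2, 3)
repeat, dim = 2, 1
counts = torch.full((x.shape[dim],), repeat, dtype=torch.int64)
gather_idx = torch.repeat_interleave(counts)   # tensor([0, 0, 1, 1, 2, 2])
decomposed = torch.index_select(x, dim, gather_idx)
assert torch.equal(decomposed, torch.repeat_interleave(x, repeat, dim=dim))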

Const = torch.fx.wrap(ascend_op.Const.get_singleton())
Transpose = torch.fx.wrap(ascend_op.Transpose.get_singleton())
3 changes: 3 additions & 0 deletions dicp/test/ascend_scripts/ops/static.ini
@@ -29,6 +29,7 @@ python_files =
test_gather.py
test_ge.py
test_getitem.py
test_index_put.py
test_index.py
test_le.py
; test_lift_fresh_copy.py
@@ -54,11 +55,13 @@ python_files =
test_scatter.py
test_select.py
test_sigmoid.py
test_slice_scatter.py
test_slice.py
test_sqrt.py
test_squeeze.py
test_sub.py
test_sum.py
test_repeat_interleave.py
; test_transpose.py
test_tril.py
test_unsqueeze.py
61 changes: 61 additions & 0 deletions dicp/test/op/test_index_put.py
@@ -0,0 +1,61 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, x, indices, values):
res_Tensor = torch.ops.aten.index_put.default(x, indices, values)
return res_Tensor


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestIndexPut():
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("sizes", [Size(((1, 32, 208, 128), (None, None, (6,)), (32, 6, 128)),
((1, 32, 208, 128), (None, None, (6,)), (32, 6, 128))),
Size(((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), (1, 2, 1), None),
(1, 1, 4, 1, 3, 11)),
((1, 2, 10, 8 ,7, 11), (None, None, (2, 3), (4, 1, 1), (1, 2, 1), None),
(1, 1, 4, 1, 3, 11))),
Size(((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
(4, 2, 3, 1, 2, 7)),
((1, 2, 10, 8 ,7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
(4, 2, 3, 1, 2, 7)))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_split(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
x_size = size[0]
indices_size_tuple = size[1]
values_size = size[2]

input1 = torch.randn(x_size, dtype=dtype)
indices = []
for dim_idx, idx_size in enumerate(indices_size_tuple):
if idx_size is None:
indices.append(None)
else:
indices.append(torch.randint(x_size[dim_idx], idx_size, dtype=torch.int32))
value = torch.randn(values_size, dtype=dtype)
dicp_input1 = input1.to(device)
dicp_indices = [None if index is None else index.to(device) for index in indices]
dicp_value = value.to(device)

output = model(input1, indices, value)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1, dicp_indices, dicp_value)

assert torch.allclose(output.cpu(), dicp_output.cpu(), equal_nan=True)