[dicp] add ascendgraph index_put, repeat_interleave, slice_scatter #669

Merged: 1 commit, Jan 25, 2024
34 changes: 31 additions & 3 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -63,9 +63,9 @@ def __init__(self):
super().__init__("Range")

def infer_result(self, start, limit=None, delta=None):
start, start_dtype, _, _ = get_op_const_arg_kwarg(start)
limit, limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
delta, delta_dtype, _, _ = get_op_const_arg_kwarg(delta)
[start], start_dtype, _, _ = get_op_const_arg_kwarg(start)
[limit], limit_dtype, _, _ = get_op_const_arg_kwarg(limit)
[delta], delta_dtype, _, _ = get_op_const_arg_kwarg(delta)

assert start is not None, (
self.__class__.__name__ + ": input 'start' can't be None!"
@@ -972,6 +972,34 @@ def infer_result(self, x, multiples):
return torch.ops.aten.repeat.default(x, multiples)


class TileWithAxis(Operator):
def __init__(self):
super().__init__("TileWithAxis")
self.torch_op = aten.repeat_interleave.self_int


class TensorScatterUpdate(Operator):
def __init__(self):
super().__init__("TensorScatterUpdate")

def infer_result(self, x, indices, updates):
_, x_shape, x_dim, x_dtype = get_fake_tensor_meta_val(x)
_, indices_shape, _, indices_dtype = get_fake_tensor_meta_val(indices)
_, updates_shape, _, _ = get_fake_tensor_meta_val(updates)
assert indices_dtype in (torch.int32, torch.int64)

# following shape constraints are from:
# https://tensorflow.google.cn/versions/r2.15/api_docs/
# python/tf/tensor_scatter_nd_update
assert indices.dim() >= 2
index_depth = indices_shape[-1]
batch_shape = indices_shape[:-1]
assert index_depth <= x_dim
inner_shape = x_shape[index_depth:]
assert updates_shape == batch_shape + inner_shape
return torch.empty(x_shape, dtype=x_dtype, memory_format=get_memory_format(x))
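For reference, the shape constraints checked above can be spelled out with concrete shapes; a minimal sketch in plain Python (all shapes here are hypothetical):

# hypothetical shapes: x is (4, 5, 6) and indices is (3, 2)
x_shape = (4, 5, 6)
indices_shape = (3, 2)
index_depth = indices_shape[-1]       # 2, must not exceed len(x_shape)
batch_shape = indices_shape[:-1]      # (3,)
inner_shape = x_shape[index_depth:]   # (6,)
# updates must then have shape batch_shape + inner_shape
assert batch_shape + inner_shape == (3, 6)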


def ret_triple(a, b, c) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return a, b, c

18 changes: 17 additions & 1 deletion dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -909,7 +909,7 @@ def Squeeze(name, x, dim):
return op.to_node()

@staticmethod
def Identity(name, input, index):
def Identity(name, input, index=None):
op = OP(name, "Identity")
if index is not None and isinstance(index, int):
op.set_input_with_index("x", input, index)
@@ -1565,3 +1565,19 @@ def LogicalOr(name, x, y):
op.set_input("x1", x)
op.set_input("x2", y)
return op.to_node()

@staticmethod
def TileWithAxis(name, x, axis, tiles):
op = OP(name, "TileWithAxis")
op.set_input("x", x)
op.set_attr_int("axis", axis)
op.set_attr_int("tiles", tiles)
return op.to_node()

@staticmethod
def TensorScatterUpdate(name, x, indices, updates):
op = OP(name, "TensorScatterUpdate")
op.set_input("x", x)
op.set_input("indices", indices)
op.set_input("updates", updates)
return op.to_node()
177 changes: 169 additions & 8 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -739,6 +739,117 @@ def unsafe_index(self, x, index):
return gather
return self.index_base(x, 0, index)

def compute_stacked_indices(self, indices, src_shape):
assert len(indices) <= len(src_shape)
# Check whether all non-None tensors in indices are 'contiguous',
# e.g. [None, a, b, None, None] is 'contiguous', while
# [None, a, None, b, None] is not 'contiguous'
# because there is a None between a and b.
#
# Also count the number of None entries in indices;
# we avoid indices.count(None) because it raises torch.fx.proxy.TraceError:
# symbolically traced variables cannot be used as inputs to control flow
tensor_none_flag = False
contiguous_flag = True
none_count_in_indices = 0
for i in range(len(indices)):
if i < len(indices) - 1:
if indices[i] is not None and indices[i + 1] is None:
tensor_none_flag = True
if tensor_none_flag and indices[i] is None and indices[i + 1] is not None:
contiguous_flag = False
if indices[i] is None:
none_count_in_indices += 1

# collect None dim_size and tensor reshape shape
tensor_reshape_shape = []
none_dim_size = []
first_tensor_pos = -1
tensor_unsqueeze_len = 0
for i, index in enumerate(indices):
if index is None:
assert not isinstance(src_shape[i], torch.SymInt)
none_dim_size.append(src_shape[i])
else:
assert isinstance(index.node.meta['val'], torch.Tensor)
if first_tensor_pos == -1:
first_tensor_pos = i
tensor_unsqueeze_len = none_count_in_indices - i if contiguous_flag \
else none_count_in_indices
indice_i_shape = index.node.meta['val'].shape
assert not symint_in_shape(indice_i_shape)
tensor_reshape_shape.append(list(indice_i_shape) + [1] * tensor_unsqueeze_len)
assert first_tensor_pos != -1, "all elements of indices are None, unsupported"
tensor_broadcast_shape = list(torch.broadcast_shapes(*tensor_reshape_shape))

# if contiguous_flag is True, e.g. [None, None, a, b, None],
# the tensor_broadcast_shape of (a, b) is inserted into the final shape at
# the original start position of the contiguous tensors, e.g. final shape:
# [src_shape[0], src_shape[1], *broadcast_shape_of(a, b), src_shape[4]]
#
# if contiguous_flag is False, e.g. [None, a, None, b, None],
# the tensor_broadcast_shape of (a, b) is inserted at the
# start of the final shape, e.g. final shape:
# [*broadcast_shape_of(a, b), src_shape[0], src_shape[2], src_shape[4]]
tensor_shape_insert_pos = first_tensor_pos if contiguous_flag else 0
# collect None reshape shape
none_reshape_shape = []
none_idx = 0
for i, index in enumerate(indices):
if index is None:
if i < tensor_shape_insert_pos:
none_unsqueeze_len = tensor_shape_insert_pos - i - 1 + len(tensor_broadcast_shape)
else:
none_unsqueeze_len = none_count_in_indices - 1 - none_idx
none_reshape_shape.append([none_dim_size[none_idx]] + [1] * none_unsqueeze_len)
none_idx += 1

# stack(pack) all index
target_indices_broadcast_shape = list(torch.broadcast_shapes(*(none_reshape_shape + tensor_reshape_shape)))
target_broadcast_shape_proxy = self.get_const_proxy(target_indices_broadcast_shape, torch.int32)
stack_input_list = []
const_int32_0 = self.get_const_proxy(0, torch.int32, target_shape=[])
const_int32_1 = self.get_const_proxy(1, torch.int32, target_shape=[])
none_idx = 0
tensor_idx = 0
for index in indices:
if index is None:
# for a None index, build a Range over the corresponding dim size, unsqueeze some dims and then broadcast to the target result shape
const_range_max = self.get_const_proxy(none_dim_size[none_idx], torch.int32, target_shape=[])
to_be_reshape_proxy = self.get_proxy(ascend_op.Range, (const_int32_0, const_range_max, const_int32_1))
index_reshape_shape = none_reshape_shape[none_idx]
none_idx += 1
else:
to_be_reshape_proxy = index
index_reshape_shape = tensor_reshape_shape[tensor_idx]
tensor_idx += 1
reshape_shape_proxy = self.get_const_proxy(index_reshape_shape, torch.int32)
reshape_proxy = self.get_proxy(ascend_op.Reshape, (to_be_reshape_proxy, reshape_shape_proxy))
stack_input_list.append(self.get_proxy(ascend_op.BroadcastTo, (reshape_proxy, target_broadcast_shape_proxy)))
return self.get_proxy(ascend_op.Pack, (stack_input_list, len(target_indices_broadcast_shape))), \
target_indices_broadcast_shape, len(stack_input_list)
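The 'contiguous' distinction handled above mirrors PyTorch/NumPy advanced-indexing semantics: adjacent tensor indices keep their broadcast shape at their original position, while separated ones move it to the front. A minimal PyTorch check with hypothetical shapes:

import torch

x = torch.randn(2, 3, 4, 5)
a = torch.tensor([0, 2])
b = torch.tensor([1, 3])
# adjacent tensor indices ([None, a, b, None]): broadcast shape stays in place
assert x[:, a, b, :].shape == (2, 2, 5)
# separated tensor indices ([None, a, None, b]): broadcast shape moves to the front
assert x[:, a, :, b].shape == (2, 2, 4)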

@register_conversion(torch.ops.aten.index_put.default)
def index_put_default(self, x, indices, values):
# the following comment is from the TensorFlow tensor_scatter_nd_update docs:
# index_depth = indices.shape[-1]
# batch_shape = indices.shape[:-1]
# assert index_depth <= tf.rank(x)
# outer_shape = x.shape[:index_depth]
# inner_shape = x.shape[index_depth:]
# assert values.shape == batch_shape + inner_shape
#
# the 'indices' param of tf.tensor_scatter_nd_update differs from the
# indices of torch.ops.aten.index_put.default, so we use broadcast and
# stack to construct the 'indices' expected by tf.tensor_scatter_nd_update
x_shape = list(x.node.meta['val'].shape)
stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \
self.compute_stacked_indices(indices, x.node.meta['val'].shape)
values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:] # batch_shape + inner_shape
values_broadcast_shape_op = self.get_const_proxy(values_broadcast_shape, torch.int32)
broadcasted_values = self.get_proxy(ascend_op.BroadcastTo, (values, values_broadcast_shape_op))
return self.get_proxy(ascend_op.TensorScatterUpdate, (x, stacked_indices, broadcasted_values))
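A minimal PyTorch sketch of this rewrite, with hypothetical shapes: a Range stands in for the None dim, it is broadcast against the tensor index, and the full coordinates are stacked along a trailing axis, which is the layout tensor_scatter_nd_update expects:

import torch

x = torch.zeros(2, 3)
col = torch.tensor([0, 2])
values = torch.tensor([[1., 2.], [3., 4.]])
ref = torch.ops.aten.index_put.default(x, [None, col], values)

row = torch.arange(2).reshape(2, 1).expand(2, 2)   # Range over the None dim
col_b = col.reshape(1, 2).expand(2, 2)             # broadcast tensor index
stacked = torch.stack([row, col_b], dim=-1)        # (2, 2, 2): last dim holds coordinates
out = x.clone()
out[stacked[..., 0], stacked[..., 1]] = values     # scatter by full coordinates
assert torch.equal(out, ref)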

@register_conversion(torch.ops.aten.index_select.default)
def index_arg3_(self, x, dim, index):
return self.index_base(x, dim, index)
@@ -1075,14 +1186,28 @@ def log_softmax_backward_data(self, grad_output, output, dim, input_dtype):
dim = [dim] if not isinstance(dim, list) else dim
return self.get_proxy(ascend_op.LogSoftmaxGrad, (grad_output, output, dim))

@register_conversion(torch.ops.aten.repeat_interleave)
def repeat_interleave(self, repeats, output_size=1):
x_shape = list(repeats.node.meta['val'].shape)
assert len(x_shape) == 1
assert x_shape[0] == 1
# TODO! fix implementation of repeatinterleave
# Consider situation for repeats > 1
return self.get_const_proxy(0, torch.int64, target_shape=[1])
@register_conversion(torch.ops.aten.repeat_interleave.self_int)
def repeat_interleave(self, x, repeats, dim):
# dim=None and Tensor-valued repeats are not supported yet
assert dim is not None
assert not isinstance(repeats, torch.Tensor)
if repeats == 1:
return self.get_proxy(ascend_op.Identity, (x, ))
tile_with_axis = self.get_proxy(ascend_op.TileWithAxis, (x, dim, repeats))
x_shape = list(x.node.meta['val'].shape)
reshape_shape = x_shape[:dim] + [repeats] + x_shape[dim:]
reshape_shape_proxy = self.get_const_proxy(reshape_shape, torch.int32)
reshape_proxy = self.get_proxy(ascend_op.Reshape, (tile_with_axis, reshape_shape_proxy))

transpose_perm = list(range(len(reshape_shape)))
transpose_perm[dim], transpose_perm[dim + 1] = \
transpose_perm[dim + 1], transpose_perm[dim]
transpose_perm_proxy = self.get_shape_proxy(transpose_perm)
transpose_proxy = self.get_proxy(ascend_op.Transpose, (reshape_proxy, transpose_perm_proxy))

result_reshape = x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim+1:]
result_reshape_shape_proxy = self.get_const_proxy(result_reshape, torch.int32)
return self.get_proxy(ascend_op.Reshape, (transpose_proxy, result_reshape_shape_proxy))
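A minimal PyTorch sketch of the decomposition above, assuming TileWithAxis concatenates `repeats` copies along `axis` (emulated here with Tensor.repeat):

import torch

x = torch.arange(6).reshape(2, 3)
repeats, dim = 2, 1
tiled = x.repeat(1, repeats)              # stand-in for TileWithAxis on dim 1
step1 = tiled.reshape(2, repeats, 3)      # x_shape[:dim] + [repeats] + x_shape[dim:]
step2 = step1.transpose(dim, dim + 1)     # swap the repeats dim with dim
out = step2.reshape(2, 3 * repeats)       # x_shape[:dim] + [x_shape[dim] * repeats] + x_shape[dim+1:]
assert torch.equal(out, torch.repeat_interleave(x, repeats, dim=dim))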

@register_conversion([aten.lift_fresh_copy, aten.lift_fresh_copy.default])
def lift_fresh_copy(self, tensor_constant):
@@ -1305,3 +1430,39 @@ def Ge(self, x, y):
@register_conversion(torch.ops.aten.logical_or.default)
def LogicalOr(self, x, y):
return self.get_proxy(ascend_op.LogicalOr, (x, y))

@register_conversion(torch.ops.aten.slice_scatter.default)
def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
# modified from torchair
if start is None:
start = 0
if end is None:
end = 9223372036854775807
if (isinstance(start, int) and start == 0) and (
isinstance(end, int) and end == 9223372036854775807) and (
isinstance(step, int) and step == 1):
return self.get_proxy(ascend_op.Identity, (src, ))

# default dtype of the output of ascend_op Shape is int32
operand_shape = self.get_proxy(ascend_op.Shape, (operand, ))
start = self.get_const_proxy(start, torch.int32, target_shape=[])
if end == 9223372036854775807:
gather_axis = self.get_const_proxy(0, torch.int32, target_shape=[])
end = self.get_proxy(ascend_op.GatherV2, (operand_shape, dim, gather_axis))
else:
end = self.get_const_proxy(end, torch.int32, target_shape=[])
step = self.get_const_proxy(step, torch.int32, target_shape=[])
indices = self.get_proxy(ascend_op.Range, (start, end, step))

dims_to_expand = [i for i in range(src.node.meta['val'].dim())]
dims_to_expand.remove(dim)

if dims_to_expand:
indices_unsqueezed = self.get_proxy(ascend_op.Unsqueeze, (indices, dims_to_expand))
src_shape = self.get_proxy(ascend_op.Shape, (src, ))
indices_expanded = self.get_proxy(ascend_op.Expand, (indices_unsqueezed, src_shape))
else:
indices_expanded = indices
return self.get_proxy(ascend_op.ScatterElements,
(operand, indices_expanded, src, dim))
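A minimal PyTorch sketch of the index construction above (shapes here are hypothetical): a Range over [start, end) with the given step, unsqueezed on every non-dim axis, expanded to src's shape, and used as ScatterElements indices:

import torch

operand = torch.zeros(4, 3)
src = torch.ones(2, 3)
dim, start, end, step = 0, 1, 3, 1
indices = torch.arange(start, end, step)            # rows to overwrite: [1, 2]
indices = indices.unsqueeze(1).expand(src.shape)    # one index per src element
out = operand.scatter(dim, indices, src)            # ScatterElements equivalent
assert torch.equal(out, torch.slice_scatter(operand, src, dim, start, end, step))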
16 changes: 16 additions & 0 deletions dicp/dicp/vendor/AscendGraph/pattern_replacement.py
@@ -27,6 +27,22 @@ def replacement(input, dims):
varVal = torch.ops.aten.var(input, dims, correction=1, keepdim=True)
return ascend_op.ret_tuple(varVal, meanVal)

@register_aten_pattern
class FusedRepeatInterleaveSelfInt(BackendPatternBase):
@staticmethod
def pattern(self, repeat, dim, input_shape, empty_device, view_1_shape,
expand_1_shape, repeat_interleave_output_size):
empty = torch.ops.aten.empty.memory_format(input_shape, dtype=torch.int64, layout=torch.strided, device=empty_device)
fill = torch.ops.aten.fill.Scalar(empty, repeat)
view_1 = torch.ops.aten.view.default(fill, view_1_shape)
expand_1 = torch.ops.aten.expand.default(view_1, expand_1_shape)
repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(expand_1, output_size=repeat_interleave_output_size)
index_select = torch.ops.aten.index_select.default(self, dim, repeat_interleave)
return index_select

@staticmethod
def replacement(self, repeat, dim):
return torch.ops.aten.repeat_interleave.self_int(self, repeat, dim)
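A quick PyTorch check that the matched pattern and the fused replacement agree (simplified: the empty/fill/view/expand chain just builds a constant repeats vector; the values here are hypothetical):

import torch

x = torch.arange(6).float().reshape(2, 3)
repeat, dim = 2, 1
fill = torch.full((x.shape[dim],), repeat, dtype=torch.int64)
gather_idx = torch.repeat_interleave(fill)     # tensor([0, 0, 1, 1, 2, 2])
pattern_out = torch.index_select(x, dim, gather_idx)
replacement_out = torch.repeat_interleave(x, repeat, dim=dim)
assert torch.equal(pattern_out, replacement_out)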

Const = torch.fx.wrap(ascend_op.Const.get_singleton())
Transpose = torch.fx.wrap(ascend_op.Transpose.get_singleton())
3 changes: 3 additions & 0 deletions dicp/test/ascend_scripts/ops/static.ini
@@ -29,6 +29,7 @@ python_files =
test_gather.py
test_ge.py
test_getitem.py
test_index_put.py
test_index.py
test_le.py
; test_lift_fresh_copy.py
@@ -54,11 +55,13 @@ python_files =
test_scatter.py
test_select.py
test_sigmoid.py
test_slice_scatter.py
test_slice.py
test_sqrt.py
test_squeeze.py
test_sub.py
test_sum.py
test_repeat_interleave.py
; test_transpose.py
test_tril.py
test_unsqueeze.py
61 changes: 61 additions & 0 deletions dicp/test/op/test_index_put.py
@@ -0,0 +1,61 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, x, indices, values):
res_Tensor = torch.ops.aten.index_put.default(x, indices, values)
return res_Tensor


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestIndexPut():
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("sizes", [Size(((1, 32, 208, 128), (None, None, (6,)), (32, 6, 128)),
((1, 32, 208, 128), (None, None, (6,)), (32, 6, 128))),
Size(((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), (1, 2, 1), None),
(1, 1, 4, 1, 3, 11)),
((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), (1, 2, 1), None),
(1, 1, 4, 1, 3, 11))),
Size(((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
(4, 2, 3, 1, 2, 7)),
((1, 2, 10, 8, 7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
(4, 2, 3, 1, 2, 7)))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_index_put(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
x_size = size[0]
indices_size_tuple = size[1]
values_size = size[2]

input1 = torch.randn(x_size, dtype=dtype)
indices = []
for dim_idx, idx_size in enumerate(indices_size_tuple):
if idx_size is None:
indices.append(None)
else:
indices.append(torch.randint(x_size[dim_idx], idx_size, dtype=torch.int32))
value = torch.randn(values_size, dtype=dtype)
dicp_input1 = input1.to(device)
dicp_indices = [None if index is None else index.to(device) for index in indices]
dicp_value = value.to(device)

output = model(input1, indices, value)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1, dicp_indices, dicp_value)

assert torch.allclose(output.cpu(), dicp_output.cpu(), equal_nan=True)