Merge branch 'DeepLink-org:main' into wgs/add_1424
wugeshui authored Feb 5, 2024
2 parents 1f4b15d + f7afefa commit 3cdb976
Showing 15 changed files with 541 additions and 233 deletions.
33 changes: 26 additions & 7 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -500,17 +500,17 @@ def ne(self, a, b):

    @register_conversion([aten.lt.Scalar, aten.lt.Tensor])
    def lt(self, x, y):
-        y_shape = [1]
-        if isinstance(y, torch.fx.proxy.Proxy):
-            y_shape = list(y.node.meta['val'].shape)
        x_shape = list(x.node.meta['val'].shape)
+        y_shape = [] if not isinstance(
+            y, torch.fx.proxy.Proxy) else list(y.node.meta['val'].shape)
        out = list(fx_traceback.get_current_meta()['val'].shape)
        out_shape = self.get_shape_proxy(out)
        x, y = self.binary_cmp_cast_input(x, y)

-        if self.shape_prod(x_shape) < self.shape_prod(out):
+        dynamic_shape = symint_in_shape(x_shape) or symint_in_shape(
+            y_shape) or symint_in_shape(out)
+        if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)):
            x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape))
-        if self.shape_prod(y_shape) < self.shape_prod(out):
+        if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)):
            y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape))
        return self.get_proxy(ascend_op.Less, (x, y))
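Editorial note on the lt change above: explicit BroadcastTo nodes are now emitted only when some shape involved contains a SymInt, which suggests the backend op is left to broadcast on its own for fully static shapes. A minimal standalone sketch of the new gating predicate; symint_in_shape and shape_prod are re-implemented here for illustration (the real helpers live elsewhere in dicp and may differ), and needs_explicit_broadcast is a hypothetical name:

import torch

def symint_in_shape(shape):
    # A dim is dynamic when it is a torch.SymInt rather than a plain int.
    return any(isinstance(dim, torch.SymInt) for dim in shape)

def shape_prod(shape):
    prod = 1
    for dim in shape:
        prod *= dim
    return prod

def needs_explicit_broadcast(operand_shape, out_shape, all_shapes):
    # Broadcast explicitly only when some shape is dynamic and the operand
    # has fewer elements than the output, mirroring the condition above.
    dynamic = any(symint_in_shape(s) for s in all_shapes)
    return dynamic and shape_prod(operand_shape) < shape_prod(out_shape)

# Static shapes: no explicit BroadcastTo, even though (3, 1) < (3, 4).
assert not needs_explicit_broadcast([3, 1], [3, 4], [[3, 1], [3, 4], [3, 4]])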

@@ -834,6 +834,26 @@ def compute_stacked_indices(self, indices, src_shape):

    @register_conversion(torch.ops.aten.index_put.default)
    def index_put_default(self, x, indices, values):
+        x_shape = list(x.node.meta['val'].shape)
+
+        # When the element type of indices is bool, the masked_fill operator
+        # should be used to achieve this. Currently, only indices with a length
+        # of 1 are supported.
+        if any([index.node.meta['val'].dtype in [torch.bool]
+                for index in indices if index is not None]):
+            assert len(indices) == 1
+            index = indices[0]
+            index_shape = list(index.node.meta['val'].shape)
+            index_shape_size = len(index_shape)
+            x_shape_size = len(x_shape)
+            if index_shape_size == x_shape_size:
+                return self.masked_fill(x, index, values)
+            reshape_shape = index_shape + [1] * \
+                (x_shape_size - index_shape_size)
+            reshape_op = self.get_const_proxy(reshape_shape, torch.int32)
+            index = self.get_proxy(ascend_op.Reshape, (index, reshape_op))
+            return self.masked_fill(x, index, values)
+
        # following comment is from tensorflow tensor_scatter_nd_update:
        # index_depth = indices.shape[-1]
        # batch_shape = indices.shape[:-1]
@@ -845,7 +865,6 @@ def index_put_default(self, x, indices, values):
        # tf.tensor_scatter_nd_update param 'indices' is different from
        # indices in torch.ops.aten.index_put.default, we use broadcast and
        # stack to construct param 'indices' in tf.tensor_scatter_nd_update
-        x_shape = list(x.node.meta['val'].shape)
        stacked_indices, indices_broadcast_shape, stacked_indices_last_dim = \
            self.compute_stacked_indices(indices, x.node.meta['val'].shape)
        values_broadcast_shape = indices_broadcast_shape + x_shape[stacked_indices_last_dim:]  # batch_shape + inner_shape
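Editorial note: the bool-index branch added to index_put_default above relies on a PyTorch-level equivalence that can be checked in eager mode. A small sketch, with hypothetical shapes chosen to exercise the reshape path (mask rank lower than input rank); the variable names here are illustrative, not from the commit:

import torch

x = torch.randn(2, 3, 4)
mask = torch.randn(2, 3) > 0          # bool mask covering the leading dims
value = torch.tensor(1.0)

# aten.index_put with a single bool index: the semantics being lowered.
ref = torch.ops.aten.index_put.default(x, [mask], value)

# Reshaping the mask to (2, 3, 1) makes it broadcastable against x, so
# masked_fill produces the same result.
out = x.masked_fill(mask.reshape(2, 3, 1), value)

assert torch.equal(ref, out)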
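For the integer-index path, the tensorflow-style comment above describes building tensor_scatter_nd_update's 'indices' by broadcasting the per-dimension index tensors to a common batch shape and stacking them along a trailing dimension of size index_depth. An eager-mode sketch of that construction under made-up shapes, checked against index_put itself:

import torch

x = torch.zeros(4, 5)
rows = torch.tensor([[0], [2]])        # shape (2, 1)
cols = torch.tensor([1, 3])            # shape (2,)
values = torch.tensor([10., 20., 30., 40.]).reshape(2, 2)

ref = torch.ops.aten.index_put.default(x, [rows, cols], values)

# scatter_nd-style indices: broadcast every index tensor to the common
# batch_shape (2, 2), then stack along a trailing dim (index_depth = 2).
rows_b, cols_b = torch.broadcast_tensors(rows, cols)
stacked = torch.stack([rows_b, cols_b], dim=-1)   # shape (2, 2, 2)

# Scattering values at those coordinates reproduces index_put.
out = x.clone()
out[stacked[..., 0], stacked[..., 1]] = values
assert torch.equal(ref, out)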
2 changes: 1 addition & 1 deletion dicp/test/model/test_hf.py
@@ -46,4 +46,4 @@
    response_list.append(response.split('\n'))

for idx, dicp_result in enumerate(response_list):
-    assert dicp_result == cuda_results[idx]
+    assert dicp_result == cuda_results[idx], f"dicp result:{dicp_result}, cuda_result:{cuda_results[idx]}"
2 changes: 1 addition & 1 deletion dicp/test/model/test_llama.py
@@ -101,4 +101,4 @@ def test_inference(
            prompt, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, device=device
        )
        dicp_result = dicp_result[0].split("\n")
-        assert dicp_result == cuda_results[i]
+        assert dicp_result == cuda_results[i], f"dicp result:{dicp_result}, cuda_result:{cuda_results[i]}"
6 changes: 4 additions & 2 deletions dicp/test/model/test_stable_diffusion.py
@@ -50,8 +50,10 @@ def test_inference(
    dicp_pipe = StableDiffusionPipeline.from_pretrained(model_path).to(device)
    dicp_pipe.text_encoder = torch.compile(dicp_pipe.text_encoder, backend=backend, dynamic=dynamic)
    dicp_pipe.unet = torch.compile(dicp_pipe.unet, backend=backend, dynamic=dynamic)
-    if backend == "ascendgraph":
-        dicp_pipe.vae.decoder = torch.compile(dicp_pipe.vae.decoder, backend=backend, dynamic=dynamic)
+
+    # Temporarily run decoder on CPU
+    # if backend == "ascendgraph":
+    #     dicp_pipe.vae.decoder = torch.compile(dicp_pipe.vae.decoder, backend=backend, dynamic=dynamic)
    dicp_image = dicp_pipe(prompt, num_inference_steps=num_inference_steps).images[0]

    similarity = get_similarity(cpu_image, dicp_image)
26 changes: 25 additions & 1 deletion dicp/test/op/test_index_put.py
@@ -34,7 +34,7 @@ class TestIndexPut():
                              ((1, 2, 10, 8 ,7, 11), (None, None, (2, 3), (4, 1, 1), None, (1, 2, 1)),
                               (4, 2, 3, 1, 2, 7)))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
-    def test_torch_split(self, sizes, dtype, compiled_model):
+    def test_torch_index_put(self, sizes, dtype, compiled_model):
        device = get_device()
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        x_size = size[0]
@@ -59,3 +59,27 @@ def test_torch_split(self, sizes, dtype, compiled_model):
        dicp_output = compiled_model.model(dicp_input1, dicp_indices, dicp_value)

        assert torch.allclose(output.cpu(), dicp_output.cpu(), equal_nan=True)
+
+    @pytest.mark.parametrize("dtype", [torch.float16])
+    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
+    @pytest.mark.parametrize("compiled_model", compiled_model)
+    def test_torch_index_put_to_masked_fill(self, sizes, dtype, compiled_model):
+        device = get_device()
+        size = sizes.dynamic if compiled_model.dynamic else sizes.static
+        mask_size = size if len(size) == 1 else size[0]
+
+        input = torch.randn(size, dtype=dtype)
+        mask = torch.randn(mask_size, dtype=dtype) > 0
+        value = torch.tensor(1).to(dtype)
+        indices = [mask]
+
+        dicp_input = input.to(device)
+        dicp_indices = [mask.to(device)]
+        dicp_value = value.to(device)
+
+        output = model(input, indices, value)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output = compiled_model.model(dicp_input, dicp_indices, dicp_value)
+
+        assert torch.allclose(output.cpu(), dicp_output.cpu(), equal_nan=True)
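Editorial note on the new test's parameterization: where the mask has the same rank as the input (the 1-D static size), the conversion above should take the direct masked_fill path, while the higher-rank cases build a mask from size[0] that covers only the leading dimension, which should exercise the Reshape-then-masked_fill branch.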
2 changes: 2 additions & 0 deletions dipu/.clang-tidy
@@ -66,6 +66,8 @@ CheckOptions:
    value: true
  - key: readability-implicit-bool-conversion.AllowPointerConditions
    value: true
+  - key: readability-simplify-boolean-expr.SimplifyDeMorgan
+    value: false
  # --- Google's naming convention BEGIN ---
  # modified part is marked as comment
  - key: readability-identifier-naming.ClassCase
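For context on the .clang-tidy addition (editorial, not part of the diff): the SimplifyDeMorgan option controls whether readability-simplify-boolean-expr rewrites negated boolean expressions via De Morgan's laws, e.g. turning `!(a && b)` into `!a || !b` in C++; setting it to false leaves such expressions as written. The underlying identity, sketched in Python to match the rest of this page:

# De Morgan's laws: negation distributes over and/or by flipping the operator.
for a in (False, True):
    for b in (False, True):
        assert (not (a and b)) == ((not a) or (not b))
        assert (not (a or b)) == ((not a) and (not b))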