Skip to content

Commit

Permalink
[DICP][ascend] add scalar_tensor for ascendgraph (#675)
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 authored Jan 29, 2024
1 parent a16c4b2 commit 6ebb801
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 7 deletions.
13 changes: 8 additions & 5 deletions dicp/dicp/dynamo_bridge/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]:
else:
continue
if 'val' in n.meta and test_infer:
(n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
(n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
assert meta_i.size() == fv_i.size(), "check infer size failed"
assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
assert meta_i.storage_offset() == fv_i.storage_offset(), "check infer storage offset failed"
if not isinstance(fv_i, FakeTensor):
continue
log_info = f"target: {n.target}, meta_i: {meta_i}, fv_i: {fv_i}"
assert meta_i.size() == fv_i.size(), f"check infer size failed, {log_info}"
assert meta_i.dtype == fv_i.dtype, f"check infer dtype failed, {log_info}"
assert meta_i.stride() == fv_i.stride(), f"check infer stride failed, {log_info}"
assert meta_i.storage_offset() == fv_i.storage_offset(), f"check infer storage offset failed, {log_info}"
if 'val' not in n.meta:
n.meta['val'] = fake_value
n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
Expand Down
11 changes: 9 additions & 2 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,9 +600,12 @@ def view_as_complex(self, x):
@register_conversion(torch.ops.aten.full.default)
def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
device='cpu', pin_memory=False, memory_format=torch.preserve_format):
if len(dims) == 0:
dims = [1]
torch_dtype = dtype if dtype else torch.get_default_dtype()
# If len(dims) == 0, it means this is a scalar tensor with a dimension of 0,
# and it can directly return a const node to construct a scalar tensor.
if len(dims) == 0:
return self.get_const_proxy(value, torch_dtype)

dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr(
dim.node, 'meta') else dim for dim in dims]
if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
Expand Down Expand Up @@ -1483,3 +1486,7 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
indices_expanded = indices
return self.get_proxy(ascend_op.ScatterElements,
(operand, indices_expanded, src, dim))

@register_conversion(torch.ops.aten.scalar_tensor.default)
def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None):
return self.get_const_proxy(x, dtype)
1 change: 1 addition & 0 deletions dicp/test/ascend_scripts/ops/static.ini
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ python_files =
test_relu.py
test_repeat.py
test_rsqrt.py
test_scalar_tensor.py
test_scatter.py
test_select.py
test_sigmoid.py
Expand Down
28 changes: 28 additions & 0 deletions dicp/test/op/test_scalar_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,27 @@ def forward(self, a, dtype):
return res_default


class AscendOpModule(torch.nn.Module):
def forward(self, a, redundant_input, dtype):
# If there is only one operator called scalar_tensor,
# ascend graph compiler will give an error:
# GE.. [Check][Param] SetInputs failed: input operator size can not be 0.
# To solve this problem, an additional redundant input is added,
# and the result of an addition operator is returned.
scalar_tensor = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
res_default = torch.ops.aten.add.Tensor(redundant_input, scalar_tensor)
return scalar_tensor, res_default


model = OpModule()
ascend_model = AscendOpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)
ascend_compiled_model = compile_model(ascend_model, args.backend, args.dynamic)


class TestScalarTensor():
@pytest.mark.skipif(args.backend == 'ascendgraph', reason="skip ascendgraph")
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
@pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
@pytest.mark.parametrize("compiled_model", compiled_model)
Expand All @@ -30,3 +45,16 @@ def test_torch_scalar_tensor(self, inputs, dtype, compiled_model):
dicp_output = compiled_model.model(inputs, dtype)

assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)

@pytest.mark.skipif(args.backend == 'topsgraph', reason="skip topsgraph")
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
@pytest.mark.parametrize("inputs", [1.0, 2.0, 3.0])
@pytest.mark.parametrize("compiled_model", ascend_compiled_model)
def test_torch_ascend_scalar_tensor(self, inputs, dtype, compiled_model):
redundant_input = torch.ones(1, dtype=dtype)
output, _ = ascend_model(inputs, redundant_input, dtype)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output, _ = compiled_model.model(inputs, redundant_input, dtype)

assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)

0 comments on commit 6ebb801

Please sign in to comment.