[DICP] add scalar_tensor for ascendgraph #675

Merged · 7 commits · Jan 29, 2024
13 changes: 8 additions & 5 deletions dicp/dicp/dynamo_bridge/graph.py
@@ -61,12 +61,15 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]:
            else:
                continue
            if 'val' in n.meta and test_infer:
                (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
                for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
-                   assert meta_i.size() == fv_i.size(), "check infer size failed"
-                   assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
-                   assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
-                   assert meta_i.storage_offset() == fv_i.storage_offset(), "check infer storage offset failed"
+                   if not isinstance(fv_i, FakeTensor):
+                       continue
+                   log_info = f"target: {n.target}, meta_i: {meta_i}, fv_i: {fv_i}"
+                   assert meta_i.size() == fv_i.size(), f"check infer size failed, {log_info}"
+                   assert meta_i.dtype == fv_i.dtype, f"check infer dtype failed, {log_info}"
+                   assert meta_i.stride() == fv_i.stride(), f"check infer stride failed, {log_info}"
+                   assert meta_i.storage_offset() == fv_i.storage_offset(), f"check infer storage offset failed, {log_info}"
            if 'val' not in n.meta:
                n.meta['val'] = fake_value
                n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
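Note on the new `FakeTensor` guard: `meta['val']` is not always a tensor; for some nodes it holds plain Python ints or SymInts, which have no `size()`, `stride()`, or `storage_offset()` to compare. A minimal sketch of the behaviour the guard provides (illustrative only, not part of the diff; the helper name `check_inferred` is made up):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def check_inferred(meta_vals, fake_vals):
    for meta_i, fv_i in zip(meta_vals, fake_vals):
        if not isinstance(fv_i, FakeTensor):
            # Scalars / SymInts carry no size, stride, or storage offset to verify.
            continue
        assert meta_i.size() == fv_i.size()
        assert meta_i.dtype == fv_i.dtype

mode = FakeTensorMode()
fake = mode.from_tensor(torch.ones(2, 3))
check_inferred((fake, 5), (fake, 5))  # the int pair is skipped instead of raising
```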
11 changes: 9 additions & 2 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -600,9 +600,12 @@ def view_as_complex(self, x):
    @register_conversion(torch.ops.aten.full.default)
    def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
             device='cpu', pin_memory=False, memory_format=torch.preserve_format):
-       if len(dims) == 0:
-           dims = [1]
        torch_dtype = dtype if dtype else torch.get_default_dtype()
+       # If len(dims) == 0, the op constructs a zero-dimensional (scalar) tensor,
+       # so it can be lowered directly to a const node.
+       if len(dims) == 0:
+           return self.get_const_proxy(value, torch_dtype)
+
        dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr(
            dim.node, 'meta') else dim for dim in dims]
        if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
@@ -1483,3 +1486,7 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
            indices_expanded = indices
        return self.get_proxy(ascend_op.ScatterElements,
                              (operand, indices_expanded, src, dim))
+
+   @register_conversion(torch.ops.aten.scalar_tensor.default)
+   def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None):
+       return self.get_const_proxy(x, dtype)
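For reference, a small eager-mode sketch (not part of the diff) of the semantics these two conversions lower: `aten.full` with an empty size list and `aten.scalar_tensor` both produce a zero-dimensional tensor holding a single value, which is why a single const node is sufficient on the Ascend side.

```python
import torch

a = torch.ops.aten.scalar_tensor.default(3.0, dtype=torch.float16)
b = torch.ops.aten.full.default([], 3.0, dtype=torch.float16)

assert a.dim() == 0 and b.dim() == 0        # both are 0-dim (scalar) tensors
assert a.item() == 3.0 and b.item() == 3.0  # holding the requested value
```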
1 change: 1 addition & 0 deletions dicp/test/ascend_scripts/ops/static.ini
@@ -54,6 +54,7 @@ python_files =
    test_relu.py
    test_repeat.py
    test_rsqrt.py
+   test_scalar_tensor.py
    test_scatter.py
    test_select.py
    test_sigmoid.py
28 changes: 28 additions & 0 deletions dicp/test/op/test_scalar_tensor.py
@@ -14,12 +14,27 @@ def forward(self, a, dtype):
        return res_default
+
+
+class AscendOpModule(torch.nn.Module):
+    def forward(self, a, redundant_input, dtype):
+        # If the graph contains only a single scalar_tensor operator, the Ascend
+        # graph compiler reports an error:
+        #   GE.. [Check][Param] SetInputs failed: input operator size can not be 0.
+        # To work around this, a redundant tensor input is added and the result
+        # of an addition on it is returned alongside the scalar tensor.
+        scalar_tensor = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
+        res_default = torch.ops.aten.add.Tensor(redundant_input, scalar_tensor)
+        return scalar_tensor, res_default
+
+
model = OpModule()
+ascend_model = AscendOpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)
+ascend_compiled_model = compile_model(ascend_model, args.backend, args.dynamic)


class TestScalarTensor():
+    @pytest.mark.skipif(args.backend == 'ascendgraph', reason="skip ascendgraph")
    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
    @pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
    @pytest.mark.parametrize("compiled_model", compiled_model)
@@ -30,3 +45,16 @@ def test_torch_scalar_tensor(self, inputs, dtype, compiled_model):
        dicp_output = compiled_model.model(inputs, dtype)

        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
+
+    @pytest.mark.skipif(args.backend == 'topsgraph', reason="skip topsgraph")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
+    @pytest.mark.parametrize("inputs", [1.0, 2.0, 3.0])
+    @pytest.mark.parametrize("compiled_model", ascend_compiled_model)
+    def test_torch_ascend_scalar_tensor(self, inputs, dtype, compiled_model):
+        redundant_input = torch.ones(1, dtype=dtype)
+        output, _ = ascend_model(inputs, redundant_input, dtype)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output, _ = compiled_model.model(inputs, redundant_input, dtype)
+
+        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
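Outside the test harness, the redundant-input workaround looks roughly like the sketch below. It assumes DICP registers `ascendgraph` as a `torch.compile` backend (as the test's `compile_model` helper suggests); the function names are made up for illustration.

```python
import torch

def scalar_only(v):
    # A graph whose only op is scalar_tensor ends up with zero tensor inputs;
    # GE rejects it with "SetInputs failed: input operator size can not be 0".
    return torch.ops.aten.scalar_tensor.default(v, dtype=torch.float32)

def scalar_with_anchor(v, anchor):
    # The extra tensor argument keeps at least one input in the Ascend graph.
    s = torch.ops.aten.scalar_tensor.default(v, dtype=torch.float32)
    return s, torch.ops.aten.add.Tensor(anchor, s)

# Compiling scalar_only with the same backend would hit the GE error above.
compiled = torch.compile(scalar_with_anchor, backend='ascendgraph')
scalar, summed = compiled(3.0, torch.ones(1))
```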