Commit

update
tangzhiyi11 committed Jan 29, 2024
1 parent 001860c commit 89dc97e
Showing 2 changed files with 27 additions and 5 deletions.
1 change: 1 addition & 0 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -82,6 +82,7 @@ def ascendgraph_opset_convert(
     if is_torch_210:
         gm = BackendPatternMatcherTransformer(
             ascend_pattern_matcher, aten_patterns_cls_list).transform(gm)
+        gm.print_readable()
     gm = AtenToAscendTransformer(gm).transform()
 
     # For bug in pytorch
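
Note: the only change in this file is the added gm.print_readable() call, which prints the FX GraphModule as generated Python source, useful for inspecting what the pattern-matcher pass produced before the aten-to-Ascend conversion runs. A minimal standalone sketch of the same call (not part of this commit; the Toy module below is made up purely for illustration):

    import torch
    import torch.fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x + 1)

    # symbolic_trace builds an FX GraphModule; print_readable() prints its
    # generated forward() as readable Python source (and returns it as a string).
    gm = torch.fx.symbolic_trace(Toy())
    gm.print_readable()
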
31 changes: 26 additions & 5 deletions dicp/test/op/test_scalar_tensor.py
@@ -9,31 +9,52 @@


 class OpModule(torch.nn.Module):
+    def forward(self, a, dtype):
+        res_default = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
+        return res_default
+
+
+class AscendOpModule(torch.nn.Module):
     def forward(self, a, redundant_input, dtype):
         # If there is only one operator called scalar_tensor,
         # ascend graph compiler will give an error:
         # GE.. [Check][Param] SetInputs failed: input operator size can not be 0.
         # To solve this problem, an additional redundant input is added,
         # and the result of an addition operator is returned.
-        res_default = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
-        redundant_output = torch.ops.aten.add.Tensor(redundant_input, res_default)
-        return redundant_output, res_default
+        scalar_tensor = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
+        res_default = torch.ops.aten.add.Tensor(redundant_input, scalar_tensor)
+        return res_default
 
 
 model = OpModule()
+ascend_model = AscendOpModule()
 args = parse_args()
 compiled_model = compile_model(model, args.backend, args.dynamic)
+ascend_compiled_model = compile_model(ascend_model, args.backend, args.dynamic)
 
 
 class TestScalarTensor():
+    @pytest.mark.skipif(args.backend == 'ascendgraph', reason="skip ascendgraph")
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
     @pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
     @pytest.mark.parametrize("compiled_model", compiled_model)
     def test_torch_scalar_tensor(self, inputs, dtype, compiled_model):
+        output = model(inputs, dtype)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output = compiled_model.model(inputs, dtype)
+
+        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
+
+    @pytest.mark.skipif(args.backend == 'topsgraph', reason="skip topsgraph")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
+    @pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
+    @pytest.mark.parametrize("compiled_model", ascend_compiled_model)
+    def test_torch_ascend_scalar_tensor(self, inputs, dtype, compiled_model):
         redundant_input = torch.ones(1, dtype=dtype)
-        _, output = model(inputs, redundant_input, dtype)
+        output = ascend_model(inputs, redundant_input, dtype)
         dynamo.reset()
         update_dynamo_config(compiled_model.dynamic)
-        _, dicp_output = compiled_model.model(inputs, redundant_input, dtype)
+        dicp_output = compiled_model.model(inputs, redundant_input, dtype)
 
         assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
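
Note: the comment in AscendOpModule explains the reason for the extra input: scalar_tensor takes no tensor inputs, so a graph containing only that op compiles to a GE graph with zero inputs and SetInputs fails. Feeding a redundant tensor input and returning the result of an add keeps at least one real input in the compiled graph. A rough eager-mode sketch of what the two modules compute (illustration only; the concrete values are made up, and the aten ops are the same ones used in the test above):

    import torch

    a, dtype = 3.0, torch.float32

    # OpModule: the scalar is materialized as a 0-d tensor.
    plain = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)   # tensor(3.)

    # AscendOpModule: same scalar, but added to a redundant input so the
    # compiled graph keeps a tensor input (avoiding the zero-input GE error).
    redundant_input = torch.ones(1, dtype=dtype)
    fused = torch.ops.aten.add.Tensor(redundant_input, plain)      # tensor([4.])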
