Skip to content

Commit

Permalink
add scalar_tensor for ascendgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 committed Jan 25, 2024
1 parent 8516682 commit 55e668a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
4 changes: 4 additions & 0 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,3 +1483,7 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
indices_expanded = indices
return self.get_proxy(ascend_op.ScatterElements,
(operand, indices_expanded, src, dim))

@register_conversion(torch.ops.aten.scalar_tensor.default)
def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None):
return self.get_const_proxy(x, dtype)
1 change: 1 addition & 0 deletions dicp/test/ascend_scripts/ops/static.ini
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ python_files =
test_relu.py
test_repeat.py
test_rsqrt.py
test_scalar_tensor.py
test_scatter.py
test_select.py
test_sigmoid.py
Expand Down
9 changes: 5 additions & 4 deletions dicp/test/op/test_scalar_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@


class OpModule(torch.nn.Module):
def forward(self, a, dtype):
def forward(self, a, redundant_input, dtype):
res_default = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
return res_default
return res_default + redundant_input, res_default


model = OpModule()
Expand All @@ -24,9 +24,10 @@ class TestScalarTensor():
@pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_scalar_tensor(self, inputs, dtype, compiled_model):
output = model(inputs, dtype)
redundant_input = torch.randn(1, dtype=dtype)
_, output = model(inputs, redundant_input, dtype)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(inputs, dtype)
_, dicp_output = compiled_model.model(inputs, redundant_input, dtype)

assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)

0 comments on commit 55e668a

Please sign in to comment.