From 55e668a352f54cc761a5f183bc0ee29ebdb9cf78 Mon Sep 17 00:00:00 2001 From: tangzhiyi Date: Thu, 25 Jan 2024 09:26:44 +0000 Subject: [PATCH] add scalar_tensor for ascendgraph --- dicp/dicp/vendor/AscendGraph/conversion.py | 4 ++++ dicp/test/ascend_scripts/ops/static.ini | 1 + dicp/test/op/test_scalar_tensor.py | 9 +++++---- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 8a58f1e19b..76d9aa3b4b 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -1483,3 +1483,7 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1): indices_expanded = indices return self.get_proxy(ascend_op.ScatterElements, (operand, indices_expanded, src, dim)) + + @register_conversion(torch.ops.aten.scalar_tensor.default) + def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None): + return self.get_const_proxy(x, dtype) diff --git a/dicp/test/ascend_scripts/ops/static.ini b/dicp/test/ascend_scripts/ops/static.ini index 688562aa5d..efbd2e7cb0 100644 --- a/dicp/test/ascend_scripts/ops/static.ini +++ b/dicp/test/ascend_scripts/ops/static.ini @@ -54,6 +54,7 @@ python_files = test_relu.py test_repeat.py test_rsqrt.py + test_scalar_tensor.py test_scatter.py test_select.py test_sigmoid.py diff --git a/dicp/test/op/test_scalar_tensor.py b/dicp/test/op/test_scalar_tensor.py index 19668c5bdf..b931cbe17b 100644 --- a/dicp/test/op/test_scalar_tensor.py +++ b/dicp/test/op/test_scalar_tensor.py @@ -9,9 +9,9 @@ class OpModule(torch.nn.Module): - def forward(self, a, dtype): + def forward(self, a, redundant_input, dtype): res_default = torch.ops.aten.scalar_tensor.default(a, dtype=dtype) - return res_default + return res_default + redundant_input, res_default model = OpModule() @@ -24,9 +24,10 @@ class TestScalarTensor(): @pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0]) @pytest.mark.parametrize("compiled_model", compiled_model) def test_torch_scalar_tensor(self, inputs, dtype, compiled_model): - output = model(inputs, dtype) + redundant_input = torch.randn(1, dtype=dtype) + _, output = model(inputs, redundant_input, dtype) dynamo.reset() update_dynamo_config(compiled_model.dynamic) - dicp_output = compiled_model.model(inputs, dtype) + _, dicp_output = compiled_model.model(inputs, redundant_input, dtype) assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)