diff --git a/dicp/dicp/dynamo_bridge/graph.py b/dicp/dicp/dynamo_bridge/graph.py
index f5f5095d2..8b94d1443 100644
--- a/dicp/dicp/dynamo_bridge/graph.py
+++ b/dicp/dicp/dynamo_bridge/graph.py
@@ -61,12 +61,15 @@ def make_tensor_meta(x) -> Optional[TensorMetadata]:
             else:
                 continue
             if 'val' in n.meta and test_infer:
-                (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
+                (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
                 for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
-                    assert meta_i.size() == fv_i.size(), "check infer size failed"
-                    assert meta_i.dtype == fv_i.dtype, "check infer dtype failed"
-                    assert meta_i.stride() == fv_i.stride(), "check infer stride failed"
-                    assert meta_i.storage_offset() == fv_i.storage_offset(), "check infer storage offset failed"
+                    if not isinstance(fv_i, FakeTensor):
+                        continue
+                    log_info = f"target: {n.target}, meta_i: {meta_i}, fv_i: {fv_i}"
+                    assert meta_i.size() == fv_i.size(), f"check infer size failed, {log_info}"
+                    assert meta_i.dtype == fv_i.dtype, f"check infer dtype failed, {log_info}"
+                    assert meta_i.stride() == fv_i.stride(), f"check infer stride failed, {log_info}"
+                    assert meta_i.storage_offset() == fv_i.storage_offset(), f"check infer storage offset failed, {log_info}"
             if 'val' not in n.meta:
                 n.meta['val'] = fake_value
                 n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index 8a58f1e19..c662c33ea 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -600,9 +600,12 @@ def view_as_complex(self, x):
     @register_conversion(torch.ops.aten.full.default)
     def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
              device='cpu', pin_memory=False, memory_format=torch.preserve_format):
-        if len(dims) == 0:
-            dims = [1]
         torch_dtype = dtype if dtype else torch.get_default_dtype()
+        # If len(dims) == 0, the result is a 0-dimensional scalar tensor,
+        # so we can directly return a const node to construct it.
+        if len(dims) == 0:
+            return self.get_const_proxy(value, torch_dtype)
+
         dims = [dim.node.meta['val'] if isinstance(dim, torch.fx.proxy.Proxy) and hasattr(
             dim.node, 'meta') else dim for dim in dims]
         if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
@@ -1483,3 +1486,7 @@ def SliceScatter(self, operand, src, dim=0, start=None, end=None, step=1):
             indices_expanded = indices
         return self.get_proxy(ascend_op.ScatterElements,
                               (operand, indices_expanded, src, dim))
+
+    @register_conversion(torch.ops.aten.scalar_tensor.default)
+    def scalar_tensor(self, x, dtype=None, layout=None, device=None, pin_memory=None):
+        return self.get_const_proxy(x, dtype)
diff --git a/dicp/test/ascend_scripts/ops/static.ini b/dicp/test/ascend_scripts/ops/static.ini
index 688562aa5..efbd2e7cb 100644
--- a/dicp/test/ascend_scripts/ops/static.ini
+++ b/dicp/test/ascend_scripts/ops/static.ini
@@ -54,6 +54,7 @@ python_files =
     test_relu.py
     test_repeat.py
     test_rsqrt.py
+    test_scalar_tensor.py
    test_scatter.py
    test_select.py
    test_sigmoid.py
diff --git a/dicp/test/op/test_scalar_tensor.py b/dicp/test/op/test_scalar_tensor.py
index 19668c5bd..05894c24c 100644
--- a/dicp/test/op/test_scalar_tensor.py
+++ b/dicp/test/op/test_scalar_tensor.py
@@ -14,12 +14,27 @@ def forward(self, a, dtype):
         return res_default
 
 
+class AscendOpModule(torch.nn.Module):
+    def forward(self, a, redundant_input, dtype):
+        # If the graph contains only a single scalar_tensor operator,
+        # the Ascend graph compiler reports an error:
+        # GE.. [Check][Param] SetInputs failed: input operator size can not be 0.
+        # To work around this, a redundant input is added and the result of
+        # an add operator is returned alongside the scalar tensor.
+        scalar_tensor = torch.ops.aten.scalar_tensor.default(a, dtype=dtype)
+        res_default = torch.ops.aten.add.Tensor(redundant_input, scalar_tensor)
+        return scalar_tensor, res_default
+
+
 model = OpModule()
+ascend_model = AscendOpModule()
 args = parse_args()
 compiled_model = compile_model(model, args.backend, args.dynamic)
+ascend_compiled_model = compile_model(ascend_model, args.backend, args.dynamic)
 
 
 class TestScalarTensor():
+    @pytest.mark.skipif(args.backend == 'ascendgraph', reason="skip ascendgraph")
     @pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
     @pytest.mark.parametrize("inputs", [1.0, 3.0, 0.0])
     @pytest.mark.parametrize("compiled_model", compiled_model)
@@ -30,3 +45,16 @@ def test_torch_scalar_tensor(self, inputs, dtype, compiled_model):
         dicp_output = compiled_model.model(inputs, dtype)
 
         assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
+
+    @pytest.mark.skipif(args.backend == 'topsgraph', reason="skip topsgraph")
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.int64, torch.float16])
+    @pytest.mark.parametrize("inputs", [1.0, 2.0, 3.0])
+    @pytest.mark.parametrize("compiled_model", ascend_compiled_model)
+    def test_torch_ascend_scalar_tensor(self, inputs, dtype, compiled_model):
+        redundant_input = torch.ones(1, dtype=dtype)
+        output, _ = ascend_model(inputs, redundant_input, dtype)
+        dynamo.reset()
+        update_dynamo_config(compiled_model.dynamic)
+        dicp_output, _ = compiled_model.model(inputs, redundant_input, dtype)
+
+        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
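Note (illustrative only, not part of the patch): both aten.scalar_tensor and aten.full with an empty dims list produce a 0-dimensional tensor, which is why both paths above can be lowered to a single const node via get_const_proxy. The minimal eager-mode sketch below checks that equivalence; the helper name demo_scalar_tensor is hypothetical and exists only for this example.

import torch

def demo_scalar_tensor():
    # 0-d tensor from the op handled by the new scalar_tensor conversion.
    a = torch.ops.aten.scalar_tensor.default(3.0, dtype=torch.float16)
    # 0-d tensor from aten.full when len(dims) == 0, i.e. the new
    # early-return branch added to full().
    b = torch.ops.aten.full.default([], 3.0, dtype=torch.float16)
    assert a.dim() == 0 and b.dim() == 0
    assert torch.equal(a, b)

if __name__ == "__main__":
    demo_scalar_tensor()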