From bc5ff90e8253919e578f74e22c8079bc26c653f6 Mon Sep 17 00:00:00 2001 From: tangzhiyi11 Date: Tue, 16 Jan 2024 15:58:30 +0800 Subject: [PATCH] update --- dicp/dicp/vendor/AscendGraph/conversion.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index d91a78af37..29adff84d5 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -229,11 +229,6 @@ def mul(self, x, y): return self.mul_scalar(x, y) x_shape = list(x.node.meta['val'].shape) y_shape = list(y.node.meta['val'].shape) - # handling with broadcasting cases - if np.prod(x_shape) < np.prod(y_shape): - x = self.get_proxy(ascend_op.BroadcastTo, (x, y_shape)) - elif np.prod(x_shape) > np.prod(y_shape): - y = self.get_proxy(ascend_op.BroadcastTo, (y, x_shape)) x, y = self.promote_dtype(x, y, target_dtype=out_dtype) return self.get_proxy(ascend_op.Mul, (x, y), {}) @@ -589,11 +584,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided, if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'): value = value.node.meta['val'] dims = self.get_shape_proxy(dims) - # temporarily split the path for dynamic/static shape cases - if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: - value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, [])) - else: - value = self.get_const_proxy(value, torch_dtype) + value = self.get_const_proxy(value, torch_dtype) return self.get_proxy(ascend_op.Fill, (dims, value)) @register_conversion(torch.ops.aten.fill.Scalar) @@ -820,9 +811,7 @@ def pow(self, x, exp): def maximum(self, a, b): a_shape = list(a.node.meta['val'].shape) b_shape = list(b.node.meta['val'].shape) - if np.prod(b_shape) < np.prod(a_shape): - b = self.get_proxy(ascend_op.BroadcastTo, (b, a_shape)) - b = self.promote_dtype(b, a.node.meta['val'].dtype) + b = self.promote_dtype(b, target_dtype=a.node.meta['val'].dtype) return self.get_proxy(ascend_op.Maximum, (a, b)) @register_conversion(aten.sub)