Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 committed Jan 17, 2024
1 parent 153a982 commit bc5ff90
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@ def mul(self, x, y):
return self.mul_scalar(x, y)
x_shape = list(x.node.meta['val'].shape)
y_shape = list(y.node.meta['val'].shape)
# handling with broadcasting cases
if np.prod(x_shape) < np.prod(y_shape):
x = self.get_proxy(ascend_op.BroadcastTo, (x, y_shape))
elif np.prod(x_shape) > np.prod(y_shape):
y = self.get_proxy(ascend_op.BroadcastTo, (y, x_shape))
x, y = self.promote_dtype(x, y, target_dtype=out_dtype)
return self.get_proxy(ascend_op.Mul, (x, y), {})

Expand Down Expand Up @@ -589,11 +584,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
value = value.node.meta['val']
dims = self.get_shape_proxy(dims)
# temporarily split the path for dynamic/static shape cases
if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
else:
value = self.get_const_proxy(value, torch_dtype)
value = self.get_const_proxy(value, torch_dtype)
return self.get_proxy(ascend_op.Fill, (dims, value))

@register_conversion(torch.ops.aten.fill.Scalar)
Expand Down Expand Up @@ -820,9 +811,7 @@ def pow(self, x, exp):
def maximum(self, a, b):
a_shape = list(a.node.meta['val'].shape)
b_shape = list(b.node.meta['val'].shape)
if np.prod(b_shape) < np.prod(a_shape):
b = self.get_proxy(ascend_op.BroadcastTo, (b, a_shape))
b = self.promote_dtype(b, a.node.meta['val'].dtype)
b = self.promote_dtype(b, target_dtype=a.node.meta['val'].dtype)
return self.get_proxy(ascend_op.Maximum, (a, b))

@register_conversion(aten.sub)
Expand Down

0 comments on commit bc5ff90

Please sign in to comment.