Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tangzhiyi11 committed Jan 17, 2024
1 parent 153a982 commit 1a9b230
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def try_to_get_dtype(x):
return None


def is_cpp_support_dtype(dtype):
def is_dicp_cpp_support_dtype(dtype):
if dtype in [torch.float32, torch.float, torch.int32, torch.int64]:
return True
return False
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_const_proxy(self, param, dtype, format=None, target_shape=None):
else:
shape = target_shape
param = param if isinstance(param, list) else [param]
if is_cpp_support_dtype(dtype):
if is_dicp_cpp_support_dtype(dtype):
param = self.get_proxy(
ascend_op.Const, (param, dtype, shape, format))
else:
Expand Down Expand Up @@ -229,11 +229,6 @@ def mul(self, x, y):
return self.mul_scalar(x, y)
x_shape = list(x.node.meta['val'].shape)
y_shape = list(y.node.meta['val'].shape)
# handling with broadcasting cases
if np.prod(x_shape) < np.prod(y_shape):
x = self.get_proxy(ascend_op.BroadcastTo, (x, y_shape))
elif np.prod(x_shape) > np.prod(y_shape):
y = self.get_proxy(ascend_op.BroadcastTo, (y, x_shape))
x, y = self.promote_dtype(x, y, target_dtype=out_dtype)
return self.get_proxy(ascend_op.Mul, (x, y), {})

Expand Down Expand Up @@ -589,11 +584,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
value = value.node.meta['val']
dims = self.get_shape_proxy(dims)
# temporarily split the path for dynamic/static shape cases
if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
else:
value = self.get_const_proxy(value, torch_dtype)
value = self.get_const_proxy(value, torch_dtype)
return self.get_proxy(ascend_op.Fill, (dims, value))

@register_conversion(torch.ops.aten.fill.Scalar)
Expand Down Expand Up @@ -820,9 +811,7 @@ def pow(self, x, exp):
def maximum(self, a, b):
a_shape = list(a.node.meta['val'].shape)
b_shape = list(b.node.meta['val'].shape)
if np.prod(b_shape) < np.prod(a_shape):
b = self.get_proxy(ascend_op.BroadcastTo, (b, a_shape))
b = self.promote_dtype(b, a.node.meta['val'].dtype)
b = self.promote_dtype(b, target_dtype=a.node.meta['val'].dtype)
return self.get_proxy(ascend_op.Maximum, (a, b))

@register_conversion(aten.sub)
Expand Down

0 comments on commit 1a9b230

Please sign in to comment.