Skip to content

Commit

Permalink
Modify logic of lt dtype, and prompt_attention fp16 conversion.
Browse files Browse the repository at this point in the history
  • Loading branch information
pdx1989 committed May 23, 2024
1 parent 4e499b9 commit 3811319
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions dicp/dicp/vendor/AscendGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,18 +653,9 @@ def ne(self, a, b):

     @register_conversion([aten.lt.Scalar, aten.lt.Tensor])
     def lt(self, x, y):
-        x_shape = list(x.node.meta['val'].shape)
-        y_shape = [] if not isinstance(
-            y, torch.fx.proxy.Proxy) else list(y.node.meta['val'].shape)
-        out = list(fx_traceback.get_current_meta()['val'].shape)
-        out_shape = self.get_shape_proxy(out)
-        x, y = self.binary_cmp_cast_input(x, y)
-        dynamic_shape = not_all_num_shape(x_shape) or not_all_num_shape(
-            y_shape) or not_all_num_shape(out)
-        if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)):
-            x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape))
-        if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)):
-            y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape))
+        if not isinstance(y, torch.fx.proxy.Proxy):
+            x_dtype = x.node.meta['val'].dtype
+            y = self.get_const_proxy(y, x_dtype)
         return self.get_proxy(ascend_op.Less, (x, y))

@register_conversion(aten.masked_fill.Scalar)
Expand Down Expand Up @@ -1711,7 +1702,8 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim):
         mask = self.get_proxy(ascend_op.OnesLike, (mask,))
         mask = self.get_proxy(ascend_op.Tril, (mask,))
         mask = self.get_proxy(ascend_op.LogicalNot, (mask,))
-        q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16)))
+        if q.node.meta['val'].dtype != torch.float16:
+            q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16)))
         fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim))
         return fa

Expand Down

0 comments on commit 3811319

Please sign in to comment.