diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index e240e2498..686f1a39f 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -653,18 +653,9 @@ def ne(self, a, b):
 
     @register_conversion([aten.lt.Scalar, aten.lt.Tensor])
     def lt(self, x, y):
-        x_shape = list(x.node.meta['val'].shape)
-        y_shape = [] if not isinstance(
-            y, torch.fx.proxy.Proxy) else list(y.node.meta['val'].shape)
-        out = list(fx_traceback.get_current_meta()['val'].shape)
-        out_shape = self.get_shape_proxy(out)
-        x, y = self.binary_cmp_cast_input(x, y)
-        dynamic_shape = not_all_num_shape(x_shape) or not_all_num_shape(
-            y_shape) or not_all_num_shape(out)
-        if dynamic_shape and (self.shape_prod(x_shape) < self.shape_prod(out)):
-            x = self.get_proxy(ascend_op.BroadcastTo, (x, out_shape))
-        if dynamic_shape and (self.shape_prod(y_shape) < self.shape_prod(out)):
-            y = self.get_proxy(ascend_op.BroadcastTo, (y, out_shape))
+        if not isinstance(y, torch.fx.proxy.Proxy):
+            x_dtype = x.node.meta['val'].dtype
+            y = self.get_const_proxy(y, x_dtype)
         return self.get_proxy(ascend_op.Less, (x, y))
 
     @register_conversion(aten.masked_fill.Scalar)
@@ -1711,7 +1702,8 @@ def prompt_attention_inference(self, q, k, v, seqlen, num_head, head_dim):
         mask = self.get_proxy(ascend_op.OnesLike, (mask,))
         mask = self.get_proxy(ascend_op.Tril, (mask,))
         mask = self.get_proxy(ascend_op.LogicalNot, (mask,))
-        q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16)))
+        if q.node.meta['val'].dtype != torch.float16:
+            q = self.get_proxy(ascend_op.Cast, (q, get_ascend_dtype(torch.float16)))
         fa = self.get_proxy(ascend_op.PromptFlashAttention, (q, k, v, num_head, seqlen, mask, head_dim))
         return fa
 