Skip to content

Commit

Permalink
Add some comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
yao-fengchen committed Dec 25, 2023
1 parent 524195a commit 706fa91
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
7 changes: 6 additions & 1 deletion dicp/dicp/vendor/TopsGraph/codegen/enflame.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,12 @@ def ComplexMul(op_var, out_shape, out_dtype, x, y):
def Concatenate(op_var, out_shape, out_dtype, tensors, dim):
return f"builder::Op {op_var} = builder::Concatenate({'{' + ', '.join(tensors) + '}'}, {dim});"

# Add an additional true flag for accuration in tops softmax.
"""
Add an additional true flag for accuration in hlir_builder Softmax.
The third parameter, half_to_float, in aten._softmax represents whether cast
inputs from float16 to float32 or not, while the third parameter ,accurate,
in hlir_builder represents whether precision calculation is performed.
"""
@staticmethod
def Softmax(op_var, out_shape, out_dtype, x, y):
return f"builder::Op {op_var} = builder::Softmax({x}, {y}, true);"
Expand Down
6 changes: 6 additions & 0 deletions dicp/dicp/vendor/TopsGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ def BatchNorm(self, *args, **kwargs):
def BatchNormBackward(*args, **kwargs):
return tops_op.BatchNormBackward(*args, **kwargs)

"""
Add an additional true flag for accuration in hlir_builder Softmax.
The third parameter, half_to_float, in aten._softmax represents whether cast
inputs from float16 to float32 or not, while the third parameter ,accurate,
in hlir_builder represents whether precision calculation is performed.
"""
@register_conversion(aten._softmax)
def Softmax(self, a, dim, half_to_float):
out_shape = fx_traceback.get_current_meta()["val"].shape
Expand Down
1 change: 1 addition & 0 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ def __init__(self, *args, **kwargs):
self.kwargs = kwargs
self.torch_op = aten.expand.default

# The third parameter broadcast_dims is only required in hlir_builder Expand.
def __call__(self, *args, **kwargs):
new_args = args[:2]
return super().__call__(*new_args, **kwargs)
Expand Down

0 comments on commit 706fa91

Please sign in to comment.