
Commit

implement params auto-cast in some binary ops.
Reinerzhou committed Jan 9, 2024
1 parent a9bcb21 commit 70c0b96
Showing 1 changed file with 90 additions and 6 deletions.
96 changes: 90 additions & 6 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -8,6 +8,7 @@
from dicp.dynamo_bridge.op_transformer import SingleOpTransformer
from dicp.dynamo_bridge.compile_fx import is_torch_210
from typing import (
    Union,
    Optional,
)
from torch.types import (
@@ -72,6 +73,50 @@ def register_conversion(aten_fn):
)


def get_cast_dtype(
    type1: Union[str, torch.dtype, type], type2: Union[str, torch.dtype, type]
) -> Union[str, torch.dtype, None]:
    if type1 == type2:
        return type1

    type_map = {
        int: torch.int,
        float: torch.float,
        complex: torch.complex64,  # default complex dtype; torch.complex itself is a function, not a dtype
        bool: torch.bool,
    }

    # Map dtype names such as "float32" onto the corresponding torch dtype.
    type1 = getattr(torch, type1) if isinstance(type1, str) else type1
    type2 = getattr(torch, type2) if isinstance(type2, str) else type2

    # Map Python builtin types onto their default torch dtypes.
    type1 = type_map[type1] if isinstance(type1, type) else type1
    type2 = type_map[type2] if isinstance(type2, type) else type2

    if type1 == torch.bool or type2 == torch.bool:
        return torch.bool
    elif type1 == torch.double or type2 == torch.double:
        return torch.double

    complex_list = [torch.complex32, torch.complex64, torch.complex128]
    float_list = [torch.float16, torch.float32, torch.float, torch.float64]
    int_list = [torch.int8, torch.int16, torch.int32, torch.int, torch.int64]

    # Within a family (complex, then float, then int), the wider of the two dtypes wins.
    if type1 in complex_list or type2 in complex_list:
        t1_idx = complex_list.index(type1) if type1 in complex_list else -1
        t2_idx = complex_list.index(type2) if type2 in complex_list else -1
        return complex_list[max(t1_idx, t2_idx)]
    elif type1 in float_list or type2 in float_list:
        t1_idx = float_list.index(type1) if type1 in float_list else -1
        t2_idx = float_list.index(type2) if type2 in float_list else -1
        return float_list[max(t1_idx, t2_idx)]
    elif type1 in int_list or type2 in int_list:
        t1_idx = int_list.index(type1) if type1 in int_list else -1
        t2_idx = int_list.index(type2) if type2 in int_list else -1
        return int_list[max(t1_idx, t2_idx)]

    assert False, f"can't cast {type1} and {type2} to a common dtype!"

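As a quick illustration of the promotion rules the new helper implements, here is a minimal standalone sketch; the dtype pairs are hypothetical examples, not taken from the commit, and it assumes get_cast_dtype is importable from dicp.vendor.TopsGraph.conversion:

import torch

from dicp.vendor.TopsGraph.conversion import get_cast_dtype  # helper added above

# Equal dtypes pass through unchanged.
assert get_cast_dtype(torch.float32, torch.float32) == torch.float32

# Python builtin types are first mapped onto their default torch dtypes.
assert get_cast_dtype(int, torch.int64) == torch.int64

# torch.double short-circuits the per-family lookup (as does torch.bool).
assert get_cast_dtype(torch.int32, torch.double) == torch.double

# Otherwise the wider dtype within the float/int family wins.
assert get_cast_dtype(torch.float16, torch.float32) == torch.float32
assert get_cast_dtype(torch.int8, torch.int32) == torch.int32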

class AtenToTopsTransformer(SingleOpTransformer):
    def __init__(self, gm):
        super().__init__(gm, conversions)
@@ -110,6 +155,16 @@ def Mul(self, a, b):
        if hasattr(a.node, "meta") and 'val' in a.node.meta:
            if (a.node.meta['val'].dtype == torch.complex64) or (a.node.meta['val'].dtype == torch.cfloat):
                return tops_op.ComplexMul(a, b)
        if isinstance(a, Proxy) and isinstance(b, Proxy):
            a_dtype = a.node.meta["val"].dtype
            b_dtype = b.node.meta["val"].dtype
            if a_dtype != b_dtype:
                out_dtype = fx_traceback.get_current_meta()['val'].dtype
                if a_dtype != out_dtype:
                    a = self.get_proxy(tops_op.Convert, (a, out_dtype))
                if b_dtype != out_dtype:
                    b = self.get_proxy(tops_op.Convert, (b, out_dtype))
            return self.get_proxy(tops_op.Mul, (a, b))
        return tops_op.Mul(a, b)

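At the tensor level, the casts inserted for Mul amount to converting both operands to the output dtype recorded in the FX metadata before multiplying. A rough standalone sketch of that invariant, with hypothetical shapes and dtypes and plain torch standing in for tops_op:

import torch

a = torch.randn(4, dtype=torch.float16)
b = torch.randn(4, dtype=torch.float32)

# Stand-in for fx_traceback.get_current_meta()['val'].dtype, i.e. the dtype
# the traced aten.mul.Tensor node reports for its output.
out_dtype = torch.result_type(a, b)  # torch.float32 here

a = a.to(out_dtype) if a.dtype != out_dtype else a
b = b.to(out_dtype) if b.dtype != out_dtype else b
assert (a * b).dtype == out_dtype  # both operands now share the output dtype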
@register_conversion(aten.mul.Scalar)
@@ -313,16 +368,34 @@ def ReduceMean(self, a, dim=None, keepdim=False, **kwargs):
        return self.get_proxy(tops_op.ReduceMean, (a, dim, keepdim))

    @register_conversion(aten.lt.Tensor)
    def Less(self, *args, **kwargs):
        return self.get_proxy(tops_op.Less, args, kwargs)
    def Less(self, a, b):
        if isinstance(a, Proxy) and isinstance(b, Proxy):
            a_dtype = a.node.meta["val"].dtype
            b_dtype = b.node.meta["val"].dtype
            if a_dtype != b_dtype:
                in_dtype = get_cast_dtype(a_dtype, b_dtype)
                if a_dtype != in_dtype:
                    a = self.get_proxy(tops_op.Convert, (a, in_dtype))
                if b_dtype != in_dtype:
                    b = self.get_proxy(tops_op.Convert, (b, in_dtype))
        return self.get_proxy(tops_op.Less, (a, b))

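For comparisons the output dtype is always bool, so the output metadata cannot drive the cast; Less instead promotes both inputs to the common dtype chosen by get_cast_dtype. A rough tensor-level sketch with hypothetical inputs, where plain torch stands in for tops_op.Less:

import torch

a = torch.tensor([1, 2, 3], dtype=torch.int32)
b = torch.tensor([1.5, 2.0, 2.5], dtype=torch.float32)

in_dtype = get_cast_dtype(a.dtype, b.dtype)  # torch.float32: float beats int
a = a.to(in_dtype) if a.dtype != in_dtype else a
b = b.to(in_dtype) if b.dtype != in_dtype else b

result = a < b  # operands share in_dtype; the comparison itself yields bool
assert result.dtype == torch.bool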
    @register_conversion(aten.le.Scalar)
    def LessEqual(self, *args, **kwargs):
        return self.get_proxy(tops_op.LessEqual, args, kwargs)

    @register_conversion([aten.eq.Tensor, aten.eq.Scalar])
    def Equal(self, *args, **kwargs):
        return self.get_proxy(tops_op.Equal, args, kwargs)
    def Equal(self, a, b):
        if isinstance(a, Proxy) and isinstance(b, Proxy):
            a_dtype = a.node.meta["val"].dtype
            b_dtype = b.node.meta["val"].dtype
            if a_dtype != b_dtype:
                in_dtype = get_cast_dtype(a_dtype, b_dtype)
                if a_dtype != in_dtype:
                    a = self.get_proxy(tops_op.Convert, (a, in_dtype))
                if b_dtype != in_dtype:
                    b = self.get_proxy(tops_op.Convert, (b, in_dtype))
        return self.get_proxy(tops_op.Equal, (a, b))

    @register_conversion(aten.ne.Scalar)
    def NotEqual(self, a, b):
@@ -463,8 +536,19 @@ def FullLike(self, *args, **kwargs):
        return self.get_proxy(tops_op.FullLike, args, kwargs)

    @register_conversion(aten.maximum.default)
    def Max(self, *args, **kwargs):
        return self.get_proxy(tops_op.Max, args, kwargs)
    def Max(self, a, b):
        if isinstance(a, Proxy) and isinstance(b, Proxy):
            a_dtype = a.node.meta["val"].dtype
            b_dtype = b.node.meta["val"].dtype
            if a_dtype != b_dtype:
                out_dtype = fx_traceback.get_current_meta()['val'].dtype
                if a_dtype != out_dtype:
                    a = self.get_proxy(tops_op.Convert, (a, out_dtype))
                if b_dtype != out_dtype:
                    b = self.get_proxy(tops_op.Convert, (b, out_dtype))
            return self.get_proxy(tops_op.Max, (a, b))
        return self.get_proxy(tops_op.Max, (a, b))

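Max follows the same recipe as Mul: when the operand dtypes differ, both are converted to the dtype the FX 'val' metadata records for the output. A rough standalone check of that invariant, with hypothetical inputs and torch.maximum standing in for tops_op.Max:

import torch

a = torch.randn(3, dtype=torch.float16)
b = torch.arange(3, dtype=torch.int32)

# aten.maximum promotes mixed inputs, and the FX metadata records that dtype.
out_dtype = torch.maximum(a, b).dtype  # torch.float16 here
assert torch.maximum(a.to(out_dtype), b.to(out_dtype)).dtype == out_dtype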

    @register_conversion([aten.pow.Tensor_Scalar, aten.pow.Tensor_Tensor])
    def Pow(self, *args, **kwargs):
