Redesign infer_shape interface.
pdx1989 committed Mar 1, 2024
1 parent bdf08c3 commit b912f5a
Showing 3 changed files with 40 additions and 92 deletions.
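In short, the per-backend topsgraph_infer_shape / ascendgraph_infer_shape passes are deleted and their logic moves into GraphTransformer.infer_shape_dtype() in dicp/dynamo_bridge/graph.py, so both vendors share one implementation. A minimal sketch of the resulting call pattern, mirroring the new lines in opset_convert.py below (it assumes dicp and the chosen vendor package are importable, and that gm is a traced torch.fx.GraphModule):

import torch
from dicp.dynamo_bridge.graph import GraphTransformer

def run_infer_shape(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # "ascendgraph" or "topsgraph"; the constructor wires up the vendor transform and codegen.
    gt = GraphTransformer(gm, "ascendgraph")
    gt.infer_shape_dtype()  # fills n.meta['val'] and n.meta['tensor_meta'] per node
    return gt.gm
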
41 changes: 36 additions & 5 deletions dicp/dicp/dynamo_bridge/graph.py
@@ -25,23 +25,54 @@ def __init__(
         if backend == 'topsgraph':
             from dicp.vendor.TopsGraph.opset_transform import topsgraph_opset_transform
             self.backend_opset_transform = topsgraph_opset_transform
-            from dicp.vendor.TopsGraph.opset_transform import topsgraph_infer_shape
-            self.backend_infer_shape = topsgraph_infer_shape
             from dicp.vendor.TopsGraph.codegen.enflame import EnflameCodegen
             self.backend_codegen = EnflameCodegen
         elif backend == 'ascendgraph':
             from dicp.vendor.AscendGraph.opset_convert import ascendgraph_opset_convert
             self.backend_opset_transform = ascendgraph_opset_convert
-            from dicp.vendor.AscendGraph.opset_convert import ascendgraph_infer_shape
-            self.backend_infer_shape = ascendgraph_infer_shape
             from dicp.vendor.AscendGraph.codegen.ascend import AscendCodegen
             self.backend_codegen = AscendCodegen

     def transform(self):
         self.gm = self.backend_opset_transform(self.gm)

     def infer_shape_dtype(self):
-        self.gm = self.backend_infer_shape(self.gm)
+        def make_tensor_meta(x) -> Optional[TensorMetadata]:
+            if isinstance(x, FakeTensor):
+                return _extract_tensor_metadata(x)
+            else:
+                return None
+        test_infer = bool(os.environ.get("TEST_DICP_INFER", False))
+        for n in self.gm.graph.nodes:
+            fake_value = None
+            if n.op == 'call_function':
+                fake_value = (n.target(*n.args, **n.kwargs))
+            elif n.op == 'get_attr':
+                target_atoms = n.target.split('.')
+                attr_itr = self.gm
+                for i, atom in enumerate(target_atoms):
+                    if not hasattr(attr_itr, atom):
+                        raise RuntimeError(
+                            f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
+                    attr_itr = getattr(attr_itr, atom)
+                attr_size, attr_dtye = attr_itr.shape, attr_itr.dtype
+                with FakeTensorMode():
+                    fake_value = torch.empty(attr_size, dtype=attr_dtye)
+            else:
+                continue
+            if 'val' in n.meta and test_infer:
+                (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
+                for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
+                    if not isinstance(fv_i, FakeTensor):
+                        continue
+                    log_info = f"target: {n.target}, meta_i: {meta_i}, fv_i: {fv_i}"
+                    assert meta_i.size() == fv_i.size(), f"check infer size failed, {log_info}"
+                    assert meta_i.dtype == fv_i.dtype, f"check infer dtype failed, {log_info}"
+                    assert meta_i.stride() == fv_i.stride(), f"check infer stride failed, {log_info}"
+                    assert meta_i.storage_offset() == fv_i.storage_offset(), f"check infer storage offset failed, {log_info}"
+            if 'val' not in n.meta:
+                n.meta['val'] = fake_value
+            n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])

     def codegen(self):
         return self.backend_codegen(self.gm, self.cpu_gm, self.folder, self.graph_key).codegen()
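
The body above leans on two PyTorch pieces that the vendor files already import (graph.py's own import block sits above the hunk shown): FakeTensorMode, under which torch.empty returns a FakeTensor instead of allocating real storage, and torch.fx.passes.shape_prop._extract_tensor_metadata, which converts that FakeTensor into the TensorMetadata stored in n.meta["tensor_meta"]. A small standalone sketch of just those building blocks (the shape and dtype are arbitrary examples):

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.passes.shape_prop import _extract_tensor_metadata

# Materialize a value the way the get_attr branch does, without real memory.
with FakeTensorMode():
    fake_value = torch.empty((2, 3), dtype=torch.float16)

assert isinstance(fake_value, FakeTensor)
meta = _extract_tensor_metadata(fake_value)  # TensorMetadata namedtuple
print(meta.shape, meta.dtype, meta.stride)   # torch.Size([2, 3]) torch.float16 (3, 1)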
50 changes: 4 additions & 46 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -7,6 +7,7 @@
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
 from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata
 from torch._subclasses import FakeTensor, FakeTensorMode
+from ...dynamo_bridge.graph import GraphTransformer

 if is_torch_210:
     from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
@@ -87,54 +88,11 @@ def ascendgraph_opset_convert(
         gm = BackendPatternMatcherTransformer(
             ascend_pattern_matcher, aten_patterns_cls_list).transform(gm)
     gm = AtenToAscendTransformer(gm).transform()
-    return gm
-
-def ascendgraph_infer_shape(
-    gm: torch.fx.GraphModule,
-):
-    def make_tensor_meta(x) -> Optional[TensorMetadata]:
-        if isinstance(x, FakeTensor):
-            return _extract_tensor_metadata(x)
-        else:
-            return None
-
-    def _infer_shape(gm):
-        test_infer = bool(os.environ.get("TEST_DICP_INFER", False))
-        for n in gm.graph.nodes:
-            fake_value = None
-            if n.op == 'call_function':
-                fake_value = (n.target(*n.args, **n.kwargs))
-            elif n.op == 'get_attr':
-                target_atoms = n.target.split('.')
-                attr_itr = gm
-                for i, atom in enumerate(target_atoms):
-                    if not hasattr(attr_itr, atom):
-                        raise RuntimeError(
-                            f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
-                    attr_itr = getattr(attr_itr, atom)
-                attr_size, attr_dtye = attr_itr.shape, attr_itr.dtype
-                with FakeTensorMode():
-                    fake_value = torch.empty(attr_size, dtype=attr_dtye)
-            else:
-                continue
-            if 'val' in n.meta and test_infer:
-                (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
-                for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
-                    if not isinstance(fv_i, FakeTensor):
-                        continue
-                    log_info = f"target: {n.target}, meta_i: {meta_i}, fv_i: {fv_i}"
-                    assert meta_i.size() == fv_i.size(), f"check infer size failed, {log_info}"
-                    assert meta_i.dtype == fv_i.dtype, f"check infer dtype failed, {log_info}"
-                    assert meta_i.stride() == fv_i.stride(), f"check infer stride failed, {log_info}"
-                    assert meta_i.storage_offset() == fv_i.storage_offset(), f"check infer storage offset failed, {log_info}"
-            if 'val' not in n.meta:
-                n.meta['val'] = fake_value
-            n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
-        return gm
-
-    gm = _infer_shape(gm)
     # For bug in pytorch
     # Avoid for dynamic shape
+    gt = GraphTransformer(gm, "ascendgraph")
+    gt.infer_shape_dtype()
+    gm = gt.gm
     if is_torch_210 and not symint_in_inputs(list(gm.graph.nodes)):
         gm = BackendPatternMatcherTransformer(
             ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm)
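
The shared pass keeps the TEST_DICP_INFER hook: when the variable is set, every node that already carries n.meta['val'] has the freshly inferred FakeTensor checked against it. Note the flag is read as bool(os.environ.get(...)), so any non-empty string, including "0", enables it. A hedged sketch of switching the check on around the pipeline above:

import os

# Any non-empty value enables the extra assertions on size, dtype,
# stride and storage_offset of each inferred FakeTensor.
os.environ["TEST_DICP_INFER"] = "1"
# ... then run GraphTransformer(gm, "ascendgraph").infer_shape_dtype() as usual.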
41 changes: 0 additions & 41 deletions dicp/dicp/vendor/TopsGraph/opset_transform.py
@@ -51,44 +51,3 @@ def topsgraph_opset_transform(
     gm = HandleInplaceCopyPass().transform(gm)

     return gm
-
-def topsgraph_infer_shape(
-    gm: torch.fx.GraphModule,
-):
-    def make_tensor_meta(x) -> Optional[TensorMetadata]:
-        if isinstance(x, FakeTensor):
-            return _extract_tensor_metadata(x)
-        else:
-            return None
-    test_infer = bool(os.environ.get("TEST_DICP_INFER", False))
-    for n in gm.graph.nodes:
-        fake_value = None
-        if n.op == 'call_function':
-            fake_value = (n.target(*n.args, **n.kwargs))
-        elif n.op == 'get_attr':
-            target_atoms = n.target.split('.')
-            attr_itr = gm
-            for i, atom in enumerate(target_atoms):
-                if not hasattr(attr_itr, atom):
-                    raise RuntimeError(
-                        f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
-                attr_itr = getattr(attr_itr, atom)
-            attr_size, attr_dtye = attr_itr.shape, attr_itr.dtype
-            with FakeTensorMode():
-                fake_value = torch.empty(attr_size, dtype=attr_dtye)
-        else:
-            continue
-        if 'val' in n.meta and test_infer:
-            (n_meta_val, fake_val) = ((n.meta['val'],),(fake_value,)) if not isinstance(n.meta['val'],(Tuple,List)) else (n.meta['val'], fake_value)
-            for i,(meta_i,fv_i) in enumerate(zip(n_meta_val, fake_val)):
-                if not isinstance(fv_i, FakeTensor):
-                    continue
-                log_info = f"target: {n.target}, meta_i: {meta_i}, fv_i: {fv_i}"
-                assert meta_i.size() == fv_i.size(), f"check infer size failed, {log_info}"
-                assert meta_i.dtype == fv_i.dtype, f"check infer dtype failed, {log_info}"
-                assert meta_i.stride() == fv_i.stride(), f"check infer stride failed, {log_info}"
-                assert meta_i.storage_offset() == fv_i.storage_offset(), f"check infer storage offset failed, {log_info}"
-        if 'val' not in n.meta:
-            n.meta['val'] = fake_value
-        n.meta["tensor_meta"] = make_tensor_meta(n.meta['val'])
-    return gm
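
For get_attr nodes, the logic deleted here (and now living in graph.py) recovers shape and dtype by splitting the dotted target and walking it with getattr on the GraphModule. A minimal illustration with a hypothetical module; the class, parameter name and sizes are made up:

import torch
import torch.fx

class Scale(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(2, 3))

    def forward(self, x):
        return x * self.scale

gm = torch.fx.symbolic_trace(Scale())
# symbolic_trace emits a get_attr node whose target is the dotted path "scale";
# the pass follows such paths atom by atom to reach the underlying tensor.
attr = gm
for atom in "scale".split('.'):
    attr = getattr(attr, atom)
print(attr.shape, attr.dtype)  # torch.Size([2, 3]) torch.float32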
