Skip to content

Commit

Permalink
[dipu]Add format_convert error handling hook. (#644)
Browse files Browse the repository at this point in the history
* Add AscendFormatTensor dispatch hook for format_convert error handling.

* Rename format_cast function name.
  • Loading branch information
pdx1989 authored Jan 17, 2024
1 parent e8b1dec commit f6f47ec
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions dipu/torch_dipu/dipu/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) 2023, DeepLink.

import torch
import torch_dipu
from torch_dipu import _C
import warnings
import os
import traceback
import threading
from multiprocessing.util import register_after_fork as _register_after_fork
from torch.utils._pytree import tree_map
import re

_initialized = True
Expand All @@ -24,6 +26,39 @@

_torch_ver_pattern = re.compile(r'^(\d+)\.(\d+)\.(\d+)\.*')


class AscendFormatTensor(torch.Tensor):
def __init__(cls, t):
cls.elem = t
cls.func_list = ['aten.transpose.int', 'aten.slice.Tensor', 'aten.view.default']

def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
# unwrap the wrapper tensors to get the inner tensor object
def unwrap(x):
return x.elem if isinstance(x, AscendFormatTensor) else x

args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
if str(func) in cls.func_list:
assert len(args) > 0
if 'FRACTAL_NZ' in str(torch_dipu.get_native_memory_format(args[0])):
raise RuntimeError("View type op calculation does not support FRACTAL_NZ!")
out = func(*args, **kwargs)

def wrap(x):
return AscendFormatTensor(x) if isinstance(x, torch.Tensor) else x

return tree_map(wrap, out)


def ascend_format_cast(tensor:torch.Tensor, format:_C.NativeMemoryFormat) -> torch.Tensor:
if isinstance(tensor, torch.Tensor) and 'FRACTAL_NZ' in str(format):
return AscendFormatTensor(torch_dipu.native_memory_format_cast(tensor, format))
if isinstance(tensor, AscendFormatTensor) and not 'FRACTAL_NZ' in str(format):
return torch_dipu.native_memory_format_cast(tensor.elem, format)
return torch_dipu.native_memory_format_cast(tensor, format)


def check_dipu_torch_compatiable():
def _replacer(matched):
replaced = ''
Expand Down

0 comments on commit f6f47ec

Please sign in to comment.