diff --git a/dipu/SupportedDiopiFunctions.txt b/dipu/SupportedDiopiFunctions.txt
index 088ba2aef..24b97f3aa 100644
--- a/dipu/SupportedDiopiFunctions.txt
+++ b/dipu/SupportedDiopiFunctions.txt
@@ -192,6 +192,7 @@ diopiNormalInp
 diopiNormalScalarTensor
 diopiNormalTensor
 diopiNormalTensorScalar
+diopiOnes
 diopiPolar
 diopiPow
 diopiPowInp
@@ -250,3 +251,5 @@ diopiUpsampleLinearBackward
 diopiUpsampleNearest
 diopiUpsampleNearestBackward
 diopiWhere
+diopiZeroInp
+diopiZeros
diff --git a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
index 6f10b07e0..fe01e00f5 100644
--- a/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
+++ b/dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
@@ -453,6 +453,7 @@ def create_call_aten_cpu_cpp_function_code_from_config(fun_config):
     opname = re.sub("\.dim_min", "_outf", opname)
     opname = re.sub("\.correction", "", opname)
     opname = re.sub("\.input", "", opname)
+    opname = re.sub("\.dim_IntList", "", opname)
     opname = opname.replace(".", "_")
     opname = opname.split(".")[0]
     if opname[-1] == "_" and len(get_function_return_param_from_schema(schema)) > 0:
diff --git a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
index 626609ebe..f1d8cc127 100755
--- a/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
+++ b/dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -2279,6 +2279,29 @@
 - schema: "atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)"
   interface: diopiAtan(ctx, out, self)
 
+- schema: "ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)"
+  interface: diopiOnes(ctx, out, size)
+
+- schema: "ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"
+  custom_code_at_the_beginning: |
+    c10::TensorOptions option;
+    auto shape = c10::asIntArrayRefUnchecked(size);
+    auto out = nodispatch::empty(shape, option.dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
+  interface: diopiOnes(ctx, out, size)
+
+- schema: "zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"
+  custom_code_at_the_beginning: |
+    c10::TensorOptions option;
+    auto shape = c10::asIntArrayRefUnchecked(size);
+    auto out = nodispatch::empty(shape, option.dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
+  interface: diopiZeroInp(ctx, out)
+
+- schema: "zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)"
+  interface: diopiZeros(ctx, out, size)
+
+- schema: "zero_(Tensor(a!) self) -> Tensor(a!)"
+  interface: diopiZeroInp(ctx, self)
+
 - schema: "im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor"
   size_attr: [kernel_size, stride, padding, dilation]
   custom_code_at_the_beginning: |
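
The YAML entries above are consumed by autogen_diopi_wrapper.py to emit the ATen-level wrappers. A minimal smoke test of the intended dispatch, as a sketch (it assumes a built torch_dipu where the "dipu" device string resolves, as in the unit tests below):

    import torch
    import torch_dipu  # registers the dipu backend and the generated wrappers

    # functional overloads: storage comes from nodispatch::empty, the fill
    # then runs on device via diopiOnes / diopiZeroInp
    x = torch.ones([2, 3], device="dipu")
    y = torch.zeros([2, 3], device="dipu")
    assert torch.equal(x.cpu(), torch.ones([2, 3]))
    assert torch.equal(y.cpu(), torch.zeros([2, 3]))
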
diff --git a/dipu/tests/python/individual_scripts/test_rt_ddp.py b/dipu/tests/python/individual_scripts/test_rt_ddp.py
index d1d1a6cfb..1fed653da 100644
--- a/dipu/tests/python/individual_scripts/test_rt_ddp.py
+++ b/dipu/tests/python/individual_scripts/test_rt_ddp.py
@@ -289,11 +289,28 @@ def demo_allgather_gloo(rank, world_size, port):
     cleanup()
 
 
+def test_special_group_stuck(rank, world_size):
+    import torch_dipu
+
+    print(f"test special group stuck on rank {rank} ")
+
+    setup(rank, world_size)
+
+    # the ranks check requires len(ranks) <= world_size
+    if world_size >= 2:
+        # the torch 2.0 gloo pg has a limitation: passing duplicated ranks
+        # will hang, but Huawei's code does exactly that.
+        ranks_dup = [rank, rank]
+        group = torch.distributed.new_group(ranks_dup)
+        print(group)
+        dist.destroy_process_group(group)
+
+    cleanup()
+
+
 if __name__ == "__main__":
     n_gpus = torch.cuda.device_count()
-    # world_size = 1
-    # demo_allreduce(0, world_size)
-    # demo_basic_ddp(0, world_size)
+
     port = random.randint(10000, 60000)
 
     world_size = 1
@@ -311,3 +328,5 @@ def demo_allgather_gloo(rank, world_size, port):
     # run_demo(demo_bcast, world_size, port)
 
     # run_demo(demo_model_parallel, world_size)
+
+    # run_demo(test_special_group_stuck, world_size)
diff --git a/dipu/tests/python/individual_scripts/test_rt_tensor.py b/dipu/tests/python/individual_scripts/test_rt_tensor.py
index f2757f759..042c584a8 100644
--- a/dipu/tests/python/individual_scripts/test_rt_tensor.py
+++ b/dipu/tests/python/individual_scripts/test_rt_tensor.py
@@ -152,6 +152,9 @@ def test_type():
     res = isinstance(s4, torch.cuda.FloatTensor)
     assert res == True
 
+    assert dev1 in s1.type()
+    assert s1.device.type == dev1
+
 
 def test_device_copy():
     import torch_dipu
diff --git a/dipu/tests/python/unittests/test_ones.py b/dipu/tests/python/unittests/test_ones.py
new file mode 100644
index 000000000..12fc94881
--- /dev/null
+++ b/dipu/tests/python/unittests/test_ones.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2024, DeepLink.
+import torch
+import torch_dipu
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+
+
+class TestOnes(TestCase):
+    def test_ones(self):
+        device = torch.device("dipu")
+        size = [5, 6]
+        x = torch.ones(size=size)
+        y = torch.ones(size=size, device=device)
+        self.assertEqual(x, y.cpu(), exact_dtype=True)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/dipu/tests/python/unittests/test_zeros.py b/dipu/tests/python/unittests/test_zeros.py
new file mode 100644
index 000000000..a83ea2af5
--- /dev/null
+++ b/dipu/tests/python/unittests/test_zeros.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024, DeepLink.
+import torch
+import torch_dipu
+from torch_dipu.testing._internal.common_utils import TestCase, run_tests
+
+
+class TestZeros(TestCase):
+    def test_zeros(self):
+        device = torch.device("dipu")
+        size = [5, 6]
+        x = torch.zeros(size=size)
+        y = torch.zeros(size=size, device=device)
+        self.assertEqual(x, y.cpu(), exact_dtype=True)
+
+    def test_zero_(self):
+        size = [3, 5]
+        y = torch.randn(size=size).cuda()
+        x = y.cpu()
+        x.zero_()
+        y.zero_()
+        self.assertEqual(x, y.cpu(), exact_dtype=True)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/dipu/third_party/DIOPI b/dipu/third_party/DIOPI
index 896ddfb6c..fe6538fab 160000
--- a/dipu/third_party/DIOPI
+++ b/dipu/third_party/DIOPI
@@ -1 +1 @@
-Subproject commit 896ddfb6cc774316d741f6506120739c4a861213
+Subproject commit fe6538fab7582b5b6a9b3ba3f915e89b7a55d287
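
The new unit tests cover only the functional overloads; the ones.out and zeros.out schemas take a preallocated tensor instead. A hedged sketch of those paths (same "dipu" device string as the tests, and the diopi routing follows the YAML above):

    import torch
    import torch_dipu

    out = torch.empty(5, 6, device="dipu")
    torch.ones([5, 6], out=out)   # ones.out, backed by diopiOnes
    torch.zeros([5, 6], out=out)  # zeros.out, backed by diopiZeros
    out.zero_()                   # zero_, backed by diopiZeroInp
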
diff --git a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
index ee559b568..82f36671b 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/RegisterDIPU.hpp
@@ -34,13 +34,13 @@ void dipu_fallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys,
 // NOLINTEND(bugprone-macro-parentheses)
 
 // Check the environment variable and call the DIPU_LOG_WARNING_ONCE
-#define DIPU_OP_LOG_WARNING_ONCE(...)                       \
-  do {                                                      \
-    const char* env = std::getenv("DIPU_DUMP_OP_ARGS");     \
-    int env_value = (env != nullptr) ? std::atoi(env) : 0;  \
-    if (env_value >= 0) {                                   \
-      DIPU_LOG_WARNING_ONCE(__VA_ARGS__);                   \
-    }                                                       \
+#define DIPU_OP_LOG_WARNING_ONCE(...)                       \
+  do {                                                      \
+    const char* env = std::getenv("DIPU_DUMP_OP_ARGS");     \
+    int env_value = (env != nullptr) ? std::atoi(env) : -1; \
+    if (env_value >= 0) {                                   \
+      DIPU_LOG_WARNING_ONCE(__VA_ARGS__);                   \
+    }                                                       \
   } while (0)
 
 // Temporarily not implement 'sub-dispatch from box' (from torch box func ->
diff --git a/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp b/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp
index c2a4239fe..ab3537950 100644
--- a/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp
+++ b/dipu/torch_dipu/csrc_dipu/aten/ops/AutoCompareUtils.hpp
@@ -192,5 +192,20 @@ inline std::string allclose_autocompare(
   return stream.str();
 }
 
+template <typename T,
+          std::enable_if_t<std::is_arithmetic<T>::value, bool> = true>
+inline std::string allclose_autocompare(T val_expected, T val_real,
+                                        int indentation = 2) {
+  std::ostringstream stream;
+  stream << std::setfill(' ');
+  if (val_expected != val_real) {
+    stream << std::setw(indentation) << "not allclose: expected val is "
+           << val_expected << " but the real val is " << val_real << std::endl;
+  } else {
+    stream << "allclose" << std::endl;
+  }
+  return stream.str();
+}
+
 } // namespace native
 } // namespace dipu
diff --git a/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp b/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
index 122c56ec9..505eb7728 100644
--- a/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
+++ b/dipu/torch_dipu/csrc_dipu/vendor/ascend/deviceimpl.cpp
@@ -16,6 +16,7 @@ namespace devapis {
 // =====================
 // Device class related
 // =====================
+using AscendDeviceId = int32_t;
 
 namespace {
diff --git a/dipu/torch_dipu/dipu/distributed.py b/dipu/torch_dipu/dipu/distributed.py
index 8b1cfd93c..73e0cfbac 100644
--- a/dipu/torch_dipu/dipu/distributed.py
+++ b/dipu/torch_dipu/dipu/distributed.py
@@ -86,7 +86,40 @@ def _wrap_get_backend(group: Optional[ProcessGroup] = None) -> str:
     return ret
 
 
+# DICL does not support coalescing yet, so batch_isend_irecv crashes on torch 2.1.
+# TODO: remove this wrapper once coalescing is supported.
+def _wrap_batch_isend_irecv(p2p_op_list):
+    dist.distributed_c10d._check_p2p_op_list(p2p_op_list)
+    reqs = []
+    for p2p_op in p2p_op_list:
+        work = p2p_op.op(p2p_op.tensor, p2p_op.peer, p2p_op.group, p2p_op.tag)
+        if work:
+            reqs.append(work)
+    return reqs
+
+
+# Huawei AscendSpeed passes rank lists like [0, 0], which makes gloo pg
+# creation fail in torch 2.0. Strictly speaking that is Huawei's problem,
+# since such a list is not valid, but deduping is all we can do here.
+# torch 2.1 creates no gloo sub-device-pg for a dicl pg, so creation does not
+# hang there and we keep that behavior; even so, such a group still hangs on any real comm.
+_raw_new_group = dist.new_group
+
+
+def _wrap_new_group(
+    ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None
+):
+    ranks = list(set(ranks))  # dedup
+    return _raw_new_group(ranks, timeout, backend, pg_options)
+
+
 def apply_dist_patch():
     dist.get_backend = _wrap_get_backend
     dist.init_process_group = _wrap_init_process_groups
     dist.ProcessGroup._register_backend = _wrapped_register_backend
+    # drop the batch_isend_irecv wrapper once coalescing is ready
+    if dipu.get_dipu_torch_version() != dipu.torch_ver_200:
+        dist.batch_isend_irecv = _wrap_batch_isend_irecv
+
+    if dipu.get_dipu_torch_version() == dipu.torch_ver_200:
+        dist.new_group = _wrap_new_group
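
_wrap_batch_isend_irecv above keeps the public call shape of dist.batch_isend_irecv: a list of P2POp in, a list of work handles out; it only drops the coalescing context. A usage sketch under the usual assumptions (an initialized two-rank process group; peer rank 1 is illustrative):

    import torch
    import torch.distributed as dist

    send_t = torch.ones(4, device="dipu")
    recv_t = torch.empty(4, device="dipu")
    ops = [
        dist.P2POp(dist.isend, send_t, peer=1),
        dist.P2POp(dist.irecv, recv_t, peer=1),
    ]
    reqs = dist.batch_isend_irecv(ops)  # each op now issued individually
    for req in reqs:
        req.wait()
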
diff --git a/dipu/torch_dipu/dipu/tensor.py b/dipu/torch_dipu/dipu/tensor.py
index aae8667ac..6e151f4e4 100644
--- a/dipu/torch_dipu/dipu/tensor.py
+++ b/dipu/torch_dipu/dipu/tensor.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2023, DeepLink.
 import torch
 
-from .device import __diputype__
+from .device import __diputype__, __dipu_device_type__
 from torch_dipu import _C, mockcuda
 
 
@@ -16,8 +16,20 @@ def __set_default_tensor_type(type=torch.FloatTensor):
     _default_tensor_type = type
 
 
+_raw_tensor_type = torch.Tensor.type
+
+
+def _wrap_tensor_type(self, *args, **kwargs):
+    ret = _raw_tensor_type(self, *args, **kwargs)
+    if isinstance(ret, str):
+        return ret.replace(__dipu_device_type__, "cuda")
+    else:
+        return ret
+
+
 # need enhance, seems change tensor define is need
 def apply_tensor_type_patch():
     torch.set_default_tensor_type = __set_default_tensor_type
     if mockcuda:
         _C._mockCudaTensor()
+    torch.Tensor.type = _wrap_tensor_type
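
_wrap_tensor_type only rewrites string results, so calls that return a tensor pass through untouched. A sketch of the observable effect, assuming mockcuda is enabled so that "cuda" maps onto the dipu device:

    import torch
    import torch_dipu

    t = torch.randn(2, 2).cuda()          # placed on the dipu device under mockcuda
    print(t.type())                       # dipu type string reported as "torch.cuda.FloatTensor"
    half = t.type(torch.cuda.HalfTensor)  # tensor return value, passed through as-is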