Skip to content

Commit

Permalink
[Ascend]zzf/add equal op (#739)
Browse files Browse the repository at this point in the history
* add equal op for ascend

* add unittest for equal op

* add unittest for equal op

* add unittest for equal op

* add unittest for equal op

* add torch.equal to dipu

* Update diopi_functions.yaml
  • Loading branch information
zhangzefeng92 authored Mar 23, 2024
1 parent 5afa987 commit 2ce08a7
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions dipu/SupportedDiopiFunctions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ diopiEq
diopiEqInp
diopiEqInpScalar
diopiEqScalar
diopiEqual
diopiErfinv
diopiErfinvInp
diopiExp
Expand Down
5 changes: 4 additions & 1 deletion dipu/scripts/autogen_diopi_wrapper/autogen_diopi_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,21 @@ def get_function_return_param_from_schema(schema):
args = return_params[i]
inplace_match = re.search("Tensor\([a-zA-Z]+!\)", args)
pure_out_match = re.search("Tensor[ ,]?", args)
bool_out_match = re.search("bool", args)
if inplace_match is not None:
arg_label = re.sub(".*(\(.*\))", r"\1", inplace_match.group())
index = schema.find(arg_label) + len(arg_label)
param = re.search("[a-zA-Z0-9_::]+", schema[index:]).group()
params.append(param)
elif inplace_match is None and pure_out_match is not None:
elif pure_out_match is not None:
name_from_schema = re.sub("\(?Tensor[ ]+([\w\d_]+)\)?", R"\1", args)
if name_from_schema == args:
name = "out" + (str(i) if len(return_params) > 1 else "")
else:
name = name_from_schema
params.append(name)
elif bool_out_match is not None:
params.append("out")
return params


Expand Down
5 changes: 5 additions & 0 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@
}
interface: diopiEqInp(ctx, self, other)

- schema: equal(Tensor self, Tensor other) -> bool
custom_code_at_the_beginning: |
bool out;
interface: diopiEqual(ctx, &out, self, other)

- schema: "lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)"
interface: diopiLtScalar(ctx, out, self, other)

Expand Down
36 changes: 36 additions & 0 deletions dipu/tests/python/unittests/test_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2023, DeepLink.
import torch
import math
import torch_dipu
from torch_dipu.testing._internal.common_utils import TestCase, run_tests, onlyOn


class TestEqual(TestCase):
@onlyOn("NPU")
def test_equal_1(self):
device = torch.device("dipu")
x = torch.rand(3, 4).to(device)
self.assertEqual(True, torch.equal(x, x))

@onlyOn("NPU")
def test_equal_2(self):
device = torch.device("dipu")
x = torch.rand(3, 4).to(device)
self.assertEqual(False, torch.equal(x, x.to(torch.float16)))

@onlyOn("NPU")
def test_equal_3(self):
device = torch.device("dipu")
x = torch.zeros(3, 4).to(device)
self.assertEqual(True, torch.equal(x, x.to(torch.float16)))

@onlyOn("NPU")
def test_equal_4(self):
device = torch.device("dipu")
x = torch.rand(3, 4).to(device)
y = torch.rand(3, 5).to(device)
self.assertEqual(False, torch.equal(x, y))


if __name__ == "__main__":
run_tests()

0 comments on commit 2ce08a7

Please sign in to comment.