Skip to content

Commit

Permalink
[DIPU] Fix test_dipu_copy_fallback: NaN from uninitialized empty tensors (#640)
Browse files Browse the repository at this point in the history
  • Loading branch information
ustclight-sls authored Jan 16, 2024
1 parent 3bb9dcf commit 7840fec
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions dipu/tests/python/individual_scripts/test_dipu_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_fallback(

test_fn()
output = captured.getvalue().decode()
print(output, end="")
print(output, end="", flush=True)
assert all(
f"force fallback has been set, {name} will be fallback to cpu" in output
for name in op_names
Expand Down Expand Up @@ -96,14 +96,16 @@ def fn():

assert torch.allclose(target_tensor, target_dipu.cpu())

N, C, H, W = 1, 3, 2, 2
x = torch.empty(N, C, H, W)
x1 = x.to(memory_format=torch.channels_last).cuda()
x2stride = list(x1.stride())
x2stride[0] = 4
x2 = x.as_strided(x1.size(), tuple(x2stride)).cuda()
y = torch.empty(N, C, H, W).cuda()
y.copy_(x2)
# repeat to catch the intermittent failure of copy_
for i in range(100):
N, C, H, W = 1, 3, 2, 2
x = torch.randn(N, C, H, W)
x1 = x.to(memory_format=torch.channels_last).cuda()
x2stride = list(x1.stride())
x2stride[0] = 4
x2 = x.as_strided(x1.size(), tuple(x2stride)).cuda()
y = torch.randn(N, C, H, W).cuda()
y.copy_(x2)

assert torch.allclose(y, x2)

Expand Down Expand Up @@ -168,6 +170,7 @@ def fn():
["custom fallback to cpu, name=convolution_overrideable"],
)


def _test_dipu_silu_fallback():
def fn():
m = torch.nn.SiLU().cuda()
Expand Down Expand Up @@ -197,7 +200,7 @@ def fn():
_test_dipu_copy_fallback_,
_test_dipu_convolution_backward_overrideable_fallback,
_test_dipu_convolution_overrideable_fallback,
_test_dipu_silu_fallback
_test_dipu_silu_fallback,
],
in_parallel=True,
)

0 comments on commit 7840fec

Please sign in to comment.