diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 438f8599..1bc007d0 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -426,8 +426,16 @@ def __init__(self): super(Net, self).__init__() self.conv = torch.nn.Conv2d(1, 1, 3) + def map_f(self, u): + return u + 1 + def forward(self, x): - y = self.conv(x) - return list(ppe.map(lambda u: u + 1, y))[0] + y1 = self.conv(x) + y2 = self.conv(x) + y = [{"u" : y1}, {"u": y2}] + return list(ppe.map(self.map_f, y))[0] - run_model_test(Net(), (torch.rand(1, 1, 112, 112),), rtol=1e-03) + model = Net() + ppe.to(model, device="cpu") + + run_model_test(model, (torch.rand(1, 1, 112, 112),), rtol=1e-03)