Skip to content

Commit 65e2439

Browse files
committed
seed RNG of lib3d test
1 parent 77c423e commit 65e2439

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

tests/test_lib3d.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
transform_pts,
2020
)
2121

22+
# seed the RNG for reproduceable tests
23+
np.random.seed(0)
24+
torch.manual_seed(0)
25+
2226

2327
class TestTransform(unittest.TestCase):
2428
"""
@@ -105,20 +109,29 @@ def test_quaternion_to_angle_axis(self):
105109
[self.quats_ts_norm[:, -1:], self.quats_ts_norm[:, :3]]
106110
)
107111
aa_ts = quaternion_to_angle_axis(quats_ts_norm)
108-
aa = pin.log3(pin.Quaternion(self.quats_arr_norm[1]).toRotationMatrix())
109-
self.assertTrue(np.allclose(aa_ts.numpy()[1], aa, atol=1e-6))
112+
aa_arr = np.zeros((self.N, 3))
113+
for i in range(self.N):
114+
aa_arr[i] = pin.log3(
115+
pin.Quaternion(self.quats_arr_norm[i]).toRotationMatrix()
116+
)
117+
# tolerance needs to be higher with seeds of 0 (for seeds 1, can be decreased to 1e-6)
118+
self.assertTrue(np.allclose(aa_ts.numpy(), aa_arr, atol=1e-4))
110119

111120
def test_quat2mat(self):
112121
# quat2mat assumes a wxyz quaternion order convention
113122
R_ts = quat2mat(self.quats_ts)
114-
R = pin.Quaternion(self.quats_arr_norm[1]).toRotationMatrix()
115-
self.assertTrue(np.allclose(R_ts[1, :3, :3].numpy(), R, atol=1e-6))
123+
R_arr = np.zeros((self.N, 3, 3))
124+
for i in range(self.N):
125+
R_arr[i] = pin.Quaternion(self.quats_arr_norm[i]).toRotationMatrix()
126+
self.assertTrue(np.allclose(R_ts[:, :3, :3].numpy(), R_arr, atol=1e-6))
116127

117128
def test_compute_rotation_matrix_from_quaternions(self):
118129
# quaternion_to_angle_axis assumes a xyzw quaternion order convention
119130
R_ts = compute_rotation_matrix_from_quaternions(self.quats_ts)
120-
R = pin.Quaternion(self.quats_arr_norm[1]).toRotationMatrix()
121-
self.assertTrue(np.allclose(R_ts.numpy()[1], R, atol=1e-6))
131+
R_arr = np.zeros((self.N, 3, 3))
132+
for i in range(self.N):
133+
R_arr[i] = pin.Quaternion(self.quats_arr_norm[i]).toRotationMatrix()
134+
self.assertTrue(np.allclose(R_ts.numpy(), R_arr, atol=1e-6))
122135

123136

124137
class TestTransformOps(unittest.TestCase):

0 commit comments

Comments
 (0)