diff --git a/tests/test_array.py b/tests/test_array.py index c9b2cbce91..82534fcc38 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1463,4 +1463,6 @@ async def test_sharding_coordinate_selection() -> None: shards=(2, 4, 4), ) arr[:] = np.arange(2 * 3 * 4).reshape((2, 3, 4)) - assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() # type: ignore[index] + result = arr[1, [0, 1]] + assert isinstance(result, NDArrayLike) + assert (result == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() # type: ignore[index]