Skip to content

Commit

Permalink
test plot grid and support 2d images
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed Mar 13, 2024
1 parent 8400d62 commit 47e2372
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
21 changes: 16 additions & 5 deletions ants/viz/plot_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,10 @@ def plot_grid(
... [mni3, mni4]])
>>> slices = np.asarray([[100, 100],
... [100, 100]])
>>> #axes = np.asarray([[2,2],[2,2]])
>>> # standard plotting
>>> ants.plot_grid(images=images, slices=slices, title='2x2 Grid')
>>> images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
... [mni3.slice_image(2,100), mni4.slice_image(2,100)]])
>>> ants.plot_grid(images=images2d, title='2x2 Grid Pre-Sliced')
>>> ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid')
>>> ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid')
Expand Down Expand Up @@ -178,6 +179,8 @@ def slice_image(img, axis, idx):
if not isinstance(images[0], list):
images = [images]

if slices is None:
one_slice = True
if isinstance(slices, int):
one_slice = True
if isinstance(slices, np.ndarray):
Expand Down Expand Up @@ -320,9 +323,17 @@ def slice_image(img, axis, idx):
tmpaxis = axes
else:
tmpaxis = axes[rowidx][colidx]
sliceidx = slices[rowidx][colidx] if not one_slice else slices
tmpslice = slice_image(tmpimg, tmpaxis, sliceidx)
tmpslice = reorient_slice(tmpslice, tmpaxis)

if tmpimg.dimension == 2:
tmpslice = tmpimg.numpy()
tmpslice = reorient_slice(tmpslice, tmpaxis)
else:
sliceidx = slices[rowidx][colidx] if not one_slice else slices
if sliceidx is None:
sliceidx = math.ceil(tmpimg.shape[tmpaxis] / 2)
tmpslice = slice_image(tmpimg, tmpaxis, sliceidx)
tmpslice = reorient_slice(tmpslice, tmpaxis)

im = ax.imshow(tmpslice, cmap=rcmap, aspect="auto", vmin=rvmin, vmax=rvmax)
ax.axis("off")

Expand Down
28 changes: 27 additions & 1 deletion tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,32 @@ def test_plot_example(self):
filename = mktemp(suffix='.png')
for img in self.imgs:
ants.plot_hist(img)


class TestModule_plot_grid(unittest.TestCase):

def setUp(self):
mni1 = ants.image_read(ants.get_data('mni'))
mni2 = mni1.smooth_image(1.)
mni3 = mni1.smooth_image(2.)
mni4 = mni1.smooth_image(3.)
self.images3d = np.asarray([[mni1, mni2],
[mni3, mni4]])
self.images2d = np.asarray([[mni1.slice_image(2,100), mni2.slice_image(2,100)],
[mni3.slice_image(2,100), mni4.slice_image(2,100)]])

def tearDown(self):
pass

def test_plot_example(self):
ants.plot_grid(self.images3d, slices=100)
# should take middle slices if none are given
ants.plot_grid(self.images3d)
# should work with 2d images
ants.plot_grid(self.images2d)





if __name__ == '__main__':
run_tests()

0 comments on commit 47e2372

Please sign in to comment.