diff --git a/ants/viz/__init__.py b/ants/viz/__init__.py index 1ebd977d..d58b0556 100644 --- a/ants/viz/__init__.py +++ b/ants/viz/__init__.py @@ -1,14 +1,14 @@ from .create_tiled_mosaic import create_tiled_mosaic -from .plot import ( - plot, - movie, - plot_hist, - plot_grid, - plot_ortho, - plot_ortho_double, - plot_ortho_stack, - plot_directory, -) + +from .plot import plot +from .movie import movie +from .plot_hist import plot_hist +from .plot_grid import plot_grid +from .plot_ortho import plot_ortho +from .plot_ortho_double import plot_ortho_double +from .plot_ortho_stack import plot_ortho_stack +from .plot_directory import plot_directory + from .render_surface_function import render_surface_function from .surface import (surf, surf_fold, surf_smooth, get_canonical_views) from .volume import (vol, vol_fold) diff --git a/ants/viz/movie.py b/ants/viz/movie.py new file mode 100644 index 00000000..39ed0a40 --- /dev/null +++ b/ants/viz/movie.py @@ -0,0 +1,91 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "movie" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def movie(image, filename=None, writer=None, fps=30): + """ + Create and save a movie - mp4, gif, etc - of the various + 2D slices of a 3D ants image + + Try this: + conda install -c conda-forge ffmpeg + + Example + ------- + >>> import ants + >>> mni = ants.image_read(ants.get_data('mni')) + >>> ants.movie(mni, filename='~/desktop/movie.mp4') + """ + + image = image.pad_image() + img_arr = image.numpy() + + minidx = max(0, np.where(image > 0)[0][0] - 5) + maxidx = max(image.shape[0], np.where(image > 0)[0][-1] + 5) + + # Creare your figure and axes + fig, ax = plt.subplots(1) + + im = ax.imshow( + img_arr[minidx, :, :], + animated=True, + cmap="Greys_r", + vmin=image.quantile(0.05), + vmax=image.quantile(0.95), + ) + + ax.axis("off") + + def init(): + fig.axes("off") + return (im,) + + def updatefig(frame): + im.set_array(img_arr[frame, :, :]) + return (im,) + + ani = animation.FuncAnimation( + fig, + updatefig, + frames=np.arange(minidx, maxidx), + # init_func=init, + interval=50, + blit=True, + ) + + if writer is None: + writer = animation.FFMpegWriter(fps=fps) + + if filename is not None: + filename = os.path.expanduser(filename) + ani.save(filename, writer=writer) + else: + plt.show() diff --git a/ants/viz/plot.py b/ants/viz/plot.py index 188132a6..5b1a9faf 100644 --- a/ants/viz/plot.py +++ b/ants/viz/plot.py @@ -1,2099 +1,34 @@ """ -Create a static 2D image of a 2D ANTsImage -or a tile of slices from a 3D ANTsImage - -TODO: -- add `plot_multichannel` function for plotting multi-channel images - - support for quivers as well -- add `plot_gif` function for making a gif/video or 2D slices across a 3D image -""" - - -__all__ = [ - "plot", - "movie", - "plot_hist", - "plot_grid", - "plot_ortho", - "plot_ortho_double", - "plot_ortho_stack", - "plot_directory", -] - -import fnmatch -import math -import os -import warnings - -from matplotlib import gridspec -import matplotlib.pyplot as plt -import matplotlib.patheffects as path_effects -import matplotlib.lines as mlines -import matplotlib.patches as patches -import matplotlib.mlab as mlab -import matplotlib.animation as animation -from mpl_toolkits.axes_grid1.inset_locator import inset_axes - - -import numpy as np - -from .. import registration as reg -from ..core import ants_image as iio -from ..core import ants_image_io as iio2 -from ..core import ants_transform as tio -from ..core import ants_transform_io as tio2 - - -def movie(image, filename=None, writer=None, fps=30): - """ - Create and save a movie - mp4, gif, etc - of the various - 2D slices of a 3D ants image - - Try this: - conda install -c conda-forge ffmpeg - - Example - ------- - >>> import ants - >>> mni = ants.image_read(ants.get_data('mni')) - >>> ants.movie(mni, filename='~/desktop/movie.mp4') - """ - - image = image.pad_image() - img_arr = image.numpy() - - minidx = max(0, np.where(image > 0)[0][0] - 5) - maxidx = max(image.shape[0], np.where(image > 0)[0][-1] + 5) - - # Creare your figure and axes - fig, ax = plt.subplots(1) - - im = ax.imshow( - img_arr[minidx, :, :], - animated=True, - cmap="Greys_r", - vmin=image.quantile(0.05), - vmax=image.quantile(0.95), - ) - - ax.axis("off") - - def init(): - fig.axes("off") - return (im,) - - def updatefig(frame): - im.set_array(img_arr[frame, :, :]) - return (im,) - - ani = animation.FuncAnimation( - fig, - updatefig, - frames=np.arange(minidx, maxidx), - # init_func=init, - interval=50, - blit=True, - ) - - if writer is None: - writer = animation.FFMpegWriter(fps=fps) - - if filename is not None: - filename = os.path.expanduser(filename) - ani.save(filename, writer=writer) - else: - plt.show() - - -def plot_hist( - image, - threshold=0.0, - fit_line=False, - normfreq=True, - ## plot label arguments - title=None, - grid=True, - xlabel=None, - ylabel=None, - ## other plot arguments - facecolor="green", - alpha=0.75, -): - """ - Plot a histogram from an ANTsImage - - Arguments - --------- - image : ANTsImage - image from which histogram will be created - """ - img_arr = image.numpy().flatten() - img_arr = img_arr[np.abs(img_arr) > threshold] - - if normfreq != False: - normfreq = 1.0 if normfreq == True else normfreq - n, bins, patches = plt.hist( - img_arr, 50, facecolor=facecolor, alpha=alpha - ) - - if fit_line: - # add a 'best fit' line - y = mlab.normpdf(bins, img_arr.mean(), img_arr.std()) - l = plt.plot(bins, y, "r--", linewidth=1) - - if xlabel is not None: - plt.xlabel(xlabel) - if ylabel is not None: - plt.ylabel(ylabel) - if title is not None: - plt.title(title) - - plt.grid(grid) - plt.show() - - -def plot_grid( - images, - slices=None, - axes=2, - # general figure arguments - figsize=1.0, - rpad=0, - cpad=0, - vmin=None, - vmax=None, - colorbar=True, - cmap="Greys_r", - # title arguments - title=None, - tfontsize=20, - title_dx=0, - title_dy=0, - # row arguments - rlabels=None, - rfontsize=14, - rfontcolor="white", - rfacecolor="black", - # column arguments - clabels=None, - cfontsize=14, - cfontcolor="white", - cfacecolor="black", - # save arguments - filename=None, - dpi=400, - transparent=True, - # other args - **kwargs -): - """ - Plot a collection of images in an arbitrarily-defined grid - - Matplotlib named colors: https://matplotlib.org/examples/color/named_colors.html - - Arguments - --------- - images : list of ANTsImage types - image(s) to plot. - if one image, this image will be used for all grid locations. - if multiple images, they should be arrange in a list the same - shape as the `gridsize` argument. - - slices : integer or list of integers - slice indices to plot - if one integer, this slice index will be used for all images - if multiple integers, they should be arranged in a list the same - shape as the `gridsize` argument - - axes : integer or list of integers - axis or axes along which to plot image slices - if one integer, this axis will be used for all images - if multiple integers, they should be arranged in a list the same - shape as the `gridsize` argument - - Example - ------- - >>> import ants - >>> import numpy as np - >>> mni1 = ants.image_read(ants.get_data('mni')) - >>> mni2 = mni1.smooth_image(1.) - >>> mni3 = mni1.smooth_image(2.) - >>> mni4 = mni1.smooth_image(3.) - >>> images = np.asarray([[mni1, mni2], - ... [mni3, mni4]]) - >>> slices = np.asarray([[100, 100], - ... [100, 100]]) - >>> #axes = np.asarray([[2,2],[2,2]]) - >>> # standard plotting - >>> ants.plot_grid(images=images, slices=slices, title='2x2 Grid') - >>> ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid') - >>> ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid') - - >>> # Padding between rows and/or columns - >>> ants.plot_grid(images, slices, cpad=0.02, title='Col Padding') - >>> ants.plot_grid(images, slices, rpad=0.02, title='Row Padding') - >>> ants.plot_grid(images, slices, rpad=0.02, cpad=0.02, title='Row and Col Padding') - - >>> # Adding plain row and/or column labels - >>> ants.plot_grid(images, slices, title='Adding Row Labels', rlabels=['Row #1', 'Row #2']) - >>> ants.plot_grid(images, slices, title='Adding Col Labels', clabels=['Col #1', 'Col #2']) - >>> ants.plot_grid(images, slices, title='Row and Col Labels', - rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2']) - - >>> # Making a publication-quality image - >>> images = np.asarray([[mni1, mni2, mni2], - ... [mni3, mni4, mni4]]) - >>> slices = np.asarray([[100, 100, 100], - ... [100, 100, 100]]) - >>> axes = np.asarray([[0, 1, 2], - [0, 1, 2]]) - >>> ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy', - tfontsize=20, title_dy=0.03, title_dx=-0.04, - rlabels=['Row 1', 'Row 2'], - clabels=['Col 1', 'Col 2', 'Col 3'], - rfontsize=16, cfontsize=16) - """ - - def mirror_matrix(x): - return x[::-1, :] - - def rotate270_matrix(x): - return mirror_matrix(x.T) - - def rotate180_matrix(x): - return x[::-1, ::-1] - - def rotate90_matrix(x): - return mirror_matrix(x).T - - def flip_matrix(x): - return mirror_matrix(rotate180_matrix(x)) - - def reorient_slice(x, axis): - if axis != 1: - x = rotate90_matrix(x) - if axis == 1: - x = rotate90_matrix(x) - x = mirror_matrix(x) - return x - - def slice_image(img, axis, idx): - if axis == 0: - return img[idx, :, :] - elif axis == 1: - return img[:, idx, :] - elif axis == 2: - return img[:, :, idx] - elif axis == -1: - return img[:, :, idx] - elif axis == -2: - return img[:, idx, :] - elif axis == -3: - return img[idx, :, :] - else: - raise ValueError("axis %i not valid" % axis) - - if isinstance(images, np.ndarray): - images = images.tolist() - if not isinstance(images, list): - raise ValueError("images argument must be of type list") - if not isinstance(images[0], list): - images = [images] - - if isinstance(slices, int): - one_slice = True - if isinstance(slices, np.ndarray): - slices = slices.tolist() - if isinstance(slices, list): - one_slice = False - if not isinstance(slices[0], list): - slices = [slices] - nslicerow = len(slices) - nslicecol = len(slices[0]) - - nrow = len(images) - ncol = len(images[0]) - - if rlabels is None: - rlabels = [None] * nrow - if clabels is None: - clabels = [None] * ncol - - if not one_slice: - if (nrow != nslicerow) or (ncol != nslicecol): - raise ValueError( - "`images` arg shape (%i,%i) must equal `slices` arg shape (%i,%i)!" - % (nrow, ncol, nslicerow, nslicecol) - ) - - fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize)) - - if title is not None: - basex = 0.5 - basey = 0.9 if clabels[0] is None else 0.95 - fig.suptitle(title, fontsize=tfontsize, x=basex + title_dx, y=basey + title_dy) - - if (cpad > 0) and (rpad > 0): - bothgridpad = max(cpad, rpad) - cpad = 0 - rpad = 0 - else: - bothgridpad = 0.0 - - gs = gridspec.GridSpec( - nrow, - ncol, - wspace=bothgridpad, - hspace=0.0, - top=1.0 - 0.5 / (nrow + 1), - bottom=0.5 / (nrow + 1) + cpad, - left=0.5 / (ncol + 1) + rpad, - right=1 - 0.5 / (ncol + 1), - ) - - if isinstance(vmin, (int, float)): - vmins = [vmin] * nrow - elif vmin is None: - vmins = [None] * nrow - else: - vmins = vmin - - if isinstance(vmax, (int, float)): - vmaxs = [vmax] * nrow - elif vmax is None: - vmaxs = [None] * nrow - else: - vmaxs = vmax - - if isinstance(cmap, str): - cmaps = [cmap] * nrow - elif cmap is None: - cmaps = [None] * nrow - else: - cmaps = cmap - - for rowidx, rvmin, rvmax, rcmap in zip(range(nrow), vmins, vmaxs, cmaps): - for colidx in range(ncol): - ax = plt.subplot(gs[rowidx, colidx]) - - if colidx == 0: - if rlabels[rowidx] is not None: - bottom, height = 0.25, 0.5 - top = bottom + height - # add label text - ax.text( - -0.07, - 0.5 * (bottom + top), - rlabels[rowidx], - horizontalalignment="right", - verticalalignment="center", - rotation="vertical", - transform=ax.transAxes, - color=rfontcolor, - fontsize=rfontsize, - ) - - # add label background - extra = 0.3 if rowidx == 0 else 0.0 - - rect = patches.Rectangle( - (-0.3, 0), - 0.3, - 1.0 + extra, - facecolor=rfacecolor, - alpha=1.0, - transform=ax.transAxes, - clip_on=False, - ) - ax.add_patch(rect) - - if rowidx == 0: - if clabels[colidx] is not None: - bottom, height = 0.25, 0.5 - left, width = 0.25, 0.5 - right = left + width - top = bottom + height - ax.text( - 0.5 * (left + right), - 0.09 + top + bottom, - clabels[colidx], - horizontalalignment="center", - verticalalignment="center", - rotation="horizontal", - transform=ax.transAxes, - color=cfontcolor, - fontsize=cfontsize, - ) - - # add label background - rect = patches.Rectangle( - (0, 1.0), - 1.0, - 0.3, - facecolor=cfacecolor, - alpha=1.0, - transform=ax.transAxes, - clip_on=False, - ) - ax.add_patch(rect) - - tmpimg = images[rowidx][colidx] - if isinstance(axes, int): - tmpaxis = axes - else: - tmpaxis = axes[rowidx][colidx] - sliceidx = slices[rowidx][colidx] if not one_slice else slices - tmpslice = slice_image(tmpimg, tmpaxis, sliceidx) - tmpslice = reorient_slice(tmpslice, tmpaxis) - im = ax.imshow(tmpslice, cmap=rcmap, aspect="auto", vmin=rvmin, vmax=rvmax) - ax.axis("off") - - # A colorbar solution with make_axes_locatable will not allow y-scaling of the colorbar. - # from mpl_toolkits.axes_grid1 import make_axes_locatable - # divider = make_axes_locatable(ax) - # cax = divider.append_axes('right', size='5%', pad=0.05) - if colorbar: - axins = inset_axes(ax, - width="5%", # width = 5% of parent_bbox width - height="90%", # height : 50% - loc='center left', - bbox_to_anchor=(1.03, 0., 1, 1), - bbox_transform=ax.transAxes, - borderpad=0, - ) - fig.colorbar(im, cax=axins, orientation='vertical') - - if filename is not None: - filename = os.path.expanduser(filename) - plt.savefig(filename, dpi=dpi, transparent=transparent, bbox_inches="tight") - plt.close(fig) - else: - plt.show() - - -def plot_ortho_stack( - images, - overlays=None, - reorient=True, - # xyz arguments - xyz=None, - xyz_lines=False, - xyz_color="red", - xyz_alpha=0.6, - xyz_linewidth=2, - xyz_pad=5, - # base image arguments - cmap="Greys_r", - alpha=1, - # overlay arguments - overlay_cmap="jet", - overlay_alpha=0.9, - # background arguments - black_bg=True, - bg_thresh_quant=0.01, - bg_val_quant=0.99, - # scale/crop/domain arguments - crop=False, - scale=False, - domain_image_map=None, - # title arguments - title=None, - titlefontsize=24, - title_dx=0, - title_dy=0, - # 4th panel text arguemnts - text=None, - textfontsize=24, - textfontcolor="white", - text_dx=0, - text_dy=0, - # save & size arguments - filename=None, - dpi=500, - figsize=1.0, - colpad=0, - rowpad=0, - transpose=False, - transparent=True, - orient_labels=True, -): - """ - Create a stack of orthographic plots with optional overlays. - - Use mask_image and/or threshold_image to preprocess images to be be - overlaid and display the overlays in a given range. See the wiki examples. - - Example - ------- - >>> import ants - >>> mni = ants.image_read(ants.get_data('mni')) - >>> ch2 = ants.image_read(ants.get_data('ch2')) - >>> ants.plot_ortho_stack([mni,mni,mni]) - """ - - def mirror_matrix(x): - return x[::-1, :] - - def rotate270_matrix(x): - return mirror_matrix(x.T) - - def reorient_slice(x, axis): - return rotate270_matrix(x) - - # need this hack because of a weird NaN warning from matplotlib with overlays - warnings.simplefilter("ignore") - - n_images = len(images) - - # handle `image` argument - for i in range(n_images): - if isinstance(images[i], str): - images[i] = iio2.image_read(images[i]) - if not isinstance(images[i], iio.ANTsImage): - raise ValueError("image argument must be an ANTsImage") - if images[i].dimension != 3: - raise ValueError("Input image must have 3 dimensions!") - - if overlays is None: - overlays = [None] * n_images - # handle `overlay` argument - for i in range(n_images): - if overlays[i] is not None: - if isinstance(overlays[i], str): - overlays[i] = iio2.image_read(overlays[i]) - if not isinstance(overlays[i], iio.ANTsImage): - raise ValueError("overlay argument must be an ANTsImage") - if overlays[i].components > 1: - raise ValueError("overlays[i] cannot have more than one voxel component") - if overlays[i].dimension != 3: - raise ValueError("Overlay image must have 3 dimensions!") - - if not iio.image_physical_space_consistency(images[i], overlays[i]): - overlays[i] = reg.resample_image_to_target( - overlays[i], images[i], interp_type="linear" - ) - - for i in range(1, n_images): - if not iio.image_physical_space_consistency(images[0], images[i]): - images[i] = reg.resample_image_to_target( - images[0], images[i], interp_type="linear" - ) - - # reorient images - if reorient != False: - if reorient == True: - reorient = "RPI" - - for i in range(n_images): - images[i] = images[i].reorient_image2(reorient) - - if overlays[i] is not None: - overlays[i] = overlays[i].reorient_image2(reorient) - - # handle `slices` argument - if xyz is None: - xyz = [int(s / 2) for s in images[0].shape] - for i in range(3): - if xyz[i] is None: - xyz[i] = int(images[0].shape[i] / 2) - - # resample image if spacing is very unbalanced - spacing = [s for i, s in enumerate(images[0].spacing)] - if (max(spacing) / min(spacing)) > 3.0: - new_spacing = (1, 1, 1) - for i in range(n_images): - images[i] = images[i].resample_image(tuple(new_spacing)) - if overlays[i] is not None: - overlays[i] = overlays[i].resample_image(tuple(new_spacing)) - xyz = [ - int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) - ] - - # potentially crop image - if crop: - for i in range(n_images): - plotmask = images[i].get_mask(cleanup=0) - if plotmask.max() == 0: - plotmask += 1 - images[i] = images[i].crop_image(plotmask) - if overlays[i] is not None: - overlays[i] = overlays[i].crop_image(plotmask) - - # pad images - for i in range(n_images): - if i == 0: - images[i], lowpad, uppad = images[i].pad_image(return_padvals=True) - else: - images[i] = images[i].pad_image() - if overlays[i] is not None: - overlays[i] = overlays[i].pad_image() - xyz = [v + l for v, l in zip(xyz, lowpad)] - - # handle `domain_image_map` argument - if domain_image_map is not None: - if isinstance(domain_image_map, iio.ANTsImage): - tx = tio2.new_ants_transform( - precision="float", transform_type="AffineTransform", dimension=3 - ) - for i in range(n_images): - images[i] = tio.apply_ants_transform_to_image( - tx, images[i], domain_image_map - ) - - if overlays[i] is not None: - overlays[i] = tio.apply_ants_transform_to_image( - tx, overlays[i], domain_image_map, interpolation="linear" - ) - elif isinstance(domain_image_map, (list, tuple)): - # expect an image and transformation - if len(domain_image_map) != 2: - raise ValueError("domain_image_map list or tuple must have length == 2") - - dimg = domain_image_map[0] - if not isinstance(dimg, iio.ANTsImage): - raise ValueError("domain_image_map first entry should be ANTsImage") - - tx = domain_image_map[1] - for i in range(n_images): - images[i] = reg.apply_transforms(dimg, images[i], transform_list=tx) - if overlays[i] is not None: - overlays[i] = reg.apply_transforms( - dimg, overlays[i], transform_list=tx, interpolator="linear" - ) - - # potentially find dynamic range - if scale == True: - vmins = [] - vmaxs = [] - for i in range(n_images): - vmin, vmax = images[i].quantile((0.05, 0.95)) - vmins.append(vmin) - vmaxs.append(vmax) - elif isinstance(scale, (list, tuple)): - if len(scale) != 2: - raise ValueError( - "scale argument must be boolean or list/tuple with two values" - ) - vmins = [] - vmaxs = [] - for i in range(n_images): - vmin, vmax = images[i].quantile(scale) - vmins.append(vmin) - vmaxs.append(vmax) - else: - vmin = None - vmax = None - - if not transpose: - nrow = n_images - ncol = 3 - else: - nrow = 3 - ncol = n_images - - fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize)) - if title is not None: - basey = 0.93 - basex = 0.5 - fig.suptitle( - title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy - ) - - if (colpad > 0) and (rowpad > 0): - bothgridpad = max(colpad, rowpad) - colpad = 0 - rowpad = 0 - else: - bothgridpad = 0.0 - - gs = gridspec.GridSpec( - nrow, - ncol, - wspace=bothgridpad, - hspace=0.0, - top=1.0 - 0.5 / (nrow + 1), - bottom=0.5 / (nrow + 1) + colpad, - left=0.5 / (ncol + 1) + rowpad, - right=1 - 0.5 / (ncol + 1), - ) - - # pad image to have isotropic array dimensions - vminols=[] - vmaxols=[] - for i in range(n_images): - images[i] = images[i].numpy() - if overlays[i] is not None: - vminols.append( overlays[i].min() ) - vmaxols.append( overlays[i].max() ) - overlays[i] = overlays[i].numpy() - if overlays[i].dtype not in ["uint8", "uint32"]: - overlays[i][np.abs(overlays[i]) == 0] = np.nan - - #################### - #################### - for i in range(n_images): - yz_slice = reorient_slice(images[i][xyz[0], :, :], 0) - if not transpose: - ax = plt.subplot(gs[i, 0]) - else: - ax = plt.subplot(gs[0, i]) - ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlays[i] is not None: - yz_overlay = reorient_slice(overlays[i][xyz[0], :, :], 0) - ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, - vmin=vminols[i], vmax=vmaxols[i]) - if xyz_lines: - # add lines - l = mlines.Line2D( - [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], - [xyz_pad, yz_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, yz_slice.shape[1] - xyz_pad], - [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - if orient_labels: - ax.text( - 0.5, - 0.98, - "S", - horizontalalignment="center", - verticalalignment="top", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.5, - 0.02, - "I", - horizontalalignment="center", - verticalalignment="bottom", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.98, - 0.5, - "A", - horizontalalignment="right", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.02, - 0.5, - "P", - horizontalalignment="left", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.axis("off") - #################### - #################### - - xz_slice = reorient_slice(images[i][:, xyz[1], :], 1) - if not transpose: - ax = plt.subplot(gs[i, 1]) - else: - ax = plt.subplot(gs[1, i]) - ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlays[i] is not None: - xz_overlay = reorient_slice(overlays[i][:, xyz[1], :], 1) - ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, - vmin=vminols[i], vmax=vmaxols[i]) - if xyz_lines: - # add lines - l = mlines.Line2D( - [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], - [xyz_pad, xz_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xz_slice.shape[1] - xyz_pad], - [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - if orient_labels: - ax.text( - 0.5, - 0.98, - "A", - horizontalalignment="center", - verticalalignment="top", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.5, - 0.02, - "P", - horizontalalignment="center", - verticalalignment="bottom", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.98, - 0.5, - "L", - horizontalalignment="right", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.02, - 0.5, - "R", - horizontalalignment="left", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.axis("off") - - #################### - #################### - xy_slice = reorient_slice(images[i][:, :, xyz[2]], 2) - if not transpose: - ax = plt.subplot(gs[i, 2]) - else: - ax = plt.subplot(gs[2, i]) - ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlays[i] is not None: - xy_overlay = reorient_slice(overlays[i][:, :, xyz[2]], 2) - ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap, - vmin=vminols[i], vmax=vmaxols[i]) - if xyz_lines: - # add lines - l = mlines.Line2D( - [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], - [xyz_pad, xy_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xy_slice.shape[1] - xyz_pad], - [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - if orient_labels: - ax.text( - 0.5, - 0.98, - "A", - horizontalalignment="center", - verticalalignment="top", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.5, - 0.02, - "P", - horizontalalignment="center", - verticalalignment="bottom", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.98, - 0.5, - "L", - horizontalalignment="right", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.02, - 0.5, - "R", - horizontalalignment="left", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.axis("off") - - #################### - #################### - - if filename is not None: - plt.savefig(filename, dpi=dpi, transparent=transparent) - plt.close(fig) - else: - plt.show() - - # turn warnings back to default - warnings.simplefilter("default") - - -def plot_ortho_double( - image, - image2, - overlay=None, - overlay2=None, - reorient=True, - # xyz arguments - xyz=None, - xyz_lines=True, - xyz_color="red", - xyz_alpha=0.6, - xyz_linewidth=2, - xyz_pad=5, - # base image arguments - cmap="Greys_r", - alpha=1, - cmap2="Greys_r", - alpha2=1, - # overlay arguments - overlay_cmap="jet", - overlay_alpha=0.9, - overlay_cmap2="jet", - overlay_alpha2=0.9, - # background arguments - black_bg=True, - bg_thresh_quant=0.01, - bg_val_quant=0.99, - # scale/crop/domain arguments - crop=False, - scale=False, - crop2=False, - scale2=True, - domain_image_map=None, - # title arguments - title=None, - titlefontsize=24, - title_dx=0, - title_dy=0, - # 4th panel text arguemnts - text=None, - textfontsize=24, - textfontcolor="white", - text_dx=0, - text_dy=0, - # save & size arguments - filename=None, - dpi=500, - figsize=1.0, - flat=True, - transpose=False, - transparent=True, -): - """ - Create a pair of orthographic plots with overlays. - - Use mask_image and/or threshold_image to preprocess images to be be - overlaid and display the overlays in a given range. See the wiki examples. - - Example - ------- - >>> import ants - >>> mni = ants.image_read(ants.get_data('mni')) - >>> ch2 = ants.image_read(ants.get_data('ch2')) - >>> ants.plot_ortho_double(mni, ch2) - """ - - def mirror_matrix(x): - return x[::-1, :] - - def rotate270_matrix(x): - return mirror_matrix(x.T) - - def reorient_slice(x, axis): - return rotate270_matrix(x) - - # need this hack because of a weird NaN warning from matplotlib with overlays - warnings.simplefilter("ignore") - - # handle `image` argument - if isinstance(image, str): - image = iio2.image_read(image) - if not isinstance(image, iio.ANTsImage): - raise ValueError("image argument must be an ANTsImage") - if image.dimension != 3: - raise ValueError("Input image must have 3 dimensions!") - - if isinstance(image2, str): - image2 = iio2.image_read(image2) - if not isinstance(image2, iio.ANTsImage): - raise ValueError("image2 argument must be an ANTsImage") - if image2.dimension != 3: - raise ValueError("Input image2 must have 3 dimensions!") - - # handle `overlay` argument - if overlay is not None: - if isinstance(overlay, str): - overlay = iio2.image_read(overlay) - if not isinstance(overlay, iio.ANTsImage): - raise ValueError("overlay argument must be an ANTsImage") - if overlay.components > 1: - raise ValueError("overlay cannot have more than one voxel component") - if overlay.dimension != 3: - raise ValueError("Overlay image must have 3 dimensions!") - - if not iio.image_physical_space_consistency(image, overlay): - overlay = reg.resample_image_to_target(overlay, image, interp_type="linear") - - if overlay2 is not None: - if isinstance(overlay2, str): - overlay2 = iio2.image_read(overlay2) - if not isinstance(overlay2, iio.ANTsImage): - raise ValueError("overlay2 argument must be an ANTsImage") - if overlay2.components > 1: - raise ValueError("overlay2 cannot have more than one voxel component") - if overlay2.dimension != 3: - raise ValueError("Overlay2 image must have 3 dimensions!") - - if not iio.image_physical_space_consistency(image2, overlay2): - overlay2 = reg.resample_image_to_target( - overlay2, image2, interp_type="linear" - ) - - if not iio.image_physical_space_consistency(image, image2): - image2 = reg.resample_image_to_target(image2, image, interp_type="linear") - - if image.pixeltype not in {"float", "double"}: - scale = False # turn off scaling if image is discrete - - if image2.pixeltype not in {"float", "double"}: - scale2 = False # turn off scaling if image is discrete - - # reorient images - if reorient != False: - if reorient == True: - reorient = "RPI" - image = image.reorient_image2(reorient) - image2 = image2.reorient_image2(reorient) - if overlay is not None: - overlay = overlay.reorient_image2(reorient) - if overlay2 is not None: - overlay2 = overlay2.reorient_image2(reorient) - - # handle `slices` argument - if xyz is None: - xyz = [int(s / 2) for s in image.shape] - for i in range(3): - if xyz[i] is None: - xyz[i] = int(image.shape[i] / 2) - - # resample image if spacing is very unbalanced - spacing = [s for i, s in enumerate(image.spacing)] - if (max(spacing) / min(spacing)) > 3.0: - new_spacing = (1, 1, 1) - image = image.resample_image(tuple(new_spacing)) - image2 = image2.resample_image_to_target(tuple(new_spacing)) - if overlay is not None: - overlay = overlay.resample_image(tuple(new_spacing)) - if overlay2 is not None: - overlay2 = overlay2.resample_image(tuple(new_spacing)) - xyz = [ - int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) - ] - - # pad images - image, lowpad, uppad = image.pad_image(return_padvals=True) - image2, lowpad2, uppad2 = image2.pad_image(return_padvals=True) - xyz = [v + l for v, l in zip(xyz, lowpad)] - if overlay is not None: - overlay = overlay.pad_image() - if overlay2 is not None: - overlay2 = overlay2.pad_image() - - # handle `domain_image_map` argument - if domain_image_map is not None: - if isinstance(domain_image_map, iio.ANTsImage): - tx = tio2.new_ants_transform( - precision="float", - transform_type="AffineTransform", - dimension=image.dimension, - ) - image = tio.apply_ants_transform_to_image(tx, image, domain_image_map) - image2 = tio.apply_ants_transform_to_image(tx, image2, domain_image_map) - if overlay is not None: - overlay = tio.apply_ants_transform_to_image( - tx, overlay, domain_image_map, interpolation="linear" - ) - if overlay2 is not None: - overlay2 = tio.apply_ants_transform_to_image( - tx, overlay2, domain_image_map, interpolation="linear" - ) - elif isinstance(domain_image_map, (list, tuple)): - # expect an image and transformation - if len(domain_image_map) != 2: - raise ValueError("domain_image_map list or tuple must have length == 2") - - dimg = domain_image_map[0] - if not isinstance(dimg, iio.ANTsImage): - raise ValueError("domain_image_map first entry should be ANTsImage") - - tx = domain_image_map[1] - image = reg.apply_transforms(dimg, image, transform_list=tx) - if overlay is not None: - overlay = reg.apply_transforms( - dimg, overlay, transform_list=tx, interpolator="linear" - ) - - image2 = reg.apply_transforms(dimg, image2, transform_list=tx) - if overlay2 is not None: - overlay2 = reg.apply_transforms( - dimg, overlay2, transform_list=tx, interpolator="linear" - ) - - ## single-channel images ## - if image.components == 1: - - # potentially crop image - if crop: - plotmask = image.get_mask(cleanup=0) - if plotmask.max() == 0: - plotmask += 1 - image = image.crop_image(plotmask) - if overlay is not None: - overlay = overlay.crop_image(plotmask) - - if crop2: - plotmask2 = image2.get_mask(cleanup=0) - if plotmask2.max() == 0: - plotmask2 += 1 - image2 = image2.crop_image(plotmask2) - if overlay2 is not None: - overlay2 = overlay2.crop_image(plotmask2) - - # potentially find dynamic range - if scale == True: - vmin, vmax = image.quantile((0.05, 0.95)) - elif isinstance(scale, (list, tuple)): - if len(scale) != 2: - raise ValueError( - "scale argument must be boolean or list/tuple with two values" - ) - vmin, vmax = image.quantile(scale) - else: - vmin = None - vmax = None - - if scale2 == True: - vmin2, vmax2 = image2.quantile((0.05, 0.95)) - elif isinstance(scale2, (list, tuple)): - if len(scale2) != 2: - raise ValueError( - "scale2 argument must be boolean or list/tuple with two values" - ) - vmin2, vmax2 = image2.quantile(scale2) - else: - vmin2 = None - vmax2 = None - - if not flat: - nrow = 2 - ncol = 4 - else: - if not transpose: - nrow = 2 - ncol = 3 - else: - nrow = 3 - ncol = 2 - - fig = plt.figure( - figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize) - ) - if title is not None: - basey = 0.88 if not flat else 0.66 - basex = 0.5 - fig.suptitle( - title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy - ) - - gs = gridspec.GridSpec( - nrow, - ncol, - wspace=0.0, - hspace=0.0, - top=1.0 - 0.5 / (nrow + 1), - bottom=0.5 / (nrow + 1), - left=0.5 / (ncol + 1), - right=1 - 0.5 / (ncol + 1), - ) - - # pad image to have isotropic array dimensions - image = image.numpy() - if overlay is not None: - overlay = overlay.numpy() - if overlay.dtype not in ["uint8", "uint32"]: - overlay[np.abs(overlay) == 0] = np.nan - - image2 = image2.numpy() - if overlay2 is not None: - overlay2 = overlay2.numpy() - if overlay2.dtype not in ["uint8", "uint32"]: - overlay2[np.abs(overlay2) == 0] = np.nan - - #################### - #################### - yz_slice = reorient_slice(image[xyz[0], :, :], 0) - ax = plt.subplot(gs[0, 0]) - ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlay is not None: - yz_overlay = reorient_slice(overlay[xyz[0], :, :], 0) - ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap) - if xyz_lines: - # add lines - l = mlines.Line2D( - [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], - [xyz_pad, yz_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, yz_slice.shape[1] - xyz_pad], - [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - ax.axis("off") - - ####### - yz_slice2 = reorient_slice(image2[xyz[0], :, :], 0) - if not flat: - ax = plt.subplot(gs[0, 1]) - else: - if not transpose: - ax = plt.subplot(gs[1, 0]) - else: - ax = plt.subplot(gs[0, 1]) - ax.imshow(yz_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) - if overlay2 is not None: - yz_overlay2 = reorient_slice(overlay2[xyz[0], :, :], 0) - ax.imshow(yz_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) - if xyz_lines: - # add lines - l = mlines.Line2D( - [yz_slice2.shape[0] - xyz[1], yz_slice2.shape[0] - xyz[1]], - [xyz_pad, yz_slice2.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, yz_slice2.shape[1] - xyz_pad], - [yz_slice2.shape[1] - xyz[2], yz_slice2.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - ax.axis("off") - #################### - #################### - - xz_slice = reorient_slice(image[:, xyz[1], :], 1) - if not flat: - ax = plt.subplot(gs[0, 2]) - else: - if not transpose: - ax = plt.subplot(gs[0, 1]) - else: - ax = plt.subplot(gs[1, 0]) - ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlay is not None: - xz_overlay = reorient_slice(overlay[:, xyz[1], :], 1) - ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap) - if xyz_lines: - # add lines - l = mlines.Line2D( - [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], - [xyz_pad, xz_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xz_slice.shape[1] - xyz_pad], - [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - ax.axis("off") - - ####### - xz_slice2 = reorient_slice(image2[:, xyz[1], :], 1) - if not flat: - ax = plt.subplot(gs[0, 3]) - else: - ax = plt.subplot(gs[1, 1]) - ax.imshow(xz_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) - if overlay is not None: - xz_overlay2 = reorient_slice(overlay2[:, xyz[1], :], 1) - ax.imshow(xz_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) - if xyz_lines: - # add lines - l = mlines.Line2D( - [xz_slice2.shape[0] - xyz[0], xz_slice2.shape[0] - xyz[0]], - [xyz_pad, xz_slice2.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xz_slice2.shape[1] - xyz_pad], - [xz_slice2.shape[1] - xyz[2], xz_slice2.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - ax.axis("off") - - #################### - #################### - xy_slice = reorient_slice(image[:, :, xyz[2]], 2) - if not flat: - ax = plt.subplot(gs[1, 2]) - else: - if not transpose: - ax = plt.subplot(gs[0, 2]) - else: - ax = plt.subplot(gs[2, 0]) - ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlay is not None: - xy_overlay = reorient_slice(overlay[:, :, xyz[2]], 2) - ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap) - if xyz_lines: - # add lines - l = mlines.Line2D( - [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], - [xyz_pad, xy_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xy_slice.shape[1] - xyz_pad], - [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - ax.axis("off") - - ####### - xy_slice2 = reorient_slice(image2[:, :, xyz[2]], 2) - if not flat: - ax = plt.subplot(gs[1, 3]) - else: - if not transpose: - ax = plt.subplot(gs[1, 2]) - else: - ax = plt.subplot(gs[2, 1]) - ax.imshow(xy_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) - if overlay is not None: - xy_overlay2 = reorient_slice(overlay2[:, :, xyz[2]], 2) - ax.imshow(xy_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) - if xyz_lines: - # add lines - l = mlines.Line2D( - [xy_slice2.shape[0] - xyz[0], xy_slice2.shape[0] - xyz[0]], - [xyz_pad, xy_slice2.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xy_slice2.shape[1] - xyz_pad], - [xy_slice2.shape[1] - xyz[1], xy_slice2.shape[1] - xyz[1]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - ax.axis("off") - - #################### - #################### - - if not flat: - # empty corner - ax = plt.subplot(gs[1, :2]) - if text is not None: - # add text - left, width = 0.25, 0.5 - bottom, height = 0.25, 0.5 - right = left + width - top = bottom + height - ax.text( - 0.5 * (left + right) + text_dx, - 0.5 * (bottom + top) + text_dy, - text, - horizontalalignment="center", - verticalalignment="center", - fontsize=textfontsize, - color=textfontcolor, - transform=ax.transAxes, - ) - # ax.text(0.5, 0.5) - img_shape = list(image.shape[:-1]) - img_shape[1] *= 2 - ax.imshow(np.zeros(img_shape), cmap="Greys_r") - ax.axis("off") - - ## multi-channel images ## - elif image.components > 1: - raise ValueError("Multi-channel images not currently supported!") - - if filename is not None: - plt.savefig(filename, dpi=dpi, transparent=transparent) - plt.close(fig) - else: - plt.show() - - # turn warnings back to default - warnings.simplefilter("default") - - -def plot_ortho( - image, - overlay=None, - reorient=True, - blend=False, - # xyz arguments - xyz=None, - xyz_lines=True, - xyz_color="red", - xyz_alpha=0.6, - xyz_linewidth=2, - xyz_pad=5, - orient_labels=True, - # base image arguments - alpha=1, - cmap="Greys_r", - # overlay arguments - overlay_cmap="jet", - overlay_alpha=0.9, - cbar=False, - cbar_length=0.8, - cbar_dx=0.0, - cbar_vertical=True, - # background arguments - black_bg=True, - bg_thresh_quant=0.01, - bg_val_quant=0.99, - # scale/crop/domain arguments - crop=False, - scale=False, - domain_image_map=None, - # title arguments - title=None, - titlefontsize=24, - title_dx=0, - title_dy=0, - # 4th panel text arguemnts - text=None, - textfontsize=24, - textfontcolor="white", - text_dx=0, - text_dy=0, - # save & size arguments - filename=None, - dpi=500, - figsize=1.0, - flat=False, - transparent=True, - resample=False, - allow_xyz_change=True, -): - """ - Plot an orthographic view of a 3D image - - Use mask_image and/or threshold_image to preprocess images to be be - overlaid and display the overlays in a given range. See the wiki examples. - - ANTsR function: N/A - - Arguments - --------- - image : ANTsImage - image to plot - - overlay : ANTsImage - image to overlay on base image - - xyz : list or tuple of 3 integers - selects index location on which to center display - if given, solid lines will be drawn to converge at this coordinate. - This is useful for pinpointing a specific location in the image. - - flat : boolean - if true, the ortho image will be plot in one row - if false, the ortho image will be a 2x2 grid with the bottom - left corner blank - - cmap : string - colormap to use for base image. See matplotlib. - - overlay_cmap : string - colormap to use for overlay images, if applicable. See matplotlib. - - overlay_alpha : float - level of transparency for any overlays. Smaller value means - the overlay is more transparent. See matplotlib. - - cbar: boolean - if true, a colorbar will be added to the plot - - cbar_length: float - length of the colorbar relative to the image - - cbar_dx: float - horizontal shift of the colorbar relative to the image - - cbar_vertical: boolean - if true, the colorbar will be vertical, if false, it will be - horizontal underneath the image - - axis : integer - which axis to plot along if image is 3D - - black_bg : boolean - if True, the background of the image(s) will be black. - if False, the background of the image(s) will be determined by the - values `bg_thresh_quant` and `bg_val_quant`. - - bg_thresh_quant : float - if white_bg=True, the background will be determined by thresholding - the image at the `bg_thresh` quantile value and setting the background - intensity to the `bg_val` quantile value. - This value should be in [0, 1] - somewhere around 0.01 is recommended. - - equal to 1 will threshold the entire image - - equal to 0 will threshold none of the image - - bg_val_quant : float - if white_bg=True, the background will be determined by thresholding - the image at the `bg_thresh` quantile value and setting the background - intensity to the `bg_val` quantile value. - This value should be in [0, 1] - - equal to 1 is pure white - - equal to 0 is pure black - - somewhere in between is gray - - domain_image_map : ANTsImage - this input ANTsImage or list of ANTsImage types contains a reference image - `domain_image` and optional reference mapping named `domainMap`. - If supplied, the image(s) to be plotted will be mapped to the domain - image space before plotting - useful for non-standard image orientations. - - crop : boolean - if true, the image(s) will be cropped to their bounding boxes, resulting - in a potentially smaller image size. - if false, the image(s) will not be cropped - - scale : boolean or 2-tuple - if true, nothing will happen to intensities of image(s) and overlay(s) - if false, dynamic range will be maximized when visualizing overlays - if 2-tuple, the image will be dynamically scaled between these quantiles - - title : string - add a title to the plot - - filename : string - if given, the resulting image will be saved to this file - - dpi : integer - determines resolution of image if saved to file. Higher values - result in higher resolution images, but at a cost of having a - larger file size - - resample : resample image in case of unbalanced spacing - - allow_xyz_change : boolean will attempt to adjust xyz after padding - - Example - ------- - >>> import ants - >>> mni = ants.image_read(ants.get_data('mni')) - >>> ants.plot_ortho(mni, xyz=(100,100,100)) - >>> mni2 = mni.threshold_image(7000, mni.max()) - >>> ants.plot_ortho(mni, overlay=mni2) - >>> ants.plot_ortho(mni, overlay=mni2, flat=True) - >>> ants.plot_ortho(mni, overlay=mni2, xyz=(110,110,110), xyz_lines=False, - text='Lines Turned Off', textfontsize=22) - >>> ants.plot_ortho(mni, mni2, xyz=(120,100,100), - text=' Example \nOrtho Text', textfontsize=26, - title='Example Ortho Title', titlefontsize=26) - """ - - def mirror_matrix(x): - return x[::-1, :] - - def rotate270_matrix(x): - return mirror_matrix(x.T) - - def reorient_slice(x, axis): - return rotate270_matrix(x) - - # need this hack because of a weird NaN warning from matplotlib with overlays - warnings.simplefilter("ignore") - - # handle `image` argument - if isinstance(image, str): - image = iio2.image_read(image) - if not isinstance(image, iio.ANTsImage): - raise ValueError("image argument must be an ANTsImage") - if image.dimension != 3: - raise ValueError("Input image must have 3 dimensions!") - - # handle `overlay` argument - if overlay is not None: - vminol = overlay.min() - vmaxol = overlay.max() - if isinstance(overlay, str): - overlay = iio2.image_read(overlay) - if not isinstance(overlay, iio.ANTsImage): - raise ValueError("overlay argument must be an ANTsImage") - if overlay.components > 1: - raise ValueError("overlay cannot have more than one voxel component") - if overlay.dimension != 3: - raise ValueError("Overlay image must have 3 dimensions!") - - if not iio.image_physical_space_consistency(image, overlay): - overlay = reg.resample_image_to_target(overlay, image, interp_type="linear") - - if blend: - if alpha == 1: - alpha = 0.5 - image = image * alpha + overlay * (1 - alpha) - overlay = None - alpha = 1.0 - - if image.pixeltype not in {"float", "double"}: - scale = False # turn off scaling if image is discrete - - # reorient images - if reorient != False: - if reorient == True: - reorient = "RPI" - image = image.reorient_image2("RPI") - if overlay is not None: - overlay = overlay.reorient_image2("RPI") - - # handle `slices` argument - if xyz is None: - xyz = [int(s / 2) for s in image.shape] - for i in range(3): - if xyz[i] is None: - xyz[i] = int(image.shape[i] / 2) - - # resample image if spacing is very unbalanced - spacing = [s for i, s in enumerate(image.spacing)] - if (max(spacing) / min(spacing)) > 3.0 and resample: - new_spacing = (1, 1, 1) - image = image.resample_image(tuple(new_spacing)) - if overlay is not None: - overlay = overlay.resample_image(tuple(new_spacing)) - xyz = [ - int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) - ] - - - # potentially crop image - if crop: - plotmask = image.get_mask(cleanup=0) - if plotmask.max() == 0: - plotmask += 1 - image = image.crop_image(plotmask) - if overlay is not None: - overlay = overlay.crop_image(plotmask) - - # pad images - if True: - image, lowpad, uppad = image.pad_image(return_padvals=True) - if allow_xyz_change: - xyz = [v + l for v, l in zip(xyz, lowpad)] - if overlay is not None: - overlay = overlay.pad_image() - - - # handle `domain_image_map` argument - if domain_image_map is not None: - if isinstance(domain_image_map, iio.ANTsImage): - tx = tio2.new_ants_transform( - precision="float", - transform_type="AffineTransform", - dimension=image.dimension, - ) - image = tio.apply_ants_transform_to_image(tx, image, domain_image_map) - if overlay is not None: - overlay = tio.apply_ants_transform_to_image( - tx, overlay, domain_image_map, interpolation="linear" - ) - elif isinstance(domain_image_map, (list, tuple)): - # expect an image and transformation - if len(domain_image_map) != 2: - raise ValueError("domain_image_map list or tuple must have length == 2") - - dimg = domain_image_map[0] - if not isinstance(dimg, iio.ANTsImage): - raise ValueError("domain_image_map first entry should be ANTsImage") - - tx = domain_image_map[1] - image = reg.apply_transforms(dimg, image, transform_list=tx) - if overlay is not None: - overlay = reg.apply_transforms( - dimg, overlay, transform_list=tx, interpolator="linear" - ) - - ## single-channel images ## - if image.components == 1: - - # potentially find dynamic range - if scale == True: - vmin, vmax = image.quantile((0.05, 0.95)) - elif isinstance(scale, (list, tuple)): - if len(scale) != 2: - raise ValueError( - "scale argument must be boolean or list/tuple with two values" - ) - vmin, vmax = image.quantile(scale) - else: - vmin = None - vmax = None +Functions for plotting ants images +""" - if not flat: - nrow = 2 - ncol = 2 - else: - nrow = 1 - ncol = 3 - - fig = plt.figure(figsize=(9 * figsize, 9 * figsize)) - if title is not None: - basey = 0.88 if not flat else 0.66 - basex = 0.5 - fig.suptitle( - title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy - ) - gs = gridspec.GridSpec( - nrow, - ncol, - wspace=0.0, - hspace=0.0, - top=1.0 - 0.5 / (nrow + 1), - bottom=0.5 / (nrow + 1), - left=0.5 / (ncol + 1), - right=1 - 0.5 / (ncol + 1), - ) +__all__ = [ + "plot" +] - # pad image to have isotropic array dimensions - imageReturn = image.clone() - image = image.numpy() - overlayReturn = None - if overlay is not None: - overlayReturn = overlay.clone() - overlay = overlay.numpy() - if overlay.dtype not in ["uint8", "uint32"]: - overlay = np.ma.masked_where( np.abs(overlay) <= 1e-16, overlay) -# overlay[np.abs(overlay) == 0] = np.nan - - yz_slice = reorient_slice(image[xyz[0], :, :], 0) - ax = plt.subplot(gs[0, 0]) - ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlay is not None: - yz_overlay = reorient_slice(overlay[xyz[0], :, :], 0) - ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol ) - if xyz_lines: - # add lines - l = mlines.Line2D( - [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], - [xyz_pad, yz_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, yz_slice.shape[1] - xyz_pad], - [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - if orient_labels: - ax.text( - 0.5, - 0.98, - "S", - horizontalalignment="center", - verticalalignment="top", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.5, - 0.02, - "I", - horizontalalignment="center", - verticalalignment="bottom", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.98, - 0.5, - "A", - horizontalalignment="right", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.02, - 0.5, - "P", - horizontalalignment="left", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.axis("off") - - xz_slice = reorient_slice(image[:, xyz[1], :], 1) - ax = plt.subplot(gs[0, 1]) - ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlay is not None: - xz_overlay = reorient_slice(overlay[:, xyz[1], :], 1) - ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol ) - - if xyz_lines: - # add lines - l = mlines.Line2D( - [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], - [xyz_pad, xz_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xz_slice.shape[1] - xyz_pad], - [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - if orient_labels: - ax.text( - 0.5, - 0.98, - "S", - horizontalalignment="center", - verticalalignment="top", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.5, - 0.02, - "I", - horizontalalignment="center", - verticalalignment="bottom", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.98, - 0.5, - "L", - horizontalalignment="right", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.02, - 0.5, - "R", - horizontalalignment="left", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.axis("off") +import fnmatch +import math +import os +import warnings - xy_slice = reorient_slice(image[:, :, xyz[2]], 2) - if not flat: - ax = plt.subplot(gs[1, 1]) - else: - ax = plt.subplot(gs[0, 2]) - im = ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) - if overlay is not None: - xy_overlay = reorient_slice(overlay[:, :, xyz[2]], 2) - im = ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol) - - if xyz_lines: - # add lines - l = mlines.Line2D( - [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], - [xyz_pad, xy_slice.shape[0] - xyz_pad], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - l = mlines.Line2D( - [xyz_pad, xy_slice.shape[1] - xyz_pad], - [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], - color=xyz_color, - alpha=xyz_alpha, - linewidth=xyz_linewidth, - ) - ax.add_line(l) - if orient_labels: - ax.text( - 0.5, - 0.98, - "A", - horizontalalignment="center", - verticalalignment="top", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.5, - 0.02, - "P", - horizontalalignment="center", - verticalalignment="bottom", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.98, - 0.5, - "L", - horizontalalignment="right", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.text( - 0.02, - 0.5, - "R", - horizontalalignment="left", - verticalalignment="center", - fontsize=20 * figsize, - color=textfontcolor, - transform=ax.transAxes, - ) - ax.axis("off") - - if not flat: - # empty corner - ax = plt.subplot(gs[1, 0]) - if text is not None: - # add text - left, width = 0.25, 0.5 - bottom, height = 0.25, 0.5 - right = left + width - top = bottom + height - ax.text( - 0.5 * (left + right) + text_dx, - 0.5 * (bottom + top) + text_dy, - text, - horizontalalignment="center", - verticalalignment="center", - fontsize=textfontsize, - color=textfontcolor, - transform=ax.transAxes, - ) - # ax.text(0.5, 0.5) - ax.imshow(np.zeros(image.shape[:-1]), cmap="Greys_r") - ax.axis("off") - - if cbar: - cbar_start = (1 - cbar_length) / 2 - if cbar_vertical: - cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length]) - cbar_orient = "vertical" - else: - cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03]) - cbar_orient = "horizontal" - fig.colorbar(im, cax=cax, orientation=cbar_orient) +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes - ## multi-channel images ## - elif image.components > 1: - raise ValueError("Multi-channel images not currently supported!") - if filename is not None: - plt.savefig(filename, dpi=dpi, transparent=transparent) - plt.close(fig) - else: - plt.show() +import numpy as np - # turn warnings back to default - warnings.simplefilter("default") - return { "image": imageReturn, "overlay": overlayReturn } +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 def plot( @@ -2587,92 +522,3 @@ def reorient_slice(x, axis): warnings.simplefilter("default") -def plot_directory( - directory, - recursive=False, - regex="*", - save_prefix="", - save_suffix="", - axis=None, - **kwargs -): - """ - Create and save an ANTsPy plot for every image matching a given regular - expression in a directory, optionally recursively. This is a good function - for quick visualize exploration of all of images in a directory - - ANTsR function: N/A - - Arguments - --------- - directory : string - directory in which to search for images and plot them - - recursive : boolean - If true, this function will search through all directories under - the given directory recursively to make plots. - If false, this function will only create plots for images in the - given directory - - regex : string - regular expression used to filter out certain filenames or suffixes - - save_prefix : string - sub-string that will be appended to the beginning of all saved plot filenames. - Default is to add nothing. - - save_suffix : string - sub-string that will be appended to the end of all saved plot filenames. - Default is add nothing. - - kwargs : keyword arguments - any additional arguments to pass onto the `ants.plot` function. - e.g. overlay, alpha, cmap, etc. See `ants.plot` for more options. - - Example - ------- - >>> import ants - >>> ants.plot_directory(directory='~/desktop/testdir', - recursive=False, regex='*') - """ - - def has_acceptable_suffix(fname): - suffixes = {".nii.gz"} - return sum([fname.endswith(sx) for sx in suffixes]) > 0 - - if directory.startswith("~"): - directory = os.path.expanduser(directory) - - if not os.path.isdir(directory): - raise ValueError("directory %s does not exist!" % directory) - - for root, dirnames, fnames in os.walk(directory): - for fname in fnames: - if fnmatch.fnmatch(fname, regex) and has_acceptable_suffix(fname): - load_fname = os.path.join(root, fname) - fname = fname.replace(".".join(fname.split(".")[1:]), "png") - fname = fname.replace(".png", "%s.png" % save_suffix) - fname = "%s%s" % (save_prefix, fname) - save_fname = os.path.join(root, fname) - img = iio2.image_read(load_fname) - - if axis is None: - axis_range = [i for i in range(img.dimension)] - else: - axis_range = axis if isinstance(axis, (list, tuple)) else [axis] - - if img.dimension > 2: - for axis_idx in axis_range: - filename = save_fname.replace(".png", "_axis%i.png" % axis_idx) - ncol = int(math.sqrt(img.shape[axis_idx])) - plot( - img, - axis=axis_idx, - nslices=img.shape[axis_idx], - ncol=ncol, - filename=filename, - **kwargs - ) - else: - filename = save_fname - plot(img, filename=filename, **kwargs) diff --git a/ants/viz/plot_directory.py b/ants/viz/plot_directory.py new file mode 100644 index 00000000..c8724ed8 --- /dev/null +++ b/ants/viz/plot_directory.py @@ -0,0 +1,122 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot_directory" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .plot import plot +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def plot_directory( + directory, + recursive=False, + regex="*", + save_prefix="", + save_suffix="", + axis=None, + **kwargs +): + """ + Create and save an ANTsPy plot for every image matching a given regular + expression in a directory, optionally recursively. This is a good function + for quick visualize exploration of all of images in a directory + + ANTsR function: N/A + + Arguments + --------- + directory : string + directory in which to search for images and plot them + + recursive : boolean + If true, this function will search through all directories under + the given directory recursively to make plots. + If false, this function will only create plots for images in the + given directory + + regex : string + regular expression used to filter out certain filenames or suffixes + + save_prefix : string + sub-string that will be appended to the beginning of all saved plot filenames. + Default is to add nothing. + + save_suffix : string + sub-string that will be appended to the end of all saved plot filenames. + Default is add nothing. + + kwargs : keyword arguments + any additional arguments to pass onto the `ants.plot` function. + e.g. overlay, alpha, cmap, etc. See `ants.plot` for more options. + + Example + ------- + >>> import ants + >>> ants.plot_directory(directory='~/desktop/testdir', + recursive=False, regex='*') + """ + + def has_acceptable_suffix(fname): + suffixes = {".nii.gz"} + return sum([fname.endswith(sx) for sx in suffixes]) > 0 + + if directory.startswith("~"): + directory = os.path.expanduser(directory) + + if not os.path.isdir(directory): + raise ValueError("directory %s does not exist!" % directory) + + for root, dirnames, fnames in os.walk(directory): + for fname in fnames: + if fnmatch.fnmatch(fname, regex) and has_acceptable_suffix(fname): + load_fname = os.path.join(root, fname) + fname = fname.replace(".".join(fname.split(".")[1:]), "png") + fname = fname.replace(".png", "%s.png" % save_suffix) + fname = "%s%s" % (save_prefix, fname) + save_fname = os.path.join(root, fname) + img = iio2.image_read(load_fname) + + if axis is None: + axis_range = [i for i in range(img.dimension)] + else: + axis_range = axis if isinstance(axis, (list, tuple)) else [axis] + + if img.dimension > 2: + for axis_idx in axis_range: + filename = save_fname.replace(".png", "_axis%i.png" % axis_idx) + ncol = int(math.sqrt(img.shape[axis_idx])) + plot( + img, + axis=axis_idx, + nslices=img.shape[axis_idx], + ncol=ncol, + filename=filename, + **kwargs + ) + else: + filename = save_fname + plot(img, filename=filename, **kwargs) diff --git a/ants/viz/plot_grid.py b/ants/viz/plot_grid.py new file mode 100644 index 00000000..524fbac2 --- /dev/null +++ b/ants/viz/plot_grid.py @@ -0,0 +1,349 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot_grid" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def plot_grid( + images, + slices=None, + axes=2, + # general figure arguments + figsize=1.0, + rpad=0, + cpad=0, + vmin=None, + vmax=None, + colorbar=True, + cmap="Greys_r", + # title arguments + title=None, + tfontsize=20, + title_dx=0, + title_dy=0, + # row arguments + rlabels=None, + rfontsize=14, + rfontcolor="white", + rfacecolor="black", + # column arguments + clabels=None, + cfontsize=14, + cfontcolor="white", + cfacecolor="black", + # save arguments + filename=None, + dpi=400, + transparent=True, + # other args + **kwargs +): + """ + Plot a collection of images in an arbitrarily-defined grid + + Matplotlib named colors: https://matplotlib.org/examples/color/named_colors.html + + Arguments + --------- + images : list of ANTsImage types + image(s) to plot. + if one image, this image will be used for all grid locations. + if multiple images, they should be arrange in a list the same + shape as the `gridsize` argument. + + slices : integer or list of integers + slice indices to plot + if one integer, this slice index will be used for all images + if multiple integers, they should be arranged in a list the same + shape as the `gridsize` argument + + axes : integer or list of integers + axis or axes along which to plot image slices + if one integer, this axis will be used for all images + if multiple integers, they should be arranged in a list the same + shape as the `gridsize` argument + + Example + ------- + >>> import ants + >>> import numpy as np + >>> mni1 = ants.image_read(ants.get_data('mni')) + >>> mni2 = mni1.smooth_image(1.) + >>> mni3 = mni1.smooth_image(2.) + >>> mni4 = mni1.smooth_image(3.) + >>> images = np.asarray([[mni1, mni2], + ... [mni3, mni4]]) + >>> slices = np.asarray([[100, 100], + ... [100, 100]]) + >>> #axes = np.asarray([[2,2],[2,2]]) + >>> # standard plotting + >>> ants.plot_grid(images=images, slices=slices, title='2x2 Grid') + >>> ants.plot_grid(images.reshape(1,4), slices.reshape(1,4), title='1x4 Grid') + >>> ants.plot_grid(images.reshape(4,1), slices.reshape(4,1), title='4x1 Grid') + + >>> # Padding between rows and/or columns + >>> ants.plot_grid(images, slices, cpad=0.02, title='Col Padding') + >>> ants.plot_grid(images, slices, rpad=0.02, title='Row Padding') + >>> ants.plot_grid(images, slices, rpad=0.02, cpad=0.02, title='Row and Col Padding') + + >>> # Adding plain row and/or column labels + >>> ants.plot_grid(images, slices, title='Adding Row Labels', rlabels=['Row #1', 'Row #2']) + >>> ants.plot_grid(images, slices, title='Adding Col Labels', clabels=['Col #1', 'Col #2']) + >>> ants.plot_grid(images, slices, title='Row and Col Labels', + rlabels=['Row 1', 'Row 2'], clabels=['Col 1', 'Col 2']) + + >>> # Making a publication-quality image + >>> images = np.asarray([[mni1, mni2, mni2], + ... [mni3, mni4, mni4]]) + >>> slices = np.asarray([[100, 100, 100], + ... [100, 100, 100]]) + >>> axes = np.asarray([[0, 1, 2], + [0, 1, 2]]) + >>> ants.plot_grid(images, slices, axes, title='Publication Figures with ANTsPy', + tfontsize=20, title_dy=0.03, title_dx=-0.04, + rlabels=['Row 1', 'Row 2'], + clabels=['Col 1', 'Col 2', 'Col 3'], + rfontsize=16, cfontsize=16) + """ + + def mirror_matrix(x): + return x[::-1, :] + + def rotate270_matrix(x): + return mirror_matrix(x.T) + + def rotate180_matrix(x): + return x[::-1, ::-1] + + def rotate90_matrix(x): + return mirror_matrix(x).T + + def flip_matrix(x): + return mirror_matrix(rotate180_matrix(x)) + + def reorient_slice(x, axis): + if axis != 1: + x = rotate90_matrix(x) + if axis == 1: + x = rotate90_matrix(x) + x = mirror_matrix(x) + return x + + def slice_image(img, axis, idx): + if axis == 0: + return img[idx, :, :] + elif axis == 1: + return img[:, idx, :] + elif axis == 2: + return img[:, :, idx] + elif axis == -1: + return img[:, :, idx] + elif axis == -2: + return img[:, idx, :] + elif axis == -3: + return img[idx, :, :] + else: + raise ValueError("axis %i not valid" % axis) + + if isinstance(images, np.ndarray): + images = images.tolist() + if not isinstance(images, list): + raise ValueError("images argument must be of type list") + if not isinstance(images[0], list): + images = [images] + + if isinstance(slices, int): + one_slice = True + if isinstance(slices, np.ndarray): + slices = slices.tolist() + if isinstance(slices, list): + one_slice = False + if not isinstance(slices[0], list): + slices = [slices] + nslicerow = len(slices) + nslicecol = len(slices[0]) + + nrow = len(images) + ncol = len(images[0]) + + if rlabels is None: + rlabels = [None] * nrow + if clabels is None: + clabels = [None] * ncol + + if not one_slice: + if (nrow != nslicerow) or (ncol != nslicecol): + raise ValueError( + "`images` arg shape (%i,%i) must equal `slices` arg shape (%i,%i)!" + % (nrow, ncol, nslicerow, nslicecol) + ) + + fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize)) + + if title is not None: + basex = 0.5 + basey = 0.9 if clabels[0] is None else 0.95 + fig.suptitle(title, fontsize=tfontsize, x=basex + title_dx, y=basey + title_dy) + + if (cpad > 0) and (rpad > 0): + bothgridpad = max(cpad, rpad) + cpad = 0 + rpad = 0 + else: + bothgridpad = 0.0 + + gs = gridspec.GridSpec( + nrow, + ncol, + wspace=bothgridpad, + hspace=0.0, + top=1.0 - 0.5 / (nrow + 1), + bottom=0.5 / (nrow + 1) + cpad, + left=0.5 / (ncol + 1) + rpad, + right=1 - 0.5 / (ncol + 1), + ) + + if isinstance(vmin, (int, float)): + vmins = [vmin] * nrow + elif vmin is None: + vmins = [None] * nrow + else: + vmins = vmin + + if isinstance(vmax, (int, float)): + vmaxs = [vmax] * nrow + elif vmax is None: + vmaxs = [None] * nrow + else: + vmaxs = vmax + + if isinstance(cmap, str): + cmaps = [cmap] * nrow + elif cmap is None: + cmaps = [None] * nrow + else: + cmaps = cmap + + for rowidx, rvmin, rvmax, rcmap in zip(range(nrow), vmins, vmaxs, cmaps): + for colidx in range(ncol): + ax = plt.subplot(gs[rowidx, colidx]) + + if colidx == 0: + if rlabels[rowidx] is not None: + bottom, height = 0.25, 0.5 + top = bottom + height + # add label text + ax.text( + -0.07, + 0.5 * (bottom + top), + rlabels[rowidx], + horizontalalignment="right", + verticalalignment="center", + rotation="vertical", + transform=ax.transAxes, + color=rfontcolor, + fontsize=rfontsize, + ) + + # add label background + extra = 0.3 if rowidx == 0 else 0.0 + + rect = patches.Rectangle( + (-0.3, 0), + 0.3, + 1.0 + extra, + facecolor=rfacecolor, + alpha=1.0, + transform=ax.transAxes, + clip_on=False, + ) + ax.add_patch(rect) + + if rowidx == 0: + if clabels[colidx] is not None: + bottom, height = 0.25, 0.5 + left, width = 0.25, 0.5 + right = left + width + top = bottom + height + ax.text( + 0.5 * (left + right), + 0.09 + top + bottom, + clabels[colidx], + horizontalalignment="center", + verticalalignment="center", + rotation="horizontal", + transform=ax.transAxes, + color=cfontcolor, + fontsize=cfontsize, + ) + + # add label background + rect = patches.Rectangle( + (0, 1.0), + 1.0, + 0.3, + facecolor=cfacecolor, + alpha=1.0, + transform=ax.transAxes, + clip_on=False, + ) + ax.add_patch(rect) + + tmpimg = images[rowidx][colidx] + if isinstance(axes, int): + tmpaxis = axes + else: + tmpaxis = axes[rowidx][colidx] + sliceidx = slices[rowidx][colidx] if not one_slice else slices + tmpslice = slice_image(tmpimg, tmpaxis, sliceidx) + tmpslice = reorient_slice(tmpslice, tmpaxis) + im = ax.imshow(tmpslice, cmap=rcmap, aspect="auto", vmin=rvmin, vmax=rvmax) + ax.axis("off") + + # A colorbar solution with make_axes_locatable will not allow y-scaling of the colorbar. + # from mpl_toolkits.axes_grid1 import make_axes_locatable + # divider = make_axes_locatable(ax) + # cax = divider.append_axes('right', size='5%', pad=0.05) + if colorbar: + axins = inset_axes(ax, + width="5%", # width = 5% of parent_bbox width + height="90%", # height : 50% + loc='center left', + bbox_to_anchor=(1.03, 0., 1, 1), + bbox_transform=ax.transAxes, + borderpad=0, + ) + fig.colorbar(im, cax=axins, orientation='vertical') + + if filename is not None: + filename = os.path.expanduser(filename) + plt.savefig(filename, dpi=dpi, transparent=transparent, bbox_inches="tight") + plt.close(fig) + else: + plt.show() diff --git a/ants/viz/plot_hist.py b/ants/viz/plot_hist.py new file mode 100644 index 00000000..11789f28 --- /dev/null +++ b/ants/viz/plot_hist.py @@ -0,0 +1,77 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot_hist" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def plot_hist( + image, + threshold=0.0, + fit_line=False, + normfreq=True, + ## plot label arguments + title=None, + grid=True, + xlabel=None, + ylabel=None, + ## other plot arguments + facecolor="green", + alpha=0.75, +): + """ + Plot a histogram from an ANTsImage + + Arguments + --------- + image : ANTsImage + image from which histogram will be created + """ + img_arr = image.numpy().flatten() + img_arr = img_arr[np.abs(img_arr) > threshold] + + if normfreq != False: + normfreq = 1.0 if normfreq == True else normfreq + n, bins, patches = plt.hist( + img_arr, 50, facecolor=facecolor, alpha=alpha + ) + + if fit_line: + # add a 'best fit' line + y = mlab.normpdf(bins, img_arr.mean(), img_arr.std()) + l = plt.plot(bins, y, "r--", linewidth=1) + + if xlabel is not None: + plt.xlabel(xlabel) + if ylabel is not None: + plt.ylabel(ylabel) + if title is not None: + plt.title(title) + + plt.grid(grid) + plt.show() diff --git a/ants/viz/plot_ortho.py b/ants/viz/plot_ortho.py new file mode 100644 index 00000000..0fa1d76a --- /dev/null +++ b/ants/viz/plot_ortho.py @@ -0,0 +1,629 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot_ortho" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def plot_ortho( + image, + overlay=None, + reorient=True, + blend=False, + # xyz arguments + xyz=None, + xyz_lines=True, + xyz_color="red", + xyz_alpha=0.6, + xyz_linewidth=2, + xyz_pad=5, + orient_labels=True, + # base image arguments + alpha=1, + cmap="Greys_r", + # overlay arguments + overlay_cmap="jet", + overlay_alpha=0.9, + cbar=False, + cbar_length=0.8, + cbar_dx=0.0, + cbar_vertical=True, + # background arguments + black_bg=True, + bg_thresh_quant=0.01, + bg_val_quant=0.99, + # scale/crop/domain arguments + crop=False, + scale=False, + domain_image_map=None, + # title arguments + title=None, + titlefontsize=24, + title_dx=0, + title_dy=0, + # 4th panel text arguemnts + text=None, + textfontsize=24, + textfontcolor="white", + text_dx=0, + text_dy=0, + # save & size arguments + filename=None, + dpi=500, + figsize=1.0, + flat=False, + transparent=True, + resample=False, + allow_xyz_change=True, +): + """ + Plot an orthographic view of a 3D image + + Use mask_image and/or threshold_image to preprocess images to be be + overlaid and display the overlays in a given range. See the wiki examples. + + ANTsR function: N/A + + Arguments + --------- + image : ANTsImage + image to plot + + overlay : ANTsImage + image to overlay on base image + + xyz : list or tuple of 3 integers + selects index location on which to center display + if given, solid lines will be drawn to converge at this coordinate. + This is useful for pinpointing a specific location in the image. + + flat : boolean + if true, the ortho image will be plot in one row + if false, the ortho image will be a 2x2 grid with the bottom + left corner blank + + cmap : string + colormap to use for base image. See matplotlib. + + overlay_cmap : string + colormap to use for overlay images, if applicable. See matplotlib. + + overlay_alpha : float + level of transparency for any overlays. Smaller value means + the overlay is more transparent. See matplotlib. + + cbar: boolean + if true, a colorbar will be added to the plot + + cbar_length: float + length of the colorbar relative to the image + + cbar_dx: float + horizontal shift of the colorbar relative to the image + + cbar_vertical: boolean + if true, the colorbar will be vertical, if false, it will be + horizontal underneath the image + + axis : integer + which axis to plot along if image is 3D + + black_bg : boolean + if True, the background of the image(s) will be black. + if False, the background of the image(s) will be determined by the + values `bg_thresh_quant` and `bg_val_quant`. + + bg_thresh_quant : float + if white_bg=True, the background will be determined by thresholding + the image at the `bg_thresh` quantile value and setting the background + intensity to the `bg_val` quantile value. + This value should be in [0, 1] - somewhere around 0.01 is recommended. + - equal to 1 will threshold the entire image + - equal to 0 will threshold none of the image + + bg_val_quant : float + if white_bg=True, the background will be determined by thresholding + the image at the `bg_thresh` quantile value and setting the background + intensity to the `bg_val` quantile value. + This value should be in [0, 1] + - equal to 1 is pure white + - equal to 0 is pure black + - somewhere in between is gray + + domain_image_map : ANTsImage + this input ANTsImage or list of ANTsImage types contains a reference image + `domain_image` and optional reference mapping named `domainMap`. + If supplied, the image(s) to be plotted will be mapped to the domain + image space before plotting - useful for non-standard image orientations. + + crop : boolean + if true, the image(s) will be cropped to their bounding boxes, resulting + in a potentially smaller image size. + if false, the image(s) will not be cropped + + scale : boolean or 2-tuple + if true, nothing will happen to intensities of image(s) and overlay(s) + if false, dynamic range will be maximized when visualizing overlays + if 2-tuple, the image will be dynamically scaled between these quantiles + + title : string + add a title to the plot + + filename : string + if given, the resulting image will be saved to this file + + dpi : integer + determines resolution of image if saved to file. Higher values + result in higher resolution images, but at a cost of having a + larger file size + + resample : resample image in case of unbalanced spacing + + allow_xyz_change : boolean will attempt to adjust xyz after padding + + Example + ------- + >>> import ants + >>> mni = ants.image_read(ants.get_data('mni')) + >>> ants.plot_ortho(mni, xyz=(100,100,100)) + >>> mni2 = mni.threshold_image(7000, mni.max()) + >>> ants.plot_ortho(mni, overlay=mni2) + >>> ants.plot_ortho(mni, overlay=mni2, flat=True) + >>> ants.plot_ortho(mni, overlay=mni2, xyz=(110,110,110), xyz_lines=False, + text='Lines Turned Off', textfontsize=22) + >>> ants.plot_ortho(mni, mni2, xyz=(120,100,100), + text=' Example \nOrtho Text', textfontsize=26, + title='Example Ortho Title', titlefontsize=26) + """ + + def mirror_matrix(x): + return x[::-1, :] + + def rotate270_matrix(x): + return mirror_matrix(x.T) + + def reorient_slice(x, axis): + return rotate270_matrix(x) + + # need this hack because of a weird NaN warning from matplotlib with overlays + warnings.simplefilter("ignore") + + # handle `image` argument + if isinstance(image, str): + image = iio2.image_read(image) + if not isinstance(image, iio.ANTsImage): + raise ValueError("image argument must be an ANTsImage") + if image.dimension != 3: + raise ValueError("Input image must have 3 dimensions!") + + # handle `overlay` argument + if overlay is not None: + vminol = overlay.min() + vmaxol = overlay.max() + if isinstance(overlay, str): + overlay = iio2.image_read(overlay) + if not isinstance(overlay, iio.ANTsImage): + raise ValueError("overlay argument must be an ANTsImage") + if overlay.components > 1: + raise ValueError("overlay cannot have more than one voxel component") + if overlay.dimension != 3: + raise ValueError("Overlay image must have 3 dimensions!") + + if not iio.image_physical_space_consistency(image, overlay): + overlay = reg.resample_image_to_target(overlay, image, interp_type="linear") + + if blend: + if alpha == 1: + alpha = 0.5 + image = image * alpha + overlay * (1 - alpha) + overlay = None + alpha = 1.0 + + if image.pixeltype not in {"float", "double"}: + scale = False # turn off scaling if image is discrete + + # reorient images + if reorient != False: + if reorient == True: + reorient = "RPI" + image = image.reorient_image2("RPI") + if overlay is not None: + overlay = overlay.reorient_image2("RPI") + + # handle `slices` argument + if xyz is None: + xyz = [int(s / 2) for s in image.shape] + for i in range(3): + if xyz[i] is None: + xyz[i] = int(image.shape[i] / 2) + + # resample image if spacing is very unbalanced + spacing = [s for i, s in enumerate(image.spacing)] + if (max(spacing) / min(spacing)) > 3.0 and resample: + new_spacing = (1, 1, 1) + image = image.resample_image(tuple(new_spacing)) + if overlay is not None: + overlay = overlay.resample_image(tuple(new_spacing)) + xyz = [ + int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) + ] + + + # potentially crop image + if crop: + plotmask = image.get_mask(cleanup=0) + if plotmask.max() == 0: + plotmask += 1 + image = image.crop_image(plotmask) + if overlay is not None: + overlay = overlay.crop_image(plotmask) + + # pad images + if True: + image, lowpad, uppad = image.pad_image(return_padvals=True) + if allow_xyz_change: + xyz = [v + l for v, l in zip(xyz, lowpad)] + if overlay is not None: + overlay = overlay.pad_image() + + + # handle `domain_image_map` argument + if domain_image_map is not None: + if isinstance(domain_image_map, iio.ANTsImage): + tx = tio2.new_ants_transform( + precision="float", + transform_type="AffineTransform", + dimension=image.dimension, + ) + image = tio.apply_ants_transform_to_image(tx, image, domain_image_map) + if overlay is not None: + overlay = tio.apply_ants_transform_to_image( + tx, overlay, domain_image_map, interpolation="linear" + ) + elif isinstance(domain_image_map, (list, tuple)): + # expect an image and transformation + if len(domain_image_map) != 2: + raise ValueError("domain_image_map list or tuple must have length == 2") + + dimg = domain_image_map[0] + if not isinstance(dimg, iio.ANTsImage): + raise ValueError("domain_image_map first entry should be ANTsImage") + + tx = domain_image_map[1] + image = reg.apply_transforms(dimg, image, transform_list=tx) + if overlay is not None: + overlay = reg.apply_transforms( + dimg, overlay, transform_list=tx, interpolator="linear" + ) + + ## single-channel images ## + if image.components == 1: + + # potentially find dynamic range + if scale == True: + vmin, vmax = image.quantile((0.05, 0.95)) + elif isinstance(scale, (list, tuple)): + if len(scale) != 2: + raise ValueError( + "scale argument must be boolean or list/tuple with two values" + ) + vmin, vmax = image.quantile(scale) + else: + vmin = None + vmax = None + + if not flat: + nrow = 2 + ncol = 2 + else: + nrow = 1 + ncol = 3 + + fig = plt.figure(figsize=(9 * figsize, 9 * figsize)) + if title is not None: + basey = 0.88 if not flat else 0.66 + basex = 0.5 + fig.suptitle( + title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy + ) + + gs = gridspec.GridSpec( + nrow, + ncol, + wspace=0.0, + hspace=0.0, + top=1.0 - 0.5 / (nrow + 1), + bottom=0.5 / (nrow + 1), + left=0.5 / (ncol + 1), + right=1 - 0.5 / (ncol + 1), + ) + + # pad image to have isotropic array dimensions + imageReturn = image.clone() + image = image.numpy() + overlayReturn = None + if overlay is not None: + overlayReturn = overlay.clone() + overlay = overlay.numpy() + if overlay.dtype not in ["uint8", "uint32"]: + overlay = np.ma.masked_where( np.abs(overlay) <= 1e-16, overlay) +# overlay[np.abs(overlay) == 0] = np.nan + + yz_slice = reorient_slice(image[xyz[0], :, :], 0) + ax = plt.subplot(gs[0, 0]) + ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlay is not None: + yz_overlay = reorient_slice(overlay[xyz[0], :, :], 0) + ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol ) + if xyz_lines: + # add lines + l = mlines.Line2D( + [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], + [xyz_pad, yz_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, yz_slice.shape[1] - xyz_pad], + [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + if orient_labels: + ax.text( + 0.5, + 0.98, + "S", + horizontalalignment="center", + verticalalignment="top", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.5, + 0.02, + "I", + horizontalalignment="center", + verticalalignment="bottom", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.5, + "A", + horizontalalignment="right", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.02, + 0.5, + "P", + horizontalalignment="left", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.axis("off") + + xz_slice = reorient_slice(image[:, xyz[1], :], 1) + ax = plt.subplot(gs[0, 1]) + ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlay is not None: + xz_overlay = reorient_slice(overlay[:, xyz[1], :], 1) + ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol ) + + if xyz_lines: + # add lines + l = mlines.Line2D( + [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], + [xyz_pad, xz_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xz_slice.shape[1] - xyz_pad], + [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + if orient_labels: + ax.text( + 0.5, + 0.98, + "S", + horizontalalignment="center", + verticalalignment="top", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.5, + 0.02, + "I", + horizontalalignment="center", + verticalalignment="bottom", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.5, + "L", + horizontalalignment="right", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.02, + 0.5, + "R", + horizontalalignment="left", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.axis("off") + + xy_slice = reorient_slice(image[:, :, xyz[2]], 2) + if not flat: + ax = plt.subplot(gs[1, 1]) + else: + ax = plt.subplot(gs[0, 2]) + im = ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlay is not None: + xy_overlay = reorient_slice(overlay[:, :, xyz[2]], 2) + im = ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap, vmin=vminol, vmax=vmaxol) + + if xyz_lines: + # add lines + l = mlines.Line2D( + [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], + [xyz_pad, xy_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xy_slice.shape[1] - xyz_pad], + [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + if orient_labels: + ax.text( + 0.5, + 0.98, + "A", + horizontalalignment="center", + verticalalignment="top", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.5, + 0.02, + "P", + horizontalalignment="center", + verticalalignment="bottom", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.5, + "L", + horizontalalignment="right", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.02, + 0.5, + "R", + horizontalalignment="left", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.axis("off") + + if not flat: + # empty corner + ax = plt.subplot(gs[1, 0]) + if text is not None: + # add text + left, width = 0.25, 0.5 + bottom, height = 0.25, 0.5 + right = left + width + top = bottom + height + ax.text( + 0.5 * (left + right) + text_dx, + 0.5 * (bottom + top) + text_dy, + text, + horizontalalignment="center", + verticalalignment="center", + fontsize=textfontsize, + color=textfontcolor, + transform=ax.transAxes, + ) + # ax.text(0.5, 0.5) + ax.imshow(np.zeros(image.shape[:-1]), cmap="Greys_r") + ax.axis("off") + + if cbar: + cbar_start = (1 - cbar_length) / 2 + if cbar_vertical: + cax = fig.add_axes([0.9 + cbar_dx, cbar_start, 0.03, cbar_length]) + cbar_orient = "vertical" + else: + cax = fig.add_axes([cbar_start, 0.08 + cbar_dx, cbar_length, 0.03]) + cbar_orient = "horizontal" + fig.colorbar(im, cax=cax, orientation=cbar_orient) + + ## multi-channel images ## + elif image.components > 1: + raise ValueError("Multi-channel images not currently supported!") + + if filename is not None: + plt.savefig(filename, dpi=dpi, transparent=transparent) + plt.close(fig) + else: + plt.show() + + # turn warnings back to default + warnings.simplefilter("default") + return { "image": imageReturn, "overlay": overlayReturn } + diff --git a/ants/viz/plot_ortho_double.py b/ants/viz/plot_ortho_double.py new file mode 100644 index 00000000..6cc3fe5c --- /dev/null +++ b/ants/viz/plot_ortho_double.py @@ -0,0 +1,566 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot_ortho_double" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def plot_ortho_double( + image, + image2, + overlay=None, + overlay2=None, + reorient=True, + # xyz arguments + xyz=None, + xyz_lines=True, + xyz_color="red", + xyz_alpha=0.6, + xyz_linewidth=2, + xyz_pad=5, + # base image arguments + cmap="Greys_r", + alpha=1, + cmap2="Greys_r", + alpha2=1, + # overlay arguments + overlay_cmap="jet", + overlay_alpha=0.9, + overlay_cmap2="jet", + overlay_alpha2=0.9, + # background arguments + black_bg=True, + bg_thresh_quant=0.01, + bg_val_quant=0.99, + # scale/crop/domain arguments + crop=False, + scale=False, + crop2=False, + scale2=True, + domain_image_map=None, + # title arguments + title=None, + titlefontsize=24, + title_dx=0, + title_dy=0, + # 4th panel text arguemnts + text=None, + textfontsize=24, + textfontcolor="white", + text_dx=0, + text_dy=0, + # save & size arguments + filename=None, + dpi=500, + figsize=1.0, + flat=True, + transpose=False, + transparent=True, +): + """ + Create a pair of orthographic plots with overlays. + + Use mask_image and/or threshold_image to preprocess images to be be + overlaid and display the overlays in a given range. See the wiki examples. + + Example + ------- + >>> import ants + >>> mni = ants.image_read(ants.get_data('mni')) + >>> ch2 = ants.image_read(ants.get_data('ch2')) + >>> ants.plot_ortho_double(mni, ch2) + """ + + def mirror_matrix(x): + return x[::-1, :] + + def rotate270_matrix(x): + return mirror_matrix(x.T) + + def reorient_slice(x, axis): + return rotate270_matrix(x) + + # need this hack because of a weird NaN warning from matplotlib with overlays + warnings.simplefilter("ignore") + + # handle `image` argument + if isinstance(image, str): + image = iio2.image_read(image) + if not isinstance(image, iio.ANTsImage): + raise ValueError("image argument must be an ANTsImage") + if image.dimension != 3: + raise ValueError("Input image must have 3 dimensions!") + + if isinstance(image2, str): + image2 = iio2.image_read(image2) + if not isinstance(image2, iio.ANTsImage): + raise ValueError("image2 argument must be an ANTsImage") + if image2.dimension != 3: + raise ValueError("Input image2 must have 3 dimensions!") + + # handle `overlay` argument + if overlay is not None: + if isinstance(overlay, str): + overlay = iio2.image_read(overlay) + if not isinstance(overlay, iio.ANTsImage): + raise ValueError("overlay argument must be an ANTsImage") + if overlay.components > 1: + raise ValueError("overlay cannot have more than one voxel component") + if overlay.dimension != 3: + raise ValueError("Overlay image must have 3 dimensions!") + + if not iio.image_physical_space_consistency(image, overlay): + overlay = reg.resample_image_to_target(overlay, image, interp_type="linear") + + if overlay2 is not None: + if isinstance(overlay2, str): + overlay2 = iio2.image_read(overlay2) + if not isinstance(overlay2, iio.ANTsImage): + raise ValueError("overlay2 argument must be an ANTsImage") + if overlay2.components > 1: + raise ValueError("overlay2 cannot have more than one voxel component") + if overlay2.dimension != 3: + raise ValueError("Overlay2 image must have 3 dimensions!") + + if not iio.image_physical_space_consistency(image2, overlay2): + overlay2 = reg.resample_image_to_target( + overlay2, image2, interp_type="linear" + ) + + if not iio.image_physical_space_consistency(image, image2): + image2 = reg.resample_image_to_target(image2, image, interp_type="linear") + + if image.pixeltype not in {"float", "double"}: + scale = False # turn off scaling if image is discrete + + if image2.pixeltype not in {"float", "double"}: + scale2 = False # turn off scaling if image is discrete + + # reorient images + if reorient != False: + if reorient == True: + reorient = "RPI" + image = image.reorient_image2(reorient) + image2 = image2.reorient_image2(reorient) + if overlay is not None: + overlay = overlay.reorient_image2(reorient) + if overlay2 is not None: + overlay2 = overlay2.reorient_image2(reorient) + + # handle `slices` argument + if xyz is None: + xyz = [int(s / 2) for s in image.shape] + for i in range(3): + if xyz[i] is None: + xyz[i] = int(image.shape[i] / 2) + + # resample image if spacing is very unbalanced + spacing = [s for i, s in enumerate(image.spacing)] + if (max(spacing) / min(spacing)) > 3.0: + new_spacing = (1, 1, 1) + image = image.resample_image(tuple(new_spacing)) + image2 = image2.resample_image_to_target(tuple(new_spacing)) + if overlay is not None: + overlay = overlay.resample_image(tuple(new_spacing)) + if overlay2 is not None: + overlay2 = overlay2.resample_image(tuple(new_spacing)) + xyz = [ + int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) + ] + + # pad images + image, lowpad, uppad = image.pad_image(return_padvals=True) + image2, lowpad2, uppad2 = image2.pad_image(return_padvals=True) + xyz = [v + l for v, l in zip(xyz, lowpad)] + if overlay is not None: + overlay = overlay.pad_image() + if overlay2 is not None: + overlay2 = overlay2.pad_image() + + # handle `domain_image_map` argument + if domain_image_map is not None: + if isinstance(domain_image_map, iio.ANTsImage): + tx = tio2.new_ants_transform( + precision="float", + transform_type="AffineTransform", + dimension=image.dimension, + ) + image = tio.apply_ants_transform_to_image(tx, image, domain_image_map) + image2 = tio.apply_ants_transform_to_image(tx, image2, domain_image_map) + if overlay is not None: + overlay = tio.apply_ants_transform_to_image( + tx, overlay, domain_image_map, interpolation="linear" + ) + if overlay2 is not None: + overlay2 = tio.apply_ants_transform_to_image( + tx, overlay2, domain_image_map, interpolation="linear" + ) + elif isinstance(domain_image_map, (list, tuple)): + # expect an image and transformation + if len(domain_image_map) != 2: + raise ValueError("domain_image_map list or tuple must have length == 2") + + dimg = domain_image_map[0] + if not isinstance(dimg, iio.ANTsImage): + raise ValueError("domain_image_map first entry should be ANTsImage") + + tx = domain_image_map[1] + image = reg.apply_transforms(dimg, image, transform_list=tx) + if overlay is not None: + overlay = reg.apply_transforms( + dimg, overlay, transform_list=tx, interpolator="linear" + ) + + image2 = reg.apply_transforms(dimg, image2, transform_list=tx) + if overlay2 is not None: + overlay2 = reg.apply_transforms( + dimg, overlay2, transform_list=tx, interpolator="linear" + ) + + ## single-channel images ## + if image.components == 1: + + # potentially crop image + if crop: + plotmask = image.get_mask(cleanup=0) + if plotmask.max() == 0: + plotmask += 1 + image = image.crop_image(plotmask) + if overlay is not None: + overlay = overlay.crop_image(plotmask) + + if crop2: + plotmask2 = image2.get_mask(cleanup=0) + if plotmask2.max() == 0: + plotmask2 += 1 + image2 = image2.crop_image(plotmask2) + if overlay2 is not None: + overlay2 = overlay2.crop_image(plotmask2) + + # potentially find dynamic range + if scale == True: + vmin, vmax = image.quantile((0.05, 0.95)) + elif isinstance(scale, (list, tuple)): + if len(scale) != 2: + raise ValueError( + "scale argument must be boolean or list/tuple with two values" + ) + vmin, vmax = image.quantile(scale) + else: + vmin = None + vmax = None + + if scale2 == True: + vmin2, vmax2 = image2.quantile((0.05, 0.95)) + elif isinstance(scale2, (list, tuple)): + if len(scale2) != 2: + raise ValueError( + "scale2 argument must be boolean or list/tuple with two values" + ) + vmin2, vmax2 = image2.quantile(scale2) + else: + vmin2 = None + vmax2 = None + + if not flat: + nrow = 2 + ncol = 4 + else: + if not transpose: + nrow = 2 + ncol = 3 + else: + nrow = 3 + ncol = 2 + + fig = plt.figure( + figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize) + ) + if title is not None: + basey = 0.88 if not flat else 0.66 + basex = 0.5 + fig.suptitle( + title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy + ) + + gs = gridspec.GridSpec( + nrow, + ncol, + wspace=0.0, + hspace=0.0, + top=1.0 - 0.5 / (nrow + 1), + bottom=0.5 / (nrow + 1), + left=0.5 / (ncol + 1), + right=1 - 0.5 / (ncol + 1), + ) + + # pad image to have isotropic array dimensions + image = image.numpy() + if overlay is not None: + overlay = overlay.numpy() + if overlay.dtype not in ["uint8", "uint32"]: + overlay[np.abs(overlay) == 0] = np.nan + + image2 = image2.numpy() + if overlay2 is not None: + overlay2 = overlay2.numpy() + if overlay2.dtype not in ["uint8", "uint32"]: + overlay2[np.abs(overlay2) == 0] = np.nan + + #################### + #################### + yz_slice = reorient_slice(image[xyz[0], :, :], 0) + ax = plt.subplot(gs[0, 0]) + ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlay is not None: + yz_overlay = reorient_slice(overlay[xyz[0], :, :], 0) + ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap) + if xyz_lines: + # add lines + l = mlines.Line2D( + [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], + [xyz_pad, yz_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, yz_slice.shape[1] - xyz_pad], + [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + ax.axis("off") + + ####### + yz_slice2 = reorient_slice(image2[xyz[0], :, :], 0) + if not flat: + ax = plt.subplot(gs[0, 1]) + else: + if not transpose: + ax = plt.subplot(gs[1, 0]) + else: + ax = plt.subplot(gs[0, 1]) + ax.imshow(yz_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) + if overlay2 is not None: + yz_overlay2 = reorient_slice(overlay2[xyz[0], :, :], 0) + ax.imshow(yz_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) + if xyz_lines: + # add lines + l = mlines.Line2D( + [yz_slice2.shape[0] - xyz[1], yz_slice2.shape[0] - xyz[1]], + [xyz_pad, yz_slice2.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, yz_slice2.shape[1] - xyz_pad], + [yz_slice2.shape[1] - xyz[2], yz_slice2.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + ax.axis("off") + #################### + #################### + + xz_slice = reorient_slice(image[:, xyz[1], :], 1) + if not flat: + ax = plt.subplot(gs[0, 2]) + else: + if not transpose: + ax = plt.subplot(gs[0, 1]) + else: + ax = plt.subplot(gs[1, 0]) + ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlay is not None: + xz_overlay = reorient_slice(overlay[:, xyz[1], :], 1) + ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap) + if xyz_lines: + # add lines + l = mlines.Line2D( + [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], + [xyz_pad, xz_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xz_slice.shape[1] - xyz_pad], + [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + ax.axis("off") + + ####### + xz_slice2 = reorient_slice(image2[:, xyz[1], :], 1) + if not flat: + ax = plt.subplot(gs[0, 3]) + else: + ax = plt.subplot(gs[1, 1]) + ax.imshow(xz_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) + if overlay is not None: + xz_overlay2 = reorient_slice(overlay2[:, xyz[1], :], 1) + ax.imshow(xz_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) + if xyz_lines: + # add lines + l = mlines.Line2D( + [xz_slice2.shape[0] - xyz[0], xz_slice2.shape[0] - xyz[0]], + [xyz_pad, xz_slice2.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xz_slice2.shape[1] - xyz_pad], + [xz_slice2.shape[1] - xyz[2], xz_slice2.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + ax.axis("off") + + #################### + #################### + xy_slice = reorient_slice(image[:, :, xyz[2]], 2) + if not flat: + ax = plt.subplot(gs[1, 2]) + else: + if not transpose: + ax = plt.subplot(gs[0, 2]) + else: + ax = plt.subplot(gs[2, 0]) + ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlay is not None: + xy_overlay = reorient_slice(overlay[:, :, xyz[2]], 2) + ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap) + if xyz_lines: + # add lines + l = mlines.Line2D( + [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], + [xyz_pad, xy_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xy_slice.shape[1] - xyz_pad], + [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + ax.axis("off") + + ####### + xy_slice2 = reorient_slice(image2[:, :, xyz[2]], 2) + if not flat: + ax = plt.subplot(gs[1, 3]) + else: + if not transpose: + ax = plt.subplot(gs[1, 2]) + else: + ax = plt.subplot(gs[2, 1]) + ax.imshow(xy_slice2, cmap=cmap2, vmin=vmin2, vmax=vmax2) + if overlay is not None: + xy_overlay2 = reorient_slice(overlay2[:, :, xyz[2]], 2) + ax.imshow(xy_overlay2, alpha=overlay_alpha2, cmap=overlay_cmap2) + if xyz_lines: + # add lines + l = mlines.Line2D( + [xy_slice2.shape[0] - xyz[0], xy_slice2.shape[0] - xyz[0]], + [xyz_pad, xy_slice2.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xy_slice2.shape[1] - xyz_pad], + [xy_slice2.shape[1] - xyz[1], xy_slice2.shape[1] - xyz[1]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + ax.axis("off") + + #################### + #################### + + if not flat: + # empty corner + ax = plt.subplot(gs[1, :2]) + if text is not None: + # add text + left, width = 0.25, 0.5 + bottom, height = 0.25, 0.5 + right = left + width + top = bottom + height + ax.text( + 0.5 * (left + right) + text_dx, + 0.5 * (bottom + top) + text_dy, + text, + horizontalalignment="center", + verticalalignment="center", + fontsize=textfontsize, + color=textfontcolor, + transform=ax.transAxes, + ) + # ax.text(0.5, 0.5) + img_shape = list(image.shape[:-1]) + img_shape[1] *= 2 + ax.imshow(np.zeros(img_shape), cmap="Greys_r") + ax.axis("off") + + ## multi-channel images ## + elif image.components > 1: + raise ValueError("Multi-channel images not currently supported!") + + if filename is not None: + plt.savefig(filename, dpi=dpi, transparent=transparent) + plt.close(fig) + else: + plt.show() + + # turn warnings back to default + warnings.simplefilter("default") diff --git a/ants/viz/plot_ortho_stack.py b/ants/viz/plot_ortho_stack.py new file mode 100644 index 00000000..8b2a28ca --- /dev/null +++ b/ants/viz/plot_ortho_stack.py @@ -0,0 +1,521 @@ +""" +Functions for plotting ants images +""" + + +__all__ = [ + "plot_ortho_stack" +] + +import fnmatch +import math +import os +import warnings + +from matplotlib import gridspec +import matplotlib.pyplot as plt +import matplotlib.patheffects as path_effects +import matplotlib.lines as mlines +import matplotlib.patches as patches +import matplotlib.mlab as mlab +import matplotlib.animation as animation +from mpl_toolkits.axes_grid1.inset_locator import inset_axes + + +import numpy as np + +from .. import registration as reg +from ..core import ants_image as iio +from ..core import ants_image_io as iio2 +from ..core import ants_transform as tio +from ..core import ants_transform_io as tio2 + +def plot_ortho_stack( + images, + overlays=None, + reorient=True, + # xyz arguments + xyz=None, + xyz_lines=False, + xyz_color="red", + xyz_alpha=0.6, + xyz_linewidth=2, + xyz_pad=5, + # base image arguments + cmap="Greys_r", + alpha=1, + # overlay arguments + overlay_cmap="jet", + overlay_alpha=0.9, + # background arguments + black_bg=True, + bg_thresh_quant=0.01, + bg_val_quant=0.99, + # scale/crop/domain arguments + crop=False, + scale=False, + domain_image_map=None, + # title arguments + title=None, + titlefontsize=24, + title_dx=0, + title_dy=0, + # 4th panel text arguemnts + text=None, + textfontsize=24, + textfontcolor="white", + text_dx=0, + text_dy=0, + # save & size arguments + filename=None, + dpi=500, + figsize=1.0, + colpad=0, + rowpad=0, + transpose=False, + transparent=True, + orient_labels=True, +): + """ + Create a stack of orthographic plots with optional overlays. + + Use mask_image and/or threshold_image to preprocess images to be be + overlaid and display the overlays in a given range. See the wiki examples. + + Example + ------- + >>> import ants + >>> mni = ants.image_read(ants.get_data('mni')) + >>> ch2 = ants.image_read(ants.get_data('ch2')) + >>> ants.plot_ortho_stack([mni,mni,mni]) + """ + + def mirror_matrix(x): + return x[::-1, :] + + def rotate270_matrix(x): + return mirror_matrix(x.T) + + def reorient_slice(x, axis): + return rotate270_matrix(x) + + # need this hack because of a weird NaN warning from matplotlib with overlays + warnings.simplefilter("ignore") + + n_images = len(images) + + # handle `image` argument + for i in range(n_images): + if isinstance(images[i], str): + images[i] = iio2.image_read(images[i]) + if not isinstance(images[i], iio.ANTsImage): + raise ValueError("image argument must be an ANTsImage") + if images[i].dimension != 3: + raise ValueError("Input image must have 3 dimensions!") + + if overlays is None: + overlays = [None] * n_images + # handle `overlay` argument + for i in range(n_images): + if overlays[i] is not None: + if isinstance(overlays[i], str): + overlays[i] = iio2.image_read(overlays[i]) + if not isinstance(overlays[i], iio.ANTsImage): + raise ValueError("overlay argument must be an ANTsImage") + if overlays[i].components > 1: + raise ValueError("overlays[i] cannot have more than one voxel component") + if overlays[i].dimension != 3: + raise ValueError("Overlay image must have 3 dimensions!") + + if not iio.image_physical_space_consistency(images[i], overlays[i]): + overlays[i] = reg.resample_image_to_target( + overlays[i], images[i], interp_type="linear" + ) + + for i in range(1, n_images): + if not iio.image_physical_space_consistency(images[0], images[i]): + images[i] = reg.resample_image_to_target( + images[0], images[i], interp_type="linear" + ) + + # reorient images + if reorient != False: + if reorient == True: + reorient = "RPI" + + for i in range(n_images): + images[i] = images[i].reorient_image2(reorient) + + if overlays[i] is not None: + overlays[i] = overlays[i].reorient_image2(reorient) + + # handle `slices` argument + if xyz is None: + xyz = [int(s / 2) for s in images[0].shape] + for i in range(3): + if xyz[i] is None: + xyz[i] = int(images[0].shape[i] / 2) + + # resample image if spacing is very unbalanced + spacing = [s for i, s in enumerate(images[0].spacing)] + if (max(spacing) / min(spacing)) > 3.0: + new_spacing = (1, 1, 1) + for i in range(n_images): + images[i] = images[i].resample_image(tuple(new_spacing)) + if overlays[i] is not None: + overlays[i] = overlays[i].resample_image(tuple(new_spacing)) + xyz = [ + int(sl * (sold / snew)) for sl, sold, snew in zip(xyz, spacing, new_spacing) + ] + + # potentially crop image + if crop: + for i in range(n_images): + plotmask = images[i].get_mask(cleanup=0) + if plotmask.max() == 0: + plotmask += 1 + images[i] = images[i].crop_image(plotmask) + if overlays[i] is not None: + overlays[i] = overlays[i].crop_image(plotmask) + + # pad images + for i in range(n_images): + if i == 0: + images[i], lowpad, uppad = images[i].pad_image(return_padvals=True) + else: + images[i] = images[i].pad_image() + if overlays[i] is not None: + overlays[i] = overlays[i].pad_image() + xyz = [v + l for v, l in zip(xyz, lowpad)] + + # handle `domain_image_map` argument + if domain_image_map is not None: + if isinstance(domain_image_map, iio.ANTsImage): + tx = tio2.new_ants_transform( + precision="float", transform_type="AffineTransform", dimension=3 + ) + for i in range(n_images): + images[i] = tio.apply_ants_transform_to_image( + tx, images[i], domain_image_map + ) + + if overlays[i] is not None: + overlays[i] = tio.apply_ants_transform_to_image( + tx, overlays[i], domain_image_map, interpolation="linear" + ) + elif isinstance(domain_image_map, (list, tuple)): + # expect an image and transformation + if len(domain_image_map) != 2: + raise ValueError("domain_image_map list or tuple must have length == 2") + + dimg = domain_image_map[0] + if not isinstance(dimg, iio.ANTsImage): + raise ValueError("domain_image_map first entry should be ANTsImage") + + tx = domain_image_map[1] + for i in range(n_images): + images[i] = reg.apply_transforms(dimg, images[i], transform_list=tx) + if overlays[i] is not None: + overlays[i] = reg.apply_transforms( + dimg, overlays[i], transform_list=tx, interpolator="linear" + ) + + # potentially find dynamic range + if scale == True: + vmins = [] + vmaxs = [] + for i in range(n_images): + vmin, vmax = images[i].quantile((0.05, 0.95)) + vmins.append(vmin) + vmaxs.append(vmax) + elif isinstance(scale, (list, tuple)): + if len(scale) != 2: + raise ValueError( + "scale argument must be boolean or list/tuple with two values" + ) + vmins = [] + vmaxs = [] + for i in range(n_images): + vmin, vmax = images[i].quantile(scale) + vmins.append(vmin) + vmaxs.append(vmax) + else: + vmin = None + vmax = None + + if not transpose: + nrow = n_images + ncol = 3 + else: + nrow = 3 + ncol = n_images + + fig = plt.figure(figsize=((ncol + 1) * 2.5 * figsize, (nrow + 1) * 2.5 * figsize)) + if title is not None: + basey = 0.93 + basex = 0.5 + fig.suptitle( + title, fontsize=titlefontsize, color=textfontcolor, x=basex + title_dx, y=basey + title_dy + ) + + if (colpad > 0) and (rowpad > 0): + bothgridpad = max(colpad, rowpad) + colpad = 0 + rowpad = 0 + else: + bothgridpad = 0.0 + + gs = gridspec.GridSpec( + nrow, + ncol, + wspace=bothgridpad, + hspace=0.0, + top=1.0 - 0.5 / (nrow + 1), + bottom=0.5 / (nrow + 1) + colpad, + left=0.5 / (ncol + 1) + rowpad, + right=1 - 0.5 / (ncol + 1), + ) + + # pad image to have isotropic array dimensions + vminols=[] + vmaxols=[] + for i in range(n_images): + images[i] = images[i].numpy() + if overlays[i] is not None: + vminols.append( overlays[i].min() ) + vmaxols.append( overlays[i].max() ) + overlays[i] = overlays[i].numpy() + if overlays[i].dtype not in ["uint8", "uint32"]: + overlays[i][np.abs(overlays[i]) == 0] = np.nan + + #################### + #################### + for i in range(n_images): + yz_slice = reorient_slice(images[i][xyz[0], :, :], 0) + if not transpose: + ax = plt.subplot(gs[i, 0]) + else: + ax = plt.subplot(gs[0, i]) + ax.imshow(yz_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlays[i] is not None: + yz_overlay = reorient_slice(overlays[i][xyz[0], :, :], 0) + ax.imshow(yz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, + vmin=vminols[i], vmax=vmaxols[i]) + if xyz_lines: + # add lines + l = mlines.Line2D( + [yz_slice.shape[0] - xyz[1], yz_slice.shape[0] - xyz[1]], + [xyz_pad, yz_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, yz_slice.shape[1] - xyz_pad], + [yz_slice.shape[1] - xyz[2], yz_slice.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + if orient_labels: + ax.text( + 0.5, + 0.98, + "S", + horizontalalignment="center", + verticalalignment="top", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.5, + 0.02, + "I", + horizontalalignment="center", + verticalalignment="bottom", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.5, + "A", + horizontalalignment="right", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.02, + 0.5, + "P", + horizontalalignment="left", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.axis("off") + #################### + #################### + + xz_slice = reorient_slice(images[i][:, xyz[1], :], 1) + if not transpose: + ax = plt.subplot(gs[i, 1]) + else: + ax = plt.subplot(gs[1, i]) + ax.imshow(xz_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlays[i] is not None: + xz_overlay = reorient_slice(overlays[i][:, xyz[1], :], 1) + ax.imshow(xz_overlay, alpha=overlay_alpha, cmap=overlay_cmap, + vmin=vminols[i], vmax=vmaxols[i]) + if xyz_lines: + # add lines + l = mlines.Line2D( + [xz_slice.shape[0] - xyz[0], xz_slice.shape[0] - xyz[0]], + [xyz_pad, xz_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xz_slice.shape[1] - xyz_pad], + [xz_slice.shape[1] - xyz[2], xz_slice.shape[1] - xyz[2]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + if orient_labels: + ax.text( + 0.5, + 0.98, + "A", + horizontalalignment="center", + verticalalignment="top", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.5, + 0.02, + "P", + horizontalalignment="center", + verticalalignment="bottom", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.5, + "L", + horizontalalignment="right", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.02, + 0.5, + "R", + horizontalalignment="left", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.axis("off") + + #################### + #################### + xy_slice = reorient_slice(images[i][:, :, xyz[2]], 2) + if not transpose: + ax = plt.subplot(gs[i, 2]) + else: + ax = plt.subplot(gs[2, i]) + ax.imshow(xy_slice, cmap=cmap, vmin=vmin, vmax=vmax) + if overlays[i] is not None: + xy_overlay = reorient_slice(overlays[i][:, :, xyz[2]], 2) + ax.imshow(xy_overlay, alpha=overlay_alpha, cmap=overlay_cmap, + vmin=vminols[i], vmax=vmaxols[i]) + if xyz_lines: + # add lines + l = mlines.Line2D( + [xy_slice.shape[0] - xyz[0], xy_slice.shape[0] - xyz[0]], + [xyz_pad, xy_slice.shape[0] - xyz_pad], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + l = mlines.Line2D( + [xyz_pad, xy_slice.shape[1] - xyz_pad], + [xy_slice.shape[1] - xyz[1], xy_slice.shape[1] - xyz[1]], + color=xyz_color, + alpha=xyz_alpha, + linewidth=xyz_linewidth, + ) + ax.add_line(l) + if orient_labels: + ax.text( + 0.5, + 0.98, + "A", + horizontalalignment="center", + verticalalignment="top", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.5, + 0.02, + "P", + horizontalalignment="center", + verticalalignment="bottom", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.98, + 0.5, + "L", + horizontalalignment="right", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.text( + 0.02, + 0.5, + "R", + horizontalalignment="left", + verticalalignment="center", + fontsize=20 * figsize, + color=textfontcolor, + transform=ax.transAxes, + ) + ax.axis("off") + + #################### + #################### + + if filename is not None: + plt.savefig(filename, dpi=dpi, transparent=transparent) + plt.close(fig) + else: + plt.show() + + # turn warnings back to default + warnings.simplefilter("default")