diff --git a/ants/registration/registration.py b/ants/registration/registration.py index 2c6739a1..1a377524 100644 --- a/ants/registration/registration.py +++ b/ants/registration/registration.py @@ -1,7 +1,7 @@ """ ANTsPy Registration """ -__all__ = ["registration", +__all__ = ["registration", "motion_correction", "label_image_registration"] @@ -64,8 +64,9 @@ def registration( See Notes below for more. initial_transform : list of strings (optional) - transforms to prepend. If None, a translation is computed to align the image centers of mass. - To use an identity transform, set this to 'Identity'. + transforms to prepend. If None, a translation is computed to align the image centers of mass, unless the type of + transform is deformable-only (time-varying diffeomorphisms, SyNOnly, or antsRegistrationSyN*[so|bo]). + To force initialization with an identity transform, set this to 'Identity'. outprefix : string output will be named with this prefix. @@ -84,10 +85,10 @@ def registration( flow_sigma : scalar smoothing for update field - At each iteration, the similarity metric and gradient is calculated. - That gradient field is also called the update field and is smoothed - before composing with the total field (i.e., the estimate of the total - transform at that iteration). This total field can also be smoothed + At each iteration, the similarity metric and gradient is calculated. + That gradient field is also called the update field and is smoothed + before composing with the total field (i.e., the estimate of the total + transform at that iteration). This total field can also be smoothed after each iteration. total_sigma : scalar @@ -155,7 +156,7 @@ def registration( singleprecision : boolean if True, use float32 for computations. This is useful for reducing memory - usage for large datasets, at the cost of precision. + usage for large datasets, at the cost of precision. kwargs : keyword args extra arguments @@ -197,12 +198,7 @@ def registration( - "SyNRA": Symmetric normalization: Rigid + Affine + deformable transformation, with mutual information as optimization metric. - "SyNOnly": Symmetric normalization with no rigid or affine stages. - Uses mutual information as optimization metric. Affine alignment is - from the initial_transform arg, either provide the .mat from linear - registration or use initial_transform='Identity' if the images are - already affinely aligned. - Can be useful if you want to run an unmasked affine followed by - masked deformable registration. + Uses mutual information as optimization metric. - "SyNCC": SyN, but with cross-correlation as the metric. - "SyNabp": SyN optimized for abpBrainExtraction. - "SyNBold": SyN, but optimized for registrations between BOLD and T1 images. @@ -446,8 +442,18 @@ def registration( else: earlymaskopt = "[NA,NA]" + deformable_only_transforms = ["SyNOnly", "antsRegistrationSyN[so]", "antsRegistrationSyNQuick[so]", + "antsRegistrationSyNRepro[so]", "antsRegistrationSyNQuickRepro[so]", + "antsRegistrationSyN[bo]", "antsRegistrationSyNQuick[bo]", + "antsRegistrationSyNRepro[bo]", "antsRegistrationSyNQuickRepro[bo]", + "TVMSQ", "TVMSQC"] + tvTypes + if initx is None: - initx = ["[%s,%s,1]" % (f, m)] + if type_of_transform in deformable_only_transforms: + initx = ["Identity"] + else: + initx = ["[%s,%s,1]" % (f, m)] + # ------------------------------------------------------------ if type_of_transform == "SyNBold": args = [ @@ -1067,7 +1073,8 @@ def registration( args = [ "-d", str(fixed.dimension), - # '-r', initx, + '-r' + ] + initx + [ "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -1098,7 +1105,8 @@ def registration( args = [ "-d", str(fixed.dimension), - # '-r', initx, + '-r' + ] + initx + [ "-m", "demons[%s,%s,0.5,0]" % (f, m), "-m", @@ -1573,7 +1581,7 @@ def motion_correction( "FD": FD, } -def label_image_registration(fixed_label_images, +def label_image_registration(fixed_label_images, moving_label_images, fixed_intensity_images=None, moving_intensity_images=None, @@ -1587,8 +1595,8 @@ def label_image_registration(fixed_label_images, verbose=False): """ - Perform pairwise registration using fixed and moving sets of label - images (and, optionally, sets of corresponding intensity images). + Perform pairwise registration using fixed and moving sets of label + images (and, optionally, sets of corresponding intensity images). Arguments --------- @@ -1607,34 +1615,34 @@ def label_image_registration(fixed_label_images, fixed_mask : ANTsImage Defines region for similarity metric calculation in the space of the fixed image. - + moving_mask : ANTsImage Defines region for similarity metric calculation in the space of the moving image. - + type_of_linear_transform : string - Use label images with the centers of mass to a calculate linear + Use label images with the centers of mass to a calculate linear transform of type 'rigid', 'similarity', or 'affine'. type_of_deformable_transform : string Only works with deformable-only transforms, specifically the family - of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms. + of antsRegistrationSyN*[so] or antsRegistrationSyN*[bo] transforms. See 'type_of_transform' in ants.registration. Additionally, one can - use a list to pass a more tailored deformably-only transform - optimization using SyN or BSplineSyN transforms. The order of + use a list to pass a more tailored deformably-only transform + optimization using SyN or BSplineSyN transforms. The order of parameters in the list would be 1) transform specification, i.e. - "SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string), - 4) intensity metric parameter (real), 5) convergence iterations per level - (tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level - (tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC", + "SyN" or "BSplineSyN", 2) gradient (real), 3) intensity metric (string), + 4) intensity metric parameter (real), 5) convergence iterations per level + (tuple) 6) smoothing factors per level (tuple), 7) shrink factors per level + (tuple). An example would type_of_deformable_transform = ["SyN", 0.2, "CC", 4, (100,50,10), (2,1,0), (4,2,1)]. label_image_weighting : float or list of floats Relative weighting for the label images. - + output_prefix : string Define the output prefix for the filenames of the output transform - files. + files. random_seed : integer Definition for deformable registration. @@ -1644,7 +1652,7 @@ def label_image_registration(fixed_label_images, Returns ------- - Set of transforms definining the mapping to/from the fixed image domain + Set of transforms definining the mapping to/from the fixed image domain to the moving image domain. Example @@ -1658,7 +1666,7 @@ def label_image_registration(fixed_label_images, >>> r64_seg1 = ants.threshold_image(r64, "Kmeans", 3) - 1 >>> r64_seg2 = ants.threshold_image(r64, "Kmeans", 5) - 1 >>> reg = ants.label_image_registration([r16_seg1, r16_seg2], - [r64_seg1, r64_seg2], + [r64_seg1, r64_seg2], fixed_intensity_images=r16, moving_intensity_images=r64, type_of_linear_transform='affine', @@ -1691,7 +1699,7 @@ def label_image_registration(fixed_label_images, else: label_image_weights = tuple(label_image_weighting) if len(fixed_label_images) != len(label_image_weights): - raise ValueError("The length of label_image_weights must" + + raise ValueError("The length of label_image_weights must" + "match the number of label image pairs.") image_dimension = fixed_label_images[0].dimension @@ -1699,10 +1707,10 @@ def label_image_registration(fixed_label_images, if output_prefix == "" or output_prefix is None or len(output_prefix) == 0: output_prefix = mktemp() - allowable_linear_transforms = ['rigid', 'similarity', 'affine'] + allowable_linear_transforms = ['rigid', 'similarity', 'affine'] if not type_of_linear_transform in allowable_linear_transforms: - raise ValueError("Unrecognized linear transform.") - + raise ValueError("Unrecognized linear transform.") + do_deformable = True if type_of_deformable_transform is None or len(type_of_deformable_transform) == 0: do_deformable = False @@ -1720,7 +1728,7 @@ def label_image_registration(fixed_label_images, print("Common label ids for image pair ", str(i), ": ", common_label_ids[i]) if len(common_label_ids[i]) == 0: raise ValueError("No common labels for image pair " + str(i)) - + if verbose: print("Total number of labels: " + str(total_number_of_labels)) @@ -1737,12 +1745,12 @@ def label_image_registration(fixed_label_images, print("\n\nComputing linear transform.\n") if total_number_of_labels < 3: - raise ValueError(" Number of labels must be >= 3.") + raise ValueError(" Number of labels must be >= 3.") - fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) + fixed_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) moving_centers_of_mass = np.zeros((total_number_of_labels, image_dimension)) deformable_multivariate_extras = list() - + count = 0 for i in range(len(common_label_ids)): for j in range(len(common_label_ids[i])): @@ -1755,17 +1763,17 @@ def label_image_registration(fixed_label_images, moving_centers_of_mass[count, :] = ants.get_center_of_mass(moving_single_label_image) count += 1 if do_deformable: - deformable_multivariate_extras.append(["MSQ", fixed_single_label_image, - moving_single_label_image, + deformable_multivariate_extras.append(["MSQ", fixed_single_label_image, + moving_single_label_image, label_image_weights[i], 0]) - - linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass, - fixed_centers_of_mass, + + linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass, + fixed_centers_of_mass, transform_type=type_of_linear_transform, verbose=verbose) - + linear_xfrm_file = output_prefix + "0GenericAffine.mat" - ants.write_transform(linear_xfrm, linear_xfrm_file) + ants.write_transform(linear_xfrm, linear_xfrm_file) ############################## # @@ -1787,7 +1795,7 @@ def label_image_registration(fixed_label_images, gradient_step = 0.1 syn_transform = "SyN" - syn_stage = list() + syn_stage = list() if isinstance(type_of_deformable_transform, list): @@ -1795,17 +1803,17 @@ def label_image_registration(fixed_label_images, not isinstance(type_of_deformable_transform[0], str) or not isinstance(type_of_deformable_transform[1], float) or not isinstance(type_of_deformable_transform[2], str) or - not isinstance(type_of_deformable_transform[3], int) or - not isinstance(type_of_deformable_transform[4], tuple) or - not isinstance(type_of_deformable_transform[5], tuple) or + not isinstance(type_of_deformable_transform[3], int) or + not isinstance(type_of_deformable_transform[4], tuple) or + not isinstance(type_of_deformable_transform[5], tuple) or not isinstance(type_of_deformable_transform[6], tuple)): - raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.") + raise ValueError("Incorrect specification for type_of_deformable_transform. See help menu.") syn_transform = type_of_deformable_transform[0] gradient_step = type_of_deformable_transform[1] intensity_metric = type_of_deformable_transform[2] intensity_metric_parameter = type_of_deformable_transform[3] - + t = type_of_deformable_transform[4] tstr = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1]) syn_convergence = "[" + tstr + ",1e-6,10]" @@ -1816,8 +1824,8 @@ def label_image_registration(fixed_label_images, t = type_of_deformable_transform[6] syn_shrink_factors = ''.join(map(lambda x: str(x) + 'x', t[:len(t)-1])) + str(t[len(t)-1]) - - else: + + else: do_quick = False if "Quick" in type_of_deformable_transform: @@ -1840,10 +1848,10 @@ def label_image_registration(fixed_label_images, spline_distance = subtype_of_deformable_transform_args[2] if do_quick: - intensity_metric = "MI" + intensity_metric = "MI" if intensity_metric_parameter is None: intensity_metric_parameter = 32 - syn_convergence = "[100x70x50x0,1e-6,10]" + syn_convergence = "[100x70x50x0,1e-6,10]" if fixed_intensity_images is not None and len(fixed_intensity_images) > 0: for i in range(len(fixed_intensity_images)): @@ -1854,7 +1862,7 @@ def label_image_registration(fixed_label_images, get_pointer_string(moving_intensity_images[i]), 1.0, intensity_metric_parameter) syn_stage.append(metric_string) - + for kk in range(len(deformable_multivariate_extras)): syn_stage.append("--metric") metricString = "%s[%s,%s,%s,%s]" % ( @@ -1862,7 +1870,7 @@ def label_image_registration(fixed_label_images, get_pointer_string(deformable_multivariate_extras[kk][1]), get_pointer_string(deformable_multivariate_extras[kk][2]), deformable_multivariate_extras[kk][3], 0.0) - syn_stage.append(metricString) + syn_stage.append(metricString) syn_stage.append("--convergence") syn_stage.append(syn_convergence) @@ -1887,25 +1895,25 @@ def label_image_registration(fixed_label_images, "-o", output_prefix] args.append(syn_stage) - fixed_mask_string = 'NA' + fixed_mask_string = 'NA' if fixed_mask is not None: fixed_mask_binary = fixed_mask != 0 fixed_mask_string = get_pointer_string(fixed_mask_binary) - moving_mask_string = 'NA' + moving_mask_string = 'NA' if moving_mask is not None: moving_mask_binary = moving_mask != 0 moving_mask_string = get_pointer_string(moving_mask_binary) mask_option = "[%s,%s]" % (fixed_mask_string, moving_mask_string) - + args.append("-x") args.append(mask_option) args = list(itertools.chain.from_iterable( - itertools.repeat(x, 1) - if isinstance(x, str) - else x for x in args)) + itertools.repeat(x, 1) + if isinstance(x, str) + else x for x in args)) args.append("--float") args.append("1") @@ -1929,7 +1937,7 @@ def label_image_registration(fixed_label_images, raise RuntimeError(f"Registration failed with error code {deformable_registration_exit_error}") all_xfrms = sorted(set(glob.glob(output_prefix + "*" + "[0-9]*"))) - + find_inverse_warps = np.where([re.search("[0-9]InverseWarp.nii.gz", ff) for ff in all_xfrms])[0] find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0] @@ -1942,8 +1950,8 @@ def label_image_registration(fixed_label_images, if verbose: print("\n\nResulting transforms") - print(" fwdtransforms: ", fwdtransforms) - print(" invtransforms: ", invtransforms) + print(" fwdtransforms: ", fwdtransforms) + print(" invtransforms: ", invtransforms) return { "fwdtransforms": fwdtransforms,