Skip to content

Commit

Permalink
Merge pull request #1 from DIAGNijmegen/nnunet-preprocessing-strategy
Browse files Browse the repository at this point in the history
New preprocessing strategy
  • Loading branch information
anindox8 authored May 19, 2022
2 parents 28c9bb4 + 9cf6a47 commit 0b46a5d
Show file tree
Hide file tree
Showing 85 changed files with 442 additions and 305 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.mha filter=lfs diff=lfs merge=lfs -text
*.nii.gz filter=lfs diff=lfs merge=lfs -text
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ jobs:

steps:
- uses: actions/checkout@v2
with:
lfs: 'true'
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
Expand Down
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
long_description = fh.read()

setuptools.setup(
version='1.1.6',
version='1.2',
author_email='Joeran.Bosma@radboudumc.nl',
long_description=long_description,
long_description_content_type="text/markdown",
Expand All @@ -14,5 +14,7 @@
"Bug Tracker": "https://github.com/DIAGNijmegen/picai_prep/issues"
},
license='Apache License, Version 2.0',
packages=['picai_prep', 'picai_prep.resources', 'picai_prep.examples.dcm2mha', 'picai_prep.examples.mha2nnunet'],
package_dir={"": "src"}, # our packages live under src, but src is not a package itself
packages=setuptools.find_packages('src', exclude=['tests']),
exclude_package_data={'': ['tests']},
)
21 changes: 11 additions & 10 deletions src/picai_prep/examples/mha2nnunet/picai_archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,17 @@ def generate_mha2nnunet_settings(
}
},
"preprocessing": {
"matrix_size": [
20,
320,
320
],
"spacing": [
3.0,
0.5,
0.5
]
# optionally, resample and perform centre crop:
# "matrix_size": [
# 20,
# 320,
# 320
# ],
# "spacing": [
# 3.0,
# 0.5,
# 0.5
# ],
},
"archive": archive_list
}
Expand Down
21 changes: 11 additions & 10 deletions src/picai_prep/examples/mha2nnunet/picai_archive_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,17 @@ def generate_mha2nnunet_settings(
}
},
"preprocessing": {
"matrix_size": [
20,
160,
160
],
"spacing": [
3.6,
0.5,
0.5
]
# optionally, resample and perform centre crop:
# "matrix_size": [
# 20,
# 160,
# 160
# ],
# "spacing": [
# 3.6,
# 0.5,
# 0.5
# ]
},
"archive": archive_list
}
Expand Down
189 changes: 49 additions & 140 deletions src/picai_prep/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@

import SimpleITK as sitk
import numpy as np
from numpy.testing import assert_allclose
from dataclasses import dataclass
from scipy import ndimage

from typing import List, Tuple, Callable, Optional, Union, Any, Iterable, cast
from typing import List, Callable, Optional, Union, Any, Iterable
try:
import numpy.typing as npt
except ImportError: # pragma: no cover
Expand All @@ -32,21 +33,15 @@ class PreprocessingSettings():
- matrix_size: number of voxels output volume (z, y, x)
- spacing: output voxel spacing in mm (z, y, x)
- physical_size: size in mm of the target volume (z, y, x)
- align_physical_space: whether to align sequences to each other, based on metadata
- crop_to_first_physical_centre: whether to crop to physical centre of first sequence,
or to the new centre after aligning sequences
- align_segmentation: whether to align the scans using the centroid of the provided segmentation
"""
matrix_size: Iterable[int] = (20, 160, 160)
matrix_size: Optional[Iterable[int]] = None
spacing: Optional[Iterable[float]] = None
physical_size: Optional[Iterable[float]] = None
align_physical_space: bool = False
crop_to_first_physical_centre: bool = False
align_segmentation: Optional[sitk.Image] = None

def __post_init__(self):
if self.physical_size is None:
assert self.spacing, "Need either physical_size or spacing"
if self.physical_size is None and self.spacing is not None and self.matrix_size is not None:
# calculate physical size
self.physical_size = [
voxel_spacing * num_voxels
Expand All @@ -56,8 +51,7 @@ def __post_init__(self):
)
]

if self.spacing is None:
assert self.physical_size, "Need either physical_size or spacing"
if self.spacing is None and self.physical_size is not None and self.matrix_size is not None:
# calculate spacing
self.spacing = [
size / num_voxels
Expand All @@ -67,13 +61,8 @@ def __post_init__(self):
)
]

@property
def _spacing(self) -> Iterable[float]:
return cast(Iterable[float], self.spacing)

@property
def _physical_size(self) -> Iterable[float]:
return cast(Iterable[float], self.physical_size)
if self.align_segmentation is not None:
raise NotImplementedError("Alignment of scans based on segmentation is not implemented yet.")


def resample_img(
Expand Down Expand Up @@ -170,83 +159,6 @@ def crop_or_pad(
return np.pad(image[tuple(slicer)], padding)


def get_overlap_start_indices(img_main: sitk.Image, img_secondary: sitk.Image):
    """
    Find the start indices of the region shared by two images.

    The first voxel of the secondary image is mapped into the main image,
    clamped to be non-negative there, and mapped back. The resulting
    secondary index is rounded up to a whole voxel, and that voxel is
    finally located in the main image.

    Returns a tuple `(start index in secondary, start index in main)`,
    both as integer numpy arrays.
    """
    # position of the secondary image's first voxel, expressed in main-image indices
    first_voxel_in_main = img_main.TransformPhysicalPointToContinuousIndex(
        img_secondary.TransformIndexToPhysicalPoint((0, 0, 0))
    )

    # negative components lie before the start of the main image: clamp them to zero
    clamped_in_main = np.clip(first_voxel_in_main, a_min=0, a_max=None)

    # express the clamped position in secondary-image indices again
    start_secondary = img_secondary.TransformPhysicalPointToContinuousIndex(
        img_main.TransformContinuousIndexToPhysicalPoint(clamped_in_main)
    )

    # snap up to a whole voxel; rounding to 5 decimals first absorbs
    # floating point noise such as 18.999999999999996
    start_secondary = np.ceil(np.round(start_secondary, decimals=5))

    # locate that whole secondary voxel in the main image
    start_main = img_main.TransformPhysicalPointToIndex(
        img_secondary.TransformContinuousIndexToPhysicalPoint(start_secondary)
    )

    return np.array(start_secondary).astype(int), np.array(start_main).astype(int)


def get_overlap_end_indices(img_main: sitk.Image, img_secondary: sitk.Image):
    """
    Find the end indices of the region shared by two images.

    The index equal to the secondary image's size is mapped into the main
    image, limited per axis to the main image's size, and mapped back. The
    resulting secondary index is rounded down to a whole voxel, and that
    voxel is finally located in the main image.

    Returns a tuple `(end index in secondary, end index in main)`,
    both as integer numpy arrays.
    """
    # position of the secondary image's end, expressed in main-image indices
    end_in_main = img_main.TransformPhysicalPointToContinuousIndex(
        img_secondary.TransformIndexToPhysicalPoint(img_secondary.GetSize())
    )

    # components beyond the main image's size are limited to that size
    limited_in_main = [
        min(extent, position)
        for position, extent in zip(end_in_main, img_main.GetSize())
    ]

    # express the limited position in secondary-image indices again
    end_secondary = img_secondary.TransformPhysicalPointToContinuousIndex(
        img_main.TransformContinuousIndexToPhysicalPoint(limited_in_main)
    )

    # snap down to a whole voxel; rounding to 5 decimals first absorbs
    # floating point noise such as 18.999999999999996
    end_secondary = np.floor(np.round(end_secondary, decimals=5))

    # locate that whole secondary voxel in the main image
    end_main = img_main.TransformPhysicalPointToIndex(
        img_secondary.TransformContinuousIndexToPhysicalPoint(end_secondary)
    )

    return np.array(end_secondary).astype(int), np.array(end_main).astype(int)


def crop_to_common_physical_space(
    img_main: sitk.Image,
    img_sec: sitk.Image
) -> Tuple[sitk.Image, sitk.Image]:
    """
    Crop two SimpleITK images to the largest physical volume they share.

    The crop window of each image is derived with `get_overlap_start_indices`
    and `get_overlap_end_indices`. Cropping is refused (AssertionError) when,
    along any axis, the window is not larger than half of the image's size.

    Returns the cropped `(img_main, img_sec)`.
    """
    # crop window per image, as (start, end) index arrays
    start_sec, start_main = get_overlap_start_indices(img_main, img_sec)
    end_sec, end_main = get_overlap_end_indices(img_main, img_sec)

    # sanity check: secondary image first, then main image
    for start, end, image in (
        (start_sec, end_sec, img_sec),
        (start_main, end_main, img_main),
    ):
        assert ((end - start) > np.array(image.GetSize()) / 2).all(), \
            "Found unrealistically little overlap when aligning scans, aborting."

    # cut both images down to their crop window
    img_main = img_main[[slice(lo, hi) for lo, hi in zip(start_main, end_main)]]
    img_sec = img_sec[[slice(lo, hi) for lo, hi in zip(start_sec, end_sec)]]

    return img_main, img_sec


def get_physical_centre(image: sitk.Image):
    """
    Return the physical point that corresponds to half of the image's size.

    NOTE(review): this uses `size / 2` as continuous index; whether that or
    `(size - 1) / 2` is the intended centre depends on the index convention
    of the image library, so confirm before relying on sub-voxel accuracy.
    """
    half_size = np.array(image.GetSize()) / 2.0
    return image.TransformContinuousIndexToPhysicalPoint(half_size)


@dataclass
class Sample:
scans: List[sitk.Image]
Expand All @@ -260,53 +172,42 @@ class Sample:
num_gt_lesions: Optional[int] = None

def __post_init__(self):
# determine main centre
self.main_centre = get_physical_centre(self.scans[0])

if self.lbl is not None:
# keep track of connected components
lbl = sitk.GetArrayFromImage(self.lbl)
_, num_gt_lesions = ndimage.label(lbl, structure=np.ones((3, 3, 3)))
self.num_gt_lesions = num_gt_lesions

def crop_to_common_physical_space(self):
    """
    Crop all scans to their shared physical volume, if their centres disagree.

    The physical centre of the first scan (e.g., T2W) is compared with the
    centre of every subsequent scan (e.g., ADC, high b-value). If any of them
    is more than 2 mm away, every ordered pair of scans is cropped to its
    common physical space and stored back in `self.scans`.
    """
    reference_centre = np.array(get_physical_centre(self.scans[0]))

    # check every secondary scan, reporting each one that is off-centre
    needs_alignment = False
    for other_scan in self.scans[1:]:
        offset = reference_centre - np.array(get_physical_centre(other_scan))
        distance = np.sqrt(np.sum(offset**2))

        # centres further apart than 2 mm trigger alignment
        if distance > 2:
            print(f"Aligning scans with distance of {distance:.1f} mm between centers for {self.name}.")
            needs_alignment = True

    if not needs_alignment:
        return

    # NOTE(review): `main_scan` is bound once per outer iteration, so a crop
    # written to self.scans[i] by an earlier inner iteration is not used as
    # input for later ones - confirm this is intended.
    for i, main_scan in enumerate(self.scans):
        for j, secondary_scan in enumerate(self.scans):
            if i != j:
                self.scans[i], self.scans[j] = crop_to_common_physical_space(main_scan, secondary_scan)

def resample(self):
"""Resample scans and label"""
def resample_to_first_scan(self):
    """Resample all other scans, and the label if present, to the first scan"""
    # the first scan is the reference image for the resampler
    reference_scan = self.scans[0]
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference_scan)

    # scans are resampled with B-spline interpolation
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampled_scans = []
    for scan in self.scans[1:]:
        resampled_scans.append(resampler.Execute(scan))
    self.scans[1:] = resampled_scans

    # the annotation is resampled with nearest neighbour interpolation
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    if self.lbl is not None:
        self.lbl = resampler.Execute(self.lbl)

def resample_spacing(self, spacing: Optional[Iterable[float]] = None):
"""Resample scans and label to the target spacing"""
if spacing is None:
assert self.settings.spacing is not None
spacing = self.settings.spacing

# resample scans to target resolution
self.scans = [
resample_img(scan, out_spacing=self.settings._spacing, is_label=False)
resample_img(scan, out_spacing=spacing, is_label=False)
for scan in self.scans
]

# resample annotation to target resolution
if self.lbl is not None:
self.lbl = resample_img(self.lbl, out_spacing=self.settings._spacing, is_label=True)
self.lbl = resample_img(self.lbl, out_spacing=spacing, is_label=True)

def centre_crop(self):
"""Centre crop scans and label"""
Expand All @@ -318,7 +219,7 @@ def centre_crop(self):
if self.lbl is not None:
self.lbl = crop_or_pad(self.lbl, size=self.settings.matrix_size)

def copy_physical_metadata(self):
def align_physical_metadata(self, check_almost_equal=True):
"""Align the origin, direction and spacing of each scan, and label"""
case_origin, case_direction, case_spacing = None, None, None
for img in self.scans:
Expand All @@ -328,6 +229,13 @@ def copy_physical_metadata(self):
case_direction = img.GetDirection()
case_spacing = img.GetSpacing()
else:
if check_almost_equal:
# check if current scan's metadata is almost equal to the first scan
assert_allclose(img.GetOrigin(), case_origin)
assert_allclose(img.GetDirection(), case_direction)
assert_allclose(img.GetSpacing(), case_spacing)

# copy over first scan's metadata to current scan
img.SetOrigin(case_origin)
img.SetDirection(case_direction)
img.SetSpacing(case_spacing)
Expand All @@ -348,18 +256,19 @@ def preprocess(self):
# apply scan transformation
self.scans = [self.scan_preprocess_func(scan) for scan in self.scans]

if self.settings.align_physical_space:
# align sequences based on metadata
self.crop_to_common_physical_space()
if self.settings.spacing is not None:
# resample scans and label to specified spacing
self.resample_spacing()

# resample scans and label
self.resample()
if self.settings.matrix_size is not None:
# perform centre crop
self.centre_crop()

# perform centre crop
self.centre_crop()
# resample scans and label to first scan's spacing, field-of-view, etc.
self.resample_to_first_scan()

# copy physical metadata to align subvoxel differences between sequences
self.copy_physical_metadata()
self.align_physical_metadata()

if self.lbl is not None:
# check connected components of annotation
Expand Down
14 changes: 3 additions & 11 deletions src/picai_prep/resources/mha2nnunet_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,16 @@
"type": "object",
"description": "Preprocessing parameters",
"properties": {
"align_physical_space": {
"description": "...",
"type": "boolean"
},
"crop_to_first_physical_centre": {
"description": "...",
"type": "boolean"
},
"physical_size": {
"description": "...",
"description": "Target field-of-view in mm (z, y, x). Automatically calculated if `matrix_size` and `spacing` are set.",
"$ref": "#/$defs/3d"
},
"matrix_size": {
"description": "Defaults to [20, 160, 160] if neither this or 'physical_size' is set.",
"description": "Target matrix size. Automatically calculated if `physical_size` and `spacing` are set.",
"$ref": "#/$defs/3d"
},
"spacing": {
"description": "...",
"description": "Target resolution in mm/voxel (z, y, x). Automatically calculated if `physical_size` and `matrix_size` are set.",
"$ref": "#/$defs/3d"
}
},
Expand Down
Binary file modified tests/input/annotations/ProstateX/ProstateX-0000_07-07-2011.nii.gz
100755 → 100644
Binary file not shown.
Binary file modified tests/input/annotations/ProstateX/ProstateX-0001_07-08-2011.nii.gz
100755 → 100644
Binary file not shown.
Git LFS file not shown
Git LFS file not shown
Loading

0 comments on commit 0b46a5d

Please sign in to comment.