Skip to content

Commit

Permalink
cell pose water
Browse files Browse the repository at this point in the history
  • Loading branch information
kapoorlab committed Sep 8, 2024
1 parent f1b6297 commit 16351e4
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 52 deletions.
4 changes: 2 additions & 2 deletions src/vollseg/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = version = "32.3.0"
__version_tuple__ = version_tuple = (32, 3, 0)
__version__ = version = "32.3.1"
__version_tuple__ = version_tuple = (32, 3, 1)
103 changes: 53 additions & 50 deletions src/vollseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cv2
from skimage.segmentation import clear_border
from scipy.ndimage import gaussian_filter

# import matplotlib.pyplot as plt
import pandas as pd
from cellpose import models
Expand Down Expand Up @@ -331,6 +332,7 @@ def dilate_label_holes(lbl_img, iterations):
lbl_img_filled[mask_filled] = lb
return lbl_img_filled


def erode_labels(lbl_img, iterations=1):
lbl_img_filled = np.zeros_like(lbl_img)
for lb in range(np.min(lbl_img), np.max(lbl_img) + 1):
Expand All @@ -339,6 +341,7 @@ def erode_labels(lbl_img, iterations=1):
lbl_img_filled[mask_filled] = lb
return lbl_img_filled


def erode_label_regions(segmentation, erosion_iterations=1):
regions = regionprops(segmentation)
erode = np.zeros(segmentation.shape)
Expand Down Expand Up @@ -366,14 +369,20 @@ def process_region_3d(label_id, z):

# Aggregate results
for future in futures:
result, z = future.result(), futures.index(future) % segmentation.shape[0]
result, z = (
future.result(),
futures.index(future) % segmentation.shape[0],
)
erode[z, :, :] += result

# For 2D segmentation, we parallelize only over regions
else:
with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_region_2d, regions[i].label) for i in range(len(regions))]

futures = [
executor.submit(process_region_2d, regions[i].label)
for i in range(len(regions))
]

# Aggregate results
for future in futures:
erode += future.result()
Expand Down Expand Up @@ -408,22 +417,27 @@ def process_region_3d(label_id, z):

# Aggregate results
for future in futures:
result, z = future.result(), futures.index(future) % segmentation.shape[0]
result, z = (
future.result(),
futures.index(future) % segmentation.shape[0],
)
erode[z, :, :] += result

# For 2D segmentation, parallelize only over regions
else:
with ThreadPoolExecutor() as executor:
futures = [executor.submit(process_region_2d, regions[i].label) for i in range(len(regions))]

futures = [
executor.submit(process_region_2d, regions[i].label)
for i in range(len(regions))
]

# Aggregate results
for future in futures:
erode += future.result()

return erode



def match_labels(ys: np.ndarray, nms_thresh=0.5):

if nms_thresh is None:
Expand Down Expand Up @@ -1573,7 +1587,7 @@ def CellPoseSeg(

def VollCellSeg(
image: np.ndarray,
nuclei_seg_image: np.ndarray ,
nuclei_seg_image: np.ndarray,
cellpose_labels: np.ndarray = None,
diameter_cellpose: float = 34.6,
stitch_threshold: float = 0.5,
Expand All @@ -1590,11 +1604,8 @@ def VollCellSeg(
Name: str = "Result",
do_3D: bool = False,
channels=None,

):



if len(image.shape) == 3 and "T" not in axes:
# Just a 3D image
image_membrane = image
Expand All @@ -1613,8 +1624,6 @@ def VollCellSeg(
channels=channels,
)



if len(image.shape) == 4 and "T" not in axes:
image_membrane = image[:, channel_membrane, :, :]

Expand All @@ -1632,8 +1641,6 @@ def VollCellSeg(
channels=channels,
)



if len(image.shape) > 4 and "T" in axes:

if len(n_tiles) == 4:
Expand All @@ -1654,53 +1661,50 @@ def VollCellSeg(
channels=channels,
)



if cellpose_model_path is not None:
cellpose_labels = cellres[0]
cellpose_labels = np.asarray(cellpose_labels)

if nuclei_seg_image is not None:


voll_cell_seg = _cellpose_block(
axes, image_membrane, nuclei_seg_image, cellpose_labels
)

if save_dir is not None:
Path(save_dir).mkdir(exist_ok=True)
voll_cell_seg = _cellpose_block(
axes, image_membrane, nuclei_seg_image, cellpose_labels
)

if cellpose_model_path is not None:
cellpose_results = Path(save_dir) / "CellPose"
Path(cellpose_results).mkdir(exist_ok=True)
imwrite(
(os.path.join(cellpose_results.as_posix(), Name + ".tif")),
np.asarray(cellpose_labels).astype("uint16"),
)
if save_dir is not None:
Path(save_dir).mkdir(exist_ok=True)

vollcellpose_results = Path(save_dir) / "VollCellPose"
Path(vollcellpose_results).mkdir(exist_ok=True)
if cellpose_model_path is not None:
cellpose_results = Path(save_dir) / "CellPose"
Path(cellpose_results).mkdir(exist_ok=True)
imwrite(
(os.path.join(vollcellpose_results.as_posix(), Name + ".tif")),
np.asarray(voll_cell_seg).astype("uint16"),
(os.path.join(cellpose_results.as_posix(), Name + ".tif")),
np.asarray(cellpose_labels).astype("uint16"),
)


vollcellpose_results = Path(save_dir) / "VollCellPose"
Path(vollcellpose_results).mkdir(exist_ok=True)
imwrite(
(os.path.join(vollcellpose_results.as_posix(), Name + ".tif")),
np.asarray(voll_cell_seg).astype("uint16"),
)


def _cellpose_block(axes, membrane_image, sized_smart_seeds,cellpose_labels):
def _cellpose_block(axes, membrane_image, sized_smart_seeds, cellpose_labels):

if "T" not in axes:

voll_cell_seg = CellPoseWater(membrane_image, sized_smart_seeds,cellpose_labels)
voll_cell_seg = CellPoseWater(
membrane_image, sized_smart_seeds, cellpose_labels
)
if "T" in axes:

voll_cell_seg = []
for time in range(sized_smart_seeds.shape[0]):
sized_smart_seeds_time = sized_smart_seeds[time]
membrane_image_time = membrane_image[time]
voll_cell_seg_time = CellPoseWater(
membrane_image_time, sized_smart_seeds_time,cellpose_labels
membrane_image_time, sized_smart_seeds_time, cellpose_labels
)
voll_cell_seg.append(voll_cell_seg_time)
voll_cell_seg = np.asarray(voll_cell_seg_time)
Expand Down Expand Up @@ -4582,41 +4586,39 @@ def simple_dist(label_image):
# Create an empty output image
binary_image = np.zeros_like(label_image, dtype=np.float32)
binary_image = find_boundaries(label_image, mode="outer") * 255
binary_image = gaussian_filter(binary_image, sigma = 2)
binary_image = gaussian_filter(binary_image, sigma=2)
output_image = binary_image / np.max(binary_image)
return output_image


return output_image


def CellPoseWater(membrane_image, sized_smart_seeds, cellpose_labels):

cellpose_labels_copy_binary = cellpose_labels > 0

cellpose_labels_copy_binary = cellpose_labels > 0

# Get centroids of regions in the current slice
properties = measure.regionprops(sized_smart_seeds)

Coordinates = [prop.centroid for prop in properties]
Coordinates.append((0, 0, 0))
Coordinates = np.asarray(Coordinates)
coordinates_int = np.round(Coordinates).astype(int)
markers_raw = np.zeros_like(sized_smart_seeds)
markers_raw[tuple(coordinates_int.T)] = 1 + np.arange(len(Coordinates))
markers = morphology.dilation(markers_raw.astype("uint16"), morphology.ball(2))
membrane_image = gaussian_filter(membrane_image, sigma=4)
membrane_image = gaussian_filter(membrane_image, sigma=1)
inverted_membrane = membrane_image == 0
# Apply watershed for the current slice
distance_map = distance_transform_edt(inverted_membrane)
watershed_result = watershed(-distance_map, markers, mask=cellpose_labels_copy_binary)

watershed_result = watershed(
-distance_map, markers, mask=cellpose_labels_copy_binary
)

# Relabel sequentially to remove any gaps in the label numbers
watershed_result, _, _ = relabel_sequential(watershed_result.astype(np.uint16))


return watershed_result


def relabel_image(image1: np.ndarray, image2: np.ndarray) -> np.ndarray:
"""
Relabels image1 such that its minimum label is greater than the maximum label in image2.
Expand All @@ -4639,6 +4641,7 @@ def relabel_image(image1: np.ndarray, image2: np.ndarray) -> np.ndarray:

return relabeled_image1


def WatershedwithMask3D(Image, Label, mask, nms_thresh, seedpool=True, z_thresh=1):

print("Watershed with Mask 3D")
Expand Down

0 comments on commit 16351e4

Please sign in to comment.