-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsegmentation.py
109 lines (89 loc) · 3.66 KB
/
segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from skimage import color
import matplotlib.pyplot as plt
import numpy as np
from sklearn import cluster
from tqdm import tqdm
from pathlib import Path
import image_preprocessing
import get_img_paths
def segmentation(gray_image, nclust=3, random_state=None):
    """Applies K-Means segmentation to gray image, returning segmented image and labels matrix.

    The image is converted to RGB with ``color.gray2rgb`` and every pixel's
    colour vector is clustered with ``cluster.KMeans``. In the segmented image
    each pixel is replaced by its cluster centre. The label matrix is
    renumbered so that label ``0`` is the darkest cluster and ``nclust - 1``
    the brightest, where a cluster's brightness is the largest channel value
    of its centre.

    Args:
        gray_image (Image Array): grayscale image. The code unpacks exactly
            three dimensions after ``gray2rgb``, so this presumably must be a
            2-D (rows, cols) array.
        nclust (int, optional): number of means to be used in K-means. Defaults to 3.
        random_state (int | None, optional): seed forwarded to ``KMeans``.
            Pass an int for reproducible segmentations. Defaults to None
            (unseeded, the previous behaviour).

    Returns:
        (Img, Label_Array): Tuple of the segmented image, shape (rows, cols, channels),
        and the (rows, cols) label array with labels in ``range(nclust)``
        ordered by brightness.
    """
    image = color.gray2rgb(gray_image)
    x, y, z = image.shape
    image_2d = image.reshape(x * y, z)

    kmeans_cluster = cluster.KMeans(n_clusters=nclust, random_state=random_state)
    kmeans_cluster.fit(image_2d)
    cluster_centers = kmeans_cluster.cluster_centers_
    cluster_labels = kmeans_cluster.labels_

    segmented_img = cluster_centers[cluster_labels].reshape(x, y, z)

    # Rank clusters by the brightness of their centres rather than of the
    # pixels assigned to them. For non-empty clusters both are identical (every
    # pixel equals its centre), but this stays well-defined when KMeans leaves
    # a cluster without pixels (e.g. fewer distinct grey levels than nclust),
    # which previously crashed the relabelling.
    brightness = cluster_centers.max(axis=1)
    order = np.argsort(brightness, kind='stable')  # old labels, darkest -> brightest
    rank = np.empty_like(order)
    rank[order] = np.arange(order.size)  # maps old label -> brightness rank
    # Keep the dtype of KMeans' labels, as the original zeros_like() did.
    cluster_labels_matrix = rank[cluster_labels].astype(cluster_labels.dtype).reshape(x, y)
    return segmented_img, cluster_labels_matrix
def _sort_labels(segmented_image, lables):
"""Returns a list of labels in the order of brightness
Args:
segmented_image (Image Array): Segmented Image
lables (Array): Array of labels for each pixel
Returns:
[int]: List of labels in the order of brightness
"""
label_brightness = []
for lable in range(np.max(lables) + 1):
label_brightness.append(np.max(segmented_image[lables == lable]))
sorted_labels = np.argsort(label_brightness)
return sorted_labels
def _save_segment_comparisons(img_paths, out_dir, nclust, segment_label=1):
    """Save a side-by-side figure of each preprocessed image and one of its segments.

    Each image is read with ``plt.imread``, passed through
    ``image_preprocessing.preprocess`` and segmented with ``segmentation``.
    The left panel shows the preprocessed image. The right panel shows the
    mask of pixels whose brightness-ordered label equals ``segment_label``.
    The figure is saved to ``out_dir / img_path.name``.

    Args:
        img_paths (Iterable[Path]): images to process; ``.name`` is used for
            the figure title and the output file name.
        out_dir (Path): output directory; created if missing (its parent must exist).
        nclust (int): number of K-means clusters passed to ``segmentation``.
        segment_label (int, optional): label to display (0 = darkest cluster).
            Defaults to 1.
    """
    out_dir.mkdir(exist_ok=True)
    for img_path in tqdm(img_paths):
        img = image_preprocessing.preprocess(plt.imread(img_path))
        seg_img, cluster_labels = segmentation(img, nclust=nclust)
        # To plot the whole segmented image instead of one segment, use:
        #     modified = seg_img
        modified = cluster_labels == segment_label
        fig = plt.figure()
        try:
            plt.subplot(1, 2, 1)
            plt.imshow(img, cmap='gray')
            plt.subplot(1, 2, 2)
            plt.imshow(modified, cmap='gray')
            plt.suptitle(img_path.name)
            plt.tight_layout()
            fig.savefig(out_dir / img_path.name, dpi=400, bbox_inches='tight')
        finally:
            # Always release the figure, even if saving fails, so a long run
            # does not accumulate open figures.
            plt.close(fig)


if __name__ == '__main__':
    # Script to run segmentation on all training images
    output_path = Path('segmentation_output')
    output_path.mkdir(exist_ok=True)
    healthy, srf = get_img_paths.train_data()
    nclust = 6
    _save_segment_comparisons(healthy, output_path / 'NoSRF', nclust)
    _save_segment_comparisons(srf, output_path / 'SRF', nclust)