-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
123 lines (97 loc) · 3.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python
# coding: utf-8
"""
utils module
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
def median_frequency_balancing(training_data, n_classes):
    """ Compute per-class weights by median frequency balancing.

    The weight of class c is computed as weight(c) = median freq/freq(c) where
    freq(c) is the number of pixels of class c divided by the total number of
    pixels in images where c is present, and median freq is the median of these
    frequencies. Based on the paper `Predicting Depth, Surface Normals and Semantic
    Labels with a Common Multi-Scale Convolutional Architecture
    <https://arxiv.org/abs/1411.4734>`_

    Classes that never appear in the training data have no defined frequency.
    They are left out of the median and receive a weight of 0, so a single
    missing class does not turn every weight into NaN.

    Args:
        training_data (array) - Training dataset of pairs (image, labels). It
            must support ``len()`` and integer indexing. ``labels`` is either
            an object exposing ``.numpy()`` or anything ``np.asarray`` accepts,
            and holds integer class indices.
        n_classes (int) - Number of classes
    Returns:
        (array) - Weight of each class, as a float array of length n_classes
    """
    pixel_counts = np.zeros(n_classes)  # Number of pixels of class c
    total_counts = np.zeros(n_classes)  # Total number of pixels in images where c is present

    # Index by position: only len() and integer indexing are required of
    # training_data, so dataset objects without __iter__ keep working.
    for p in range(len(training_data)):
        labels = training_data[p][1]
        pixels = labels.numpy() if hasattr(labels, "numpy") else np.asarray(labels)
        classes, counts = np.unique(pixels, return_counts=True)
        # pixels.size counts every label element whatever the rank of the
        # label map (e.g. (H, W) or (1, H, W)).
        total_pixels = pixels.size
        for c, count in zip(classes, counts):
            pixel_counts[c] += count
            total_counts[c] += total_pixels

    # Only classes seen at least once have a frequency; dividing for the
    # others would be 0/0.
    present = total_counts > 0
    weights = np.zeros(n_classes)
    if present.any():
        freqs = pixel_counts[present] / total_counts[present]
        weights[present] = np.median(freqs) / freqs
    return weights
def plot_seg_results(images, ground_truths, predictions):
    """ Plot a grid of several images, their ground-truth segmentations
    and their predicted segmentations, one row per sample.

    Each row shows, left to right, the image, its ground truth and its
    prediction. The figure is displayed with ``plt.show()``.

    Args:
        images (array-like) - Images; its length sets the number of rows
        ground_truths (array-like) - Ground-truth segmentations, indexed in
            step with images
        predictions (array-like) - Predicted segmentations, indexed in step
            with images; each element must provide ``.squeeze()``
    Raises:
        ValueError - If images is empty (there is nothing to plot)
    """
    n_rows = len(images)
    if n_rows == 0:
        raise ValueError("plot_seg_results requires at least one image")

    # squeeze=False keeps axarr two-dimensional even when there is a single
    # row, so the [row, column] indexing below is always valid.
    f, axarr = plt.subplots(n_rows, 3, squeeze=False)
    f.set_size_inches(10, 3 * n_rows)

    for i in range(n_rows):
        axarr[i, 0].imshow(images[i])
        axarr[i, 1].imshow(ground_truths[i])
        axarr[i, 2].imshow(predictions[i].squeeze())

        # Remove axis
        for ax in axarr[i]:
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

    # Set columns titles
    axarr[0, 0].set_title('IMAGE')
    axarr[0, 1].set_title('GROUND TRUTH')
    axarr[0, 2].set_title('PREDICTION')

    plt.show()
def plot_seg_result(image, ground_truth, prediction):
    """ Show a single image, its ground-truth segmentation and its
    predicted segmentation side by side.

    The three panels are drawn in one row and the figure is displayed
    with ``plt.show()``.

    Args:
        image (torch.Tensor or numpy.array) - Image
        ground_truth (torch.Tensor or numpy.array) - Ground-truth segmentation
        prediction (torch.Tensor or numpy.array) - Predicted segmentation
    """
    f, axarr = plt.subplots(1, 3)
    # NOTE(review): (width, height) = (5, 10) is tall for a single row of
    # three panels; confirm this is the intended figure size.
    f.set_size_inches(5, 10)

    axarr[0].imshow(image)
    axarr[1].imshow(ground_truth)
    axarr[2].imshow(prediction)

    # Remove axis. With one row, plt.subplots returns a 1-D array of axes,
    # so each panel is addressed by its column index only.
    for j in range(3):
        axarr[j].xaxis.set_visible(False)
        axarr[j].yaxis.set_visible(False)

    # Set columns titles
    axarr[0].set_title('IMAGE')
    axarr[1].set_title('GROUND TRUTH')
    axarr[2].set_title('PREDICTION')

    plt.show()
def plot_metric(metric_history, label, color='b'):
    """ Draw a metric's value against the epoch index and show the figure.

    Epochs are numbered from 0 and the x-axis carries a tick every two
    epochs. The plot title is ``label`` followed by " vs. Epochs".

    Args:
        metric_history (numpy.array): the metric's values, ordered from
            oldest to newest (one value per epoch).
        label (string): name of the metric; used for the y-axis label, the
            legend entry and the title.
        color (string): passed to ``plt.plot`` as its third positional
            argument. Defaults to 'b'.
    """
    n_epochs = len(metric_history)

    plt.plot(range(n_epochs), metric_history, color, label=label)

    # Axis decoration
    plt.xlabel('Epochs')
    plt.ylabel(label)
    plt.xticks(np.arange(0, n_epochs, 2.0))

    plt.title(label + " vs. Epochs")
    plt.legend()

    plt.show()