-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
104 lines (91 loc) · 5.06 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import sys
import time
from options.test_options import TestOptions
from util.data_loader import DataLoader
from util.evaluation import ExcelEvaluate
from util.plot_handler import PlotHandler
from util.util import info, warning, get_timestamp, remove_background, normalize_with_opt
sys.path.append(os.getcwd() + "/build/cluster") # Fix this, make it nicer
sys.path.append(os.getcwd() + "/build/map") # Fix this, make it nicer
from map.map import Mapping
def find_model_file(search_folder, method, main_clusters, sub_clusters, model_index=-1):
    """Return the path of the stored model matching the given parameters.

    Files are matched in ``search_folder`` with the glob pattern
    ``model_<method>_main<main_clusters>_sub<sub_clusters>_*`` and ordered by
    the integer that follows the last underscore of the file stem.

    Args:
        search_folder: ``pathlib.Path``-like folder in which to glob.
        method: Method name embedded in the model file name.
        main_clusters: Number of main clusters embedded in the file name.
        sub_clusters: Number of sub clusters embedded in the file name.
        model_index: Position of the wanted model in the ordered list. Any
            value outside ``[0, number of models)`` selects the last model.

    Returns:
        The path of the selected model file.

    Raises:
        SystemExit: If no file matches the pattern in ``search_folder``.
        ValueError: If a matching file stem does not end with ``_<integer>``.
    """
    pattern = "model_" + method + "_main" + str(main_clusters) + "_sub" + str(sub_clusters) + "_*"
    # Sort numerically on the trailing counter so that e.g. "_10" comes after "_2".
    candidates = sorted(search_folder.glob(pattern), key=lambda path: int(path.stem.rsplit("_", 1)[1]))
    if not candidates:
        # sys.exit instead of the site-provided exit(), which is not guaranteed to exist.
        sys.exit("No model has been found in the " + str(search_folder) + " folder with the given parameters.")
    if not -1 < model_index < len(candidates):
        model_index = -1  # Fall back to the last one (highest trailing counter)
    return candidates[model_index]


def _restore_mapping(data_loader, plot_handler, model_folder, search_folder, opts):
    """Pick a model file in ``search_folder`` and return a ``Mapping`` restored from it."""
    model = find_model_file(search_folder, opts.method, opts.main_clusters, opts.sub_clusters, opts.model_index)
    print("Testing will use model: " + str(model))
    mapping = Mapping(data_loader, plot_handler, model_folder, opts.main_clusters, opts.sub_clusters, opts.method)
    mapping.restore_table(model)
    return mapping


def main():
    """Run the test phase over every query file.

    For each query file the script restores a mapping model (once up front
    when ``opts.model_phase == "train"``, once per query file when it is
    ``"search"``), computes the mapped results, evaluates them when a target
    is available, normalises the non-truth volumes and plots the results.
    """
    # Set options for printing and plotting
    opt_handler = TestOptions()
    opt_handler.initialize(opt_handler.parser)
    opts = opt_handler.set_and_create()
    # NOTE(review): the returned timestamp is never used; the call is kept only in
    # case get_timestamp() has side effects - confirm and remove.
    get_timestamp()
    data_loader = DataLoader(opts, opt_handler)
    plot_handler = PlotHandler(opts, opt_handler)
    model_folder = opt_handler.set_model(opts)

    # Retrieve main parameters
    mapping_source = opts.mapping_source
    mapping_target = opts.mapping_target

    excel_filepath = plot_handler.plot_folder / (opts.test_set + "_" + opts.experiment_name + ".csv")
    excel = ExcelEvaluate(excel_filepath, opts.excel)

    try:
        # A training model is shared by all the query files: restore it once.
        mapping = None
        if opts.model_phase == "train":
            mapping = _restore_mapping(data_loader, plot_handler, model_folder, model_folder, opts)

        time_init = time.time()
        for query_filename in data_loader.query_files:
            query_friendly_filename = query_filename.name
            info("Testing with image " + query_friendly_filename + ", please make sure that you used the same "
                 "settings as for training.")

            # A search model is stored per query file, in a sub-folder named after it.
            if opts.model_phase == "search":
                mapping = _restore_mapping(data_loader, plot_handler, model_folder,
                                           model_folder / query_friendly_filename, opts)

            if mapping is None:
                # Neither branch above ran, so there is no model to test with.
                sys.exit("Unsupported model phase '" + str(opts.model_phase) + "': expected 'train' or 'search'.")

            # Find the query MRIs
            mris = data_loader.return_file(query_filename, query_file=True)

            truth_nonzero = None
            if 'truth' in mris:
                # Consider the truth about the tumour
                _, truth_nonzero = remove_background(mris['truth'])
                # If the truth is not there, then we don't have any tumour on this slice.
                # len() is used on purpose: the container type is not visible here and
                # plain truthiness would be ambiguous for array-like objects.
                if len(truth_nonzero) == 0:
                    truth_nonzero = None
                    warning("The slice " + str(opts.chosen_slice) + " does not contain any tumour, "
                            "and thus the tumour MSE cannot be computed.")
                else:
                    plot_handler.print_tumour(mris['truth'], query_friendly_filename, data_loader.mri_shape,
                                              data_loader.affine)

            info("Computing mapping " + mapping_source + " to " + mapping_target
                 + " for query " + query_friendly_filename + ".")
            mris = mapping.return_results_query(mris, opts.smoothing)

            # Evaluation is only possible when a target is available for comparison.
            if 'target' in mris:
                excel.evaluate(mris, query_friendly_filename, truth_nonzero, opts.smoothing)

            # Post-process everything but the truth volumes before plotting.
            for label in mris.keys():
                if 'truth' not in label:
                    mris[label] = normalize_with_opt(mris[label], opts.postprocess)

            plot_handler.plot_results(mris, query_friendly_filename, opts.smoothing, data_loader.mri_shape,
                                      data_loader.affine)

        time_end = round(time.time() - time_init, 3)
        print("Time spent for testing " + str(data_loader.query_files_size) + " images " + str(time_end) + "s.")
    finally:
        # Close the evaluation output even when a query fails or the run exits early.
        excel.close()


if __name__ == "__main__":
    main()