forked from ellisdg/3DUnetCNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
79 lines (68 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
import glob
from unet3d.data import write_data_to_file, open_data_file
from unet3d.generator import get_training_and_validation_generators
from unet3d.model import unet_model_3d
from unet3d.training import load_old_model, train_model
from brats.config import config_one
# Bind the configuration imported from brats.config (config_one) to the name
# `config`; every function in this script reads its settings through this
# mapping, so pointing this alias at another configuration switches the run.
config = config_one
def fetch_training_data_files():
    """Collect the per-subject image paths used to build the training data file.

    Subject directories are discovered by globbing two levels below
    ``<directory of this script>/data/<config["preprocessed"]>``. For each
    subject directory one tuple is produced, holding the path
    ``<subject_dir>/<name>.nii.gz`` for every entry of
    ``config["training_modalities"]`` followed by ``"truth"``. Paths are built
    from names only; their existence is not checked here.

    Returns:
        list of tuple of str: one tuple of file paths per subject directory,
        in the order returned by ``glob.glob``.
    """
    print(config["preprocessed"])
    subject_pattern = os.path.join(
        os.path.dirname(__file__), "data", config["preprocessed"], "*", "*"
    )
    return [
        tuple(
            os.path.join(subject_dir, image_name + ".nii.gz")
            for image_name in config["training_modalities"] + ["truth"]
        )
        for subject_dir in glob.glob(subject_pattern)
    ]
def main(overwrite=False):
    """Run the full training pipeline described by the module-level ``config``.

    The steps are:

    1. Write the data file at ``config["data_file"]`` from the files returned
       by ``fetch_training_data_files`` when it does not exist yet or when
       ``overwrite`` is true.
    2. Open that data file.
    3. Load the model stored at ``config["model_file"]`` if it exists and
       ``overwrite`` is false; otherwise build a new model with
       ``unet_model_3d`` (batch normalization disabled).
    4. Create the training and validation generators.
    5. Train the model with ``train_model``.

    The opened data file is closed when the function exits, including when
    model construction, generator creation or training raises.

    Args:
        overwrite: When true, regenerate the data file and build a new model
            even if the corresponding files already exist. The flag is also
            forwarded to ``get_training_and_validation_generators``.
    """
    # convert input images into an hdf5 file
    if overwrite or not os.path.exists(config["data_file"]):
        training_files = fetch_training_data_files()
        write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"])

    data_file_opened = open_data_file(config["data_file"])
    # Everything that uses the open data file runs inside try/finally so the
    # handle is released even if a later step fails.
    try:
        if not overwrite and os.path.exists(config["model_file"]):
            # resume from the previously saved model
            model = load_old_model(config["model_file"])
        else:
            # instantiate new model
            print("BN:")
            model = unet_model_3d(input_shape=config["input_shape"],
                                  pool_size=config["pool_size"],
                                  n_labels=config["n_labels"],
                                  initial_learning_rate=config["initial_learning_rate"],
                                  deconvolution=config["deconvolution"],
                                  batch_normalization=False)

        # get training and testing generators
        train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
            data_file_opened,
            batch_size=config["batch_size"],
            data_split=config["validation_split"],
            overwrite=overwrite,
            validation_keys_file=config["validation_file"],
            training_keys_file=config["training_file"],
            n_labels=config["n_labels"],
            labels=config["labels"],
            patch_shape=config["patch_shape"],
            validation_batch_size=config["validation_batch_size"],
            validation_patch_overlap=config["validation_patch_overlap"],
            training_patch_start_offset=config["training_patch_start_offset"],
            permute=config["permute"],
            augment=config["augment"],
            skip_blank=config["skip_blank"],
            augment_flip=config["flip"],
            augment_distortion_factor=config["distort"])

        # run training
        train_model(model=model,
                    model_file=config["model_file"],
                    training_generator=train_generator,
                    validation_generator=validation_generator,
                    steps_per_epoch=n_train_steps,
                    validation_steps=n_validation_steps,
                    initial_learning_rate=config["initial_learning_rate"],
                    learning_rate_drop=config["learning_rate_drop"],
                    learning_rate_patience=config["patience"],
                    early_stopping_patience=config["early_stop"],
                    n_epochs=config["n_epochs"])
    finally:
        data_file_opened.close()
if __name__ == "__main__":
    # Script entry point: start training, taking the overwrite flag from the
    # active configuration.
    overwrite_requested = config["overwrite"]
    main(overwrite=overwrite_requested)