forked from chenzhaiyu/unet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
45 lines (32 loc) · 1.67 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from model import *
from data import *
from config import train_config, data_aug_config
# Master switch for data augmentation. When False, the generators built below
# receive None instead of the augmentation settings from config.
IS_DATA_AUG = False
# Which weights to start from:
#   None              -> train from scratch
#   "last"            -> resume from the checkpoint returned by find_last()
#   any other string  -> a weights file name inside train_config["weights_dir"]
MODEL_NAME = None
# Augmentation arguments forwarded to both the train and val generators.
data_gen_args = data_aug_config if IS_DATA_AUG else None
# Batch generators for training and validation. Both use the same batch size,
# the same "images"/"masks" folder names (presumably sub-directories of the
# respective data dir — confirm against data.py) and the same augmentation
# arguments. Neither writes generated samples to disk (save_to_dir=None).
# NOTE(review): valGenerator also receives data_gen_args, so turning on
# IS_DATA_AUG may augment validation batches as well — confirm this is intended.
trainGene = trainGenerator(
    train_config["batch_size"],
    train_config["train_data_dir"],
    "images",
    "masks",
    data_gen_args,
    save_to_dir=None,
)
valGene = valGenerator(
    train_config["batch_size"],
    train_config["val_data_dir"],
    "images",
    "masks",
    data_gen_args,
    save_to_dir=None,
)
# Build the network via the project's unet() factory.
model = unet()
# Write a weights file after every epoch (save_best_only=False). Each file name
# carries the epoch number and the validation loss so checkpoints can be
# compared afterwards.
# monitor is 'val_loss' to match the metric embedded in the file name. It is not
# consulted while save_best_only is False. If save_best_only is switched to True,
# this makes the "best" checkpoint be chosen by validation loss rather than by
# training loss (the previous 'loss' setting).
model_checkpoint = ModelCheckpoint(
    os.path.join(
        train_config["weights_dir"],
        'unet_buildings_weights.{epoch:02d}-{val_loss:.2f}.hdf5',
    ),
    monitor='val_loss',
    verbose=1,
    save_best_only=False,
)
# Epoch counter handed to fit_generator(initial_epoch=...). It only changes
# when resuming via MODEL_NAME == "last". When loading an explicit weights file
# it stays 0, so checkpoint epoch numbering restarts.
initial_epoch = 0
if MODEL_NAME is None:
    print("training model from scratch")
elif MODEL_NAME == "last":
    # Resume from the most recent checkpoint in weights_dir. find_last returns
    # both the checkpoint path and the epoch to continue counting from.
    last_model_path, initial_epoch = find_last(train_config["weights_dir"])
    print("resuming from {} at epoch {}".format(last_model_path, initial_epoch))
    model.load_weights(last_model_path, by_name=True)
elif isinstance(MODEL_NAME, str):
    # Any other string names a weights file inside weights_dir.
    print("loading weights from {}".format(MODEL_NAME))
    model.load_weights(os.path.join(train_config["weights_dir"], MODEL_NAME), by_name=True)
else:
    # Reject misconfiguration loudly and say what the valid options are.
    raise TypeError(
        "MODEL_NAME must be None, 'last', or a weights file name; got {!r}".format(MODEL_NAME)
    )
# Run training. Validation runs every epoch on valGene, a checkpoint is saved
# through model_checkpoint, and epoch counting starts at initial_epoch.
# NOTE(review): fit_generator is deprecated in TensorFlow >= 2.1 in favour of
# fit(); it is kept here because the Keras version in use isn't visible.
model.fit_generator(
    trainGene,
    validation_data=valGene,
    validation_steps=train_config["validation_steps"],
    steps_per_epoch=train_config["steps_epoch"],
    epochs=train_config["epochs"],
    callbacks=[model_checkpoint],
    initial_epoch=initial_epoch,
)