"""Train a classifier.
This script trains a classifier, that can be used as a base model for triplet-
loss training. Of the trained classifier only the convolutional layers should
be saved, while the final fully connected layers will be dropped and replaced
in the tripletloss training.
"""
import numpy as np
import chainer
from chainer import cuda
from chainer import optimizers
from chainer import functions as F
from chainer import links as L
from aux import helpers
from aux.labelled_loader import LabelledLoader
from tripletembedding.aux import Logger, load_snapshot
from models import vgg_small
from models import vgg_xs


if __name__ == '__main__':
    args = helpers.get_args()

    model = vgg_small.VGGClf(2)  # TODO provide parameter
    xp = cuda.cupy if args.gpu >= 0 else np
    dl = LabelledLoader(xp)

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
        dl.use_device(args.gpu)
        model = model.to_gpu()

    optimizer = optimizers.MomentumSGD(lr=0.001)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(args.weight_decay))
    if args.initmodel and args.resume:
        load_snapshot(args.initmodel, args.resume, model, optimizer)
        print("Continuing from snapshot. LR: {}".format(optimizer.lr))

    logger = Logger(args, optimizer, args.out)
    for _ in range(1, args.epoch + 1):
        optimizer.new_epoch()
        model.zerograds()
        print('========\nepoch', optimizer.epoch)

        # training
        dl.create_sources(args.data, args.batchsize, 1.0 - args.test)
        while True:
            data = dl.get_batch('train')
            if data is None:
                break
            t_data, x_data = data
            x = chainer.Variable(x_data)
            t = chainer.Variable(t_data)
            optimizer.update(model, x, t)
            logger.log_iteration("train", float(model.loss.data),
                                 float(model.accuracy.data), 0.0, 0.0)
        logger.log_mean("train")
        if optimizer.epoch % args.lrinterval == 0 and optimizer.lr > 0.000001:
            optimizer.lr *= 0.5
            logger.mark_lr()
            print("learning rate decreased to {}".format(optimizer.lr))

        if optimizer.epoch % args.interval == 0:
            logger.make_snapshot(model)
        # testing
        while True:
            data = dl.get_batch('test')
            if data is None:
                break
            t_data, x_data = data
            # volatile variables skip graph construction (Chainer v1 API)
            x = chainer.Variable(x_data, volatile=True)
            t = chainer.Variable(t_data, volatile=True)
            model(x, t)  # forward pass populates model.loss and model.accuracy
            logger.log_iteration("test", float(model.loss.data),
                                 float(model.accuracy.data), 0.0, 0.0)
        logger.log_mean("test")
    # make final snapshot if one was not just taken
    if optimizer.epoch % args.interval != 0:
        logger.make_snapshot(model)
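
# Example invocation (a sketch: the exact flag spellings depend on
# helpers.get_args, which is not shown here; the names below simply mirror
# the args.* attributes used above and are assumptions):
#
#     python train_clf.py --gpu 0 --data ./dataset --batchsize 64 \
#         --epoch 100 --test 0.1 --lrinterval 20 --interval 10 \
#         --weight_decay 0.0001 --out logs/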