-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsolver.py
174 lines (135 loc) · 4.85 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import os
import torch
import time
import datetime
from model import STDN
class Solver(object):
    """Wraps an STDN model with training, checkpointing and evaluation helpers.

    Several of the training/evaluation methods are still unimplemented stubs.

    Hyper-parameters are passed as a dict and copied onto the instance as
    attributes (``self.<key>``).  Every attribute read below (``config``,
    ``input_channels``, ``class_count``, ``num_features``,
    ``compress_factor``, ``expand_factor``, ``growth_rate``,
    ``model_save_path``, ``pretrained_model``, ``num_epochs``) must therefore
    be a key of that dict or of ``DEFAULTS``.
    """

    # Fallback hyper-parameters; keys in the ``config`` dict override these.
    DEFAULTS = {}

    def __init__(self, version, train_data_loader, test_data_loader, config):
        """Initialize a Solver, build its model and optionally load weights.

        Args:
            version: Run identifier; ``save_model`` uses it as the checkpoint
                sub-directory name.
            train_data_loader: Loader for the training split.
            test_data_loader: Loader for the test split.
            config: Dict of hyper-parameters; each key becomes an attribute.
        """
        # Defaults are applied first, then ``config``, so config values win.
        self.__dict__.update(Solver.DEFAULTS, **config)
        self.version = version

        # data loaders
        self.train_data_loader = train_data_loader
        self.test_data_loader = test_data_loader

        self.build_model()

        # TODO: build tensorboard

        # start with a pre-trained model
        if self.pretrained_model:
            self.load_pretrained_model()

    def build_model(self):
        """Instantiate the model (loss criterion and optimizer are TODO)."""
        # NOTE(review): ``self.config`` is not the ``config`` dict passed to
        # __init__; it only exists if that dict has a 'config' key. Confirm
        # this is intended.
        self.model = STDN(config=self.config,
                          channels=self.input_channels,
                          class_count=self.class_count,
                          num_features=self.num_features,
                          compress_factor=self.compress_factor,
                          expand_factor=self.expand_factor,
                          growth_rate=self.growth_rate)

        # TODO: instantiate loss criterion
        # TODO: instantiate optimizer
        # TODO: print network
        # self.print_network(self.model, '')
        # TODO: use gpu if enabled
        # if torch.cuda.is_available() and self.use_gpu:
        #     self.model.cuda()
        #     self.criterion.cuda()

    def print_network(self, model, name):
        """Print ``name``, the module structure and its total parameter count.

        Args:
            model: Module whose ``parameters()`` are counted.
            name: Label printed before the structure.
        """
        num_params = sum(p.numel() for p in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def load_pretrained_model(self):
        """Load weights from ``<model_save_path>/<pretrained_model>.pth``.

        ``pretrained_model`` is the path relative to ``model_save_path``
        without the ``.pth`` extension, e.g. ``'<version>/<epoch>'`` as
        written by ``save_model``.
        """
        path = os.path.join(self.model_save_path,
                            '{}.pth'.format(self.pretrained_model))
        # Load onto the CPU so checkpoints saved from a GPU run also load on
        # CPU-only machines; load_state_dict copies the tensors into the
        # model's existing parameters.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        state_dict = torch.load(path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        print('loaded trained model ver {}'.format(self.pretrained_model))

    def print_loss_log(self, start_time, iters_per_epoch, e, i, loss):
        """Print the loss together with elapsed and estimated remaining time.

        Meant to be called after iteration ``i`` of epoch ``e`` has finished,
        so the elapsed time covers ``cur_iter + 1`` iterations.

        Args:
            start_time: ``time.time()`` value taken when training started.
            iters_per_epoch: Number of iterations in one epoch.
            e: 0-based epoch index (printed 1-based).
            i: 0-based iteration index within the epoch (printed 1-based).
            loss: Loss value; must support the ``:.4f`` format spec.
        """
        total_iter = self.num_epochs * iters_per_epoch
        cur_iter = e * iters_per_epoch + i
        done_iters = cur_iter + 1
        elapsed = time.time() - start_time
        seconds_per_iter = elapsed / done_iters

        # Remaining work excludes the iteration that just finished.
        # Counting it again overestimated the ETA by one iteration, and the
        # ETA never reached zero at the end.
        total_time = (total_iter - done_iters) * seconds_per_iter
        epoch_time = (iters_per_epoch - (i + 1)) * seconds_per_iter

        epoch_time = str(datetime.timedelta(seconds=epoch_time))
        total_time = str(datetime.timedelta(seconds=total_time))
        elapsed = str(datetime.timedelta(seconds=elapsed))

        # Format: elapsed / remaining-in-epoch -- remaining-overall.
        log = "Elapsed {}/{} -- {}, Epoch [{}/{}], Iter [{}/{}], " \
              "loss: {:.4f}".format(elapsed,
                                    epoch_time,
                                    total_time,
                                    e + 1,
                                    self.num_epochs,
                                    i + 1,
                                    iters_per_epoch,
                                    loss)

        # TODO: add tensorboard

        print(log)

    def save_model(self, e):
        """Save the model's ``state_dict`` after epoch ``e``.

        The checkpoint is written to ``<model_save_path>/<version>/<e+1>.pth``.
        The version directory is created if it does not exist yet.

        Args:
            e: 0-based epoch index (the file name is 1-based).
        """
        path = os.path.join(
            self.model_save_path,
            '{}/{}.pth'.format(self.version, e + 1)
        )
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(self.model.state_dict(), path)

    def model_step(self, images, labels):
        """Run one optimization step on a batch and return its loss.

        Not implemented yet: currently returns None.
        """
        # TODO: implement once the criterion and optimizer exist, roughly:
        #   self.model.train()
        #   self.optimizer.zero_grad()
        #   output = self.model(images)
        #   loss = self.criterion(output, labels.squeeze())
        #   loss.backward()
        #   self.optimizer.step()
        #   return loss
        pass

    def train(self):
        """Run the training process. Not implemented yet."""
        # TODO: add training process
        pass

    def eval(self, data_loader):
        """Return the top-1 and top-5 prediction counts over ``data_loader``.

        Not implemented yet: currently returns None.
        """
        # TODO: set the model to eval mode: self.model.eval()
        # TODO: return evaluation metric
        pass

    def train_evaluate(self, e):
        """Evaluate the model on the train dataset after epoch ``e``.

        Not implemented yet.
        """
        # TODO: call self.eval() then print log
        pass

    def test(self):
        """Evaluate the model on the test dataset. Not implemented yet."""
        # TODO: call self.eval() then print log
        pass