-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsupervised.py
269 lines (206 loc) · 10 KB
/
supervised.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import argparse
import logging
import glob
import os
import pprint
import torch
import numpy as np
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import yaml
from dataset.semi import SemiDataset
from model.semseg.dpt import DPT
from util.classes import CLASSES
from util.ohem import ProbOhemCrossEntropy2d
from util.utils import count_params, AverageMeter, intersectionAndUnion, init_log
from util.dist_helper import setup_distributed
parser = argparse.ArgumentParser(description='Fully-Supervised Training in Semantic Segmentation')
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--labeled-id-path', type=str, required=True)
parser.add_argument('--unlabeled-id-path', type=str, default=None)
parser.add_argument('--pretrained-path', type=str, default=None)
parser.add_argument('--save-path', type=str, required=True)
parser.add_argument('--local_rank', '--local-rank', default=0, type=int)
parser.add_argument('--port', default=None, type=int)
def evaluate(model, loader, mode, cfg, multiplier=None):
model.eval()
assert mode in ['original', 'center_crop', 'sliding_window']
intersection_meter = AverageMeter()
union_meter = AverageMeter()
with torch.no_grad():
for img, mask, id in loader:
img = img.cuda()
if mode == 'sliding_window':
grid = cfg['crop_size']
b, _, h, w = img.shape
final = torch.zeros(b, 19, h, w).cuda()
row = 0
while row < h:
col = 0
while col < w:
pred = model(img[:, :, row: row + grid, col: col + grid])
final[:, :, row: row + grid, col: col + grid] += pred.softmax(dim=1)
if col == w - grid:
break
col = min(col + int(grid * 2 / 3), w - grid)
if row == h - grid:
break
row = min(row + int(grid * 2 / 3), h - grid)
pred = final
else:
assert mode == 'original'
if multiplier is not None:
ori_h, ori_w = img.shape[-2:]
if multiplier == 512:
new_h, new_w = 512, 512
else:
new_h, new_w = int(ori_h / multiplier + 0.5) * multiplier, int(ori_w / multiplier + 0.5) * multiplier
img = F.interpolate(img, (new_h, new_w), mode='bilinear', align_corners=True)
pred = model(img)
if multiplier is not None:
pred = F.interpolate(pred, (ori_h, ori_w), mode='bilinear', align_corners=True)
pred = pred.argmax(dim=1)
intersection, union, target = \
intersectionAndUnion(pred.cpu().numpy(), mask.numpy(), cfg['nclass'], 255)
reduced_intersection = torch.from_numpy(intersection).cuda()
reduced_union = torch.from_numpy(union).cuda()
reduced_target = torch.from_numpy(target).cuda()
dist.all_reduce(reduced_intersection)
dist.all_reduce(reduced_union)
dist.all_reduce(reduced_target)
intersection_meter.update(reduced_intersection.cpu().numpy())
union_meter.update(reduced_union.cpu().numpy())
iou_class = intersection_meter.sum / (union_meter.sum + 1e-10) * 100.0
mIOU = np.mean(iou_class)
return mIOU, iou_class
def main():
args = parser.parse_args()
cfg = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
logger = init_log('global', logging.INFO)
logger.propagate = 0
rank, world_size = setup_distributed(port=args.port)
cfg['batch_size'] *= 2
if rank == 0:
all_args = {**cfg, **vars(args), 'ngpus': world_size}
logger.info('{}\n'.format(pprint.pformat(all_args)))
writer = SummaryWriter(args.save_path)
os.makedirs(args.save_path, exist_ok=True)
cudnn.enabled = True
cudnn.benchmark = True
model_configs = {
'small': {'encoder_size': 'small', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'base': {'encoder_size': 'base', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'large': {'encoder_size': 'large', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'giant': {'encoder_size': 'giant', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
model = DPT(**{**model_configs[cfg['backbone'].split('_')[-1]], 'nclass': cfg['nclass']})
state_dict = torch.load(f'./pretrained/{cfg["backbone"]}.pth')
model.backbone.load_state_dict(state_dict)
if cfg['lock_backbone']:
model.lock_backbone()
optimizer = AdamW(
[
{'params': [p for p in model.backbone.parameters() if p.requires_grad], 'lr': cfg['lr']},
{'params': [param for name, param in model.named_parameters() if 'backbone' not in name], 'lr': cfg['lr'] * cfg['lr_multi']}
],
lr=cfg['lr'], betas=(0.9, 0.999), weight_decay=0.01
)
if rank == 0:
logger.info('Total params: {:.1f}M\n'.format(count_params(model)))
local_rank = int(os.environ["LOCAL_RANK"])
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model.cuda(local_rank)
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[local_rank], broadcast_buffers=False, output_device=local_rank, find_unused_parameters=True
)
if cfg['criterion']['name'] == 'CELoss':
criterion = nn.CrossEntropyLoss(**cfg['criterion']['kwargs']).cuda(local_rank)
elif cfg['criterion']['name'] == 'OHEM':
criterion = ProbOhemCrossEntropy2d(**cfg['criterion']['kwargs']).cuda(local_rank)
else:
raise NotImplementedError('%s criterion is not implemented' % cfg['criterion']['name'])
n_upsampled = {
'pascal': 3000,
'cityscapes': 3000,
'ade20k': 6000,
'coco': 30000
}
trainset = SemiDataset(
cfg['dataset'], cfg['data_root'], 'train_l', cfg['crop_size'], args.labeled_id_path, nsample=n_upsampled[cfg['dataset']]
)
valset = SemiDataset(
cfg['dataset'], cfg['data_root'], 'val'
)
trainsampler = torch.utils.data.distributed.DistributedSampler(trainset)
trainloader = DataLoader(
trainset, batch_size=cfg['batch_size'], pin_memory=True, num_workers=4, drop_last=True, sampler=trainsampler
)
valsampler = torch.utils.data.distributed.DistributedSampler(valset)
valloader = DataLoader(
valset, batch_size=1, pin_memory=True, num_workers=1, drop_last=False, sampler=valsampler
)
iters = 0
total_iters = len(trainloader) * cfg['epochs']
previous_best = 0.0
epoch = -1
if os.path.exists(os.path.join(args.save_path, 'latest.pth')):
checkpoint = torch.load(os.path.join(args.save_path, 'latest.pth'), map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
previous_best = checkpoint['previous_best']
if rank == 0:
logger.info('************ Load from checkpoint at epoch %i\n' % epoch)
for epoch in range(epoch + 1, cfg['epochs']):
if rank == 0:
logger.info('===========> Epoch: {:}, LR: {:.7f}, Previous best: {:.2f}'.format(
epoch, optimizer.param_groups[0]['lr'], previous_best))
model.train()
total_loss = AverageMeter()
trainsampler.set_epoch(epoch)
for i, (img, mask) in enumerate(trainloader):
img, mask = img.cuda(), mask.cuda()
pred = model(img)
loss = criterion(pred, mask)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss.update(loss.item())
iters = epoch * len(trainloader) + i
lr = cfg['lr'] * (1 - iters / total_iters) ** 0.9
optimizer.param_groups[0]["lr"] = lr
optimizer.param_groups[1]["lr"] = lr * cfg['lr_multi']
if rank == 0:
writer.add_scalar('train/loss_all', loss.item(), iters)
writer.add_scalar('train/loss_x', loss.item(), iters)
if (i % (len(trainloader) // 8) == 0) and (rank == 0):
logger.info('Iters: {:}, Total loss: {:.3f}'.format(i, total_loss.avg))
eval_mode = 'sliding_window' if cfg['dataset'] == 'cityscapes' else 'original'
mIoU, iou_class = evaluate(model, valloader, eval_mode, cfg, multiplier=14)
if rank == 0:
for (cls_idx, iou) in enumerate(iou_class):
logger.info('***** Evaluation ***** >>>> Class [{:} {:}] '
'IoU: {:.2f}'.format(cls_idx, CLASSES[cfg['dataset']][cls_idx], iou))
logger.info('***** Evaluation {} ***** >>>> MeanIoU: {:.2f}\n'.format(eval_mode, mIoU))
writer.add_scalar('eval/mIoU', mIoU, epoch)
for i, iou in enumerate(iou_class):
writer.add_scalar('eval/%s_IoU' % (CLASSES[cfg['dataset']][i]), iou, epoch)
is_best = mIoU > previous_best
previous_best = max(mIoU, previous_best)
if rank == 0:
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'previous_best': previous_best,
}
torch.save(checkpoint, os.path.join(args.save_path, 'latest.pth'))
if is_best:
torch.save(checkpoint, os.path.join(args.save_path, 'best.pth'))
if __name__ == '__main__':
main()