-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy path main_ddp.py
executable file
·467 lines (380 loc) · 18.9 KB
/
main_ddp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import CosineAnnealingLR
from util.data_util import ModelNet40 as ModelNet40
import numpy as np
from util.util import cal_loss, load_cfg_from_cfg_file, merge_cfg_from_list, find_free_port, AverageMeter, intersectionAndUnionGPU
import time
import logging
import random
from tensorboardX import SummaryWriter
def get_logger():
    """Return the process-wide "main-logger", configuring it on first use.

    The logger emits INFO-level records to a stream handler (stderr) and to a
    timestamped file ``checkpoints/<exp_name>/main-<unix_time>.log``. It reads
    ``exp_name`` from the module-level ``args`` global, which ``main_worker``
    sets, and the ``checkpoints/<exp_name>`` directory must already exist.

    Handlers are attached only when the logger has none yet. Repeated calls in
    the same process therefore return the same logger without duplicating
    every log line.
    """
    logger = logging.getLogger("main-logger")
    logger.setLevel(logging.INFO)
    if logger.handlers:
        # Already configured in this process; adding handlers again would
        # print (and write) each record once per extra handler.
        return logger
    formatter = logging.Formatter(
        "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s")
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    log_path = os.path.join('checkpoints', args.exp_name, 'main-' + str(int(time.time())) + '.log')
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger
def get_parser():
    """Parse the command line and return the merged experiment config.

    The config file named by ``--config`` is loaded first. Any trailing
    positional ``opts`` are then merged on top of it with
    ``merge_cfg_from_list``. Finally, keys the rest of this script reads get a
    fallback value when the config does not define them.
    """
    cli_parser = argparse.ArgumentParser(description='3D Object Classification')
    cli_parser.add_argument('--config', type=str, default='config/dgcnn_paconv.yaml', help='config file')
    cli_parser.add_argument('opts', help='see config/dgcnn_paconv.yaml for all options', default=None,
                            nargs=argparse.REMAINDER)
    cli = cli_parser.parse_args()
    assert cli.config is not None

    cfg = load_cfg_from_cfg_file(cli.config)
    if cli.opts is not None:
        cfg = merge_cfg_from_list(cfg, cli.opts)

    # Fallbacks for keys the config may omit, applied in this order.
    fallbacks = (
        ('classes', 40),
        ('sync_bn', True),
        ('dist_url', 'tcp://127.0.0.1:6789'),
        ('dist_backend', 'nccl'),
        ('multiprocessing_distributed', True),
        ('world_size', 1),
        ('rank', 0),
        ('manual_seed', 0),
        ('workers', 6),
        ('print_freq', 10),
    )
    for key, default in fallbacks:
        cfg[key] = cfg.get(key, default)
    return cfg
def worker_init_fn(worker_id):
    """Seed Python's ``random`` module for one DataLoader worker.

    Worker ``worker_id`` is seeded with ``args.manual_seed + worker_id``, read
    from the module-level ``args`` global. The DataLoaders built in this file
    do not pass this function as ``worker_init_fn``.
    """
    seed = args.manual_seed + worker_id
    random.seed(seed)
def main_process():
    """Return True if this process should log, write TensorBoard and save checkpoints.

    Without multi-process distributed training every process qualifies.
    Otherwise only processes whose rank is a multiple of ``args.ngpus_per_node``
    qualify. Given how ``main_worker`` computes ranks, that is the process
    driving local GPU 0 on each node.
    """
    if not args.multiprocessing_distributed:
        return True
    return args.rank % args.ngpus_per_node == 0
# weight initialization:
def weight_init(m):
    """Initialise the parameters of a single module ``m`` in place.

    This is meant to be used with ``nn.Module.apply``, which calls it on every
    submodule.

    * ``Linear`` / ``Conv1d`` / ``Conv2d``: Kaiming-normal weight, and a zero
      bias if the layer has one.
    * ``BatchNorm1d`` / ``BatchNorm2d``: weight 1, bias 0. Layers built with
      ``affine=False`` have no such parameters and are skipped.

    Any other module type is left untouched.
    """
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d)):
        torch.nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
        # With affine=False, BatchNorm's weight/bias are None; initialising them
        # unconditionally would raise.
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
def train(gpu, ngpus_per_node):
    """Build the model and data pipeline, then run the whole training schedule.

    Args:
        gpu: local device index for this process. It is only used when
            ``args.distributed`` is set.
        ngpus_per_node: number of processes on this node. In distributed mode
            the batch sizes are divided by it and the worker count is divided
            by it, rounding up.

    After every epoch the model is evaluated on the test split. On the main
    process, whenever test accuracy does not decrease, the state dict is saved
    to ``checkpoints/<exp_name>/best_model.t7``.
    """
    # ============= Model ===================
    if args.arch == 'dgcnn':
        from model.DGCNN_PAConv import PAConv
    elif args.arch == 'pointnet':
        from model.PointNet_PAConv import PAConv
    else:
        raise Exception("Not implemented")
    model = PAConv(args)
    model.apply(weight_init)
    if main_process():
        logger.info(model)

    if args.sync_bn and args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        # Split the configured sizes across this node's processes.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.test_batch_size = int(args.test_batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu],
                                                          find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model.cuda())

    # =========== Dataloader =================
    train_data = ModelNet40(partition='train', num_points=args.num_points, pt_norm=args.pt_norm)
    test_data = ModelNet40(partition='test', num_points=args.num_points, pt_norm=False)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data) if args.distributed else None
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_data) if args.distributed else None
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    test_loader = torch.utils.data.DataLoader(
        test_data, batch_size=args.test_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=test_sampler)

    # ============= Optimizer ===================
    if main_process():
        logger.info("Use SGD")
    opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr / 100)
    criterion = cal_loss

    # ============= Training from scratch=================
    best_test_acc = 0
    start_epoch = 0
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle the distributed shards differently every epoch.
            train_sampler.set_epoch(epoch)
        train_epoch(train_loader, model, opt, scheduler, epoch, criterion)
        test_acc = test_epoch(test_loader, model, epoch, criterion)
        if not (test_acc >= best_test_acc and main_process()):
            continue
        best_test_acc = test_acc
        logger.info('Max Acc:%.6f' % best_test_acc)
        torch.save(model.state_dict(), 'checkpoints/%s/best_model.t7' % args.exp_name)  # save the best model
def train_epoch(train_loader, model, opt, scheduler, epoch, criterion):
    """Run one training epoch, then step the LR scheduler once.

    The model is called as ``model(data, label, criterion)`` and must return
    ``(logits, loss)``. Each process back-propagates its own loss. With
    multi-process distributed training, the *reported* loss and the
    intersection/union/target counts are summed over all processes with
    ``dist.all_reduce``. Progress is logged every ``args.print_freq``
    iterations. Epoch loss, overall accuracy (allAcc) and mean class accuracy
    (mAcc) go to the logger and TensorBoard on the main process.
    """
    running_loss = 0.0
    seen = 0.0
    batch_time, data_time = AverageMeter(), AverageMeter()
    forward_time, backward_time = AverageMeter(), AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter, union_meter, target_meter = AverageMeter(), AverageMeter(), AverageMeter()

    model.train()
    tic = time.time()
    num_batches = len(train_loader)
    max_iter = args.epochs * num_batches

    for step, (data, label) in enumerate(train_loader):
        data_time.update(time.time() - tic)
        data = data.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True).squeeze(1)
        data = data.permute(0, 2, 1)
        batch_size = data.size(0)

        fwd_start = time.time()
        logits, loss = model(data, label, criterion)
        forward_time.update(time.time() - fwd_start)
        preds = logits.max(dim=1)[1]
        if not args.multiprocessing_distributed:
            # Presumably one loss per replica under DataParallel; reduce to a scalar.
            loss = torch.mean(loss)

        bwd_start = time.time()
        opt.zero_grad()
        loss.backward()  # each process steps with its own loss
        opt.step()
        backward_time.update(time.time() - bwd_start)

        if args.multiprocessing_distributed:
            # Batch-size-weighted average of the loss over all processes (for logging only).
            loss = loss * batch_size
            total = label.new_tensor([batch_size], dtype=torch.long).cuda(non_blocking=True)
            dist.all_reduce(loss)
            dist.all_reduce(total)
            loss = loss / total.item()

        loss_value = loss.item()
        seen += batch_size
        running_loss += loss_value * batch_size
        loss_meter.update(loss_value, batch_size)
        batch_time.update(time.time() - tic)
        tic = time.time()

        current_iter = epoch * num_batches + step + 1
        remain_seconds = (max_iter - current_iter) * batch_time.avg
        minutes, seconds = divmod(remain_seconds, 60)
        hours, minutes = divmod(minutes, 60)
        remain_time = '{:02d}:{:02d}:{:02d}'.format(int(hours), int(minutes), int(seconds))

        if (step + 1) % args.print_freq == 0 and main_process():
            logger.info('Epoch: [{}/{}][{}/{}] '
                        'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Forward {for_time.val:.3f} ({for_time.avg:.3f}) '
                        'Backward {back_time.val:.3f} ({back_time.avg:.3f}) '
                        'Remain {remain_time} '
                        'Loss {loss_meter.val:.4f} '.format(epoch + 1, args.epochs, step + 1, num_batches,
                                                            batch_time=batch_time,
                                                            data_time=data_time,
                                                            for_time=forward_time,
                                                            back_time=backward_time,
                                                            remain_time=remain_time,
                                                            loss_meter=loss_meter))

        intersection, union, target = intersectionAndUnionGPU(preds, label, args.classes)
        if args.multiprocessing_distributed:
            # Sum the per-class counts over all processes.
            for stat in (intersection, union, target):
                dist.all_reduce(stat)
        intersection_meter.update(intersection.cpu().numpy())
        union_meter.update(union.cpu().numpy())
        target_meter.update(target.cpu().numpy())

    scheduler.step()
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mAcc = np.mean(accuracy_class)
    # Inner sums collapse the per-class counts into totals.
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    epoch_loss = running_loss * 1.0 / seen
    outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch + 1, epoch_loss, allAcc, mAcc)
    if main_process():
        logger.info(outstr)
        # Write to tensorboard
        writer.add_scalar('loss_train', epoch_loss, epoch + 1)
        writer.add_scalar('mAcc_train', mAcc, epoch + 1)
        writer.add_scalar('allAcc_train', allAcc, epoch + 1)
def test_epoch(test_loader, model, epoch, criterion):
    """Evaluate ``model`` on ``test_loader`` and return the overall accuracy.

    The model is called as ``model(data)`` and its logits are scored with
    ``criterion(logits, label)``. With multi-process distributed training, the
    loss and the intersection/union/target counts are summed across processes
    with ``dist.all_reduce``. Loss, allAcc and mAcc are logged (and written to
    TensorBoard) on the main process.

    Everything runs under ``torch.no_grad()``. No backward pass follows, so
    building an autograd graph for every test batch would only waste memory.

    Returns:
        allAcc: total correct predictions / total targets over the epoch.
    """
    test_loss = 0.0
    count = 0.0
    model.eval()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    with torch.no_grad():
        for data, label in test_loader:
            data, label = data.cuda(non_blocking=True), label.cuda(non_blocking=True).squeeze(1)
            data = data.permute(0, 2, 1)
            batch_size = data.size(0)
            logits = model(data)
            # Loss
            loss = criterion(logits, label)  # here use model's output directly
            if args.multiprocessing_distributed:
                # Batch-size-weighted average over all processes.
                loss = loss * batch_size
                _count = label.new_tensor([batch_size], dtype=torch.long).cuda(non_blocking=True)
                dist.all_reduce(loss)
                dist.all_reduce(_count)
                n = _count.item()
                loss = loss / n
            else:
                loss = torch.mean(loss)
            preds = logits.max(dim=1)[1]
            count += batch_size
            test_loss += loss.item() * batch_size
            intersection, union, target = intersectionAndUnionGPU(preds, label, args.classes)
            if args.multiprocessing_distributed:
                dist.all_reduce(intersection)
                dist.all_reduce(union)
                dist.all_reduce(target)
            intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    outstr = 'Test %d, loss: %.6f, test acc: %.6f, ' \
             'test avg acc: %.6f' % (epoch + 1,
                                     test_loss * 1.0 / count,
                                     allAcc,
                                     mAcc)
    if main_process():
        logger.info(outstr)
        # Write to tensorboard
        writer.add_scalar('loss_test', test_loss * 1.0 / count, epoch + 1)
        writer.add_scalar('mAcc_test', mAcc, epoch + 1)
        writer.add_scalar('allAcc_test', allAcc, epoch + 1)
    return allAcc
def test(gpu, ngpus_per_node):
    """Load ``checkpoints/<exp_name>/best_model.t7`` and report test-set accuracy.

    On the main process this logs overall accuracy (allAcc), mean per-class
    accuracy (mAcc) and the accuracy of every class. ``main_worker`` in this
    file asserts ``not args.eval`` and always calls ``train``, so this function
    is not reached from there.

    Args:
        gpu: local device index; only used when ``args.distributed``.
        ngpus_per_node: processes on this node; in distributed mode the batch
            sizes and worker count are split across them, as in ``train``.
    """
    if main_process():
        logger.info('<<<<<<<<<<<<<<<<< Start Evaluation <<<<<<<<<<<<<<<<<')
    # ============= Model ===================
    if args.arch == 'dgcnn':
        from model.DGCNN_PAConv import PAConv
        model = PAConv(args)
    elif args.arch == 'pointnet':
        from model.PointNet_PAConv import PAConv
        model = PAConv(args)
    else:
        raise Exception("Not implemented")
    if main_process():
        logger.info(model)
    # Same rule as train(): only convert to SyncBatchNorm when a process group exists.
    if args.sync_bn and args.distributed:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.test_batch_size = int(args.test_batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu], find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model.cuda())

    state_dict = torch.load("checkpoints/%s/best_model.t7" % args.exp_name, map_location=torch.device('cpu'))
    # The wrapped model (DataParallel / DDP) expects every key to start with
    # "module."; add the prefix to any key that lacks it, e.g. when the
    # checkpoint was saved from an unwrapped model.
    from collections import OrderedDict
    state_dict = OrderedDict(
        (key if key.startswith('module.') else 'module.' + key, value)
        for key, value in state_dict.items())
    model.load_state_dict(state_dict)

    # Dataloader: built like the evaluation set in train() (pt_norm=False).
    test_data = ModelNet40(partition='test', num_points=args.num_points, pt_norm=False)
    if args.distributed:
        test_sampler = torch.utils.data.distributed.DistributedSampler(test_data)
    else:
        test_sampler = None
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True, sampler=test_sampler)
    model.eval()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    for data, label in test_loader:
        data, label = data.cuda(non_blocking=True), label.cuda(non_blocking=True).squeeze(1)
        data = data.permute(0, 2, 1)
        with torch.no_grad():
            logits = model(data)
        preds = logits.max(dim=1)[1]
        intersection, union, target = intersectionAndUnionGPU(preds, label, args.classes)
        if args.multiprocessing_distributed:
            dist.all_reduce(intersection)
            dist.all_reduce(union)
            dist.all_reduce(target)
        intersection, union, target = intersection.cpu().numpy(), union.cpu().numpy(), target.cpu().numpy()
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if main_process():
        logger.info('Test result: mAcc/allAcc {:.4f}/{:.4f}.'.format(mAcc, allAcc))
        for i in range(args.classes):
            logger.info('Class_{} Result: accuracy {:.4f}.'.format(i, accuracy_class[i]))
        logger.info('<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<')
def main_worker(gpu, ngpus_per_node, argss):
    """Per-process entry point: distributed setup, output directories, logging, then training.

    Args:
        gpu: local process/device index (the index ``mp.spawn`` passes in the
            multi-process case).
        ngpus_per_node: number of processes on this node, used to turn the
            node rank into a global rank.
        argss: the parsed config. It is published as the module-level ``args``
            global that the other functions read.

    On the main process this creates ``checkpoints/<exp_name>`` and copies the
    source files into it as ``*.backup``. The copy is best-effort: a failed
    copy is reported and skipped. It also creates the TensorBoard ``writer``
    and the ``logger`` globals.
    """
    import shutil

    global args, logger, writer
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # Convert the node rank into a global process rank.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size,
                                rank=args.rank)
    if main_process():
        exp_dir = os.path.join('checkpoints', args.exp_name)
        os.makedirs(exp_dir, exist_ok=True)
        if not args.eval:  # backup the running files
            sources = ['main_ddp.py', 'util/PAConv_util.py', 'util/data_util.py']
            if args.arch == 'dgcnn':
                sources.append('model/DGCNN_PAConv.py')
            elif args.arch == 'pointnet':
                sources.append('model/PointNet_PAConv.py')
            for src in sources:
                dst = os.path.join(exp_dir, os.path.basename(src) + '.backup')
                try:
                    # shutil.copy avoids building a shell command from exp_name,
                    # which broke on spaces or shell metacharacters.
                    shutil.copy(src, dst)
                except OSError as err:
                    # Keep the previous best-effort semantics: report and continue.
                    print('Warning: could not back up %s to %s: %s' % (src, dst, err))
        writer = SummaryWriter('checkpoints/' + args.exp_name)
        logger = get_logger()
        logger.info(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    assert not args.eval, "The all_reduce function of PyTorch DDP will ignore/repeat inputs " \
                          "(leading to the wrong result), " \
                          "please use main.py to test (avoid DDP) for getting the right result."
    train(gpu, ngpus_per_node)
def get_visible_gpus():
    """Return the list of GPU ids this launcher should use.

    The ids are parsed from the comma-separated ``CUDA_VISIBLE_DEVICES``
    environment variable, with empty entries ignored (so ``""`` gives ``[]``).
    When the variable is unset, this falls back to
    ``range(torch.cuda.device_count())``.
    """
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        return list(range(torch.cuda.device_count()))
    return [int(token) for token in visible.split(',') if token.strip()]


if __name__ == "__main__":
    args = get_parser()
    # Previously os.environ['CUDA_VISIBLE_DEVICES'] raised KeyError when unset.
    args.gpu = get_visible_gpus()
    if not args.gpu:
        raise RuntimeError('No CUDA devices are visible; set CUDA_VISIBLE_DEVICES or check the GPU setup.')
    if args.manual_seed is not None:
        random.seed(args.manual_seed)
        np.random.seed(args.manual_seed)
        torch.manual_seed(args.manual_seed)
        torch.cuda.manual_seed(args.manual_seed)
        torch.cuda.manual_seed_all(args.manual_seed)
        cudnn.benchmark = False
        cudnn.deterministic = True
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
    args.ngpus_per_node = len(args.gpu)
    if len(args.gpu) == 1:
        # A single GPU needs neither a process group nor SyncBatchNorm.
        args.sync_bn = False
        args.distributed = False
        args.multiprocessing_distributed = False
    if args.multiprocessing_distributed:
        port = find_free_port()
        args.dist_url = f"tcp://127.0.0.1:{port}"
        args.world_size = args.ngpus_per_node * args.world_size
        mp.spawn(main_worker, nprocs=args.ngpus_per_node, args=(args.ngpus_per_node, args))
    else:
        main_worker(args.gpu, args.ngpus_per_node, args)