@@ -422,44 +422,50 @@ def train_epoch():
     iterator = tqdm(ds_iter_train, ncols=80)
     t = time.time()
     for n, sample in enumerate(iterator):
-        if n > 0:
-            meters_time["data"].add(time.time() - t)
-
-        optimizer.zero_grad()
-
-        t = time.time()
-        loss = h(data=sample, meters=meters_train)
-        meters_time["forward"].add(time.time() - t)
-        iterator.set_postfix(loss=loss.item())
-        meters_train["loss_total"].add(loss.item())
-
-        t = time.time()
-        loss.backward()
-        total_grad_norm = torch.nn.utils.clip_grad_norm_(
-            model.parameters(),
-            max_norm=args.clip_grad_norm,
-            norm_type=2,
-        )
-        meters_train["grad_norm"].add(torch.as_tensor(total_grad_norm).item())
-
-        optimizer.step()
-        meters_time["backward"].add(time.time() - t)
-        meters_time["memory"].add(
-            torch.cuda.max_memory_allocated() / 1024.0 ** 2,
-        )
-
-        if epoch < args.n_epochs_warmup:
-            lr_scheduler_warmup.step()
-        t = time.time()
+        if n < 5:
+            if n > 0:
+                meters_time["data"].add(time.time() - t)
+
+            optimizer.zero_grad()
+
+            t = time.time()
+            loss = h(data=sample, meters=meters_train)
+            meters_time["forward"].add(time.time() - t)
+            iterator.set_postfix(loss=loss.item())
+            meters_train["loss_total"].add(loss.item())
+
+            t = time.time()
+            loss.backward()
+            total_grad_norm = torch.nn.utils.clip_grad_norm_(
+                model.parameters(),
+                max_norm=args.clip_grad_norm,
+                norm_type=2,
+            )
+            meters_train["grad_norm"].add(torch.as_tensor(total_grad_norm).item())
+
+            optimizer.step()
+            meters_time["backward"].add(time.time() - t)
+            meters_time["memory"].add(
+                torch.cuda.max_memory_allocated() / 1024.0 ** 2,
+            )
+
+            if epoch < args.n_epochs_warmup:
+                lr_scheduler_warmup.step()
+            t = time.time()
+        else:
+            continue
     if epoch >= args.n_epochs_warmup:
         lr_scheduler.step()
 
 @torch.no_grad()
 def validation():
     model.eval()
-    for sample in tqdm(ds_iter_val, ncols=80):
-        loss = h(data=sample, meters=meters_val)
-        meters_val["loss_total"].add(loss.item())
+    for n, sample in enumerate(tqdm(ds_iter_val, ncols=80)):
+        if n < 5:
+            loss = h(data=sample, meters=meters_val)
+            meters_val["loss_total"].add(loss.item())
+        else:
+            continue
 
 @torch.no_grad()
 def test():