diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 3c412fc0..18341050 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -27,6 +27,15 @@ class ARModel(pl.LightningModule): # pylint: disable=arguments-differ # Disable to override args/kwargs from superclass + @property + def lr(self): + """ + Get learning rate of optimizer + """ + if self._trainer: + return self.trainer.optimizers[0].param_groups[0]["lr"] + return self.config.training.optimization.lr + def __init__( self, args, @@ -316,7 +325,7 @@ def training_step(self, batch): log_dict = { "train_loss": batch_loss, - "lr": self.trainer.optimizers[0].param_groups[0]["lr"], + "lr": self.lr, } self.log_dict( log_dict, @@ -329,6 +338,11 @@ def training_step(self, batch): return batch_loss + def get_lr(self): + if hasattr(self, "trainer"): + return self.trainer.optimizers[0].param_groups[0]["lr"] + return None + def on_train_batch_end(self, outputs, batch, batch_idx): if scheduler := self.lr_schedulers(): if (