
Commit 4e111a6
make lr a property that is only accessed if the model has a trainer attached; otherwise, return the lr from the configuration
Jacob Mathias Schreiner committed Jan 28, 2025
1 parent ea9b9df commit 4e111a6
Showing 1 changed file with 15 additions and 1 deletion.
neural_lam/models/ar_model.py (15 additions & 1 deletion)
@@ -27,6 +27,15 @@ class ARModel(pl.LightningModule):
     # pylint: disable=arguments-differ
     # Disable to override args/kwargs from superclass
 
+    @property
+    def lr(self):
+        """
+        Get learning rate of optimizer
+        """
+        if self._trainer:
+            return self.trainer.optimizers[0].param_groups[0]["lr"]
+        return self.config.training.optimization.lr
+
     def __init__(
         self,
         args,
@@ -316,7 +325,7 @@ def training_step(self, batch):
 
         log_dict = {
             "train_loss": batch_loss,
-            "lr": self.trainer.optimizers[0].param_groups[0]["lr"],
+            "lr": self.lr,
         }
         self.log_dict(
             log_dict,
@@ -329,6 +338,11 @@ def training_step(self, batch):
 
         return batch_loss
 
+    def get_lr(self):
+        if hasattr(self, "trainer"):
+            return self.trainer.optimizers[0].param_groups[0]["lr"]
+        return None
+
     def on_train_batch_end(self, outputs, batch, batch_idx):
         if scheduler := self.lr_schedulers():
             if (
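For context, the fallback this commit introduces can be sketched outside of Lightning: read the live learning rate from the attached trainer's first optimizer when one exists, otherwise return the configured value. The sketch below is a hypothetical, self-contained illustration; StubOptimizer and the SimpleNamespace-based config are stand-ins for a torch optimizer and neural_lam's configuration object, not code from the repository.

from types import SimpleNamespace


class StubOptimizer:
    """Stands in for a torch optimizer: exposes param_groups with an "lr" key."""

    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]


class Model:
    """Minimal stand-in for ARModel, showing only the lr fallback."""

    def __init__(self, configured_lr):
        # Mirrors the Lightning convention: _trainer stays None until a
        # Trainer attaches itself to the module.
        self._trainer = None
        self.config = SimpleNamespace(
            training=SimpleNamespace(optimization=SimpleNamespace(lr=configured_lr))
        )

    @property
    def lr(self):
        """Live optimizer lr if a trainer is attached, else the configured lr."""
        if self._trainer:
            return self._trainer.optimizers[0].param_groups[0]["lr"]
        return self.config.training.optimization.lr


model = Model(configured_lr=1e-3)
print(model.lr)  # 0.001 -- no trainer attached, the configured value is returned

# Attach a stub trainer whose scheduler has already decayed the lr.
model._trainer = SimpleNamespace(optimizers=[StubOptimizer(lr=2.5e-4)])
print(model.lr)  # 0.00025 -- the live optimizer value now wins

Note that the committed property guards on the private _trainer attribute before touching self.trainer: in recent PyTorch Lightning versions the public trainer accessor raises a RuntimeError when the module is not attached to a Trainer, so checking _trainer first is what makes the configuration fallback reachable.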
