-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(training): Add initial TimeLimit callback (#115)
* Add TimeLimit callback - Stops training if time limit is reached - Log to file the last checkpoint --------- Co-authored-by: Julien Lefaucheur <julien.lefaucheur@ecmwf.int>
- Loading branch information
Showing
3 changed files
with
194 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
training/src/anemoi/training/diagnostics/callbacks/stopping.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# (C) Copyright 2025- Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
import time | ||
from datetime import timedelta | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
|
||
import pytorch_lightning as pl | ||
|
||
from anemoi.utils.dates import frequency_to_string | ||
from anemoi.utils.dates import frequency_to_timedelta | ||
|
||
if TYPE_CHECKING: | ||
from omegaconf import OmegaConf | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class TimeLimit(pl.callbacks.Callback):
    """Callback to stop the training process after a given time limit."""

    def __init__(self, config: OmegaConf, limit: int | str, record_file: str | None = None) -> None:
        """Initialise the TimeLimit callback.

        Parameters
        ----------
        config : OmegaConf
            Job configuration. Stored on the callback; not read by the stopping logic itself.
        limit : int or str
            The time limit. If an integer, it is assumed to be in hours. If a string, it can be in the format:
            - "1h" for 1 hour
            - "1d" for 1 day
            - "1m" for 1 minute
            - "1s" for 1 second
            - "1:30" for 1 hour and 30 minutes
            - "1:30:10" for 1 hour, 30 minutes and 10 seconds
            - "PT10M" for 10 minutes (ISO8601)
        record_file : str or None
            The file to record the last checkpoint to. If None, no file is written.
        """
        super().__init__()
        self.config = config

        # Normalise the user-supplied limit to a timedelta up front.
        self.limit = frequency_to_timedelta(limit)
        self._record_file = Path(record_file) if record_file is not None else None

        # Fail early if the record path already exists but is not a regular file
        # (e.g. a directory), rather than erroring at write time.
        if self._record_file is not None and self._record_file.exists():
            assert self._record_file.is_file(), "The record file must be a file."

        self._start_time = time.time()

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Check the time limit at the end of each training epoch."""
        _ = pl_module
        self._run_stopping_check(trainer)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Check the time limit after each validation run."""
        _ = pl_module
        self._run_stopping_check(trainer)

    def _run_stopping_check(self, trainer: pl.Trainer) -> None:
        """Check if the time limit has been reached and stop the training if so.

        Parameters
        ----------
        trainer : pl.Trainer
            Pytorch Lightning trainer
        """
        if timedelta(seconds=time.time() - self._start_time) < self.limit:
            return

        # Fix: frequency_to_string yields a frequency string (e.g. "24h"),
        # so the message must not claim the value is in seconds.
        LOGGER.info("Time limit of %s reached. Stopping training.", frequency_to_string(self.limit))
        trainer.should_stop = True
        self._log_to_file(trainer)

    def _log_to_file(self, trainer: pl.Trainer) -> None:
        """Log the last checkpoint path to a file if given.

        Parameters
        ----------
        trainer : pl.Trainer
            Pytorch Lightning trainer
        """
        if self._record_file is not None:
            # NOTE(review): assumes a checkpoint callback is configured on the
            # trainer; checkpoint_callback would be None otherwise — confirm.
            last_checkpoint = trainer.checkpoint_callback.last_model_path
            self._record_file.parent.mkdir(parents=True, exist_ok=True)

            # Replace any stale record from a previous run.
            if self._record_file.exists():
                self._record_file.unlink()

            # _record_file is already a Path; no need to re-wrap it.
            self._record_file.write_text(str(last_checkpoint))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# (C) Copyright 2025- Anemoi contributors. | ||
# | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
|
||
from typing import Any | ||
from unittest.mock import Mock | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
from pytorch_lightning import Trainer | ||
|
||
from anemoi.training.diagnostics.callbacks.stopping import TimeLimit | ||
|
||
|
||
@pytest.fixture
def mock_trainer() -> Mock:
    """Provide a Trainer mock with its stop flag initially cleared."""
    trainer = Mock(Trainer)
    trainer.should_stop = False
    return trainer
|
||
|
||
def test_time_limit_stops_training(mock_trainer: Mock) -> None:
    # A 24-hour limit given as a frequency string.
    callback = TimeLimit(config=None, limit="24h")

    # Pretend slightly more than 24 hours have elapsed since start.
    elapsed = 24 * 60 * 60 + 1
    with patch("time.time", return_value=callback._start_time + elapsed):
        callback.on_train_epoch_end(mock_trainer, None)

    # The trainer must be flagged to stop.
    assert mock_trainer.should_stop
|
||
|
||
def test_time_limit_stops_training_int(mock_trainer: Mock) -> None:
    # Integer limits are interpreted as hours (here: 24 hours).
    callback = TimeLimit(config=None, limit=24)

    # Pretend slightly more than 24 hours have elapsed since start.
    elapsed = 24 * 60 * 60 + 1
    with patch("time.time", return_value=callback._start_time + elapsed):
        callback.on_train_epoch_end(mock_trainer, None)

    # The trainer must be flagged to stop.
    assert mock_trainer.should_stop
|
||
|
||
def test_time_limit_does_not_stop_training(mock_trainer: Mock) -> None:
    # A one-hour limit.
    callback = TimeLimit(config=None, limit="1h")

    # Only 2 seconds have elapsed — well under the limit.
    with patch("time.time", return_value=callback._start_time + 2):
        callback.on_train_epoch_end(mock_trainer, None)

    # The trainer must NOT be flagged to stop.
    assert not mock_trainer.should_stop
|
||
|
||
def test_time_limit_creates_file_on_stop(mock_trainer: Mock, tmp_path: Any) -> None:
    # A one-second limit with a record file configured.
    record = tmp_path / "log"
    callback = TimeLimit(config=None, limit="1s", record_file=record)

    # Simulate 2 seconds of elapsed time — past the limit.
    with patch("time.time", return_value=callback._start_time + 2):
        callback.on_train_epoch_end(mock_trainer, None)

    # The trainer must be flagged to stop...
    assert mock_trainer.should_stop

    # ...and the checkpoint record file must have been written.
    assert record.exists()
|
||
|
||
def test_time_limit_does_not_create_file_when_not_stopping(mock_trainer: Mock, tmp_path: Any) -> None:
    # A 24-hour limit with a record file configured.
    record = tmp_path / "log"
    callback = TimeLimit(config=None, limit="24h", record_file=record)

    # Only 2 seconds have elapsed — well under the limit.
    with patch("time.time", return_value=callback._start_time + 2):
        callback.on_train_epoch_end(mock_trainer, None)

    # The trainer must NOT be flagged to stop...
    assert not mock_trainer.should_stop

    # ...and no record file should have been written.
    assert not record.exists()