diff --git a/training/src/anemoi/training/diagnostics/callbacks/__init__.py b/training/src/anemoi/training/diagnostics/callbacks/__init__.py
index 65a19ce1..8a7eab0b 100644
--- a/training/src/anemoi/training/diagnostics/callbacks/__init__.py
+++ b/training/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -188,7 +188,6 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
     )
 
     # Plotting callbacks
-
     trainer_callbacks.extend(
         instantiate(callback, config) for callback in config.diagnostics.plot.get("callbacks", None) or []
     )
diff --git a/training/src/anemoi/training/diagnostics/callbacks/stopping.py b/training/src/anemoi/training/diagnostics/callbacks/stopping.py
new file mode 100644
index 00000000..d5c674b5
--- /dev/null
+++ b/training/src/anemoi/training/diagnostics/callbacks/stopping.py
@@ -0,0 +1,101 @@
+# (C) Copyright 2025- Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+from __future__ import annotations
+
+import logging
+import time
+from datetime import timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import pytorch_lightning as pl
+
+from anemoi.utils.dates import frequency_to_string
+from anemoi.utils.dates import frequency_to_timedelta
+
+if TYPE_CHECKING:
+    from omegaconf import OmegaConf
+
+LOGGER = logging.getLogger(__name__)
+
+
+class TimeLimit(pl.callbacks.Callback):
+    """Callback to stop the training process after a given time limit."""
+
+    def __init__(self, config: OmegaConf, limit: int | str, record_file: str | None = None) -> None:
+        """Initialise the TimeLimit callback.
+
+        Parameters
+        ----------
+        config : OmegaConf
+            The training configuration.
+        limit : int or str
+            The time limit after which training is stopped. If an integer, it is interpreted as hours.
+            If a string, it can be in the format:
+
+            - "1h" for 1 hour
+            - "1d" for 1 day
+            - "1m" for 1 minute
+            - "1s" for 1 second
+            - "1:30" for 1 hour and 30 minutes
+            - "1:30:10" for 1 hour, 30 minutes and 10 seconds
+            - "PT10M" for 10 minutes (ISO 8601)
+
+        record_file : str or None
+            The file to record the last checkpoint path to. If None, no file is written.
+
+        """
+        super().__init__()
+        self.config = config
+
+        self.limit = frequency_to_timedelta(limit)
+        self._record_file = Path(record_file) if record_file is not None else None
+
+        if self._record_file is not None and self._record_file.exists():
+            assert self._record_file.is_file(), "The record file must be a file."
+
+        self._start_time = time.time()
+
+    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        _ = pl_module
+        self._run_stopping_check(trainer)
+
+    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+        _ = pl_module
+        self._run_stopping_check(trainer)
+
+    def _run_stopping_check(self, trainer: pl.Trainer) -> None:
+        """Check if the time limit has been reached and stop the training if so.
+
+        Parameters
+        ----------
+        trainer : pl.Trainer
+            PyTorch Lightning trainer
+        """
+        if timedelta(seconds=time.time() - self._start_time) < self.limit:
+            return
+
+        LOGGER.info("Time limit of %s reached. Stopping training.", frequency_to_string(self.limit))
+        trainer.should_stop = True
+        self._log_to_file(trainer)
+
+    def _log_to_file(self, trainer: pl.Trainer) -> None:
+        """Log the last checkpoint path to a file if given.
+
+        Parameters
+        ----------
+        trainer : pl.Trainer
+            PyTorch Lightning trainer
+        """
+        if self._record_file is not None:
+            last_checkpoint = trainer.checkpoint_callback.last_model_path
+            self._record_file.parent.mkdir(parents=True, exist_ok=True)
+            # write_text() truncates any existing file, so no explicit unlink is needed.
+            self._record_file.write_text(str(last_checkpoint))
diff --git a/training/tests/diagnostics/callbacks/test_timelimit.py b/training/tests/diagnostics/callbacks/test_timelimit.py
new file mode 100644
index 00000000..4af00fd6
--- /dev/null
+++ b/training/tests/diagnostics/callbacks/test_timelimit.py
@@ -0,0 +1,91 @@
+# (C) Copyright 2025- Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+
+from typing import Any
+from unittest.mock import Mock
+from unittest.mock import patch
+
+import pytest
+from pytorch_lightning import Trainer
+
+from anemoi.training.diagnostics.callbacks.stopping import TimeLimit
+
+
+@pytest.fixture
+def mock_trainer() -> Mock:
+    training_process = Mock(Trainer)
+    training_process.should_stop = False
+    return training_process
+
+
+def test_time_limit_stops_training(mock_trainer: Mock) -> None:
+    # Setting a time limit of 24 hours
+    time_limit_callback = TimeLimit(config=None, limit="24h")
+
+    # Mocking the time to simulate training for 24 hours and 1 second
+    with patch("time.time", return_value=time_limit_callback._start_time + 24 * 60 * 60 + 1):
+        time_limit_callback.on_train_epoch_end(mock_trainer, None)
+
+    # Check if the training process should stop
+    assert mock_trainer.should_stop
+
+
+def test_time_limit_stops_training_int(mock_trainer: Mock) -> None:
+    # Setting a time limit of 24 hours (an integer is interpreted as hours)
+    time_limit_callback = TimeLimit(config=None, limit=24)
+
+    # Mocking the time to simulate training for 24 hours and 1 second
+    with patch("time.time", return_value=time_limit_callback._start_time + 24 * 60 * 60 + 1):
+        time_limit_callback.on_train_epoch_end(mock_trainer, None)
+
+    # Check if the training process should stop
+    assert mock_trainer.should_stop
+
+
+def test_time_limit_does_not_stop_training(mock_trainer: Mock) -> None:
+    # Setting a time limit of 1 hour
+    time_limit_callback = TimeLimit(config=None, limit="1h")
+
+    # Mocking the time to simulate training for 2 seconds
+    with patch("time.time", return_value=time_limit_callback._start_time + 2):
+        time_limit_callback.on_train_epoch_end(mock_trainer, None)
+
+    # Check if the training process should not stop
+    assert not mock_trainer.should_stop
+
+
+def test_time_limit_creates_file_on_stop(mock_trainer: Mock, tmp_path: Any) -> None:
+    # Setting a time limit of 1 second
+    time_limit_callback = TimeLimit(config=None, limit="1s", record_file=tmp_path / "log")
+
+    # Mocking the time to simulate training for 2 seconds
+    with patch("time.time", return_value=time_limit_callback._start_time + 2):
+        time_limit_callback.on_train_epoch_end(mock_trainer, None)
+
+    # Check if the training process should stop
+    assert mock_trainer.should_stop
+
+    # Check if the file is created
+    assert (tmp_path / "log").exists()
+
+
+def test_time_limit_does_not_create_file_when_not_stopping(mock_trainer: Mock, tmp_path: Any) -> None:
+    # Setting a time limit of 24 hours
+    time_limit_callback = TimeLimit(config=None, limit="24h", record_file=tmp_path / "log")
+
+    # Mocking the time to simulate training for 2 seconds
+    with patch("time.time", return_value=time_limit_callback._start_time + 2):
+        time_limit_callback.on_train_epoch_end(mock_trainer, None)
+
+    # Check if the training process should not stop
+    assert not mock_trainer.should_stop
+
+    # Check if the file is not created
+    assert not (tmp_path / "log").exists()
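
A minimal usage sketch of the new callback (not part of the patch). It wires TimeLimit directly into a PyTorch Lightning trainer, as the unit tests above do; config=None mirrors what the tests pass, the limit values follow the formats documented in the docstring, and the record_file path is a hypothetical example.

    import pytorch_lightning as pl

    from anemoi.training.diagnostics.callbacks.stopping import TimeLimit

    # Stop training roughly 12 hours after the callback is constructed.
    time_limit = TimeLimit(
        config=None,  # training config; the unit tests pass None
        limit="12h",  # also accepts e.g. 24 (hours), "1:30:10" or "PT10M"
        record_file="run/last_checkpoint.txt",  # hypothetical path, written only when the limit triggers
    )

    trainer = pl.Trainer(max_epochs=100, callbacks=[time_limit])
    # trainer.fit(model, datamodule)
    # The check runs at the end of every training epoch and validation loop;
    # once the elapsed time exceeds the limit, trainer.should_stop is set and
    # the path of the last checkpoint is written to record_file.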