feat(training): Add initial TimeLimit callback (#115)
* Add TimeLimit callback
- Stops training when the time limit is reached
- Logs the last checkpoint path to a file


---------

Co-authored-by: Julien Lefaucheur <julien.lefaucheur@ecmwf.int>
HCookie and jjlk authored Feb 11, 2025
1 parent 0a9cfa7 commit 41ff583
Showing 3 changed files with 194 additions and 1 deletion.
@@ -188,7 +188,6 @@ def get_callbacks(config: DictConfig) -> list[Callback]:
)

# Plotting callbacks

trainer_callbacks.extend(
instantiate(callback, config) for callback in config.diagnostics.plot.get("callbacks", None) or []
)
103 changes: 103 additions & 0 deletions training/src/anemoi/training/diagnostics/callbacks/stopping.py
@@ -0,0 +1,103 @@
# (C) Copyright 2025- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import logging
import time
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING

import pytorch_lightning as pl

from anemoi.utils.dates import frequency_to_string
from anemoi.utils.dates import frequency_to_timedelta

if TYPE_CHECKING:
from omegaconf import OmegaConf

LOGGER = logging.getLogger(__name__)


class TimeLimit(pl.callbacks.Callback):
"""Callback to stop the training process after a given time limit."""

def __init__(self, config: OmegaConf, limit: int | str, record_file: str | None = None) -> None:
"""Initialise the TimeLimit callback.
Parameters
----------
limit : int or str
The frequency to convert. If an integer, it is assumed to be in hours. If a string, it can be in the format:
- "1h" for 1 hour
- "1d" for 1 day
- "1m" for 1 minute
- "1s" for 1 second
- "1:30" for 1 hour and 30 minutes
- "1:30:10" for 1 hour, 30 minutes and 10 seconds
- "PT10M" for 10 minutes (ISO8601)
record_file : str or None
The file to record the last checkpoint to. If None, no file is written.
"""
super().__init__()
self.config = config

self.limit = frequency_to_timedelta(limit)
self._record_file = Path(record_file) if record_file is not None else None

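        # Fail fast if a record path was given but points at something other than a regular file.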
if self._record_file is not None and self._record_file.exists():
assert self._record_file.is_file(), "The record file must be a file."

self._start_time = time.time()

def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
_ = pl_module
self._run_stopping_check(trainer)

def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
_ = pl_module
self._run_stopping_check(trainer)

def _run_stopping_check(self, trainer: pl.Trainer) -> None:
"""
Check if the time limit has been reached and stop the training if so.
Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
"""
if timedelta(seconds=time.time() - self._start_time) < self.limit:
return

        LOGGER.info("Time limit of %s reached. Stopping training.", frequency_to_string(self.limit))
trainer.should_stop = True
self._log_to_file(trainer)

def _log_to_file(self, trainer: pl.Trainer) -> None:
"""
Log the last checkpoint path to a file if given.
Parameters
----------
trainer : pl.Trainer
Pytorch Lightning trainer
"""
if self._record_file is not None:
last_checkpoint = trainer.checkpoint_callback.last_model_path
self._record_file.parent.mkdir(parents=True, exist_ok=True)

if self._record_file.exists():
self._record_file.unlink()

            self._record_file.write_text(str(last_checkpoint))
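
As a usage sketch (not part of the diff), the callback can be attached directly to a Lightning Trainer. The limit formats follow the __init__ docstring above; the limit values, record-file path, and max_epochs below are illustrative assumptions:

from datetime import timedelta

import pytorch_lightning as pl

from anemoi.training.diagnostics.callbacks.stopping import TimeLimit
from anemoi.utils.dates import frequency_to_timedelta

# The documented limit formats normalise to timedeltas
# (assumed behaviour, per the docstring above).
assert frequency_to_timedelta(24) == timedelta(hours=24)
assert frequency_to_timedelta("1:30") == timedelta(hours=1, minutes=30)

# Illustrative values: stop after 12 hours and record the last checkpoint path.
callback = TimeLimit(config=None, limit="12h", record_file="run/last_checkpoint.txt")
trainer = pl.Trainer(max_epochs=100, callbacks=[callback])
# trainer.fit(model, datamodule)  # model and datamodule supplied by the caller

In anemoi-training itself the callback would normally be instantiated through the config-driven get_callbacks path shown in the first hunk, rather than constructed by hand.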
91 changes: 91 additions & 0 deletions training/tests/diagnostics/callbacks/test_timelimit.py
@@ -0,0 +1,91 @@
# (C) Copyright 2025- Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from typing import Any
from unittest.mock import Mock
from unittest.mock import patch

import pytest
from pytorch_lightning import Trainer

from anemoi.training.diagnostics.callbacks.stopping import TimeLimit


@pytest.fixture
def mock_trainer() -> Mock:
training_process = Mock(Trainer)
training_process.should_stop = False
return training_process


def test_time_limit_stops_training(mock_trainer: Mock) -> None:
    # Setting a time limit of 24 hours
time_limit_callback = TimeLimit(config=None, limit="24h")

    # Mocking the time to simulate training for 24 hours and 1 second
with patch("time.time", return_value=time_limit_callback._start_time + 24 * 60 * 60 + 1):
time_limit_callback.on_train_epoch_end(mock_trainer, None)

# Check if the training process should stop
assert mock_trainer.should_stop


def test_time_limit_stops_training_int(mock_trainer: Mock) -> None:
    # Setting a time limit of 24 hours, given as an integer
time_limit_callback = TimeLimit(config=None, limit=24)

    # Mocking the time to simulate training for 24 hours and 1 second
with patch("time.time", return_value=time_limit_callback._start_time + 24 * 60 * 60 + 1):
time_limit_callback.on_train_epoch_end(mock_trainer, None)

# Check if the training process should stop
assert mock_trainer.should_stop


def test_time_limit_does_not_stop_training(mock_trainer: Mock) -> None:
# Setting a time limit of 1 hour
time_limit_callback = TimeLimit(config=None, limit="1h")

# Mocking the time to simulate training for 2 seconds
with patch("time.time", return_value=time_limit_callback._start_time + 2):
time_limit_callback.on_train_epoch_end(mock_trainer, None)

# Check if the training process should not stop
assert not mock_trainer.should_stop


def test_time_limit_creates_file_on_stop(mock_trainer: Mock, tmp_path: Any) -> None:
# Setting a time limit of 1 second
time_limit_callback = TimeLimit(config=None, limit="1s", record_file=tmp_path / "log")

# Mocking the time to simulate training for 2 seconds
with patch("time.time", return_value=time_limit_callback._start_time + 2):
time_limit_callback.on_train_epoch_end(mock_trainer, None)

# Check if the training process should stop
assert mock_trainer.should_stop

# Check if the file is created
assert (tmp_path / "log").exists()


def test_time_limit_does_not_create_file_when_not_stopping(mock_trainer: Mock, tmp_path: Any) -> None:
    # Setting a time limit of 24 hours
time_limit_callback = TimeLimit(config=None, limit="24h", record_file=tmp_path / "log")

# Mocking the time to simulate training for 2 seconds
with patch("time.time", return_value=time_limit_callback._start_time + 2):
time_limit_callback.on_train_epoch_end(mock_trainer, None)

    # Check that the training process should not stop
assert not mock_trainer.should_stop

    # Check that the record file is not created
assert not (tmp_path / "log").exists()
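
For context on the mocking used above: patch("time.time", ...) replaces the clock at module level, so the elapsed-time computation inside TimeLimit._run_stopping_check sees the faked value. A minimal standalone sketch of the same mechanism (values are illustrative):

import time
from datetime import timedelta
from unittest.mock import patch

start = time.time()

# While the patch is active, every call to time.time() returns the fixed
# value, so roughly one hour appears to have elapsed.
with patch("time.time", return_value=start + 3600.0):
    elapsed = timedelta(seconds=time.time() - start)
    assert abs(elapsed.total_seconds() - 3600.0) < 1e-3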
