diff --git a/src/dvclive/lightning.py b/src/dvclive/lightning.py index c097ddfd..335ab5e7 100644 --- a/src/dvclive/lightning.py +++ b/src/dvclive/lightning.py @@ -57,8 +57,9 @@ def __init__( self._live_init["dir"] = dir self._experiment = experiment self._version = run_name - # Force Live instantiation - self.experiment # noqa: B018 + if report == "notebook": + # Force Live instantiation + self.experiment # noqa: B018 @property def name(self): diff --git a/tests/test_frameworks/test_lightning.py b/tests/test_frameworks/test_lightning.py index ef5422e4..a26f5edb 100644 --- a/tests/test_frameworks/test_lightning.py +++ b/tests/test_frameworks/test_lightning.py @@ -15,6 +15,7 @@ from torch.optim import SGD, Adam from torch.utils.data import DataLoader, Dataset + from dvclive import Live from dvclive.lightning import DVCLiveLogger except ImportError: pytest.skip("skipping pytorch_lightning tests", allow_module_level=True) @@ -239,3 +240,14 @@ def test_lightning_val_udpates_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio # Without `self.experiment._latest_studio_step -= 1` # This would be empty assert len(val_loss["data"]) == 1 + + +def test_lightning_force_init(tmp_dir, mocker): + """Regression test for https://github.com/iterative/dvclive/issues/594 + Only call Live.__init__ when report is notebook. + """ + init = mocker.spy(Live, "__init__") + DVCLiveLogger() + init.assert_not_called() + DVCLiveLogger(report="notebook") + init.assert_called_once()