From 3dac4889735a9859443c566364d9711ac05e9cfd Mon Sep 17 00:00:00 2001 From: Ivan Shcheklein Date: Wed, 5 Feb 2025 19:28:39 -0800 Subject: [PATCH] fix(studio): handle unexpected exceptions in updates thread (#864) --- src/dvclive/live.py | 11 +++++++++-- tests/test_post_to_studio.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/dvclive/live.py b/src/dvclive/live.py index d73be31..86f78ee 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -929,10 +929,17 @@ def post_data_to_studio(self): self._studio_queue = queue.Queue() def worker(): + error_occurred = False while True: item, data = self._studio_queue.get() - post_to_studio(item, "data", data) - self._studio_queue.task_done() + try: + if not error_occurred: + post_to_studio(item, "data", data) + except Exception: + logger.exception("Failed to post data to studio") + error_occurred = True + finally: + self._studio_queue.task_done() threading.Thread(target=worker, daemon=True).start() diff --git a/tests/test_post_to_studio.py b/tests/test_post_to_studio.py index 0c6d119..6554ded 100644 --- a/tests/test_post_to_studio.py +++ b/tests/test_post_to_studio.py @@ -274,6 +274,36 @@ def long_post(*args, **kwargs): assert metrics_file.read_text() == metrics_content +def test_studio_update_raises_exception(tmp_path, mocked_dvc_repo, mocked_studio_post): + # Test that if a studio update raises an exception, main process doesn't hang on + # queue join in the Live main thread. + # https://github.com/iterative/dvclive/pull/864 + mocked_post, valid_response = mocked_studio_post + + def post_raises_exception(*args, **kwargs): + if kwargs["json"]["type"] == "data": + # We'll hit this sleep only once, other calls are ignored + # after the exception is raised + time.sleep(1) + raise Exception("test exception") # noqa: TRY002, TRY003 + return valid_response + + mocked_post.side_effect = post_raises_exception + + with Live() as live: + live.log_metric("foo", 1) + live.log_metric("foo", 2) + live.log_metric("foo", 3) + + # Only 1 data call is made, other calls are ignored after the exception is raised + assert mocked_post.call_count == 3 + assert [e.kwargs["json"]["type"] for e in mocked_post.call_args_list] == [ + "start", + "data", + "done", + ] + + @pytest.mark.studio def test_post_to_studio_skip_start_and_done_on_env_var( tmp_dir, mocked_dvc_repo, mocked_studio_post, monkeypatch