Skip to content

Commit

Permalink
fix(studio): handle unexpected exceptions in updates thread (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Feb 6, 2025
1 parent 6e29c5e commit 3dac488
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,10 +929,17 @@ def post_data_to_studio(self):
self._studio_queue = queue.Queue()

def worker():
    """Drain ``self._studio_queue`` and post each item to Studio.

    Runs forever (intended for a daemon thread). Every item taken from
    the queue is acknowledged with exactly one ``task_done()`` call,
    whether or not posting succeeded, so that a ``join()`` on the queue
    is not left waiting on items that were never acknowledged.

    After the first unexpected exception the failure is logged once and
    all later items are dropped without being posted; the loop keeps
    running only to keep acknowledging queued items.
    """
    error_occurred = False
    while True:
        item, data = self._studio_queue.get()
        try:
            # Once a post has failed, stop talking to Studio but keep
            # consuming so the queue's unfinished-task count reaches zero.
            if not error_occurred:
                post_to_studio(item, "data", data)
        except Exception:
            # Thread boundary: log with traceback instead of letting the
            # exception terminate the worker silently.
            logger.exception("Failed to post data to studio")
            error_occurred = True
        finally:
            # Exactly one task_done() per get(); a second call would
            # raise ValueError from queue.Queue.
            self._studio_queue.task_done()

threading.Thread(target=worker, daemon=True).start()

Expand Down
30 changes: 30 additions & 0 deletions tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,36 @@ def long_post(*args, **kwargs):
assert metrics_file.read_text() == metrics_content


def test_studio_update_raises_exception(tmp_path, mocked_dvc_repo, mocked_studio_post):
    """A failing Studio "data" post must not hang ``Live``.

    If a studio update raises an exception, the main process must not
    hang on the queue join in the Live main thread.
    https://github.com/iterative/dvclive/pull/864
    """
    mocked_post, valid_response = mocked_studio_post

    def fail_on_data(*args, **kwargs):
        # Non-"data" events ("start", "done") succeed normally.
        if kwargs["json"]["type"] != "data":
            return valid_response
        # Reached only once: after the exception is raised, further data
        # posts are ignored, so this sleep is not repeated.
        time.sleep(1)
        raise Exception("test exception")  # noqa: TRY002, TRY003

    mocked_post.side_effect = fail_on_data

    with Live() as live:
        for value in (1, 2, 3):
            live.log_metric("foo", value)

    # Only one data call is made; the rest are skipped after the failure.
    assert mocked_post.call_count == 3
    posted_types = [
        call.kwargs["json"]["type"] for call in mocked_post.call_args_list
    ]
    assert posted_types == ["start", "data", "done"]


@pytest.mark.studio
def test_post_to_studio_skip_start_and_done_on_env_var(
tmp_dir, mocked_dvc_repo, mocked_studio_post, monkeypatch
Expand Down

0 comments on commit 3dac488

Please sign in to comment.