diff --git a/src/dvclive/studio.py b/src/dvclive/studio.py index e2e9e1c5..e1ede8d2 100644 --- a/src/dvclive/studio.py +++ b/src/dvclive/studio.py @@ -1,4 +1,5 @@ # ruff: noqa: SLF001 +import base64 import os from pathlib import Path @@ -43,6 +44,21 @@ def _adapt_plot_datapoints(live, plot): return _cast_to_numbers(datapoints) +def _adapt_image(image_path): + with open(image_path, "rb") as fobj: + return base64.b64encode(fobj.read()).decode("utf-8") + + +def _adapt_images(live): + return { + _adapt_plot_name(live, image.output_path): { + "image": _adapt_image(image.output_path) + } + for image in live._images.values() + if image.step > live._latest_studio_step + } + + def get_studio_updates(live): if os.path.isfile(live.params_file): params_file = live.params_file @@ -65,4 +81,6 @@ def get_studio_updates(live): } plots = {k: {"data": v} for k, v in plots.items()} + plots.update(_adapt_images(live)) + return metrics, params, plots diff --git a/tests/test_studio.py b/tests/test_studio.py index 97666aee..b62bbcef 100644 --- a/tests/test_studio.py +++ b/tests/test_studio.py @@ -1,10 +1,12 @@ from pathlib import Path import pytest +from PIL import Image as ImagePIL from dvclive import Live from dvclive.env import DVC_EXP_BASELINE_REV, DVC_EXP_NAME -from dvclive.plots import Metric +from dvclive.plots import Image, Metric +from dvclive.studio import _adapt_image def test_post_to_studio(tmp_dir, mocked_dvc_repo, mocked_studio_post): @@ -378,3 +380,36 @@ def test_post_to_studio_inside_subdir_dvc_exp( def test_post_to_studio_requires_exp(tmp_dir, mocked_dvc_repo, mocked_studio_post): assert Live()._studio_events_to_skip == {"start", "data", "done"} assert not Live(save_dvc_exp=True)._studio_events_to_skip + + +def test_post_to_studio_images(tmp_dir, mocked_dvc_repo, mocked_studio_post): + mocked_post, _ = mocked_studio_post + + live = Live(save_dvc_exp=True) + live.log_image("foo.png", ImagePIL.new("RGB", (10, 10), (0, 0, 0))) + live.next_step() + + dvc_path = Path(live.dvc_file).as_posix() + metrics_path = Path(live.metrics_file).as_posix() + foo_path = (Path(live.plots_dir) / Image.subfolder / "foo.png").as_posix() + + mocked_post.assert_called_with( + "https://0.0.0.0/api/live", + json={ + "type": "data", + "repo_url": "STUDIO_REPO_URL", + "baseline_sha": live._baseline_rev, + "name": live._exp_name, + "step": 0, + "metrics": {f"{metrics_path}": {"data": {"step": 0}}}, + "plots": { + f"{dvc_path}::{foo_path}": {"image": _adapt_image(foo_path)}, + }, + "client": "dvclive", + }, + headers={ + "Authorization": "token STUDIO_TOKEN", + "Content-type": "application/json", + }, + timeout=5, + )