Skip to content

Commit

Permalink
Explicitly pass backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 committed Apr 6, 2024
1 parent 788e87e commit b6539f1
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 11 deletions.
8 changes: 4 additions & 4 deletions streamjoy/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,7 @@ def default_holoviews_renderer(
The rendered HoloViews Element.
"""
import holoviews as hv

backend = kwargs.get("backend", "bokeh")
hv.extension(backend)
backend = kwargs.get("backend", hv.Store.current_backend)

clims = kwargs.pop("clims", {})
for hv_el in hv_obj.traverse(full_breadth=False):
Expand All @@ -142,10 +140,12 @@ def default_holoviews_renderer(
except IndexError:
continue
if vdim in clims:
hv_el.opts(clim=clims[vdim])
hv_el.opts(clim=clims[vdim], backend=backend)

if backend == "bokeh":
kwargs["toolbar"] = None
elif backend == "matplotlib":
kwargs["cbar_extend"] = kwargs.get("cbar_extend", "both")
hv_obj.opts(**kwargs)

return hv_obj
9 changes: 6 additions & 3 deletions streamjoy/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,7 @@ def serialize_holoviews(
"""
import holoviews as hv

backend = hv.Store.current_backend
hv.extension(backend)
backend = kwargs.get("backend", hv.Store.current_backend)

def _select_element(hv_obj, key):
try:
Expand Down Expand Up @@ -332,7 +331,10 @@ def _select_element(hv_obj, key):
if len(kdims) > 1:
raise ValueError("Can only handle 1D HoloViews objects.")

resources = [_select_element(hv_obj, key).opts(title=str(key)) for key in keys]
resources = [
_select_element(hv_obj, key).opts(title=str(key))
for key in keys
]

renderer_kwargs = renderer_kwargs or {}
renderer_kwargs.update(_utils.pop_from_cls(stream_cls, kwargs))
Expand All @@ -353,6 +355,7 @@ def _select_element(hv_obj, key):
array = hv_el.dimension_values(vdim)
clim = (np.nanmin(array), np.nanmax(array))
clims[vdim] = clim

renderer_kwargs.update(
backend=backend,
clims=clims,
Expand Down
6 changes: 2 additions & 4 deletions streamjoy/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,7 @@ def wrapper(renderer):
def wrapped(*args, **kwargs) -> Path | BytesIO:
import holoviews as hv

backend = kwargs.get("backend", "bokeh")
hv.extension(backend)

backend = kwargs.get("backend", hv.Store.current_backend)
output = renderer(*args, **kwargs)

hv_obj = output
Expand Down Expand Up @@ -129,7 +127,7 @@ def wrapped(*args, **kwargs) -> Path | BytesIO:
service=Service(ChromeDriverManager().install()), options=options
) as webdriver:
image = get_screenshot_as_png(
hv.render(hv_obj, backend="bokeh"), driver=webdriver
hv.render(hv_obj, backend=backend), driver=webdriver
)
if fsspec_fs:
with fsspec_fs.open(uri, "wb") as f:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def _assert_stream_and_props(self, sj, stream_cls):
buf = sj.write()
props = improps(buf)
props.n_images == 3
return props

def test_from_pandas(self, stream_cls, df):
sj = stream_cls.from_pandas(df)
Expand Down Expand Up @@ -54,6 +55,20 @@ def test_fsspec_fs(self, stream_cls, df, fsspec_fs):
sj = stream_cls.from_pandas(df, fsspec_fs=fsspec_fs)
self._assert_stream_and_props(sj, stream_cls)

def test_holoviews_matplotlib_backend(self, stream_cls, ds):
sj = stream_cls.from_holoviews(
ds.hvplot("lon", "lat", fig_size=200, backend="matplotlib")
)
props = self._assert_stream_and_props(sj, stream_cls)
assert props.shape[1] == 300

def test_holoviews_bokeh_backend(self, stream_cls, ds):
sj = stream_cls.from_holoviews(
ds.hvplot("lon", "lat", width=300, backend="bokeh")
)
props = self._assert_stream_and_props(sj, stream_cls)
assert props.shape[1] == 300


class TestGifStream(AbstractTestMediaStream):
@pytest.fixture(scope="class")
Expand Down

0 comments on commit b6539f1

Please sign in to comment.