Skip to content

Commit

Permalink
Add connect
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 committed Mar 4, 2024
1 parent cae3abe commit 39543fd
Show file tree
Hide file tree
Showing 8 changed files with 567 additions and 255 deletions.
28 changes: 0 additions & 28 deletions .pre-commit-config.yaml

This file was deleted.

6 changes: 5 additions & 1 deletion streamjoy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging

from .core import stream
from ._utils import update_logger
from .core import stream, connect
from .models import GifStream, Mp4Stream
from .renderers import (
default_holoviews_renderer,
Expand All @@ -19,6 +20,7 @@
"file_handlers",
"obj_handlers",
"stream",
"connect",
"GifStream",
"Mp4Stream",
]
Expand All @@ -29,3 +31,5 @@
format=config["logging_format"],
datefmt=config["logging_datefmt"],
)

update_logger()
14 changes: 14 additions & 0 deletions streamjoy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from io import BytesIO
from pathlib import Path
from typing import Any, Callable
from itertools import islice

from dask.distributed import Client, Future, get_client

Expand Down Expand Up @@ -146,6 +147,12 @@ def get_max_frames(total_frames: int, max_frames: int) -> int:
return max_frames


def get_first(iterable):
if isinstance(iterable, (list, tuple)):
return iterable[0]
return next(islice(iterable, 0, 1), None)


def get_result(future: Future) -> Any:
if isinstance(future, Future):
return future.result()
Expand Down Expand Up @@ -207,3 +214,10 @@ def validate_xarray(
if ds.ndim > 3:
raise ValueError(f"Can only handle 3D arrays; {ds.ndim}D array found")
return ds


def map_over(client, func, resources, batch_size, *args, **kwargs):
try:
return client.map(func, resources, *args, batch_size=batch_size, **kwargs)
except TypeError as exc:
return [client.submit(func, resource, *args, **kwargs) for resource in resources]
48 changes: 29 additions & 19 deletions streamjoy/core.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from __future__ import annotations

from functools import partial
from pathlib import Path
from typing import Any, Callable

from . import _utils
from .models import AnyStream, GifStream, Mp4Stream
from .models import AnyStream, GifStream, Mp4Stream, ConnectedStreams


def stream(
resources: Any,
output_path: str | Path | None = None,
renderer: Callable | None = None,
renderer_kwargs: dict | None = None,
iterables: list[Any] | None = None,
**kwargs,
) -> AnyStream | GifStream | Mp4Stream | Path:
"""
Expand All @@ -24,13 +21,10 @@ def stream(
output_path: The path to write the stream to. If None, the stream is returned.
renderer: The renderer to use. If None, the default renderer is used.
renderer_kwargs: Additional keyword arguments to pass to the renderer.
iterables: A list of iterables to map alongside the resources; useful for
rendering resources with additional metadata. Each item in the
list should be the same length as the resources.
**kwargs: Additional keyword arguments to pass to the stream constructor.
Returns:
The stream if path is None, otherwise None.
The stream if output_path is None, otherwise the output_path.
"""
stream_cls = AnyStream
if output_path:
Expand All @@ -42,19 +36,35 @@ def stream(
else:
raise ValueError(f"Unsupported file extension {output_path.suffix}")

params = {
key: kwargs.pop(key) for key in stream_cls.param.values() if key in kwargs
}
resources, renderer, renderer_kwargs = stream_cls._expand_from_any(
resources, renderer, renderer_kwargs or {}, **kwargs
)
stream = stream_cls(
renderer=renderer, renderer_kwargs=renderer_kwargs or {}, **params
resources=resources,
renderer=renderer,
renderer_kwargs=renderer_kwargs,
**kwargs,
)

if output_path:
return stream.write(
resources,
output_path=output_path,
iterables=iterables,
**kwargs,
)
stream.write = partial(stream.write, resources, **kwargs)
return stream.write(output_path=output_path)
return stream


def connect(
streams: list[AnyStream | GifStream | Mp4Stream],
output_path: str | Path | None = None,
) -> ConnectedStreams | Path:
"""
Connect multiple streams into a single stream.
Args:
streams: The streams to connect.
Returns:
The connected streams if output_path is None, otherwise the output_path.
"""
stream = ConnectedStreams(streams=streams)
if output_path:
return stream.write(output_path=output_path)
return stream
Loading

0 comments on commit 39543fd

Please sign in to comment.