Clean up the PipelineBuilder implementation (#144)
mthrok authored Aug 11, 2024
1 parent 8c41e50 commit bd57fbe
Showing 2 changed files with 10 additions and 25 deletions.
src/spdl/dataloader/_builder.py (32 changes: 9 additions & 23 deletions)
@@ -495,7 +495,7 @@ async def afunc(item: T) -> U:
 
 
 class PipelineBuilder:
-    """**[Experimental]** Build :py:class:`~spdl.dataloader.Pipeline` object.
+    """Build :py:class:`~spdl.dataloader.Pipeline` object.
 
     See :py:class:`~spdl.dataloader.Pipeline` for details.
     """
@@ -727,39 +727,28 @@ def add_sink(self, buffer_size: int) -> "PipelineBuilder":
         self._sink_buffer_size = buffer_size
         return self
 
-    def _build(
-        self,
-        num_items: int | None,
-        # TODO: Once we remove AsyncPipeline, construct queues internally.
-        queues: list[AsyncQueue],
-    ) -> Coroutine[None, None, None]:
+    def _build(self) -> tuple[Coroutine[None, None, None], list[AsyncQueue]]:
         if self._source is None:
             raise ValueError("Source is not set.")
-        if num_items is not None and num_items < 1:
-            raise ValueError("num_items must be >= 0")
-
-        construct_queues = len(queues) == 0
 
         # Note:
         # Make sure that coroutines are ordered from source to sink.
         # `_run_pipeline_coroutines` expects and rely on this ordering.
         coros = []
+        queues: list[AsyncQueue] = []
 
         # source
-        if construct_queues:
-            queues.append(AsyncQueue(self._source_buffer_size))
-
+        queues.append(AsyncQueue(self._source_buffer_size))
         coros.append(
             (
                 "AsyncPipeline::0_source",
-                _enqueue(self._source, queues[0], max_items=num_items),
+                _enqueue(self._source, queues[0]),
             )
         )
 
         # pipes
         for i, (type_, args, buffer_size) in enumerate(self._process_args, start=1):
-            if construct_queues:
-                queues.append(AsyncQueue(buffer_size))
+            queues.append(AsyncQueue(buffer_size))
             in_queue, out_queue = queues[i - 1 : i + 1]
 
             match type_:
@@ -776,17 +765,15 @@ def _build(
 
         # sink
         if self._sink_buffer_size is not None:
-            if construct_queues:
-                queues.append(AsyncQueue(self._sink_buffer_size))
-
+            queues.append(AsyncQueue(self._sink_buffer_size))
             coros.append(
                 (
                     f"AsyncPipeline::{len(self._process_args) + 1}_sink",
                     _sink(*queues[-2:]),
                 )
             )
 
-        return _run_pipeline_coroutines(coros)
+        return _run_pipeline_coroutines(coros), queues
 
     def _get_desc(self) -> list[str]:
         parts = []
@@ -828,8 +815,7 @@ def build(self, *, num_threads: int | None = None) -> Pipeline:
                 async event loop.
 
                 If not specified, the maximum concurrency value is used.
         """
-        queues = []
-        coro = self._build(None, queues)
+        coro, queues = self._build()
 
         if num_threads is None:
             concurrencies = [
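
The shape of the refactor: `_build()` previously filled a caller-supplied `queues` list (the removed TODO notes that dropping AsyncPipeline would let it construct queues internally); it now builds the queues itself and returns them alongside the pipeline coroutine. A minimal, self-contained sketch of the new contract, using hypothetical `_build_sketch`/`_enqueue_sketch` helpers rather than spdl's actual internals:

    import asyncio
    from collections.abc import Coroutine, Iterable


    async def _enqueue_sketch(source: Iterable[int], queue: asyncio.Queue) -> None:
        # Feed every item from the source into the first queue, then signal EOF.
        for item in source:
            await queue.put(item)
        await queue.put(None)  # sentinel marking end of stream


    def _build_sketch(
        source: Iterable[int], buffer_size: int
    ) -> tuple[Coroutine[None, None, None], list[asyncio.Queue]]:
        # The builder owns queue construction and hands the queues back to
        # the caller together with the coroutine, instead of mutating a
        # pre-built (possibly empty) list passed in by the caller.
        queues: list[asyncio.Queue] = [asyncio.Queue(maxsize=buffer_size)]
        return _enqueue_sketch(source, queues[0]), queues


    async def main() -> None:
        coro, queues = _build_sketch(range(3), buffer_size=4)
        await coro  # run the source stage to completion
        while (item := await queues[0].get()) is not None:
            print(item)


    asyncio.run(main())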
src/spdl/dataloader/_pipeline.py (3 changes: 1 addition & 2 deletions)
@@ -176,8 +176,7 @@ class _EventLoopState(IntEnum):
 class Pipeline(Generic[T]):
     """Pipeline()
 
-    **[Experimental]** Data processing pipeline. Use :py:class:`PipelineBuilder` to
-    instantiate.
+    Data processing pipeline. Use :py:class:`PipelineBuilder` to instantiate.
 
     ``Pipeline`` and ``PipelineBuilder`` facilitate building data processing pipeline
     consists of multiple stages of async operations.
