Skip to content

Commit

Permalink
Update SPDL Image Classification Example with Dataloader (facebookres…
Browse files Browse the repository at this point in the history
…earch#289)

Summary:

Update SPDL image classification example to use Dataloader instead of Pipeline.

The dataloader builds a Pipeline when converted to an iterable.

Minor difference: the Dataloader aggregator uses the default concurrency.

Reviewed By: moto-meta

Differential Revision: D66724034
  • Loading branch information
Victor Bourgin authored and facebook-github-bot committed Dec 14, 2024
1 parent a520ba4 commit e03b51a
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions examples/imagenet_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"""Benchmark the performance of loading images from local file systems and
classifying them using a GPU.
This script builds the data loading pipeline and instantiates an image
This script builds the data loader and instantiates an image
classification model in a GPU.
The pipeline transfer the batch image data to the GPU concurrently, and
The data loader transfers the batch image data to the GPU concurrently, and
the foreground thread runs the model on data one by one.
.. include:: ../plots/imagenet_classification_chart.txt
Expand Down Expand Up @@ -43,7 +43,7 @@
import spdl.io
import spdl.utils
import torch
from spdl.pipeline import Pipeline, PipelineBuilder
from spdl.dataloader._dataloader import DataLoader
from torch import Tensor
from torch.profiler import profile

Expand All @@ -55,7 +55,7 @@
"benchmark",
"source",
"get_decode_func",
"get_pipeline",
"get_dataloader",
"get_model",
"ModelBundle",
"Classification",
Expand Down Expand Up @@ -365,31 +365,34 @@ async def decode_images_nvdec(items: list[tuple[str, int]]):
return decode_images_nvdec


def get_pipeline(
def get_dataloader(
src: Iterator[tuple[str, int]],
batch_size: int,
decode_func: Callable[[list[tuple[str, int]]], Awaitable[tuple[Tensor, Tensor]]],
concurrency: int,
buffer_size: int,
num_threads: int,
) -> Pipeline:
"""Build image data loading pipeline.
) -> DataLoader:
"""Build the dataloader for the ImageNet classification task.
The pipeline uses the ``decode_func`` for decoding images concurrently and
The dataloader uses the ``decode_func`` for decoding images concurrently and
sends the resulting data to the GPU.
Args:
src: The source of the data. See :py:func:`source`.
batch_size: The number of images in a batch.
decode_func: The function to decode images.
buffer_size: The size of the buffer for the dataloader sink.
num_threads: The number of worker threads.
"""
return (
PipelineBuilder()
.add_source(src)
.aggregate(batch_size, drop_last=True)
.pipe(decode_func, concurrency=concurrency)
.add_sink(buffer_size)
.build(num_threads=num_threads)
return DataLoader(
src,
batch_size=batch_size,
drop_last=True,
aggregator=decode_func,
buffer_size=buffer_size,
num_threads=num_threads,
timeout=20,
)


Expand Down Expand Up @@ -432,24 +435,20 @@ def benchmark(dataloader: Iterator[tuple[Tensor, Tensor]], model: ModelBundle) -
_LG.info(f"Accuracy (top5)={acc5:.2%} ({num_correct_top5}/{num_frames})")


def _get_pipeline(args, device_index) -> Pipeline:
def _get_dataloader(args, device_index) -> DataLoader:
src = source(args.input_flist, args.prefix, args.max_samples)

if args.use_nvjpeg:
decode_func = _get_experimental_nvjpeg_decode_function(device_index)
concurrency = 7
elif args.use_nvdec:
decode_func = _get_experimental_nvdec_decode_function(device_index)
concurrency = 4
else:
decode_func = get_decode_func(device_index)
concurrency = args.num_threads

return get_pipeline(
return get_dataloader(
src,
args.batch_size,
decode_func,
concurrency,
args.queue_size,
args.num_threads,
)
Expand All @@ -464,9 +463,7 @@ def entrypoint(args: list[int] | None = None):

device_index = 0
model = get_model(args.batch_size, device_index, args.compile, args.use_bf16)
pipeline = _get_pipeline(args, device_index)

print(pipeline)
dataloader = _get_dataloader(args, device_index)

trace_path = f"{args.trace}"
if args.use_nvjpeg:
Expand All @@ -478,9 +475,8 @@ def entrypoint(args: list[int] | None = None):
torch.no_grad(),
profile() if args.trace else contextlib.nullcontext() as prof,
spdl.utils.tracing(f"{trace_path}.pftrace", enable=args.trace is not None),
pipeline.auto_stop(timeout=1),
):
benchmark(pipeline.get_iterator(), model)
benchmark(dataloader, model)

if args.trace:
prof.export_chrome_trace(f"{trace_path}.json")
Expand Down

0 comments on commit e03b51a

Please sign in to comment.