diff --git a/src/spdl/dataloader/_pytorch_dataloader.py b/src/spdl/dataloader/_pytorch_dataloader.py index eb579ce1..2ea74eef 100644 --- a/src/spdl/dataloader/_pytorch_dataloader.py +++ b/src/spdl/dataloader/_pytorch_dataloader.py @@ -279,6 +279,7 @@ def get_pytorch_dataloader( prefetch_factor: int = 2, persistent_workers: bool = False, pin_memory_device: str | None = None, + in_order: bool = False, ) -> PyTorchDataLoader[U]: from torch.utils.data.dataloader import IterableDataset @@ -336,4 +337,5 @@ def get_pytorch_dataloader( num_workers=num_workers, timeout=timeout, buffer_size=buffer_size, + output_order="input" if in_order else "completion", )