Skip to content

Commit

Permalink
Knobs for controlling index downloading concurrency.
Browse files Browse the repository at this point in the history
  • Loading branch information
knighton committed Feb 4, 2024
1 parent 2e26454 commit 25c8f15
Showing 1 changed file with 87 additions and 3 deletions.
90 changes: 87 additions & 3 deletions streaming/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ class StreamingDataset(Array, IterableDataset):
* ``validate_hash``
* ``keep_phases``
* How to iterate:
* Init:
* ``index_download_procs``
* ``index_download_procs_per_cpu``
* ``index_download_max_procs``
* Shard lifecycle:
* ``predownload``
* ``cache_limit``
Expand Down Expand Up @@ -267,6 +271,15 @@ class StreamingDataset(Array, IterableDataset):
Specified as a single use or phase to keep, a sequence of uses or phases to keep, a
mapping of uses or phases to whether to keep or drop, or a ``Phaser`` (which performs
the same keeping or dropping). Defaults to ``None``.
index_download_procs (int, optional): Size of the process pool. You may set either this
arg or ``index_download_procs_per_cpu``, but not both. Defaults to ``None``.
index_download_procs_per_cpu (float | int, optional): Size of the process pool as the
    number of processes per CPU. You may set either this arg or ``index_download_procs``,
    but not both. If its value is negative, it is taken to be the reciprocal of that value
    as a positive number, e.g. ``-4`` means ``num_cpus // 4``. Defaults to ``4``.
index_download_max_procs (int, optional): Optional ceiling on the number of index
    download processes. Used to not overwhelm remote storage with concurrent connections.
    Defaults to ``64``.
predownload (int, optional): Target number of samples to download per worker in advance
of current sample. Workers will attempt to download ahead by this many samples during,
but not before, training. Recommendation is to provide a value greater than per device
Expand Down Expand Up @@ -333,6 +346,9 @@ def __init__(
download_max_size: Optional[Union[str, int]] = '200mb',
validate_hash: Union[None, str, Sequence[str]] = None,
keep_phases: Union[None, str, Sequence[str], Dict[str, Optional[bool]], Phaser] = None,
index_download_procs: Optional[int] = None,
index_download_procs_per_cpu: Optional[Union[float, int]] = 4,
index_download_max_procs: Optional[int] = 64,
predownload: Optional[int] = None,
cache_limit: Optional[Union[str, int]] = None,
shuffle_seed: int = 9176,
Expand Down Expand Up @@ -473,8 +489,16 @@ def __init__(
# Parallelism is important because there could be a very large number of Streams, and we
# expect equal performance as them having been concatenated into one Stream.
if world.is_local_leader:
with Pool() as pool:
pool.imap_unordered(lambda stream: stream.download_index(), self.streams)
num_procs = self._get_index_download_procs(
index_download_procs,
index_download_procs_per_cpu,
index_download_max_procs,
)
pool = Pool(num_procs)
pool.imap_unordered(lambda stream: stream.download_index(), self.streams)
pool.close()
else:
pool = None

# All ranks then walk all Streams, for which they (1) wait for its index to become
# downloaded, (2) load its shards, and (3) map streams <-> shards <-> samples.
Expand All @@ -501,6 +525,9 @@ def __init__(
self.stream_per_shard = np.array(stream_per_shard, np.int64)
self.num_shards = len(self.shards)

if pool is not None:
pool.join()

# Check that cache limit is possible.
if cache_limit:
self.cache_limit = normalize_bytes(cache_limit)
Expand Down Expand Up @@ -689,7 +716,64 @@ def __len__(self) -> int:
"""
return self.length

def _set_shuffle_block_size(self, world: World):
def _get_index_download_procs(
self,
index_download_procs: Optional[int] = None,
index_download_procs_per_cpu: Optional[Union[float, int]] = 4,
index_download_max_procs: Optional[int] = 64,
) -> int:
"""Get the size of the process pool used by index downloading.
Args:
index_download_procs (int, optional): Size of the process pool. You may set either this
arg or ``index_download_procs_per_cpu``, but not both. Defaults to ``None``.
index_download_procs_per_cpu (float | int, optional): Size of the process pool as the
number of processes per CPU. You may set either this arg or
``index_download_procs``, but not both. If its value is negative, it is taken to be
the reciprocal of that value as a positive number, e.g. ``-4`` means
``num_cpus // 4``. Defaults to ``4``.
index_download_max_procs (int, optional): Optional ceiling on the number of index
downlaod processes. Used to not overwhelm remote storage with concurrency
connectinos. Defaults to ``64``.
"""
if index_download_procs is not None:
if index_download_procs_per_cpu is not None:
raise ValueError(
f'You may specify `index_download_procs`, `index_download_procs_per_cpu`, ' +
f'or neither, but got both: {index_download_procs} and ' +
f'{index_download_procs_per_cpu} respectively.')
else:
num_procs = index_download_procs
if num_procs <= 0:
raise ValueError(f'`index_download_procs` must be a positive integer.')
else:
num_cpus = os.cpu_count() or 1
if index_download_procs_per_cpu is not None:
procs_per_cpu = index_download_procs_per_cpu
if procs_per_cpu < 0:
num_procs = num_cpus / -procs_per_cpu
elif not procs_per_cpu:
num_procs = num_cpus
else:
num_procs = num_cpus * procs_per_cpu
num_procs = int(np.ceil(num_procs))
if num_procs <= 0:
raise ValueError(f'`index_download_procs_per_cpu` must result in a positive ' +
f'number of index download processes after rounding up.')
else:
num_procs = num_cpus

max_procs = index_download_max_procs
if max_procs is not None:
if max_procs <= 0:
raise ValueError(f'`index_download_max_procs` must be a positive integer, but ' +
f'got: {max_procs}.')
if max_procs < num_procs:
num_procs = max_procs

return num_procs

def _set_shuffle_block_size(self, world: World) -> None:
"""Set the shuffle block size value."""
if self.shuffle_block_size is None:
if not world.worker_of_rank:
Expand Down

0 comments on commit 25c8f15

Please sign in to comment.