Commit: add turbo mode
drcege committed Aug 22, 2024
1 parent 7ee70df commit 17e0714
Showing 3 changed files with 28 additions and 13 deletions.
26 changes: 18 additions & 8 deletions data_juicer/config/config.py
@@ -192,6 +192,12 @@ def init_configs(args=None):
help='Suffixes of files that will be find and loaded. If not set, we '
'will find all suffix files, and select a suitable formatter '
'with the most files as default.')
+ parser.add_argument(
+     '--turbo',
+     type=bool,
+     default=False,
+     help='Enable Turbo mode to maximize processing speed. Stability '
+     'features like fault tolerance will be disabled.')
parser.add_argument(
'--use_cache',
type=bool,
@@ -463,6 +469,8 @@ def init_setup_from_cfg(cfg):
'image_key': cfg.image_key,
'audio_key': cfg.audio_key,
'video_key': cfg.video_key,
+ 'num_proc': cfg.np,
+ 'turbo': cfg.turbo,
}
else:
if 'text_key' not in args or args['text_key'] is None:
@@ -473,6 +481,10 @@ def init_setup_from_cfg(cfg):
args['audio_key'] = cfg.audio_key
if 'video_key' not in args or args['video_key'] is None:
args['video_key'] = cfg.video_key
+ if 'num_proc' not in args or args['num_proc'] is None:
+     args['num_proc'] = cfg.np
+ if 'turbo' not in args or args['turbo'] is None:
+     args['turbo'] = cfg.turbo
op[op_name] = args

return cfg
@@ -567,14 +579,12 @@ def update_op_process(cfg, parser):

# update op params of cfg.process
internal_op_para = temp_cfg.get(op_in_process_name)
- if internal_op_para is not None:
-     num_proc = internal_op_para.get('num_proc')
-     if 'num_proc' in internal_op_para:
-         internal_op_para['num_proc'] = num_proc or cfg.np
-     internal_op_para = namespace_to_dict(internal_op_para)
- else:
-     internal_op_para = None
- cfg.process[i] = {op_in_process_name: internal_op_para}

+ cfg.process[i] = {
+     op_in_process_name:
+     None if internal_op_para is None else
+     namespace_to_dict(internal_op_para)
+ }

# check the op params via type hint
temp_parser = copy.deepcopy(parser)
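For readers skimming the config changes: the new `turbo` flag is wired in exactly like `num_proc` above, that is, a per-op value takes precedence and the global config value is only used as a fallback. A minimal sketch of that defaulting pattern follows; the helper name fill_op_defaults and the plain dicts are illustrative, not the project's actual API.

# Illustrative sketch of the per-op defaulting pattern shown above (not project code).
def fill_op_defaults(op_args, global_cfg):
    """Fill missing per-op settings from the global config, like init_setup_from_cfg."""
    op_args = dict(op_args or {})
    for key, fallback in (('num_proc', global_cfg['np']),
                          ('turbo', global_cfg['turbo'])):
        if key not in op_args or op_args[key] is None:
            op_args[key] = fallback
    return op_args

# A per-op value wins; anything left unset inherits the global setting.
global_cfg = {'np': 4, 'turbo': True}
print(fill_op_defaults({'num_proc': 2}, global_cfg))  # {'num_proc': 2, 'turbo': True}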
13 changes: 8 additions & 5 deletions data_juicer/core/data.py
@@ -187,7 +187,7 @@ def process(self,
dataset = op.run(dataset, exporter=exporter, tracer=tracer)
# record processed ops
if checkpointer is not None:
- checkpointer.record(op._of_cfg)
+ checkpointer.record(op._op_cfg)
end = time()
logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. '
f'Left {len(dataset)} samples.')
@@ -227,11 +227,14 @@ def map(self, *args, **kargs):
called_func, '__wrapped__'):
called_func = called_func.__wrapped__

- # Batched is always required for fault tolerance
if inspect.ismethod(called_func):
-     kargs['batched'] = True
-     kargs['batch_size'] = kargs.pop(
-         'batch_size', 1) if called_func.__self__.is_batched_op() else 1
+     # batched is required for fault-tolerant or batched OP
+     if not called_func.__self__.turbo or \
+             called_func.__self__.is_batched_op():
+         kargs['batched'] = True
+         kargs['batch_size'] = kargs.pop('batch_size', 1)
+     else:
+         kargs['batched'] = False

if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None:
new_fingerprint = generate_fingerprint(self, *args, **kargs)
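The map() change above is the heart of turbo mode: batched dispatch, which is what provides per-batch fault tolerance, is now skipped for sample-wise OPs once turbo is enabled. A standalone sketch of that decision follows; the function name and plain arguments are illustrative simplifications, not the wrapper's real signature.

# Simplified version of the batching decision made in the map() wrapper above (illustrative).
def decide_batching(op_is_batched, turbo, requested_batch_size=1):
    """Return the kwargs that would be forwarded to the underlying map() call."""
    if not turbo or op_is_batched:
        # Fault-tolerant (or genuinely batched) path: always dispatch in batches.
        return {'batched': True, 'batch_size': requested_batch_size}
    # Turbo path for sample-wise OPs: drop the batching overhead entirely.
    return {'batched': False}

print(decide_batching(op_is_batched=False, turbo=False))  # {'batched': True, 'batch_size': 1}
print(decide_batching(op_is_batched=False, turbo=True))   # {'batched': False}
print(decide_batching(op_is_batched=True, turbo=True))    # {'batched': True, 'batch_size': 1}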
2 changes: 2 additions & 0 deletions data_juicer/ops/base_op.py
@@ -148,6 +148,8 @@ def __init__(self, *args, **kwargs):
if isinstance(self.mem_required, str):
self.mem_required = size_to_bytes(self.mem_required) / 1024**3

+ self.turbo = kwargs.get('turbo', False)

# nested wrappers
from data_juicer.core.data import wrap_func_with_nested_access
for name in ['process', 'compute_stats', 'compute_hash']:
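The base-class change simply records the flag on every OP instance so that downstream code, such as the map() wrapper above, can inspect it. A toy illustration of the pattern follows; the class name is made up, and only the kwargs.get line mirrors base_op.py.

# Toy OP showing how instances end up carrying the flag (illustrative class name).
class ToyOp:
    def __init__(self, **kwargs):
        # Mirrors base_op.py: default to the safe, non-turbo behaviour.
        self.turbo = kwargs.get('turbo', False)

op = ToyOp(turbo=True)  # in practice populated from cfg.turbo by the config layer
print(op.turbo)         # True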
