From 213f7f8aef078395eaefd89817b098eeb94a45b4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ce=20Ge=20=28=E6=88=88=E7=AD=96=29?=
Date: Thu, 22 Aug 2024 15:03:51 +0800
Subject: [PATCH 1/2] Enhance/ckpt (#396)

* enhance ckpt logic

* fix tests
---
 data_juicer/core/analyzer.py      |  6 ++---
 data_juicer/core/data.py          | 45 ++++++++++++++++++-------------
 data_juicer/core/executor.py      | 25 ++++-------------
 data_juicer/core/ray_executor.py  |  7 ++---
 data_juicer/ops/base_op.py        | 21 ---------------
 data_juicer/ops/load.py           |  6 ++---
 data_juicer/utils/ckpt_utils.py   |  4 +--
 tests/config/test_config_funcs.py |  4 +--
 tests/ops/test_op_fusion.py       |  3 ++-
 9 files changed, 45 insertions(+), 76 deletions(-)

diff --git a/data_juicer/core/analyzer.py b/data_juicer/core/analyzer.py
index 5dbf233ba..ce1af2d84 100644
--- a/data_juicer/core/analyzer.py
+++ b/data_juicer/core/analyzer.py
@@ -31,7 +31,6 @@ def __init__(self, cfg=None):
         self.cfg = init_configs() if cfg is None else cfg

         self.work_dir = self.cfg.work_dir
-        self.ops = None

         if self.cfg.use_cache:
             logger.info(f'Using cache compression method: '
@@ -79,13 +78,12 @@ def run(self, load_data_np=None, skip_export=False):

         # extract processes
         logger.info('Preparing process operators...')
-        self.cfg.process, self.ops = load_ops(self.cfg.process,
-                                              self.cfg.op_fusion)
+        ops = load_ops(self.cfg.process, self.cfg.op_fusion)

         # 2. stats precompute only for filter ops
         logger.info('Computing the stats of dataset...')
         stats_collected = False
-        for op in self.ops:
+        for op in ops:
             if isinstance(op, Filter):
                 original_process = op.process
                 op.process = None
diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index b0afefd19..ccf6ff1aa 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -2,6 +2,7 @@

 import copy
 import inspect
+import traceback
 from abc import ABC, abstractmethod
 from functools import wraps
 from time import time
@@ -174,24 +175,32 @@ def process(self,
         unforkable_operators = set(UNFORKABLE.modules.keys())

         dataset = self
-        for op in operators:
-            mp_context = ['forkserver', 'spawn'] if (
-                op.use_cuda() or op._name in unforkable_operators) else None
-            setup_mp(mp_context)
-
-            start = time()
-            # run single op
-            dataset = op(dataset,
-                         exporter=exporter,
-                         checkpointer=checkpointer,
-                         tracer=tracer)
-            # record processed ops
-            if checkpointer is not None:
-                checkpointer.record(op._name,
-                                    list(op._process_kwargs.values())[0])
-            end = time()
-            logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. '
-                        f'Left {len(dataset)} samples.')
+        try:
+            for op in operators:
+                mp_context = ['forkserver', 'spawn'] if (
+                    op.use_cuda()
+                    or op._name in unforkable_operators) else None
+                setup_mp(mp_context)
+
+                start = time()
+                # run single op
+                dataset = op.run(dataset, exporter=exporter, tracer=tracer)
+                # record processed ops
+                if checkpointer is not None:
+                    checkpointer.record(op._of_cfg)
+                end = time()
+                logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. '
+                            f'Left {len(dataset)} samples.')
+        except:  # noqa: E722
+            logger.error(f'An error occurred during Op [{op._name}].')
+            traceback.print_exc()
+            exit(1)
+        finally:
+            if checkpointer:
+                logger.info('Writing checkpoint of dataset processed by '
+                            'last op...')
+                dataset.cleanup_cache_files()
+                checkpointer.save_ckpt(dataset)
         return dataset

     def map(self, *args, **kargs):
diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py
index c514cd99d..5949df76d 100644
--- a/data_juicer/core/executor.py
+++ b/data_juicer/core/executor.py
@@ -1,5 +1,4 @@
 import os
-import traceback
 from time import time

 from loguru import logger
@@ -38,7 +37,6 @@ def __init__(self, cfg=None):

         self.work_dir = self.cfg.work_dir

-        self.ops = None
         self.tracer = None
         self.ckpt_manager = None

@@ -58,17 +56,15 @@ def __init__(self, cfg=None):
         # check if there are existing checkpoints first and try to load the
         # checkpoints. If the checkpoints are loaded successfully, ops that
         # have been processed will be skipped.
-        self.process_list = self.cfg.process
         if self.cfg.use_checkpoint:
             logger.info('Preparing checkpoint manager...')
             self.ckpt_dir = os.path.join(self.work_dir, 'ckpt')
             self.ckpt_manager = CheckpointManager(self.ckpt_dir,
-                                                  self.process_list,
+                                                  self.cfg.process,
                                                   self.cfg.np)
             if self.ckpt_manager.ckpt_available:
                 logger.info('Found existed dataset checkpoint.')
-                self.process_list = self.ckpt_manager.get_left_process_list()
-                self.cfg.process = self.process_list
+                self.cfg.process = self.ckpt_manager.get_left_process_list()

         # prepare exporter and check export path suffix
         logger.info('Preparing exporter...')
@@ -155,15 +151,14 @@ def run(self, load_data_np=None):

         # 2. extract processes
         logger.info('Preparing process operators...')
-        self.process_list, self.ops = load_ops(self.cfg.process,
-                                               self.cfg.op_fusion)
+        ops = load_ops(self.cfg.process, self.cfg.op_fusion)

         # 3. data process
         # - If tracer is open, trace each op after it's processed
         # - If checkpoint is open, clean the cache files after each process
         logger.info('Processing data...')
         tstart = time()
-        dataset = dataset.process(self.ops,
+        dataset = dataset.process(ops,
                                   exporter=self.exporter,
                                   checkpointer=self.ckpt_manager,
                                   tracer=self.tracer)
@@ -172,17 +167,7 @@ def run(self, load_data_np=None):

         # 4. data export
         logger.info('Exporting dataset to disk...')
-        try:
-            self.exporter.export(dataset)
-        except:  # noqa: E722
-            logger.error('An error occurred during exporting the processed '
-                         'dataset.')
-            traceback.print_exc()
-            if self.cfg.use_checkpoint:
-                logger.info('Writing checkpoint of dataset processed by '
-                            'last op...')
-                dataset.cleanup_cache_files()
-                self.ckpt_manager.save_ckpt(dataset)
+        self.exporter.export(dataset)
         # compress the last dataset after exporting
         if self.cfg.use_cache and self.cfg.cache_compress:
             from data_juicer.utils.compress import compress
diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py
index 291792d1e..ae1a51359 100644
--- a/data_juicer/core/ray_executor.py
+++ b/data_juicer/core/ray_executor.py
@@ -34,11 +34,9 @@ def __init__(self, cfg=None):

         self.work_dir = self.cfg.work_dir

-        self.ops = None
         # init ray
         logger.info('Initing Ray ...')
         ray.init(self.cfg.ray_address)
-        self.process_list = self.cfg.process

     def run(self, load_data_np=None):
         """
@@ -55,13 +53,12 @@ def run(self, load_data_np=None):
         dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg)

         # 2. extract processes
         logger.info('Preparing process operators...')
-        self.process_list, self.ops = load_ops(self.cfg.process,
-                                               self.cfg.op_fusion)
+        ops = load_ops(self.cfg.process, self.cfg.op_fusion)

         # 3. data process
         logger.info('Processing data...')
         tstart = time.time()
-        dataset.process(self.ops)
+        dataset.process(ops)
         tend = time.time()
         logger.info(f'All Ops are done in {tend - tstart:.3f}s.')
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 5f602e165..d822a4fc7 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -157,27 +157,6 @@ def __init__(self, *args, **kwargs):
                 method = wrap_func_with_nested_access(method)
                 setattr(self, name, method)

-    def __call__(self,
-                 dataset,
-                 *,
-                 exporter=None,
-                 checkpointer=None,
-                 tracer=None):
-        try:
-            dataset = self.run(dataset, exporter=exporter, tracer=tracer)
-            if checkpointer:
-                checkpointer.record(self._name, self._process_kwargs)
-            return dataset
-        except:  # noqa: E722
-            logger.error(f'An error occurred during Op [{self._name}].')
-            traceback.print_exc()
-            if checkpointer:
-                logger.info('Writing checkpoint of dataset processed by '
-                            'last op...')
-                dataset.cleanup_cache_files()
-                checkpointer.save_ckpt(dataset)
-            exit(1)
-
     @classmethod
     def is_batched_op(cls):
         return cls._batched_op
diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py
index 60aac3ec4..e82ebb16a 100644
--- a/data_juicer/ops/load.py
+++ b/data_juicer/ops/load.py
@@ -32,7 +32,7 @@ def load_ops(process_list, op_fusion=False):
     if op_fusion:
         new_process_list, ops = fuse_operators(new_process_list, ops)

-    for process, op in zip(new_process_list, ops):
-        op._process_kwargs = process
+    for op_cfg, op in zip(new_process_list, ops):
+        op._op_cfg = op_cfg

-    return new_process_list, ops
+    return ops
diff --git a/data_juicer/utils/ckpt_utils.py b/data_juicer/utils/ckpt_utils.py
index 61b90b248..7a9bcffd6 100644
--- a/data_juicer/utils/ckpt_utils.py
+++ b/data_juicer/utils/ckpt_utils.py
@@ -58,10 +58,10 @@ def check_ckpt(self):
         os.makedirs(self.ckpt_dir, exist_ok=True)
         return False

-    def record(self, op_name, op_args):
+    def record(self, op_cfg: dict):
         """Save op name and args to op record, which is used to compare
         with the process list from config to decide if a checkpoint is
         available."""
-        self.op_record.append({op_name: op_args})
+        self.op_record.append(op_cfg)

     def check_ops_to_skip(self):
         """
diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py
index cf7e21f3d..5b3eeef06 100644
--- a/tests/config/test_config_funcs.py
+++ b/tests/config/test_config_funcs.py
@@ -68,8 +68,8 @@ def test_yaml_cfg_file(self):
                 }
             }, 'nested dict load fail, un-expected internal value')

-            _, op_from_cfg = load_ops(cfg.process)
-            self.assertTrue(len(op_from_cfg) == 3)
+            ops_from_cfg = load_ops(cfg.process)
+            self.assertTrue(len(ops_from_cfg) == 3)

     def test_val_range_check_cmd(self):
         out = StringIO()
diff --git a/tests/ops/test_op_fusion.py b/tests/ops/test_op_fusion.py
index ad50ba472..d545e0074 100644
--- a/tests/ops/test_op_fusion.py
+++ b/tests/ops/test_op_fusion.py
@@ -9,7 +9,8 @@ class OpFusionTest(DataJuicerTestCaseBase):

     def _run_op_fusion(self, original_process_list, target_process_list):
-        new_process_list, _ = load_ops(original_process_list, op_fusion=True)
+        ops = load_ops(original_process_list, op_fusion=True)
+        new_process_list = [op._op_cfg for op in ops]
         self.assertEqual(new_process_list, target_process_list)

     def test_regular_config(self):

From 69e199e48cf412a9a4ba888b3af2331d7bed32a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ce=20Ge=20=28=E6=88=88=E7=AD=96=29?=
Date: Thu, 22 Aug 2024 19:26:01 +0800
Subject: [PATCH 2/2] Enhance/ckpt (#399)

---
 data_juicer/core/data.py        | 8 ++++++--
 data_juicer/utils/ckpt_utils.py | 4 ++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index ccf6ff1aa..c8f05a767 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -187,7 +187,7 @@ def process(self,
                 dataset = op.run(dataset, exporter=exporter, tracer=tracer)
                 # record processed ops
                 if checkpointer is not None:
-                    checkpointer.record(op._of_cfg)
+                    checkpointer.record(op._op_cfg)
                 end = time()
                 logger.info(f'OP [{op._name}] Done in {end - start:.3f}s. '
                             f'Left {len(dataset)} samples.')
@@ -196,7 +196,7 @@ def process(self,
             traceback.print_exc()
             exit(1)
         finally:
-            if checkpointer:
+            if checkpointer and dataset is not self:
                 logger.info('Writing checkpoint of dataset processed by '
                             'last op...')
                 dataset.cleanup_cache_files()
                 checkpointer.save_ckpt(dataset)
@@ -334,6 +334,10 @@ def cleanup_cache_files(self):
         cleanup_compressed_cache_files(self)
         return super().cleanup_cache_files()

+    @staticmethod
+    def load_from_disk(*args, **kargs):
+        return NestedDataset(Dataset.load_from_disk(*args, **kargs))
+

 def nested_query(root_obj: Union[NestedDatasetDict, NestedDataset,
                                  NestedQueryDict], key):
diff --git a/data_juicer/utils/ckpt_utils.py b/data_juicer/utils/ckpt_utils.py
index 7a9bcffd6..d22762adb 100644
--- a/data_juicer/utils/ckpt_utils.py
+++ b/data_juicer/utils/ckpt_utils.py
@@ -1,7 +1,6 @@
 import json
 import os

-from datasets import Dataset
 from loguru import logger


@@ -133,5 +132,6 @@ def load_ckpt(self):

         :return: a dataset stored in checkpoint file.
         """
-        ds = Dataset.load_from_disk(self.ckpt_ds_dir)
+        from data_juicer.core.data import NestedDataset
+        ds = NestedDataset.load_from_disk(self.ckpt_ds_dir)
         return ds
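
Note on the combined effect of the two patches (reviewer commentary, not part of the patch content): load_ops now returns only the instantiated op list and stashes each op's full config dict on op._op_cfg; CheckpointManager.record stores that dict as-is; and checkpoint writing moves out of BaseOP.__call__ into the try/except/finally of NestedDataset.process, so a checkpoint of the last successful op is saved even when a later op crashes. Below is a minimal sketch of the resulting record-and-resume flow, using only APIs visible in the diffs; the work dir, process list, and dataset loading are illustrative placeholders:

    import os

    from data_juicer.ops.load import load_ops
    from data_juicer.utils.ckpt_utils import CheckpointManager

    work_dir = './outputs/demo'  # hypothetical; the executor uses cfg.work_dir
    process_list = [{'language_id_score_filter': {'lang': 'en'}}]  # illustrative

    # The third argument mirrors the executor's self.cfg.np (worker count).
    ckpt_manager = CheckpointManager(os.path.join(work_dir, 'ckpt'),
                                     process_list, 1)
    if ckpt_manager.ckpt_available:
        # Resume: reload the saved dataset and keep only the unfinished ops.
        dataset = ckpt_manager.load_ckpt()
        process_list = ckpt_manager.get_left_process_list()
    else:
        dataset = ...  # load the original dataset here

    ops = load_ops(process_list)  # post-patch: returns just the ops
    for op in ops:
        # The executor also passes exporter=/tracer=; omitted in this sketch.
        dataset = op.run(dataset)
        # record() now receives the op's whole config dict, so
        # check_ops_to_skip() can compare it to cfg.process verbatim.
        ckpt_manager.record(op._op_cfg)
    ckpt_manager.save_ckpt(dataset)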
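
On PATCH 2/2 specifically: the first commit's record(op._of_cfg) referenced an attribute that was never set (load.py sets op._op_cfg), the finally block now skips checkpointing when no op ever ran (dataset is not self), and the new NestedDataset.load_from_disk wraps datasets.Dataset.load_from_disk so a reloaded checkpoint keeps Data-Juicer's nested access instead of coming back as a plain datasets.Dataset. A short illustration; the checkpoint path is hypothetical:

    from data_juicer.core.data import NestedDataset

    # A reloaded checkpoint is a NestedDataset again, so nested column
    # queries keep working after a resume.
    ds = NestedDataset.load_from_disk('./outputs/demo/ckpt/latest')
    print(type(ds))  # data_juicer.core.data.NestedDataset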