diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index 6c22555a0..d7db64837 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -196,7 +196,7 @@ def process(self,
             traceback.print_exc()
             exit(1)
         finally:
-            if checkpointer:
+            if checkpointer and dataset is not self:
                 logger.info('Writing checkpoint of dataset processed by '
                             'last op...')
                 dataset.cleanup_cache_files()
@@ -337,6 +337,10 @@ def cleanup_cache_files(self):
         cleanup_compressed_cache_files(self)
         return super().cleanup_cache_files()
 
+    @staticmethod
+    def load_from_disk(*args, **kargs):
+        return NestedDataset(Dataset.load_from_disk(*args, **kargs))
+
 
 def nested_query(root_obj: Union[NestedDatasetDict, NestedDataset,
                                  NestedQueryDict], key):
diff --git a/data_juicer/utils/ckpt_utils.py b/data_juicer/utils/ckpt_utils.py
index 7a9bcffd6..d22762adb 100644
--- a/data_juicer/utils/ckpt_utils.py
+++ b/data_juicer/utils/ckpt_utils.py
@@ -1,7 +1,6 @@
 import json
 import os
 
-from datasets import Dataset
 from loguru import logger
 
 
@@ -133,5 +132,6 @@ def load_ckpt(self):
 
         :return: a dataset stored in checkpoint file.
         """
-        ds = Dataset.load_from_disk(self.ckpt_ds_dir)
+        from data_juicer.core.data import NestedDataset
+        ds = NestedDataset.load_from_disk(self.ckpt_ds_dir)
         return ds
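
Context for the diff above: CheckpointManager.load_ckpt previously called Dataset.load_from_disk, so a run resumed from a checkpoint got back a plain datasets.Dataset rather than a NestedDataset; the new NestedDataset.load_from_disk staticmethod re-wraps the loaded dataset before returning it. The snippet below is a minimal, self-contained sketch of that re-wrapping pattern using only the HuggingFace datasets package; WrappedDataset is a hypothetical stand-in for NestedDataset, not data_juicer's actual implementation.

# Sketch of the type-preserving load_from_disk pattern (assumptions noted above).
import copy

from datasets import Dataset


class WrappedDataset(Dataset):
    """A Dataset subclass whose type survives a save/load round trip."""

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and isinstance(args[0], Dataset):
            # Wrap an existing Dataset by adopting its internal state.
            self.__dict__ = copy.copy(args[0].__dict__)
        else:
            super().__init__(*args, **kwargs)

    @staticmethod
    def load_from_disk(*args, **kwargs):
        # The base-class loader returns a plain Dataset; re-wrap it so callers
        # (e.g. checkpoint resumption) get the subclass and its behavior back.
        return WrappedDataset(Dataset.load_from_disk(*args, **kwargs))


# Usage: after ds.save_to_disk(path), WrappedDataset.load_from_disk(path)
# yields a WrappedDataset rather than a bare Dataset.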