diff --git a/src/distilabel/steps/generators/huggingface.py b/src/distilabel/steps/generators/huggingface.py
index 243277fd4..4789dfe10 100644
--- a/src/distilabel/steps/generators/huggingface.py
+++ b/src/distilabel/steps/generators/huggingface.py
@@ -460,16 +460,16 @@ class LoadDataFromDisk(LoadDataFromHub):
     Attributes:
         dataset_path: The path to the dataset or distiset.
         split: The split of the dataset to load (typically will be `train`, `test` or `validation`).
-        config: The configuration of the dataset to load. This is optional and only needed
-            if the dataset has multiple configurations.
+        config: The configuration of the dataset to load. Defaults to `default`. If the dataset
+            has multiple configurations, this must be supplied or an error is raised.
 
     Runtime parameters:
         - `batch_size`: The batch size to use when processing the data.
         - `dataset_path`: The path to the dataset or distiset.
         - `is_distiset`: Whether the dataset to load is a `Distiset` or not. Defaults to False.
         - `split`: The split of the dataset to load. Defaults to 'train'.
-        - `config`: The configuration of the dataset to load. This is optional and only
-            needed if the dataset has multiple configurations.
+        - `config`: The configuration of the dataset to load. Defaults to `default`. If the dataset
+            has multiple configurations, this must be supplied or an error is raised.
         - `num_examples`: The number of examples to load from the dataset.
             By default will load all examples.
         - `storage_options`: Key/value pairs to be passed on to the file-system backend, if any.
@@ -539,10 +539,12 @@ class LoadDataFromDisk(LoadDataFromHub):
         default=None,
         description="Path to the dataset or distiset.",
     )
-    config: RuntimeParameter[str] = Field(
-        default=None,
-        description="The configuration of the dataset to load. This is optional and only"
-        " needed if the dataset has multiple configurations.",
+    config: Optional[RuntimeParameter[str]] = Field(
+        default="default",
+        description=(
+            "The configuration of the dataset to load. Will default to 'default'"
+            " which corresponds to a distiset with a single configuration."
+        ),
     )
     is_distiset: Optional[RuntimeParameter[bool]] = Field(
         default=False,
@@ -557,6 +559,7 @@ class LoadDataFromDisk(LoadDataFromHub):
         default=None,
         description="The split of the dataset to load. By default will load the whole Dataset/Distiset.",
     )
+    repo_id: ExcludedField[Union[str, None]] = None
 
     def load(self) -> None:
         """Load the dataset from the file/s in disk."""
@@ -567,8 +570,15 @@ def load(self) -> None:
                 keep_in_memory=self.keep_in_memory,
                 storage_options=self.storage_options,
             )
-            if self.config:
-                ds = ds[self.config]
+            if self.config not in ds.keys():
+                # TODO: Add FAQ for this case, pointing to the Distiset documentation on configurations.
+                raise ValueError(
+                    f"Configuration '{self.config}' not found in the Distiset, available ones"
+                    f" are: {list(ds.keys())}. Please try changing the `config` parameter to one "
+                    "of the available configurations.\n\n"
+                    "For more information on Distiset configurations, please visit https://distilabel.argilla.io/dev/sections/how_to_guides/advanced/distiset/#using-the-distiset-dataset-object"
+                )
+            ds = ds[self.config]
 
         else:
             ds = load_from_disk(
@@ -596,9 +606,8 @@ def outputs(self) -> List[str]:
             The columns that will be generated by this step.
         """
         # We assume there are Dataset/IterableDataset, not it's ...Dict counterparts
-        if self._dataset is Ellipsis:
-            raise ValueError(
-                "Dataset not loaded yet, you must call `load` method first."
-            )
+        if self._dataset is None:
+            # Dataset not loaded yet: load it on demand instead of raising.
+            self.load()
 
         return self._dataset.column_names
diff --git a/tests/unit/steps/generators/test_huggingface.py b/tests/unit/steps/generators/test_huggingface.py
index 281d5187e..23b3b8007 100644
--- a/tests/unit/steps/generators/test_huggingface.py
+++ b/tests/unit/steps/generators/test_huggingface.py
@@ -168,6 +168,32 @@ def test_load_dataset_from_disk(self) -> None:
         assert isinstance(generator_step_output[1], bool)
         assert len(generator_step_output[0]) == 3
 
+    @pytest.mark.parametrize("config_name", ["default", "missnamed_config"])
+    def test_load_distiset_from_disk_default(self, config_name: str) -> None:
+        distiset = Distiset(
+            {
+                "default": Dataset.from_dict({"a": [1, 2, 3]}),
+            }
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dataset_path = str(Path(tmpdir) / "dataset_path")
+            distiset.save_to_disk(dataset_path)
+
+            loader = LoadDataFromDisk(
+                dataset_path=dataset_path,
+                is_distiset=True,
+                config=config_name,
+            )
+            if config_name != "default":
+                with pytest.raises(ValueError):
+                    loader.load()
+            else:
+                loader.load()
+                generator_step_output = next(loader.process())
+                assert isinstance(generator_step_output, tuple)
+                assert isinstance(generator_step_output[1], bool)
+                assert len(generator_step_output[0]) == 3
+
     def test_load_distiset_from_disk(self) -> None:
         distiset = Distiset(
             {