diff --git a/datumaro/components/dataset.py b/datumaro/components/dataset.py index d5309dadf1..b9752092c1 100644 --- a/datumaro/components/dataset.py +++ b/datumaro/components/dataset.py @@ -22,6 +22,7 @@ from datumaro.components.errors import ( CategoriesRedefinedError, ConflictingCategoriesError, + DatasetNotFoundError, MediaTypeError, MultipleFormatsMatchError, NoMatchingFormatsError, @@ -421,6 +422,7 @@ def _update_status(item_id, new_status: ItemStatus): ) transform = None + old_ids = set() if self._transforms: transform = _StackedTransform(source, self._transforms) if transform.is_local: @@ -1157,6 +1159,15 @@ def import_from( if not format: format = cls.detect(path, env=env) + else: + not_found_error_instance = DatasetNotFoundError(path) + + def not_found_error(format_name, reason, human_message): + not_found_error_instance.reason = human_message + + detected = env.detect_dataset(path, rejection_callback=not_found_error, depth=3, format=format) + if not detected: + raise not_found_error_instance # TODO: remove importers, put this logic into extractors if format in env.importers: @@ -1226,7 +1237,14 @@ def import_from( return dataset @staticmethod - def detect(path: str, *, env: Optional[Environment] = None, depth: int = 2) -> str: + def detect( + path: str, + *, + env: Optional[Environment] = None, + depth: int = 2, + rejection_callback:Optional[Callable] = None, + format: Optional[str] = None, + ) -> str: """ Attempts to detect dataset format of a given directory. @@ -1246,7 +1264,7 @@ def detect(path: str, *, env: Optional[Environment] = None, depth: int = 2) -> s if depth < 0: raise ValueError("Depth cannot be less than zero") - matches = env.detect_dataset(path, depth=depth) + matches = env.detect_dataset(path, depth=depth, rejection_callback=rejection_callback, format=format) if not matches: raise NoMatchingFormatsError() elif 1 < len(matches): diff --git a/datumaro/components/environment.py b/datumaro/components/environment.py index 4e2b9e72c3..4a1848da1b 100644 --- a/datumaro/components/environment.py +++ b/datumaro/components/environment.py @@ -252,15 +252,22 @@ def detect_dataset( path: str, depth: int = 1, rejection_callback: Optional[Callable[[str, RejectionReason, str], None]] = None, + format: Optional[str] = None ) -> List[str]: ignore_dirs = {"__MSOSX", "__MACOSX"} matched_formats = set() + detectors = [] + if not format: + detectors = ( + (format_name, importer.detect) + for format_name, importer in self.importers.items.items() + ) + elif self.is_format_known(format): + detectors = [(format, self.importers.get(format).detect)] + for _ in range(depth + 1): detected_formats = detect_dataset_format( - ( - (format_name, importer.detect) - for format_name, importer in self.importers.items.items() - ), + detectors, path, rejection_callback=rejection_callback, ) diff --git a/datumaro/components/errors.py b/datumaro/components/errors.py index 45c533efdc..b6387fe38b 100644 --- a/datumaro/components/errors.py +++ b/datumaro/components/errors.py @@ -273,9 +273,10 @@ def __str__(self): @define(auto_exc=False) class DatasetNotFoundError(DatasetImportError): path = field() + reason = field(default='') def __str__(self): - return f"Failed to find dataset at '{self.path}'" + return f"Failed to find dataset at '{self.path}' {self.reason}" @define(auto_exc=False) diff --git a/datumaro/components/extractor.py b/datumaro/components/extractor.py index 290798e02f..6b1867394f 100644 --- a/datumaro/components/extractor.py +++ b/datumaro/components/extractor.py @@ -504,6 +504,7 @@ def _find_sources_recursive( ) if sources: break + return sources