Skip to content

Commit aae4bb8

Browse files
committed
updated solution to use detect_dataset
1 parent 48993d4 commit aae4bb8

File tree

4 files changed

+33
-31
lines changed

4 files changed

+33
-31
lines changed

datumaro/components/dataset.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from datumaro.components.errors import (
2323
CategoriesRedefinedError,
2424
ConflictingCategoriesError,
25+
DatasetNotFoundError,
2526
MediaTypeError,
2627
MultipleFormatsMatchError,
2728
NoMatchingFormatsError,
@@ -1158,6 +1159,15 @@ def import_from(
11581159

11591160
if not format:
11601161
format = cls.detect(path, env=env)
1162+
else:
1163+
not_found_error_instance = DatasetNotFoundError(path)
1164+
1165+
def not_found_error(format_name, reason, human_message):
1166+
not_found_error_instance.reason = human_message
1167+
1168+
detected = env.detect_dataset(path, rejection_callback=not_found_error, depth=3, format=format)
1169+
if not detected:
1170+
raise not_found_error_instance
11611171

11621172
# TODO: remove importers, put this logic into extractors
11631173
if format in env.importers:
@@ -1227,7 +1237,14 @@ def import_from(
12271237
return dataset
12281238

12291239
@staticmethod
1230-
def detect(path: str, *, env: Optional[Environment] = None, depth: int = 2) -> str:
1240+
def detect(
1241+
path: str,
1242+
*,
1243+
env: Optional[Environment] = None,
1244+
depth: int = 2,
1245+
rejection_callback:Optional[Callable] = None,
1246+
format: Optional[str] = None,
1247+
) -> str:
12311248
"""
12321249
Attempts to detect dataset format of a given directory.
12331250
@@ -1247,7 +1264,7 @@ def detect(path: str, *, env: Optional[Environment] = None, depth: int = 2) -> s
12471264
if depth < 0:
12481265
raise ValueError("Depth cannot be less than zero")
12491266

1250-
matches = env.detect_dataset(path, depth=depth)
1267+
matches = env.detect_dataset(path, depth=depth, rejection_callback=rejection_callback, format=format)
12511268
if not matches:
12521269
raise NoMatchingFormatsError()
12531270
elif 1 < len(matches):

datumaro/components/environment.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,22 @@ def detect_dataset(
252252
path: str,
253253
depth: int = 1,
254254
rejection_callback: Optional[Callable[[str, RejectionReason, str], None]] = None,
255+
format: Optional[str] = None
255256
) -> List[str]:
256257
ignore_dirs = {"__MSOSX", "__MACOSX"}
257258
matched_formats = set()
259+
detectors = []
260+
if not format:
261+
detectors = (
262+
(format_name, importer.detect)
263+
for format_name, importer in self.importers.items.items()
264+
)
265+
elif self.is_format_known(format):
266+
detectors = [(format, self.importers.get(format).detect)]
267+
258268
for _ in range(depth + 1):
259269
detected_formats = detect_dataset_format(
260-
(
261-
(format_name, importer.detect)
262-
for format_name, importer in self.importers.items.items()
263-
),
270+
detectors,
264271
path,
265272
rejection_callback=rejection_callback,
266273
)

datumaro/components/errors.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -273,14 +273,10 @@ def __str__(self):
273273
@define(auto_exc=False)
274274
class DatasetNotFoundError(DatasetImportError):
275275
path = field()
276-
ext = field(default="")
277-
filename = field(default="")
276+
reason = field(default='')
278277

279278
def __str__(self):
280-
file_ext_info = (
281-
f", file '{self.filename}{self.ext}' was not found" if self.ext or self.filename else ""
282-
)
283-
return f"Failed to find dataset at '{self.path}' {file_ext_info}"
279+
return f"Failed to find dataset at '{self.path}' {self.reason}"
284280

285281

286282
@define(auto_exc=False)

datumaro/components/extractor.py

+1-19
Original file line numberDiff line numberDiff line change
@@ -415,17 +415,6 @@ def get(self, id, subset=None):
415415

416416

417417
class Importer(CliPlugin):
418-
def __init__(self):
419-
self.__not_found_error_data = {"ext": "", "filename": ""}
420-
421-
@property
422-
def _not_found_error_data(self):
423-
return self.__not_found_error_data
424-
425-
@_not_found_error_data.setter
426-
def _not_found_error_data_setter(self, val):
427-
self.__not_found_error_data = val
428-
429418
@classmethod
430419
def detect(
431420
cls,
@@ -450,11 +439,7 @@ def __call__(self, path, **extra_params):
450439

451440
found_sources = self.find_sources_with_params(osp.normpath(path), **extra_params)
452441
if not found_sources:
453-
raise DatasetNotFoundError(
454-
path,
455-
self._not_found_error_data.get("ext", ""),
456-
self._not_found_error_data.get("filename", ""),
457-
)
442+
raise DatasetNotFoundError(path)
458443

459444
sources = []
460445
for desc in found_sources:
@@ -520,9 +505,6 @@ def _find_sources_recursive(
520505
if sources:
521506
break
522507

523-
if not sources:
524-
cls._not_found_error_data = {"ext": ext, "filename": filename}
525-
526508
return sources
527509

528510

0 commit comments

Comments
 (0)