Skip to content

Commit cd39c59

Browse files
committed
added wrappers for DatasetNotFound error
1 parent d0456d1 commit cd39c59

File tree

4 files changed

+54
-2
lines changed

4 files changed

+54
-2
lines changed

datumaro/components/environment.py

+5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from datumaro.components.cli_plugin import CliPlugin, plugin_types
1414
from datumaro.components.format_detection import RejectionReason, detect_dataset_format
15+
from datumaro.components.wrappers import wrap_importer
1516
from datumaro.util.os_util import import_foreign_module, split_path
1617

1718
T = TypeVar("T")
@@ -226,6 +227,10 @@ def _register_plugins(self, plugins):
226227
self.transforms.batch_register(plugins)
227228
self.validators.batch_register(plugins)
228229

230+
for key in self.importers:
231+
importer = self.importers.get(key)
232+
wrap_importer(importer)
233+
229234
def make_extractor(self, name, *args, **kwargs):
    """Create an extractor by calling the entry stored under ``name``.

    The entry is fetched from ``self.extractors`` and invoked with the
    remaining positional and keyword arguments; its result is returned.
    """
    extractor_factory = self.extractors.get(name)
    return extractor_factory(*args, **kwargs)
231236

datumaro/components/errors.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,11 @@ def __str__(self):
273273
@define(auto_exc=False)
class DatasetNotFoundError(DatasetImportError):
    """Import error reported when no dataset is found at ``path``.

    When ``ext`` is set, the message additionally names the file that was
    not found.
    """

    path = field()
    # Name/extension of the file that was looked for; empty when unknown.
    ext = field(default="")

    def __str__(self):
        message = f"Failed to find dataset at '{self.path}'"
        if self.ext:
            # Append the detail directly after the closing quote so the
            # message has no stray space before the comma and no trailing
            # space when there is nothing to add.
            message += f", file '{self.ext}' was not found"
        return message
279281

280282

281283
@define(auto_exc=False)

datumaro/components/extractor.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -433,13 +433,17 @@ def find_sources(cls, path) -> List[Dict]:
433433
def find_sources_with_params(cls, path, **extra_params) -> List[Dict]:
434434
return cls.find_sources(path)
435435

436+
@classmethod
def _generate_not_found_error(cls, path):
    """Build the error to raise when no sources are found for ``path``.

    Kept as a separate classmethod so that the error construction can be
    replaced per importer class.

    Returns:
        A ``DatasetNotFoundError`` instance for ``path`` (not raised here).
    """
    return DatasetNotFoundError(path)
439+
436440
def __call__(self, path, **extra_params):
437441
if not path or not osp.exists(path):
438442
raise DatasetNotFoundError(path)
439443

440444
found_sources = self.find_sources_with_params(osp.normpath(path), **extra_params)
441445
if not found_sources:
442-
raise DatasetNotFoundError(path)
446+
raise self._generate_not_found_error(path)
443447

444448
sources = []
445449
for desc in found_sources:

datumaro/components/wrappers.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (C) 2024 CVAT.ai Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
from unittest import mock
6+
from typing import Optional, Callable
7+
8+
from datumaro.components.errors import DatasetNotFoundError
9+
10+
def wrap_find_sources_recursive(importer):
    """Build a replacement ``_find_sources_recursive`` classmethod.

    The replacement delegates to the implementation that follows
    ``importer`` in the MRO. When that search finds nothing, it stores the
    extension that was searched for in ``cls._not_found_error_data``.

    Args:
        importer: the class the returned classmethod will be attached to;
            used as the starting point for the ``super()`` lookup.

    Returns:
        A ``classmethod`` object to be assigned to
        ``importer._find_sources_recursive``.
    """

    @classmethod
    def updated_find_sources_recursive(
        cls,
        path: str,
        ext: Optional[str],
        extractor_name: str,
        filename: str = "*",
        dirname: str = "",
        file_filter: Optional[Callable[[str], bool]] = None,
        max_depth: int = 3,
    ):
        sources = super(importer, cls)._find_sources_recursive(
            path, ext, extractor_name, filename, dirname, file_filter, max_depth
        )
        if not sources:
            # Remember what was searched for, for later error reporting.
            cls._not_found_error_data = {"ext": ext}

        # Hand the search result back to the caller; without this return
        # every wrapped importer would receive None instead of its sources.
        return sources

    return updated_find_sources_recursive
30+
31+
def wrap_generate_not_found_error(importer):
    """Build a replacement ``_generate_not_found_error`` classmethod.

    The replacement creates a ``DatasetNotFoundError`` for the given path
    and passes along the ``"ext"`` entry of ``cls._not_found_error_data``.
    The ``importer`` argument is accepted but not used here.
    """

    def updated_generate_not_found_error(cls, path):
        missing_ext = cls._not_found_error_data.get("ext")
        return DatasetNotFoundError(path, missing_ext)

    return classmethod(updated_generate_not_found_error)
37+
38+
def wrap_importer(importer):
    """Install the "not found" reporting wrappers on an importer class.

    Replaces ``_find_sources_recursive`` and ``_generate_not_found_error``
    on ``importer`` and gives it a fresh ``_not_found_error_data`` dict.

    Plain attribute assignment is used instead of ``unittest.mock`` patches:
    started-but-never-stopped patches are test tooling, and any call to
    ``mock.patch.stopall()`` elsewhere would silently remove the wrappers.

    Args:
        importer: the importer class to modify in place.
    """
    importer._find_sources_recursive = wrap_find_sources_recursive(importer)
    importer._generate_not_found_error = wrap_generate_not_found_error(importer)
    # A new dict per class so importers never share this state.
    importer._not_found_error_data = {"ext": ""}

0 commit comments

Comments
 (0)