Skip to content

Commit 97cc020

Browse files
committed
add tests for yolo_detection_format
1 parent 06da123 commit 97cc020

File tree

10 files changed

+1122
-31
lines changed

10 files changed

+1122
-31
lines changed

datumaro/plugins/yolo_detection_format/converter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _check_dataset(self):
6363
if subset not in YoloDetectionPath.ALLOWED_SUBSET_NAMES:
6464
raise DatasetExportError(
6565
f"The allowed subset name is in {YoloDetectionPath.ALLOWED_SUBSET_NAMES}, "
66-
f'so that subset "{subset}" is not allowed.'
66+
f"so that subset '{subset}' is not allowed."
6767
)
6868

6969
for must_name in YoloDetectionPath.MUST_SUBSET_NAMES:

datumaro/plugins/yolo_detection_format/extractor.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import yaml
88

99
import os.path as osp
10+
import os
1011
import re
1112
from collections import OrderedDict
1213
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union, Iterator
@@ -68,6 +69,9 @@ def __init__(
6869
if not osp.isdir(config_path):
6970
raise DatasetImportError(f"{config_path} should be a directory.")
7071

72+
if not osp.isfile(osp.join(config_path, META_FILE)):
73+
raise DatasetImportError(f"Can't find {META_FILE} in {config_path}")
74+
7175
if not urls:
7276
raise DatasetImportError(
7377
f"`urls` should be specified for {self.__class__.__name__}, "
@@ -86,7 +90,32 @@ def __init__(
8690
self._urls = urls
8791
self._img_files = self._load_img_files(rootpath)
8892
self._ann_types = set()
93+
94+
config = YoloDetectionPath._parse_config(osp.join(config_path, META_FILE))
8995

96+
subsets = {k: v for k, v in config.items() if k in YoloDetectionPath.ALLOWED_SUBSET_NAMES and v is not None}
97+
98+
for subset_name, list_path in subsets.items():
99+
subset = YoloDetectionExtractor.Subset(subset_name, self)
100+
101+
if osp.isdir(osp.join(rootpath, list_path)):
102+
list_path = osp.join(rootpath, list_path)
103+
f = os.listdir(list_path)
104+
subset.items = OrderedDict(
105+
(self.name_from_path(p), self.localize_path(p)) for p in f if p.strip()
106+
)
107+
elif osp.isfile(osp.join(rootpath, list_path)):
108+
with open(osp.join(rootpath, list_path), "r", encoding="utf-8") as f:
109+
subset.items = OrderedDict(
110+
(self.name_from_path(p), self.localize_path(p)) for p in f if p.strip()
111+
)
112+
else:
113+
raise InvalidAnnotationError(f"Can't find '{subset_name}' subset list file")
114+
115+
subsets[subset_name] = subset
116+
117+
self._subsets: Dict[str, YoloDetectionExtractor.Subset] = subsets
118+
90119
self._categories = {
91120
AnnotationType.label: self._load_categories(
92121
osp.join(self._path, META_FILE)
@@ -292,13 +321,11 @@ def _parse_annotations(
292321

293322
annotations.append(
294323
Bbox(
295-
x * image_width,
296-
y * image_height,
297-
w * image_width,
298-
h * image_height,
324+
int(x * image_width),
325+
int(y * image_height),
326+
int(w * image_width),
327+
int(h * image_height),
299328
label=label_id,
300-
id=idx,
301-
group=idx,
302329
)
303330
)
304331
except Exception as e:

datumaro/plugins/yolo_detection_format/format.py

+7-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from enum import Enum, auto
66
from typing import Dict
77
import re
8-
8+
import yaml
99

1010
class YoloDetectionPath:
1111
DEFAULT_SUBSET_NAME = "train"
@@ -16,18 +16,10 @@ class YoloDetectionPath:
1616

1717
@staticmethod
1818
def _parse_config(path: str) -> Dict[str, str]:
19-
with open(path, "r", encoding="utf-8") as f:
20-
config_lines = f.readlines()
21-
22-
config = {}
23-
24-
for line in config_lines:
25-
match = re.match(r"^\s*(\w+)\s*=\s*(.+)$", line)
26-
if not match:
27-
continue
28-
29-
key = match.group(1)
30-
value = match.group(2)
31-
config[key] = value
19+
with open(path, "r") as fp:
20+
loaded = yaml.safe_load(fp.read())
3221

33-
return config
22+
if not isinstance(loaded, dict):
23+
raise Exception("Invalid config format")
24+
25+
return loaded

datumaro/plugins/yolo_detection_format/importer.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _check_ann_file_impl(cls, fp: TextIOWrapper) -> bool:
5353
fields = line.rstrip("\n").split(" ")
5454
if len(fields) != 5:
5555
raise DatasetImportError(
56-
f"Yolo format txt file should have 5 fields for each line, "
56+
f"Yolo Detection format txt file should have 5 fields for each line, "
5757
f"but the read line has {len(fields)} fields: fields={fields}."
5858
)
5959

@@ -68,19 +68,11 @@ def _check_ann_file_impl(cls, fp: TextIOWrapper) -> bool:
6868

6969
@classmethod
7070
def _find_loose(cls, path: str, dirname: str) -> List[Dict[str, Any]]:
71-
def _filter_ann_file(fpath: str):
72-
try:
73-
with open(fpath, "r") as fp:
74-
return cls._check_ann_file_impl(fp)
75-
except DatasetImportError:
76-
return False
77-
7871
sources = cls._find_sources_recursive(
7972
path,
8073
ext=".txt",
8174
extractor_name="",
8275
dirname=dirname,
83-
file_filter=_filter_ann_file,
8476
filename="**/*",
8577
max_depth=1
8678
)

0 commit comments

Comments
 (0)