Skip to content

Commit b7d527e

Browse files
authored
Merge branch 'develop' into zm/fix-dataset-transforms
2 parents 3150a6f + d0456d1 commit b7d527e

7 files changed

+186
-34
lines changed

CHANGELOG.md

+8
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6262
(<https://github.com/cvat-ai/datumaro/pull/17>)
6363
- Annotation matching algorithm in `datumaro.components.operations.match_segments()`
6464
(<https://github.com/cvat-ai/datumaro/pull/30>)
65+
- Automatic detection of `is_crowd` parameter is disabled in
66+
`segment_iou()`, added a separate function argument
67+
(turned off by default)
68+
(<https://github.com/cvat-ai/datumaro/pull/41>)
6569

6670
### Deprecated
6771
- `--save-images` is replaced with `--save-media` in CLI and converter API
6872
(<https://github.com/openvinotoolkit/datumaro/pull/539>)
6973
- \[API\] `image`, `point_cloud` and `related_images` of `DatasetItem` are
7074
replaced with `media` and `media_as(type)` members and c-tor parameters
7175
(<https://github.com/openvinotoolkit/datumaro/pull/539>)
76+
- \[API\] `datumaro.util.annotation_util._get_bbox()` is renamed into `get_bbox()`
77+
(<https://github.com/cvat-ai/datumaro/pull/41>)
7278

7379
### Removed
7480
- TBD
@@ -86,6 +92,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
8692
(<https://github.com/cvat-ai/datumaro/pull/34>)
8793
- Added missing `PointCloud` media type in the datumaro module namespace
8894
(<https://github.com/cvat-ai/datumaro/pull/34>)
95+
- Incorrect computation of binary mask bbox (missed 1 pixel of the size)
96+
(<https://github.com/cvat-ai/datumaro/pull/41>)
8997
- `Dataset.get()` could ignore existing transforms in the dataset
9098
(<https://github.com/cvat-ai/datumaro/pull/45>)
9199

datumaro/util/annotation_util.py

+75-29
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,34 @@
11
# Copyright (C) 2020-2021 Intel Corporation
2+
# Copyright (C) 2024 CVAT.ai Corporation
23
#
34
# SPDX-License-Identifier: MIT
45

6+
import warnings
57
from itertools import groupby
68
from typing import Callable, Dict, Iterable, NewType, Optional, Sequence, Tuple, Union
79

810
import numpy as np
911
from typing_extensions import Literal
1012

11-
from datumaro.components.annotation import AnnotationType, LabelCategories, Mask, RleMask, _Shape
13+
from datumaro.components.annotation import (
14+
Annotation,
15+
AnnotationType,
16+
LabelCategories,
17+
Mask,
18+
RleMask,
19+
_Shape,
20+
)
1221
from datumaro.util.mask_tools import mask_to_rle
1322

23+
BboxCoords = Tuple[float, float, float, float]
24+
"A tuple of bounding box coordinates, (x, y, w, h)"
25+
26+
Shape = NewType("Shape", _Shape)
27+
28+
SpatialAnnotation = Union[Shape, Mask]
1429

15-
def find_instances(instance_anns):
30+
31+
def find_instances(instance_anns: Sequence[Annotation]) -> Sequence[Sequence[Annotation]]:
1632
instance_anns = sorted(instance_anns, key=lambda a: a.group)
1733
ann_groups = []
1834
for g_id, group in groupby(instance_anns, lambda a: a.group):
@@ -24,22 +40,22 @@ def find_instances(instance_anns):
2440
return ann_groups
2541

2642

27-
def find_group_leader(group):
43+
def find_group_leader(group: Sequence[SpatialAnnotation]) -> SpatialAnnotation:
2844
return max(group, key=lambda x: x.get_area())
2945

3046

31-
BboxCoords = Tuple[float, float, float, float]
32-
Shape = NewType("Shape", _Shape)
33-
SpatialAnnotation = Union[Shape, Mask]
34-
47+
def get_bbox(ann: Union[Sequence, BboxCoords, SpatialAnnotation]) -> BboxCoords:
48+
"An utility function to get a bbox of the bbox-like annotation"
3549

36-
def _get_bbox(ann: Union[Sequence, SpatialAnnotation]) -> BboxCoords:
37-
if isinstance(ann, (_Shape, Mask)):
50+
if hasattr(ann, "get_bbox"):
3851
return ann.get_bbox()
3952
elif hasattr(ann, "__len__") and len(ann) == 4:
4053
return ann
4154
else:
42-
raise ValueError("The value of type '%s' can't be treated as a " "bounding box" % type(ann))
55+
raise ValueError("The value of type '%s' can't be treated as a bounding box" % type(ann))
56+
57+
58+
_deprecated_get_bbox = get_bbox # backward compatibility
4359

4460

4561
def max_bbox(annotations: Iterable[Union[BboxCoords, SpatialAnnotation]]) -> BboxCoords:
@@ -50,7 +66,7 @@ def max_bbox(annotations: Iterable[Union[BboxCoords, SpatialAnnotation]]) -> Bbo
5066
bbox (tuple): (x, y, w, h)
5167
"""
5268

53-
boxes = [_get_bbox(ann) for ann in annotations]
69+
boxes = [get_bbox(ann) for ann in annotations]
5470
x0 = min((b[0] for b in boxes), default=0)
5571
y0 = min((b[1] for b in boxes), default=0)
5672
x1 = max((b[0] + b[2] for b in boxes), default=0)
@@ -67,7 +83,7 @@ def mean_bbox(annotations: Iterable[Union[BboxCoords, SpatialAnnotation]]) -> Bb
6783
"""
6884

6985
le = len(annotations)
70-
boxes = [_get_bbox(ann) for ann in annotations]
86+
boxes = [get_bbox(ann) for ann in annotations]
7187
mlb = sum(b[0] for b in boxes) / le
7288
mtb = sum(b[1] for b in boxes) / le
7389
mrb = sum(b[0] + b[2] for b in boxes) / le
@@ -101,12 +117,15 @@ def nms(segments, iou_thresh=0.5):
101117
return predictions
102118

103119

104-
def bbox_iou(a, b) -> Union[Literal[-1], float]:
120+
def bbox_iou(
121+
a: Union[SpatialAnnotation, BboxCoords],
122+
b: Union[SpatialAnnotation, BboxCoords],
123+
) -> Union[Literal[-1], float]:
105124
"""
106125
IoU computations for simple cases with bounding boxes
107126
"""
108-
bbox_a = _get_bbox(a)
109-
bbox_b = _get_bbox(b)
127+
bbox_a = get_bbox(a)
128+
bbox_b = get_bbox(b)
110129

111130
aX, aY, aW, aH = bbox_a
112131
bX, bY, bW, bH = bbox_b
@@ -127,23 +146,39 @@ def bbox_iou(a, b) -> Union[Literal[-1], float]:
127146
return intersection / union
128147

129148

130-
def segment_iou(a, b):
149+
def segment_iou(
150+
gt_ann: SpatialAnnotation,
151+
ds_ann: SpatialAnnotation,
152+
*,
153+
is_crowd: Union[bool, str] = False,
154+
) -> float:
131155
"""
132156
Generic IoU computation with masks, polygons, and boxes.
133-
Returns -1 if no intersection, [0; 1] otherwise
157+
158+
Parameters:
159+
is_crowd - bool or GT annotation attribute name - if true, consider
160+
the GT annotation a crowd, so that the DS annotation is excluded
161+
from the denominator of the IoU formula, i.e. it becomes I / GT area.
162+
This is useful if you want to check a specific object to be within a crowd,
163+
where the crowd ob objects is annotated by a single GT mask.
164+
165+
Returns: -1 if no intersection, [0; 1] otherwise
134166
"""
135167
from pycocotools import mask as mask_utils
136168

137-
a_bbox = list(a.get_bbox())
138-
b_bbox = list(b.get_bbox())
169+
gt_bbox = list(gt_ann.get_bbox())
170+
ds_bbox = list(ds_ann.get_bbox())
171+
172+
if isinstance(is_crowd, str):
173+
is_crowd = gt_ann.attributes.get(is_crowd, False) is True
139174

140-
is_bbox = AnnotationType.bbox in [a.type, b.type]
175+
is_bbox = AnnotationType.bbox in [gt_ann.type, ds_ann.type]
141176
if is_bbox:
142-
a = [a_bbox]
143-
b = [b_bbox]
177+
gt_ann = [gt_bbox]
178+
ds_ann = [ds_bbox]
144179
else:
145-
w = max(a_bbox[0] + a_bbox[2], b_bbox[0] + b_bbox[2])
146-
h = max(a_bbox[1] + a_bbox[3], b_bbox[1] + b_bbox[3])
180+
w = max(gt_bbox[0] + gt_bbox[2], ds_bbox[0] + ds_bbox[2])
181+
h = max(gt_bbox[1] + gt_bbox[3], ds_bbox[1] + ds_bbox[3])
147182

148183
def _to_rle(ann):
149184
if ann.type == AnnotationType.polygon:
@@ -153,11 +188,12 @@ def _to_rle(ann):
153188
elif ann.type == AnnotationType.mask:
154189
return mask_utils.frPyObjects([mask_to_rle(ann.image)], h, w)
155190
else:
156-
raise TypeError("Unexpected arguments: %s, %s" % (a, b))
191+
raise TypeError("Unexpected arguments: %s, %s" % (gt_ann, ds_ann))
157192

158-
a = _to_rle(a)
159-
b = _to_rle(b)
160-
return float(mask_utils.iou(a, b, [not is_bbox]).item())
193+
gt_ann = _to_rle(gt_ann)
194+
ds_ann = _to_rle(ds_ann)
195+
196+
return float(mask_utils.iou(gt_ann, ds_ann, [is_crowd]).item())
161197

162198

163199
def PDJ(a, b, eps=None, ratio=0.05, bbox=None):
@@ -270,7 +306,7 @@ def make_label_id_mapping(
270306
Returns:
271307
272308
| map_id (callable): src id -> dst id
273-
| id_mapping (dict): src id -> dst i
309+
| id_mapping (dict): src id -> dst id
274310
| src_labels (dict): src id -> src label
275311
| dst_labels (dict): dst id -> dst label
276312
"""
@@ -286,3 +322,13 @@ def map_id(src_id):
286322
return id_mapping.get(src_id, fallback)
287323

288324
return map_id, id_mapping, source_labels, target_labels
325+
326+
327+
def __getattr__(name: str):
328+
if name is "_get_bbox":
329+
warnings.warn(
330+
"_get_bbox() is deprecated, please use get_bbox() instead", category=DeprecationWarning
331+
)
332+
return _deprecated_get_bbox
333+
334+
return globals().get(name)

datumaro/util/mask_tools.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,13 @@ def rles_to_mask(rles: Sequence[Union[CompressedRle, Polygon]], width, height) -
367367
def find_mask_bbox(mask: BinaryMask) -> BboxCoords:
368368
cols = np.any(mask, axis=0)
369369
rows = np.any(mask, axis=1)
370+
has_pixels = np.any(cols)
371+
if not has_pixels:
372+
return BboxCoords(0, 0, 0, 0)
373+
370374
x0, x1 = np.where(cols)[0][[0, -1]]
371375
y0, y1 = np.where(rows)[0][[0, -1]]
372-
return BboxCoords(x0, y0, x1 - x0, y1 - y0)
376+
return BboxCoords(x0, y0, x1 - x0 + 1, y1 - y0 + 1)
373377

374378

375379
def merge_masks(

tests/test_annotation_util.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (C) 2024 CVAT.ai Corporation
2+
#
3+
# SPDX-License-Identifier: MIT
4+
5+
import numpy as np
6+
import pytest
7+
8+
from datumaro.components.annotation import Bbox, Mask, Polygon
9+
from datumaro.util.annotation_util import SpatialAnnotation, get_bbox, segment_iou
10+
11+
from .requirements import Requirements, mark_requirement
12+
13+
14+
class SegmentIouTest:
15+
@pytest.mark.parametrize(
16+
"a, b, expected_iou",
17+
[
18+
(Bbox(0, 0, 2, 2), Bbox(0, 0, 2, 1), 0.5), # nested
19+
(Bbox(0, 0, 2, 2), Bbox(1, 0, 2, 2), 1 / 3), # partially intersecting
20+
(Bbox(0, 0, 2, 2), Polygon([0, 0, 0, 1, 1, 1, 1, 0]), 0.25),
21+
(Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Polygon([1, 0, 3, 0, 3, 2, 1, 2]), 1 / 3),
22+
(Bbox(0, 0, 2, 2), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 1 / 3),
23+
(Mask(np.array([[1, 1, 0], [1, 1, 0]])), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 1 / 3),
24+
(Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 1 / 3),
25+
],
26+
)
27+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
28+
def test_segment_iou_can_match_shapes(
29+
self, a: SpatialAnnotation, b: SpatialAnnotation, expected_iou: float
30+
):
31+
assert expected_iou == segment_iou(a, b)
32+
33+
@pytest.mark.parametrize(
34+
"a, b, expected_iou",
35+
[
36+
(Bbox(0, 0, 2, 2), Bbox(0, 0, 2, 1), 0.5), # nested
37+
(Bbox(0, 0, 2, 2), Bbox(1, 0, 2, 2), 0.5), # partially intersecting
38+
(Bbox(0, 0, 2, 2), Polygon([0, 0, 0, 1, 1, 1, 1, 0]), 0.25),
39+
(Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Polygon([1, 0, 3, 0, 3, 2, 1, 2]), 0.5),
40+
(Bbox(0, 0, 2, 2), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 0.5),
41+
(Mask(np.array([[1, 1, 0], [1, 1, 0]])), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 0.5),
42+
(Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 0.5),
43+
],
44+
)
45+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
46+
def test_segment_iou_can_match_shapes_as_crowd(
47+
self, a: SpatialAnnotation, b: SpatialAnnotation, expected_iou: float
48+
):
49+
# In this mode, intersection is divided by the GT object area
50+
assert expected_iou == segment_iou(a, b, is_crowd=True)
51+
52+
@pytest.mark.parametrize(
53+
"a, b, expected_iou",
54+
[
55+
(Bbox(0, 0, 2, 2, attributes={"is_crowd": True}), Bbox(1, 0, 2, 2), 0.5),
56+
],
57+
)
58+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
59+
def test_segment_iou_can_get_is_crowd_from_attribute(
60+
self, a: SpatialAnnotation, b: SpatialAnnotation, expected_iou: float
61+
):
62+
# In this mode, intersection is divided by the GT object area
63+
assert expected_iou == segment_iou(a, b, is_crowd="is_crowd")
64+
65+
66+
@pytest.mark.parametrize(
67+
"obj, expected_bbox",
68+
[
69+
((0, 1, 3, 4), (0, 1, 3, 4)),
70+
(Bbox(0, 0, 2, 2), (0, 0, 2, 2)),
71+
(Polygon([0, 0, 0, 1, 1, 1, 1, 0]), (0, 0, 1, 1)), # polygons don't include the last pixel
72+
(Polygon([1, 0, 3, 0, 3, 2, 1, 2]), (1, 0, 2, 2)),
73+
(Mask(np.array([[0, 1, 1], [0, 1, 1]])), (1, 0, 2, 2)),
74+
],
75+
)
76+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
77+
def test_can_get_bbox(obj, expected_bbox):
78+
assert expected_bbox == tuple(get_bbox(obj))

tests/test_masks.py

+15
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from unittest import TestCase
22

33
import numpy as np
4+
import pytest
45

56
import datumaro.util.mask_tools as mask_tools
67
from datumaro.components.annotation import CompiledMask
8+
from datumaro.util.annotation_util import BboxCoords
79

810
from .requirements import Requirements, mark_requirement
911

@@ -461,3 +463,16 @@ def test_can_decode_compiled_mask(self):
461463
labels = compiled_mask.get_instance_labels()
462464

463465
self.assertEqual({instance_idx: class_idx}, labels)
466+
467+
468+
class MaskTest:
469+
@pytest.mark.parametrize(
470+
"mask, expected_bbox",
471+
[
472+
(np.array([[0, 1, 1], [0, 1, 1]]), [1, 0, 2, 2]),
473+
(np.array([[0, 0, 0], [0, 0, 0]]), [0, 0, 0, 0]),
474+
],
475+
)
476+
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
477+
def test_find_mask_bbox(self, mask: mask_tools.BinaryMask, expected_bbox: BboxCoords):
478+
assert tuple(expected_bbox) == mask_tools.find_mask_bbox(mask)

tests/test_transforms.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def test_shapes_to_boxes(self):
352352
id=1,
353353
media=Image(data=np.zeros((5, 5, 3))),
354354
annotations=[
355-
Bbox(0, 0, 4, 4, id=1),
355+
Bbox(0, 0, 5, 5, id=1),
356356
Bbox(1, 1, 3, 3, id=2),
357357
Bbox(1, 1, 1, 1, id=3),
358358
Bbox(2, 2, 2, 2, id=4),

tests/test_validator.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (C) 2021 Intel Corporation
2+
# Copyright (C) 2024 CVAT.ai Corporation
23
#
34
# SPDX-License-Identifier: MIT
45

@@ -967,17 +968,17 @@ def test_validate_annotations_segmentation(self):
967968
report_types = [r["anomaly_type"] for r in actual_reports]
968969
count_by_type = Counter(report_types)
969970

970-
self.assertEqual(len(actual_reports), 24)
971+
self.assertEqual(len(actual_reports), 25)
971972
self.assertEqual(count_by_type["ImbalancedDistInLabel"], 0)
972-
self.assertEqual(count_by_type["ImbalancedDistInAttribute"], 13)
973+
self.assertEqual(count_by_type["ImbalancedDistInAttribute"], 14)
973974
self.assertEqual(count_by_type["MissingAnnotation"], 1)
974975
self.assertEqual(count_by_type["UndefinedLabel"], 2)
975976
self.assertEqual(count_by_type["FewSamplesInAttribute"], 4)
976977
self.assertEqual(count_by_type["UndefinedAttribute"], 4)
977978

978979
with self.subTest("Test of summary", i=2):
979980
actual_summary = actual_results["summary"]
980-
expected_summary = {"errors": 6, "warnings": 18}
981+
expected_summary = {"errors": 6, "warnings": 19}
981982

982983
self.assertEqual(actual_summary, expected_summary)
983984

0 commit comments

Comments
 (0)