|
| 1 | +# Copyright (C) 2024 CVAT.ai Corporation |
| 2 | +# |
| 3 | +# SPDX-License-Identifier: MIT |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import pytest |
| 7 | + |
| 8 | +from datumaro.components.annotation import Bbox, Mask, Polygon |
| 9 | +from datumaro.util.annotation_util import SpatialAnnotation, get_bbox, segment_iou |
| 10 | + |
| 11 | +from .requirements import Requirements, mark_requirement |
| 12 | + |
| 13 | + |
| 14 | +class SegmentIouTest: |
| 15 | + @pytest.mark.parametrize( |
| 16 | + "a, b, expected_iou", |
| 17 | + [ |
| 18 | + (Bbox(0, 0, 2, 2), Bbox(0, 0, 2, 1), 0.5), # nested |
| 19 | + (Bbox(0, 0, 2, 2), Bbox(1, 0, 2, 2), 1 / 3), # partially intersecting |
| 20 | + (Bbox(0, 0, 2, 2), Polygon([0, 0, 0, 1, 1, 1, 1, 0]), 0.25), |
| 21 | + (Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Polygon([1, 0, 3, 0, 3, 2, 1, 2]), 1 / 3), |
| 22 | + (Bbox(0, 0, 2, 2), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 1 / 3), |
| 23 | + (Mask(np.array([[1, 1, 0], [1, 1, 0]])), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 1 / 3), |
| 24 | + (Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 1 / 3), |
| 25 | + ], |
| 26 | + ) |
| 27 | + @mark_requirement(Requirements.DATUM_GENERAL_REQ) |
| 28 | + def test_segment_iou_can_match_shapes( |
| 29 | + self, a: SpatialAnnotation, b: SpatialAnnotation, expected_iou: float |
| 30 | + ): |
| 31 | + assert expected_iou == segment_iou(a, b) |
| 32 | + |
| 33 | + @pytest.mark.parametrize( |
| 34 | + "a, b, expected_iou", |
| 35 | + [ |
| 36 | + (Bbox(0, 0, 2, 2), Bbox(0, 0, 2, 1), 0.5), # nested |
| 37 | + (Bbox(0, 0, 2, 2), Bbox(1, 0, 2, 2), 0.5), # partially intersecting |
| 38 | + (Bbox(0, 0, 2, 2), Polygon([0, 0, 0, 1, 1, 1, 1, 0]), 0.25), |
| 39 | + (Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Polygon([1, 0, 3, 0, 3, 2, 1, 2]), 0.5), |
| 40 | + (Bbox(0, 0, 2, 2), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 0.5), |
| 41 | + (Mask(np.array([[1, 1, 0], [1, 1, 0]])), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 0.5), |
| 42 | + (Polygon([0, 0, 0, 2, 2, 2, 2, 0]), Mask(np.array([[0, 1, 1], [0, 1, 1]])), 0.5), |
| 43 | + ], |
| 44 | + ) |
| 45 | + @mark_requirement(Requirements.DATUM_GENERAL_REQ) |
| 46 | + def test_segment_iou_can_match_shapes_as_crowd( |
| 47 | + self, a: SpatialAnnotation, b: SpatialAnnotation, expected_iou: float |
| 48 | + ): |
| 49 | + # In this mode, intersection is divided by the GT object area |
| 50 | + assert expected_iou == segment_iou(a, b, is_crowd=True) |
| 51 | + |
| 52 | + @pytest.mark.parametrize( |
| 53 | + "a, b, expected_iou", |
| 54 | + [ |
| 55 | + (Bbox(0, 0, 2, 2, attributes={"is_crowd": True}), Bbox(1, 0, 2, 2), 0.5), |
| 56 | + ], |
| 57 | + ) |
| 58 | + @mark_requirement(Requirements.DATUM_GENERAL_REQ) |
| 59 | + def test_segment_iou_can_get_is_crowd_from_attribute( |
| 60 | + self, a: SpatialAnnotation, b: SpatialAnnotation, expected_iou: float |
| 61 | + ): |
| 62 | + # In this mode, intersection is divided by the GT object area |
| 63 | + assert expected_iou == segment_iou(a, b, is_crowd="is_crowd") |
| 64 | + |
| 65 | + |
| 66 | +@pytest.mark.parametrize( |
| 67 | + "obj, expected_bbox", |
| 68 | + [ |
| 69 | + ((0, 1, 3, 4), (0, 1, 3, 4)), |
| 70 | + (Bbox(0, 0, 2, 2), (0, 0, 2, 2)), |
| 71 | + (Polygon([0, 0, 0, 1, 1, 1, 1, 0]), (0, 0, 1, 1)), # polygons don't include the last pixel |
| 72 | + (Polygon([1, 0, 3, 0, 3, 2, 1, 2]), (1, 0, 2, 2)), |
| 73 | + (Mask(np.array([[0, 1, 1], [0, 1, 1]])), (1, 0, 2, 2)), |
| 74 | + ], |
| 75 | +) |
| 76 | +@mark_requirement(Requirements.DATUM_GENERAL_REQ) |
| 77 | +def test_can_get_bbox(obj, expected_bbox): |
| 78 | + assert expected_bbox == tuple(get_bbox(obj)) |
0 commit comments