From c38afb169c3a09784237842958f9dd712827ab52 Mon Sep 17 00:00:00 2001 From: Rosie Wood Date: Mon, 12 Aug 2024 15:12:06 +0100 Subject: [PATCH 1/2] fix for get_label_index --- mapreader/classify/load_annotations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mapreader/classify/load_annotations.py b/mapreader/classify/load_annotations.py index 3c31d408..8a52e273 100644 --- a/mapreader/classify/load_annotations.py +++ b/mapreader/classify/load_annotations.py @@ -882,7 +882,8 @@ def _get_label_index(self, label: str) -> int: Used to generate the ``label_index`` column. """ - return self.unique_labels.index(label) + index_map = {v: k for k, v in self.labels_map.items()} + return index_map[label] def __str__(self): print(f"[INFO] Number of annotations: {len(self.annotations)}\n") From 5c9fc6f02f3088ac71fef3c22bae3591df7bd438 Mon Sep 17 00:00:00 2001 From: Rosie Wood Date: Mon, 12 Aug 2024 15:22:40 +0100 Subject: [PATCH 2/2] add tests --- .../test_classify/test_annotations_loader.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_classify/test_annotations_loader.py b/tests/test_classify/test_annotations_loader.py index ac5e64ec..fd2d1376 100644 --- a/tests/test_classify/test_annotations_loader.py +++ b/tests/test_classify/test_annotations_loader.py @@ -68,6 +68,34 @@ def test_labels_map(sample_dir): assert annots.labels_map == {0: "railspace", 1: "no", 2: "building"} +def test_get_label_index(sample_dir): + annots = AnnotationsLoader() + annots.load( + f"{sample_dir}/test_annots.csv", + reset_index=True, + remove_broken=False, + ignore_broken=True, + labels_map={ + 0: "railspace", + 1: "no", + }, # different order vs in the csv + ) + assert annots.labels_map == {0: "railspace", 1: "no"} + assert annots._get_label_index("railspace") == 0 + + # test append + annots.load( + f"{sample_dir}/test_annots_append.csv", + append=True, + remove_broken=False, + ignore_broken=True, + ) + assert annots.unique_labels == ["no", "railspace", "building"] + assert annots.labels_map == {0: "railspace", 1: "no", 2: "building"} + assert annots._get_label_index("railspace") == 0 + assert annots._get_label_index("building") == 2 + + @pytest.mark.dependency(name="load_annots_df", scope="session") def test_load_df(sample_dir): annots = AnnotationsLoader()