diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 305b675..4db7836 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -1,5 +1,7 @@ name: CI Code Checks on: [pull_request] +permissions: + contents: write jobs: build: name: code checks @@ -41,4 +43,4 @@ jobs: folder: coverage - name: Type checking with mypy - run: mypy --config setup.cfg src \ No newline at end of file + run: mypy --config setup.cfg src diff --git a/src/nervaluate/evaluate.py b/src/nervaluate/evaluate.py index 290cc05..0468f7a 100644 --- a/src/nervaluate/evaluate.py +++ b/src/nervaluate/evaluate.py @@ -144,6 +144,10 @@ def compute_metrics( # type: ignore true_named_entities = [clean_entities(ent) for ent in true_named_entities if ent["label"] in tags] pred_named_entities = [clean_entities(ent) for ent in pred_named_entities if ent["label"] in tags] + # Sort the lists to improve the speed of the overlap comparison + true_named_entities.sort(key=lambda x: x["start"]) + pred_named_entities.sort(key=lambda x: x["end"]) + # go through each predicted named-entity for pred in pred_named_entities: found_overlap = False @@ -169,6 +173,10 @@ def compute_metrics( # type: ignore else: # check for overlaps with any of the true entities for true in true_named_entities: + # Only enter this block if an overlap is possible + if pred["end"] < true["start"]: + break + # overlapping needs to take into account last token as well pred_range = range(pred["start"], pred["end"] + 1) true_range = range(true["start"], true["end"] + 1) @@ -214,29 +222,27 @@ def compute_metrics( # type: ignore found_overlap = True - break + else: + # Scenario VI: Entities overlap, but the entity type is + # different. - # Scenario VI: Entities overlap, but the entity type is - # different. - - # overall results - evaluation["strict"]["incorrect"] += 1 - evaluation["ent_type"]["incorrect"] += 1 - evaluation["partial"]["partial"] += 1 - evaluation["exact"]["incorrect"] += 1 + # overall results + evaluation["strict"]["incorrect"] += 1 + evaluation["ent_type"]["incorrect"] += 1 + evaluation["partial"]["partial"] += 1 + evaluation["exact"]["incorrect"] += 1 - # aggregated by entity type results - # Results against the true entity + # aggregated by entity type results + # Results against the true entity - evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 - evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1 - evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1 - evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1 + evaluation_agg_entities_type[true["label"]]["strict"]["incorrect"] += 1 + evaluation_agg_entities_type[true["label"]]["partial"]["partial"] += 1 + evaluation_agg_entities_type[true["label"]]["ent_type"]["incorrect"] += 1 + evaluation_agg_entities_type[true["label"]]["exact"]["incorrect"] += 1 - # Results against the predicted entity - # evaluation_agg_entities_type[pred['label']]['strict']['spurious'] += 1 - found_overlap = True - break + # Results against the predicted entity + # evaluation_agg_entities_type[pred['label']]['strict']['spurious'] += 1 + found_overlap = True # Scenario II: Entities are spurious (i.e., over-generated). if not found_overlap: diff --git a/tests/test_nervaluate.py b/tests/test_nervaluate.py index 95bc75f..c4e0c07 100644 --- a/tests/test_nervaluate.py +++ b/tests/test_nervaluate.py @@ -816,3 +816,74 @@ def test_compute_precision_recall(): out = compute_precision_recall(results) assert out == expected + + +def test_compute_metrics_one_pred_two_true(): + true_named_entities_1 = [ + {"start": 0, "end": 12, "label": "A"}, + {"start": 14, "end": 17, "label": "B"}, + ] + true_named_entities_2 = [ + {"start": 14, "end": 17, "label": "B"}, + {"start": 0, "end": 12, "label": "A"}, + ] + pred_named_entities = [ + {"start": 0, "end": 17, "label": "A"}, + ] + + results1, _ = compute_metrics(true_named_entities_1, pred_named_entities, ["A", "B"]) + results2, _ = compute_metrics(true_named_entities_2, pred_named_entities, ["A", "B"]) + + expected = { + 'ent_type': { + 'correct': 1, + 'incorrect': 1, + 'partial': 0, + 'missed': 0, + 'spurious': 0, + 'possible': 2, + 'actual': 2, + 'precision': 0, + 'recall': 0, + 'f1': 0 + }, + 'partial': { + 'correct': 0, + 'incorrect': 0, + 'partial': 2, + 'missed': 0, + 'spurious': 0, + 'possible': 2, + 'actual': 2, + 'precision': 0, + 'recall': 0, + 'f1': 0 + }, + 'strict': { + 'correct': 0, + 'incorrect': 2, + 'partial': 0, + 'missed': 0, + 'spurious': 0, + 'possible': 2, + 'actual': 2, + 'precision': 0, + 'recall': 0, + 'f1': 0 + }, + 'exact': { + 'correct': 0, + 'incorrect': 2, + 'partial': 0, + 'missed': 0, + 'spurious': 0, + 'possible': 2, + 'actual': 2, + 'precision': 0, + 'recall': 0, + 'f1': 0 + } + } + + assert results1 == expected + assert results2 == expected