Optimize RelationClassifier by adding the option to filter long sentences and truncate context #3593

Merged · 21 commits · Feb 4, 2025

Changes from all commits (21 commits)

fc786b3
Optimize RelationClassifier by filtering long sentences
alanakbik Jan 2, 2025
594d858
Optimize RelationClassifier by filtering long sentences
alanakbik Jan 2, 2025
8fc8a58
Fix serialization
alanakbik Jan 2, 2025
1fd1851
Change context window calculation
alanakbik Jan 2, 2025
7f89bb0
Change context window calculation
alanakbik Jan 2, 2025
70148da
Add sanity check to ensure entities are not contained in one another
alanakbik Jan 2, 2025
f50c3b3
Fix slicing such that left and right context are of equal length
alanakbik Jan 2, 2025
142703b
Make mypy happy
alanakbik Jan 2, 2025
3ad499b
Remove unnecessary if statement
alanakbik Jan 2, 2025
f798a3c
Merge branch 'master' into filter_relations
alanakbik Jan 11, 2025
0cca6a8
Ensure presence of head and tail entity in the original sentence
dobbersc Jan 15, 2025
4ed2e49
Refactor `_slice_encoded_sentence_to_max_allowed_length` to use min/m…
dobbersc Jan 21, 2025
306412f
Allow to disable the `max_allowed_tokens_between_entities` and `max_s…
dobbersc Jan 21, 2025
4fc4878
Rearrange parameters and make sentence filters public
dobbersc Jan 21, 2025
de8b7f4
Add test cases for `max_allowed_tokens_between_entities` and `max_sur…
dobbersc Jan 22, 2025
6789a6a
Merge branch 'master' into filter_relations
dobbersc Jan 22, 2025
a06fa30
Fix tests due to additional training data point in `train.conllup`
dobbersc Jan 22, 2025
3ce22f1
Merge remote-tracking branch 'origin/filter_relations' into filter_re…
dobbersc Jan 22, 2025
6f8d168
Fix serialization issue in ModelTrainer
alanakbik Jan 27, 2025
cbd8be3
Merge branch 'master' into filter_relations
alanakbik Feb 1, 2025
863d903
Merge branch 'master' into filter_relations
alanakbik Feb 4, 2025
1 change: 0 additions & 1 deletion flair/models/regexp_tagger.py
@@ -1,5 +1,4 @@
import re
import typing
from dataclasses import dataclass, field
from typing import Union

86 changes: 80 additions & 6 deletions flair/models/relation_classifier_model.py
@@ -252,11 +252,13 @@ def __init__(
entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]],
entity_pair_labels: Optional[set[tuple[str, str]]] = None,
entity_threshold: Optional[float] = None,
max_allowed_tokens_between_entities: Optional[int] = 20,
max_surrounding_context_length: Optional[int] = 10,
cross_augmentation: bool = True,
encoding_strategy: EncodingStrategy = TypedEntityMarker(),
zero_tag_value: str = "O",
allow_unk_tag: bool = True,
**classifierargs,
**classifierargs: Any,
) -> None:
"""Initializes a `RelationClassifier`.

@@ -267,6 +269,8 @@ def __init__(
entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'PER' and 'ORG' labels from a NER-tagger: `{'ner': {'PER', 'ORG'}}`. To use all labels from 'ner', pass 'ner'.
entity_pair_labels: A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (<HEAD>, <TAIL>). E.g. for the `born_in` relation, only relations from 'PER' to 'LOC' make sense. Here, relations from 'PER' to 'PER' are not meaningful, so it is advised to specify the `entity_pair_labels` as `{('PER', 'LOC')}`. This setting may help to reduce the number of relation candidates. Leaving this parameter as `None` (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient).
entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model.
max_allowed_tokens_between_entities: The maximum number of tokens allowed between entities. All other entity pairs are filtered from consideration. If `None`, the filter will be disabled.
max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context in between the entity pairs will always be included. If `None`, the filter will be disabled.
cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentence`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence.
encoding_strategy: An instance of a class conforming to the :class:`EncodingStrategy` protocol
zero_tag_value: The label to use for out-of-class relations
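
For illustration, a minimal sketch of how the two new parameters might be passed when constructing a `RelationClassifier` (the corpus, embeddings, and concrete values below are assumptions for the example; only `max_allowed_tokens_between_entities` and `max_surrounding_context_length` are introduced by this change):

    from flair.datasets import RE_ENGLISH_CONLL04
    from flair.embeddings import TransformerDocumentEmbeddings
    from flair.models import RelationClassifier

    corpus = RE_ENGLISH_CONLL04()

    classifier = RelationClassifier(
        embeddings=TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True),
        label_dictionary=corpus.make_label_dictionary("relation"),
        label_type="relation",
        entity_label_types="ner",
        entity_pair_labels={("Peop", "Loc"), ("Peop", "Org")},
        max_allowed_tokens_between_entities=20,  # skip entity pairs that are further apart
        max_surrounding_context_length=10,  # truncate context around surviving entity pairs
    )

Setting either argument to `None` disables the corresponding filter entirely.
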
@@ -302,6 +306,8 @@ def __init__(
self.entity_pair_labels = entity_pair_labels

self.entity_threshold = entity_threshold
self.max_allowed_tokens_between_entities = max_allowed_tokens_between_entities
self.max_surrounding_context_length = max_surrounding_context_length
self.cross_augmentation = cross_augmentation
self.encoding_strategy = encoding_strategy

@@ -393,12 +399,41 @@ def _entity_pair_permutations(

yield head, tail, gold_label

@staticmethod
def _truncate_context_around_entities(
encoded_sentence_tokens: list[str],
head_idx: int,
tail_idx: int,
context_length: int,
) -> list[str]:
"""Truncates the encoded sentence to include the head and tail entity and their surrounding context.

The context in between the entity pairs will always be included.

Args:
encoded_sentence_tokens: The list of tokens corresponding to the encoded sentence.
head_idx: The index of the head entity in the token list.
tail_idx: The index of the tail entity in the token list.
context_length: The maximum number of tokens to include as surrounding context around the head and tail entities.

Returns:
The tokens of the truncated sentence.
"""
begin_slice: int = min(head_idx, tail_idx)
end_slice: int = max(head_idx, tail_idx)

# Preserve context around the entities. Always include their in-between context.
begin_slice = max(begin_slice - context_length, 0)
end_slice = min(end_slice + context_length + 1, len(encoded_sentence_tokens))

return encoded_sentence_tokens[begin_slice:end_slice]

def _encode_sentence(
self,
head: _Entity,
tail: _Entity,
gold_label: Optional[str] = None,
) -> EncodedSentence:
) -> Optional[EncodedSentence]:
"""Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy.

If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`.
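
Before continuing with `_encode_sentence` below, a worked example of the `_truncate_context_around_entities` helper above (an editor's sketch with made-up tokens and markers, not taken from this change):

    from flair.models import RelationClassifier

    tokens = ["t0", "t1", "t2", "[HEAD]", "t4", "t5", "[TAIL]", "t7", "t8", "t9"]

    # begin_slice = max(min(3, 6) - 2, 0) = 1 and end_slice = min(max(3, 6) + 2 + 1, 10) = 9,
    # so two tokens of context survive on each side while the in-between span is kept in full.
    truncated = RelationClassifier._truncate_context_around_entities(tokens, head_idx=3, tail_idx=6, context_length=2)
    assert truncated == ["t1", "t2", "[HEAD]", "t4", "t5", "[TAIL]", "t7", "t8"]
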
@@ -414,6 +449,12 @@ def _encode_sentence(
original_sentence: Sentence = head.span.sentence
assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence."

# Sanity check: Do not create a labeled span if one entity contains the other
if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx:
return None
if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx:
return None

# Pre-compute non-leading head and tail tokens for entity masking
non_leading_head_tokens: list[Token] = head.span.tokens[1:]
non_leading_tail_tokens: list[Token] = tail.span.tokens[1:]
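
A quick illustration of the containment sanity check above (invented token indices, not from this change): if one span lies entirely inside the other, `_encode_sentence` returns `None` and the pair is dropped.

    head_first, head_last = 3, 5  # first/last token idx of the head span (assumed)
    tail_first, tail_last = 3, 4  # first/last token idx of the tail span (assumed)

    head_contains_tail = head_first <= tail_first and head_last >= tail_last  # True: the pair is skipped
    head_inside_tail = head_first >= tail_first and head_last <= tail_last    # False here
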
@@ -422,11 +463,15 @@ def _encode_sentence(
# since there may be multiple occurrences of the same entity mentioned in the sentence.
# Therefore, we use the span's position in the sentence.
encoded_sentence_tokens: list[str] = []
head_idx: Optional[int] = None
tail_idx: Optional[int] = None
for token in original_sentence:
if token is head.span[0]:
head_idx = len(encoded_sentence_tokens)
encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label))

elif token is tail.span[0]:
tail_idx = len(encoded_sentence_tokens)
encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label))

elif all(
@@ -435,6 +480,27 @@
):
encoded_sentence_tokens.append(token.text)

msg: str
if head_idx is None:
msg = f"The head entity ({head!r}) is not located inside the original sentence ({original_sentence!r})."
raise AssertionError(msg)
if tail_idx is None:
msg = f"The tail entity ({tail!r}) is not located inside the original sentence ({original_sentence!r})."
raise AssertionError(msg)

# Filter cases in which the distance between the two entities is too large
if (
self.max_allowed_tokens_between_entities is not None
and abs(head_idx - tail_idx) > self.max_allowed_tokens_between_entities
):
return None

# Remove excess tokens left and right of entity pair to make encoded sentence shorter
if self.max_surrounding_context_length is not None:
encoded_sentence_tokens = self._truncate_context_around_entities(
encoded_sentence_tokens, head_idx, tail_idx, self.max_surrounding_context_length
)

# Create masked sentence
encoded_sentence: EncodedSentence = EncodedSentence(
" ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer()
@@ -445,6 +511,7 @@
# Using the sentence label instead of annotating a separate `Relation` object is easier to manage since,
# during prediction, the forward pass does not need any knowledge about the entities in the sentence.
encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0)

encoded_sentence.copy_context_from_sentence(original_sentence)
return encoded_sentence
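
The distance filter a few lines above operates purely on the marker positions within the encoded token list; a small sketch of the check with invented values:

    max_allowed_tokens_between_entities = 20

    head_idx, tail_idx = 4, 31  # positions of the head/tail markers in encoded_sentence_tokens (assumed)
    if abs(head_idx - tail_idx) > max_allowed_tokens_between_entities:
        encoded_sentence = None  # 27 > 20: the pair is filtered and no encoded sentence is produced
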

@@ -469,13 +536,15 @@ def _encode_sentence_for_inference(
Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence
"""
for head, tail, gold_label in self._entity_pair_permutations(sentence):
masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label if gold_label is not None else self.zero_tag_value,
)
original_relation: Relation = Relation(first=head.span, second=tail.span)
yield masked_sentence, original_relation

if masked_sentence is not None:
yield masked_sentence, original_relation

def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]:
"""Create Encoded Sentences and Relation pairs for Training.
@@ -492,13 +561,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS
else:
continue # Skip generated data points that do not express an originally annotated relation

masked_sentence: EncodedSentence = self._encode_sentence(
masked_sentence: Optional[EncodedSentence] = self._encode_sentence(
head=head,
tail=tail,
gold_label=gold_label,
)

yield masked_sentence
if masked_sentence is not None:
yield masked_sentence

def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]:
"""Transforms sentences into encoded sentences specific to the `RelationClassifier`.
@@ -702,6 +772,8 @@ def _get_state_dict(self) -> dict[str, Any]:
"entity_label_types": self.entity_label_types,
"entity_pair_labels": self.entity_pair_labels,
"entity_threshold": self.entity_threshold,
"max_allowed_tokens_between_entities": self.max_allowed_tokens_between_entities,
"max_surrounding_context_length": self.max_surrounding_context_length,
"cross_augmentation": self.cross_augmentation,
"encoding_strategy": self.encoding_strategy,
"zero_tag_value": self.zero_tag_value,
@@ -719,6 +791,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs):
entity_label_types=state["entity_label_types"],
entity_pair_labels=state["entity_pair_labels"],
entity_threshold=state["entity_threshold"],
max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"),
max_surrounding_context_length=state.get("max_surrounding_context_length"),
cross_augmentation=state["cross_augmentation"],
encoding_strategy=state["encoding_strategy"],
zero_tag_value=state["zero_tag_value"],
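
A side note on the `state.get(...)` calls above: checkpoints saved before this change lack the two new keys, so loading yields `None` for them, i.e. both filters come back disabled rather than raising a `KeyError`. A tiny sketch with a hypothetical old state dict:

    old_state = {"entity_label_types": "ner", "cross_augmentation": True}  # no filter keys saved
    assert old_state.get("max_allowed_tokens_between_entities") is None
    assert old_state.get("max_surrounding_context_length") is None
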
20 changes: 12 additions & 8 deletions flair/trainers/trainer.py
@@ -966,14 +966,18 @@ def _initialize_model_card(self, **training_parameters):
except ImportError:
pass

# remember all parameters used in train() call
model_card["training_parameters"] = {
k: str(v) if isinstance(v, Path) else v for k, v in training_parameters.items()
}

model_card["training_parameters"] = {
k: f"{v.__module__}.{v.__name__}" if inspect.isclass(v) else v for k, v in training_parameters.items()
}
# remember the training parameters
model_card["training_parameters"] = {}
for k, v in training_parameters.items():

# special rule for Path variables to make sure models can be deserialized on other operating systems
if isinstance(v, Path):
v = str(v)
# classes are only serialized as names
if inspect.isclass(v):
v = f"{v.__module__}.{v.__name__}"

model_card["training_parameters"][k] = v

plugins = [plugin.get_state() for plugin in model_card["training_parameters"]["plugins"]]
model_card["training_parameters"]["plugins"] = plugins