From d34bfd096353d057ae0e39aaa9d4c041f47f7278 Mon Sep 17 00:00:00 2001 From: Sheldon Roberts Date: Fri, 16 Aug 2024 17:23:40 -0700 Subject: [PATCH 01/46] Add DeepNCMClassifier model Add tests for DeepNCMClassifier Remove old test Add multi label support Add type hints and doc strings --- flair/models/__init__.py | 2 + flair/models/deepncm_classification_model.py | 455 ++++++++++++++++++ flair/trainers/plugins/__init__.py | 2 + .../functional/deepncm_trainer_plugin.py | 41 ++ tests/models/test_deepncm_classifier.py | 167 +++++++ 5 files changed, 667 insertions(+) create mode 100644 flair/models/deepncm_classification_model.py create mode 100644 flair/trainers/plugins/functional/deepncm_trainer_plugin.py create mode 100644 tests/models/test_deepncm_classifier.py diff --git a/flair/models/__init__.py b/flair/models/__init__.py index e75daf074b..bf3651078a 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,3 +1,4 @@ +from .deepncm_classification_model import DeepNCMClassifier from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -37,4 +38,5 @@ "TextClassifier", "TextRegressor", "MultitaskModel", + "DeepNCMClassifier", ] diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py new file mode 100644 index 0000000000..b942e28919 --- /dev/null +++ b/flair/models/deepncm_classification_model.py @@ -0,0 +1,455 @@ +import logging +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +from tqdm import tqdm + +import flair +from flair.data import Dictionary, Sentence +from flair.datasets import DataLoader, FlairDatapointDataset +from flair.embeddings import DocumentEmbeddings +from flair.embeddings.base import load_embeddings +from flair.nn import Classifier + +log = logging.getLogger("flair") + + +class DeepNCMClassifier(Classifier[Sentence]): + """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. + + This model combines deep learning with the Nearest Class Mean (NCM) approach. + It uses document embeddings to represent text, optionally applies an encoder, + and classifies based on the nearest class prototype in the embedded space. + + The model supports various methods for updating class prototypes during training, + making it adaptable to different learning scenarios. + + This implementation is based on the research paper: + Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers. + In International Conference on Learning Representations (ICLR) 2018 Workshop. + URL: https://openreview.net/forum?id=rkPLZ4JPM + """ + + def __init__( + self, + embeddings: DocumentEmbeddings, + label_dictionary: Dictionary, + label_type: str, + encoding_dim: Optional[int] = None, + alpha: float = 0.9, + mean_update_method: Literal["online", "condensation", "decay"] = "online", + use_encoder: bool = True, + multi_label: bool = False, + multi_label_threshold: float = 0.5, + ): + """Initialize a DeepNCMClassifier. + + Args: + embeddings: Document embeddings to use for encoding text. + label_dictionary: Dictionary containing the label vocabulary. + label_type: The type of label to predict. + encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings). + alpha: The decay factor for updating class prototypes (default is 0.9). + mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). + use_encoder: Whether to apply an encoder to the input embeddings (default is True). + multi_label: Whether to predict multiple labels per sentence (default is False). + multi_label_threshold: The threshold for multi-label prediction (default is 0.5). + """ + super().__init__() + + self.embeddings = embeddings + self.label_dictionary = label_dictionary + self._label_type = label_type + self.alpha = alpha + self.mean_update_method = mean_update_method + self.use_encoder = use_encoder + self.multi_label = multi_label + self.multi_label_threshold = multi_label_threshold + self.num_classes = len(label_dictionary) + self.embedding_dim = embeddings.embedding_length + + if use_encoder: + self.encoding_dim = encoding_dim or self.embedding_dim + else: + self.encoding_dim = self.embedding_dim + + self._validate_parameters() + + if self.use_encoder: + self.encoder = torch.nn.Sequential( + torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), + torch.nn.ReLU(), + torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), + ) + else: + self.encoder = torch.nn.Sequential(torch.nn.Identity()) + + self.loss_function = ( + torch.nn.BCEWithLogitsLoss(reduction="sum") + if self.multi_label + else torch.nn.CrossEntropyLoss(reduction="sum") + ) + + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self.num_classes, self.encoding_dim)), requires_grad=False + ) + self.class_counts = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self.num_classes).to(flair.device) + self.to(flair.device) + + def _validate_parameters(self) -> None: + """Validate the input parameters.""" + assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" + assert self.mean_update_method in [ + "online", + "condensation", + "decay", + ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" + assert self.encoding_dim > 0, "encoding_dim must be greater than 0" + + def forward(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor: + """Encode the input sentences using embeddings and optional encoder. + + Args: + sentences: Input sentence or list of sentences. + + Returns: + torch.Tensor: Encoded representations of the input sentences. + """ + if not isinstance(sentences, list): + sentences = [sentences] + + self.embeddings.embed(sentences) + sentence_embeddings = torch.stack([sentence.get_embedding() for sentence in sentences]) + encoded_embeddings = self.encoder(sentence_embeddings) + + return encoded_embeddings + + def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: + """Calculate distances between encoded embeddings and class prototypes. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + + Returns: + torch.Tensor: Distances between encoded embeddings and class prototypes. + """ + return torch.cdist(encoded_embeddings, self.class_prototypes) + + def forward_loss(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, int]: + """Compute the loss for a batch of sentences. + + Args: + data_points: A list of sentences. + + Returns: + Tuple[torch.Tensor, int]: The total loss and the number of sentences. + """ + encoded_embeddings = self.forward(data_points) + labels = self._prepare_label_tensor(data_points) + distances = self._calculate_distances(encoded_embeddings) + loss = self.loss_function(-distances, labels) + self._calculate_prototype_updates(encoded_embeddings, labels) + + return loss, len(data_points) + + def _prepare_label_tensor(self, sentences: List[Sentence]) -> torch.Tensor: + """Prepare the label tensor for the given sentences. + + Args: + sentences: A list of sentences. + + Returns: + torch.Tensor: The label tensor for the given sentences. + """ + if self.multi_label: + return torch.tensor( + [ + [ + ( + 1 + if label + in [sentence_label.value for sentence_label in sentence.get_labels(self._label_type)] + else 0 + ) + for label in self.label_dictionary.get_items() + ] + for sentence in sentences + ], + dtype=torch.float, + device=flair.device, + ) + else: + return torch.tensor( + [ + self.label_dictionary.get_idx_for_item(sentence.get_label(self._label_type).value) + for sentence in sentences + ], + dtype=torch.long, + device=flair.device, + ) + + def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: + """Calculate updates for class prototypes based on the current batch. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + labels: True labels for the input sentences. + """ + one_hot = ( + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_classes).float() + ) + + updates = torch.matmul(one_hot.t(), encoded_embeddings) + counts = one_hot.sum(dim=0) + mask = counts > 0 + self.prototype_updates[mask] += updates[mask] + self.prototype_update_counts[mask] += counts[mask] + + def update_prototypes(self) -> None: + """Apply accumulated updates to class prototypes.""" + with torch.no_grad(): + update_mask = self.prototype_update_counts > 0 + if update_mask.any(): + if self.mean_update_method in ["online", "condensation"]: + new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] + self.class_prototypes[update_mask] = ( + self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] + + self.prototype_updates[update_mask] + ) / new_counts.unsqueeze(1) + self.class_counts[update_mask] = new_counts + elif self.mean_update_method == "decay": + new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ + update_mask + ].unsqueeze(1) + self.class_prototypes[update_mask] = ( + self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes + ) + self.class_counts[update_mask] += self.prototype_update_counts[update_mask] + + # Reset prototype updates + self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) + self.prototype_update_counts = torch.zeros(self.num_classes, device=flair.device) + + def predict( + self, + sentences: Union[List[Sentence], Sentence], + mini_batch_size: int = 32, + return_probabilities_for_all_classes: bool = False, + verbose: bool = False, + label_name: Optional[str] = None, + return_loss: bool = False, + embedding_storage_mode: str = "none", + ) -> Union[List[Sentence], Tuple[float, int]]: + """Predict classes for a list of sentences. + + Args: + sentences: A list of sentences or a single sentence. + mini_batch_size: Size of mini batches during prediction. + return_probabilities_for_all_classes: Whether to return probabilities for all classes. + verbose: If True, show progress bar during prediction. + label_name: The name of the label to use for prediction. + return_loss: If True, compute and return loss. + embedding_storage_mode: The mode for storing embeddings ('none', 'cpu', or 'gpu'). + + Returns: + Union[List[Sentence], Tuple[float, int]]: + if return_loss is True, returns a tuple of total loss and total number of sentences; + otherwise, returns the list of sentences with predicted labels. + """ + with torch.no_grad(): + if not isinstance(sentences, list): + sentences = [sentences] + if not sentences: + return sentences + + label_name = label_name or self.label_type + Sentence.set_context_for_sentences(sentences) + + filtered_sentences = [sent for sent in sentences if len(sent) > 0] + reordered_sentences = sorted(filtered_sentences, key=len, reverse=True) + + if len(reordered_sentences) == 0: + return sentences + + dataloader = DataLoader( + dataset=FlairDatapointDataset(reordered_sentences), + batch_size=mini_batch_size, + ) + + if verbose: + progress_bar = tqdm(dataloader) + progress_bar.set_description("Predicting") + dataloader = progress_bar + + total_loss = 0.0 + total_sentences = 0 + + for batch in dataloader: + if not batch: + continue + + encoded_embeddings = self.forward(batch) + distances = self._calculate_distances(encoded_embeddings) + + if self.multi_label: + probabilities = torch.sigmoid(-distances) + else: + probabilities = torch.nn.functional.softmax(-distances, dim=1) + + if return_loss: + labels = self._prepare_label_tensor(batch) + loss = self.loss_function(-distances, labels) + total_loss += loss.item() + total_sentences += len(batch) + + for sentence_index, sentence in enumerate(batch): + sentence.remove_labels(label_name) + + if self.multi_label: + for label_index, probability in enumerate(probabilities[sentence_index]): + if probability > self.multi_label_threshold or return_probabilities_for_all_classes: + label_value = self.label_dictionary.get_item_for_index(label_index) + sentence.add_label(label_name, label_value, probability.item()) + else: + predicted_idx = torch.argmax(probabilities[sentence_index]) + label_value = self.label_dictionary.get_item_for_index(predicted_idx.item()) + sentence.add_label(label_name, label_value, probabilities[sentence_index, predicted_idx].item()) + + if return_probabilities_for_all_classes: + for label_index, probability in enumerate(probabilities[sentence_index]): + label_value = self.label_dictionary.get_item_for_index(label_index) + sentence.add_label(f"{label_name}_all", label_value, probability.item()) + + for sentence in batch: + sentence.clear_embeddings(embedding_storage_mode) + + if return_loss: + return total_loss, total_sentences + return sentences + + def _get_state_dict(self) -> Dict[str, Any]: + """Get the state dictionary of the model. + + Returns: + Dict[str, Any]: The state dictionary containing model parameters and configuration. + """ + model_state = { + "embeddings": self.embeddings.save_embeddings(), + "label_dictionary": self.label_dictionary, + "label_type": self.label_type, + "encoding_dim": self.encoding_dim, + "alpha": self.alpha, + "mean_update_method": self.mean_update_method, + "use_encoder": self.use_encoder, + "multi_label": self.multi_label, + "multi_label_threshold": self.multi_label_threshold, + "class_prototypes": self.class_prototypes.cpu(), + "class_counts": self.class_counts.cpu(), + "encoder": self.encoder.state_dict(), + } + return model_state + + @classmethod + def _init_model_with_state_dict(cls, state, **kwargs) -> "DeepNCMClassifier": + """Initialize the model from a state dictionary. + + Args: + state: The state dictionary containing model parameters and configuration. + **kwargs: Additional keyword arguments for model initialization. + + Returns: + DeepNCMClassifier: An instance of the model initialized with the given state. + """ + embeddings = state["embeddings"] + if isinstance(embeddings, dict): + embeddings = load_embeddings(embeddings) + + model = cls( + embeddings=embeddings, + label_dictionary=state["label_dictionary"], + label_type=state["label_type"], + encoding_dim=state["encoding_dim"], + alpha=state["alpha"], + mean_update_method=state["mean_update_method"], + use_encoder=state["use_encoder"], + multi_label=state.get("multi_label", False), + multi_label_threshold=state.get("multi_label_threshold", 0.5), + **kwargs, + ) + + if "encoder" in state: + model.encoder.load_state_dict(state["encoder"]) + if "class_prototypes" in state: + model.class_prototypes.data = state["class_prototypes"].to(flair.device) + if "class_counts" in state: + model.class_counts.data = state["class_counts"].to(flair.device) + + return model + + def get_prototype(self, class_name: str) -> torch.Tensor: + """Get the prototype vector for a given class name. + + Args: + class_name: The name of the class whose prototype vector is requested. + + Returns: + torch.Tensor: The prototype vector for the given class. + + Raises: + ValueError: If the class name is not found in the label dictionary. + """ + try: + class_idx = self.label_dictionary.get_idx_for_item(class_name) + except IndexError as exc: + raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc + + return self.class_prototypes[class_idx].clone() + + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> List[Tuple[str, float]]: + """Get the top_k closest prototype vectors to the given input vector using the configured distance metric. + + Args: + input_vector (torch.Tensor): The input vector to compare against prototypes. + top_k (int): The number of closest prototypes to return (default is 5). + + Returns: + List[Tuple[str, float]]: Each tuple contains (class_name, distance). + """ + if input_vector.dim() != 1: + raise ValueError("Input vector must be a 1D tensor") + if input_vector.size(0) != self.class_prototypes.size(1): + raise ValueError( + f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" + ) + + input_vector = input_vector.unsqueeze(0) + distances = self._calculate_distances(input_vector) + top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) + + nearest_prototypes = [] + for idx, value in zip(top_k_indices, top_k_values): + class_name = self.label_dictionary.get_item_for_index(idx.item()) + nearest_prototypes.append((class_name, value.item())) + + return nearest_prototypes + + @property + def label_type(self) -> str: + """Get the label type for this classifier.""" + return self._label_type + + def __str__(self) -> str: + """Get a string representation of the model. + + Returns: + str: A string describing the model architecture. + """ + return ( + f"DeepNCMClassifier(\n" + f" (embeddings): {self.embeddings}\n" + f" (encoder): {self.encoder}\n" + f" (prototypes): {self.class_prototypes.shape}\n" + f")" + ) diff --git a/flair/trainers/plugins/__init__.py b/flair/trainers/plugins/__init__.py index 373fdf969b..c3b1c1bab3 100644 --- a/flair/trainers/plugins/__init__.py +++ b/flair/trainers/plugins/__init__.py @@ -1,6 +1,7 @@ from .base import BasePlugin, Pluggable, TrainerPlugin, TrainingInterrupt from .functional.anneal_on_plateau import AnnealingPlugin from .functional.checkpoints import CheckpointPlugin +from .functional.deepncm_trainer_plugin import DeepNCMPlugin from .functional.linear_scheduler import LinearSchedulerPlugin from .functional.reduce_transformer_vocab import ReduceTransformerVocabPlugin from .functional.weight_extractor import WeightExtractorPlugin @@ -15,6 +16,7 @@ "AnnealingPlugin", "CheckpointPlugin", "ClearmlLoggerPlugin", + "DeepNCMPlugin", "LinearSchedulerPlugin", "WeightExtractorPlugin", "LogFilePlugin", diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py new file mode 100644 index 0000000000..2c4c0ccb49 --- /dev/null +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -0,0 +1,41 @@ +import torch + +from flair.models import DeepNCMClassifier, MultitaskModel +from flair.trainers.plugins.base import TrainerPlugin + + +class DeepNCMPlugin(TrainerPlugin): + """Plugin for training DeepNCMClassifier. + + Handles both multitask and single-task scenarios. + """ + + def _process_models(self, operation: str): + """Process updates for all DeepNCMClassifier models in the trainer. + + Args: + operation (str): The operation to perform ('condensation' or 'update') + """ + model = self.trainer.model + + models = model.tasks.values() if isinstance(model, MultitaskModel) else [model] + + for sub_model in models: + if isinstance(sub_model, DeepNCMClassifier): + if operation == "condensation" and sub_model.mean_update_method == "condensation": + sub_model.class_counts.data = torch.ones_like(sub_model.class_counts) + elif operation == "update": + sub_model.update_prototypes() + + @TrainerPlugin.hook + def after_training_epoch(self, **kwargs): + """Update prototypes after each training epoch.""" + self._process_models("condensation") + + @TrainerPlugin.hook + def after_training_batch(self, **kwargs): + """Update prototypes after each training batch.""" + self._process_models("update") + + def __str__(self) -> str: + return "DeepNCMPlugin" diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py new file mode 100644 index 0000000000..3b76b6c0b9 --- /dev/null +++ b/tests/models/test_deepncm_classifier.py @@ -0,0 +1,167 @@ +import pytest +import torch + +from flair.data import Sentence +from flair.datasets import ClassificationCorpus +from flair.embeddings import TransformerDocumentEmbeddings +from flair.models import DeepNCMClassifier +from flair.trainers import ModelTrainer +from flair.trainers.plugins import DeepNCMPlugin +from tests.model_test_utils import BaseModelTest + + +class TestDeepNCMClassifier(BaseModelTest): + model_cls = DeepNCMClassifier + train_label_type = "class" + multiclass_prediction_labels = ["POSITIVE", "NEGATIVE"] + training_args = { + "max_epochs": 2, + "mini_batch_size": 4, + "learning_rate": 1e-5, + } + + @pytest.fixture() + def embeddings(self): + return TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True) + + @pytest.fixture() + def corpus(self, tasks_base_path): + return ClassificationCorpus(tasks_base_path / "imdb", label_type=self.train_label_type) + + @pytest.fixture() + def multiclass_train_test_sentence(self): + return Sentence("This movie was great!") + + def build_model(self, embeddings, label_dict, **kwargs): + model_args = { + "embeddings": embeddings, + "label_dictionary": label_dict, + "label_type": self.train_label_type, + "use_encoder": False, + "encoding_dim": 64, + "alpha": 0.95, + } + model_args.update(kwargs) + return self.model_cls(**model_args) + + @pytest.mark.integration() + def test_train_load_use_classifier( + self, results_base_path, corpus, embeddings, example_sentence, train_test_sentence + ): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + + model = self.build_model(embeddings, label_dict, mean_update_method="condensation") + + trainer = ModelTrainer(model, corpus) + trainer.fine_tune( + results_base_path, optimizer=torch.optim.AdamW, plugins=[DeepNCMPlugin()], **self.training_args + ) + + model.predict(train_test_sentence) + + for label in train_test_sentence.get_labels(self.train_label_type): + assert label.value is not None + assert 0.0 <= label.score <= 1.0 + assert isinstance(label.score, float) + + del trainer, model, corpus + + loaded_model = self.model_cls.load(results_base_path / "final-model.pt") + + loaded_model.predict(example_sentence) + loaded_model.predict([example_sentence, self.empty_sentence]) + loaded_model.predict([self.empty_sentence]) + + def test_get_prototype(self, corpus, embeddings): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict) + + prototype = model.get_prototype(next(iter(label_dict.get_items()))) + assert isinstance(prototype, torch.Tensor) + assert prototype.shape == (model.encoding_dim,) + + with pytest.raises(ValueError): + model.get_prototype("NON_EXISTENT_CLASS") + + def test_get_closest_prototypes(self, corpus, embeddings): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict) + input_vector = torch.randn(model.encoding_dim) + closest_prototypes = model.get_closest_prototypes(input_vector, top_k=2) + + assert len(closest_prototypes) == 2 + assert all(isinstance(item, tuple) and len(item) == 2 for item in closest_prototypes) + + with pytest.raises(ValueError): + model.get_closest_prototypes(torch.randn(model.encoding_dim + 1)) + + def test_forward_loss(self, corpus, embeddings): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict) + + sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] + for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): + sentence.add_label(self.train_label_type, label) + + loss, count = model.forward_loss(sentences) + assert isinstance(loss, torch.Tensor) + assert loss.item() > 0 + assert count == len(sentences) + + @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"]) + def test_mean_update_methods(self, corpus, embeddings, mean_update_method): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method) + + initial_prototypes = model.class_prototypes.clone() + + sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] + for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): + sentence.add_label(self.train_label_type, label) + + model.forward_loss(sentences) + model.update_prototypes() + + assert not torch.all(torch.eq(initial_prototypes, model.class_prototypes)) + + @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"]) + def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): + label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) + model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method) + + trainer = ModelTrainer(model, corpus) + plugin = DeepNCMPlugin() + plugin.attach_to(trainer) + + initial_class_counts = model.class_counts.clone() + initial_prototypes = model.class_prototypes.clone() + + # Simulate training epoch + plugin.after_training_epoch() + + if mean_update_method == "condensation": + assert torch.all(model.class_counts == 1), "Class counts should be 1 for condensation method after epoch" + elif mean_update_method == "online": + assert torch.all( + torch.eq(model.class_counts, initial_class_counts) + ), "Class counts should not change for online method after epoch" + + # Simulate training batch + sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] + for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): + sentence.add_label(self.train_label_type, label) + model.forward_loss(sentences) + plugin.after_training_batch() + + assert not torch.all( + torch.eq(initial_prototypes, model.class_prototypes) + ), "Prototypes should be updated after a batch" + + if mean_update_method == "condensation": + assert torch.all( + model.class_counts >= 1 + ), "Class counts should be >= 1 for condensation method after a batch" + elif mean_update_method == "online": + assert torch.all( + model.class_counts > initial_class_counts + ), "Class counts should increase for online method after a batch" From 213396c762492d9ffc77e82a234eb0eb1ecebf42 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 8 Nov 2024 17:45:23 -0800 Subject: [PATCH 02/46] feat: change DeepNCM classifier to a decoder so it can be used with different model types. make small changes to DefaultClassifier forward_loss to pass label tensor when needed. update tests --- flair/models/__init__.py | 4 +- flair/models/deepncm_classification_model.py | 328 +++--------------- flair/nn/model.py | 11 +- .../functional/deepncm_trainer_plugin.py | 13 +- tests/models/test_deepncm_classifier.py | 61 ++-- 5 files changed, 103 insertions(+), 314 deletions(-) diff --git a/flair/models/__init__.py b/flair/models/__init__.py index bf3651078a..d9fca4a706 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,4 @@ -from .deepncm_classification_model import DeepNCMClassifier +from .deepncm_classification_model import DeepNCMDecoder from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -38,5 +38,5 @@ "TextClassifier", "TextRegressor", "MultitaskModel", - "DeepNCMClassifier", + "DeepNCMDecoder", ] diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py index b942e28919..ec3385a78a 100644 --- a/flair/models/deepncm_classification_model.py +++ b/flair/models/deepncm_classification_model.py @@ -1,20 +1,15 @@ import logging -from typing import Any, Dict, List, Literal, Optional, Tuple, Union +from typing import Literal, Optional import torch -from tqdm import tqdm import flair -from flair.data import Dictionary, Sentence -from flair.datasets import DataLoader, FlairDatapointDataset -from flair.embeddings import DocumentEmbeddings -from flair.embeddings.base import load_embeddings -from flair.nn import Classifier +from flair.data import Dictionary log = logging.getLogger("flair") -class DeepNCMClassifier(Classifier[Sentence]): +class DeepNCMDecoder(torch.nn.Module): """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. This model combines deep learning with the Nearest Class Mean (NCM) approach. @@ -32,47 +27,50 @@ class DeepNCMClassifier(Classifier[Sentence]): def __init__( self, - embeddings: DocumentEmbeddings, label_dictionary: Dictionary, - label_type: str, + embeddings_size: int, encoding_dim: Optional[int] = None, alpha: float = 0.9, mean_update_method: Literal["online", "condensation", "decay"] = "online", use_encoder: bool = True, - multi_label: bool = False, - multi_label_threshold: float = 0.5, - ): - """Initialize a DeepNCMClassifier. + multi_label: bool = False, # should get from the Model it belongs to + ) -> None: + """Initialize a DeepNCMDecoder. Args: - embeddings: Document embeddings to use for encoding text. - label_dictionary: Dictionary containing the label vocabulary. - label_type: The type of label to predict. encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings). - alpha: The decay factor for updating class prototypes (default is 0.9). + alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'. mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). use_encoder: Whether to apply an encoder to the input embeddings (default is True). multi_label: Whether to predict multiple labels per sentence (default is False). - multi_label_threshold: The threshold for multi-label prediction (default is 0.5). """ + super().__init__() - self.embeddings = embeddings self.label_dictionary = label_dictionary - self._label_type = label_type + self._num_prototypes = len(label_dictionary) + self.alpha = alpha self.mean_update_method = mean_update_method self.use_encoder = use_encoder self.multi_label = multi_label - self.multi_label_threshold = multi_label_threshold - self.num_classes = len(label_dictionary) - self.embedding_dim = embeddings.embedding_length + + self.embedding_dim = embeddings_size if use_encoder: self.encoding_dim = encoding_dim or self.embedding_dim else: self.encoding_dim = self.embedding_dim + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False + ) + + self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) + self.to(flair.device) + self._validate_parameters() if self.use_encoder: @@ -84,22 +82,11 @@ def __init__( else: self.encoder = torch.nn.Sequential(torch.nn.Identity()) - self.loss_function = ( - torch.nn.BCEWithLogitsLoss(reduction="sum") - if self.multi_label - else torch.nn.CrossEntropyLoss(reduction="sum") - ) - - self.class_prototypes = torch.nn.Parameter( - torch.nn.functional.normalize(torch.randn(self.num_classes, self.encoding_dim)), requires_grad=False - ) - self.class_counts = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False) - self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) - self.prototype_update_counts = torch.zeros(self.num_classes).to(flair.device) + # all parameters will be pushed internally to the specified device self.to(flair.device) def _validate_parameters(self) -> None: - """Validate the input parameters.""" + """Validate that the input parameters have valid and compatible values.""" assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" assert self.mean_update_method in [ "online", @@ -108,26 +95,13 @@ def _validate_parameters(self) -> None: ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" assert self.encoding_dim > 0, "encoding_dim must be greater than 0" - def forward(self, sentences: Union[List[Sentence], Sentence]) -> torch.Tensor: - """Encode the input sentences using embeddings and optional encoder. - - Args: - sentences: Input sentence or list of sentences. - - Returns: - torch.Tensor: Encoded representations of the input sentences. - """ - if not isinstance(sentences, list): - sentences = [sentences] - - self.embeddings.embed(sentences) - sentence_embeddings = torch.stack([sentence.get_embedding() for sentence in sentences]) - encoded_embeddings = self.encoder(sentence_embeddings) - - return encoded_embeddings + @property + def num_prototypes(self) -> int: + """The number of class prototypes.""" + return self.class_prototypes.size(0) def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: - """Calculate distances between encoded embeddings and class prototypes. + """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. Args: encoded_embeddings: Encoded representations of the input sentences. @@ -135,60 +109,7 @@ def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor Returns: torch.Tensor: Distances between encoded embeddings and class prototypes. """ - return torch.cdist(encoded_embeddings, self.class_prototypes) - - def forward_loss(self, data_points: List[Sentence]) -> Tuple[torch.Tensor, int]: - """Compute the loss for a batch of sentences. - - Args: - data_points: A list of sentences. - - Returns: - Tuple[torch.Tensor, int]: The total loss and the number of sentences. - """ - encoded_embeddings = self.forward(data_points) - labels = self._prepare_label_tensor(data_points) - distances = self._calculate_distances(encoded_embeddings) - loss = self.loss_function(-distances, labels) - self._calculate_prototype_updates(encoded_embeddings, labels) - - return loss, len(data_points) - - def _prepare_label_tensor(self, sentences: List[Sentence]) -> torch.Tensor: - """Prepare the label tensor for the given sentences. - - Args: - sentences: A list of sentences. - - Returns: - torch.Tensor: The label tensor for the given sentences. - """ - if self.multi_label: - return torch.tensor( - [ - [ - ( - 1 - if label - in [sentence_label.value for sentence_label in sentence.get_labels(self._label_type)] - else 0 - ) - for label in self.label_dictionary.get_items() - ] - for sentence in sentences - ], - dtype=torch.float, - device=flair.device, - ) - else: - return torch.tensor( - [ - self.label_dictionary.get_idx_for_item(sentence.get_label(self._label_type).value) - for sentence in sentences - ], - dtype=torch.long, - device=flair.device, - ) + return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: """Calculate updates for class prototypes based on the current batch. @@ -198,7 +119,7 @@ def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: labels: True labels for the input sentences. """ one_hot = ( - labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_classes).float() + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() ) updates = torch.matmul(one_hot.t(), encoded_embeddings) @@ -230,163 +151,25 @@ def update_prototypes(self) -> None: # Reset prototype updates self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) - self.prototype_update_counts = torch.zeros(self.num_classes, device=flair.device) - - def predict( - self, - sentences: Union[List[Sentence], Sentence], - mini_batch_size: int = 32, - return_probabilities_for_all_classes: bool = False, - verbose: bool = False, - label_name: Optional[str] = None, - return_loss: bool = False, - embedding_storage_mode: str = "none", - ) -> Union[List[Sentence], Tuple[float, int]]: - """Predict classes for a list of sentences. - - Args: - sentences: A list of sentences or a single sentence. - mini_batch_size: Size of mini batches during prediction. - return_probabilities_for_all_classes: Whether to return probabilities for all classes. - verbose: If True, show progress bar during prediction. - label_name: The name of the label to use for prediction. - return_loss: If True, compute and return loss. - embedding_storage_mode: The mode for storing embeddings ('none', 'cpu', or 'gpu'). - - Returns: - Union[List[Sentence], Tuple[float, int]]: - if return_loss is True, returns a tuple of total loss and total number of sentences; - otherwise, returns the list of sentences with predicted labels. - """ - with torch.no_grad(): - if not isinstance(sentences, list): - sentences = [sentences] - if not sentences: - return sentences - - label_name = label_name or self.label_type - Sentence.set_context_for_sentences(sentences) - - filtered_sentences = [sent for sent in sentences if len(sent) > 0] - reordered_sentences = sorted(filtered_sentences, key=len, reverse=True) - - if len(reordered_sentences) == 0: - return sentences - - dataloader = DataLoader( - dataset=FlairDatapointDataset(reordered_sentences), - batch_size=mini_batch_size, - ) - - if verbose: - progress_bar = tqdm(dataloader) - progress_bar.set_description("Predicting") - dataloader = progress_bar - - total_loss = 0.0 - total_sentences = 0 - - for batch in dataloader: - if not batch: - continue - - encoded_embeddings = self.forward(batch) - distances = self._calculate_distances(encoded_embeddings) - - if self.multi_label: - probabilities = torch.sigmoid(-distances) - else: - probabilities = torch.nn.functional.softmax(-distances, dim=1) + self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) - if return_loss: - labels = self._prepare_label_tensor(batch) - loss = self.loss_function(-distances, labels) - total_loss += loss.item() - total_sentences += len(batch) + def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of the decoder, which calculates the scores as prototype distances. - for sentence_index, sentence in enumerate(batch): - sentence.remove_labels(label_name) - - if self.multi_label: - for label_index, probability in enumerate(probabilities[sentence_index]): - if probability > self.multi_label_threshold or return_probabilities_for_all_classes: - label_value = self.label_dictionary.get_item_for_index(label_index) - sentence.add_label(label_name, label_value, probability.item()) - else: - predicted_idx = torch.argmax(probabilities[sentence_index]) - label_value = self.label_dictionary.get_item_for_index(predicted_idx.item()) - sentence.add_label(label_name, label_value, probabilities[sentence_index, predicted_idx].item()) - - if return_probabilities_for_all_classes: - for label_index, probability in enumerate(probabilities[sentence_index]): - label_value = self.label_dictionary.get_item_for_index(label_index) - sentence.add_label(f"{label_name}_all", label_value, probability.item()) - - for sentence in batch: - sentence.clear_embeddings(embedding_storage_mode) - - if return_loss: - return total_loss, total_sentences - return sentences - - def _get_state_dict(self) -> Dict[str, Any]: - """Get the state dictionary of the model. - - Returns: - Dict[str, Any]: The state dictionary containing model parameters and configuration. + :param embedded: Embedded representations of the input sentences. + :param label_tensor: True labels for the input sentences as a tensor. + :return: Scores as a tensor of distances to class prototypes. """ - model_state = { - "embeddings": self.embeddings.save_embeddings(), - "label_dictionary": self.label_dictionary, - "label_type": self.label_type, - "encoding_dim": self.encoding_dim, - "alpha": self.alpha, - "mean_update_method": self.mean_update_method, - "use_encoder": self.use_encoder, - "multi_label": self.multi_label, - "multi_label_threshold": self.multi_label_threshold, - "class_prototypes": self.class_prototypes.cpu(), - "class_counts": self.class_counts.cpu(), - "encoder": self.encoder.state_dict(), - } - return model_state - - @classmethod - def _init_model_with_state_dict(cls, state, **kwargs) -> "DeepNCMClassifier": - """Initialize the model from a state dictionary. + encoded_embeddings = self.encoder(embedded) - Args: - state: The state dictionary containing model parameters and configuration. - **kwargs: Additional keyword arguments for model initialization. + distances = self._calculate_distances(encoded_embeddings) - Returns: - DeepNCMClassifier: An instance of the model initialized with the given state. - """ - embeddings = state["embeddings"] - if isinstance(embeddings, dict): - embeddings = load_embeddings(embeddings) - - model = cls( - embeddings=embeddings, - label_dictionary=state["label_dictionary"], - label_type=state["label_type"], - encoding_dim=state["encoding_dim"], - alpha=state["alpha"], - mean_update_method=state["mean_update_method"], - use_encoder=state["use_encoder"], - multi_label=state.get("multi_label", False), - multi_label_threshold=state.get("multi_label_threshold", 0.5), - **kwargs, - ) + if label_tensor is not None: + self._calculate_prototype_updates(encoded_embeddings, label_tensor) - if "encoder" in state: - model.encoder.load_state_dict(state["encoder"]) - if "class_prototypes" in state: - model.class_prototypes.data = state["class_prototypes"].to(flair.device) - if "class_counts" in state: - model.class_counts.data = state["class_counts"].to(flair.device) + scores = -distances - return model + return scores def get_prototype(self, class_name: str) -> torch.Tensor: """Get the prototype vector for a given class name. @@ -407,15 +190,15 @@ def get_prototype(self, class_name: str) -> torch.Tensor: return self.class_prototypes[class_idx].clone() - def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> List[Tuple[str, float]]: - """Get the top_k closest prototype vectors to the given input vector using the configured distance metric. + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: + """Get the k closest prototype vectors to the given input vector using the configured distance metric. Args: input_vector (torch.Tensor): The input vector to compare against prototypes. top_k (int): The number of closest prototypes to return (default is 5). Returns: - List[Tuple[str, float]]: Each tuple contains (class_name, distance). + list[tuple[str, float]]: Each tuple contains (class_name, distance). """ if input_vector.dim() != 1: raise ValueError("Input vector must be a 1D tensor") @@ -434,22 +217,3 @@ def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> nearest_prototypes.append((class_name, value.item())) return nearest_prototypes - - @property - def label_type(self) -> str: - """Get the label type for this classifier.""" - return self._label_type - - def __str__(self) -> str: - """Get a string representation of the model. - - Returns: - str: A string describing the model architecture. - """ - return ( - f"DeepNCMClassifier(\n" - f" (embeddings): {self.embeddings}\n" - f" (encoder): {self.encoder}\n" - f" (prototypes): {self.class_prototypes.shape}\n" - f")" - ) diff --git a/flair/nn/model.py b/flair/nn/model.py index 03834afc76..69c51f7a5e 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, List, Optional, Tuple, Union import torch.nn from torch import Tensor @@ -778,8 +778,11 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # pass data points through network to get encoded data point tensor data_point_tensor = self._encode_data_points(sentences, data_points) - # decode - scores = self.decoder(data_point_tensor) + # decode, passing label tensor if needed, such as for prototype updates + if "label_tensor" in inspect.signature(self.decoder.forward).parameters: + scores = self.decoder(data_point_tensor, label_tensor) + else: + scores = self.decoder(data_point_tensor) # an optional masking step (no masking in most cases) scores = self._mask_scores(scores, data_points) @@ -814,7 +817,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ): + ) -> Optional[Union[List[DT], Tuple[float, int]]]: """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index 2c4c0ccb49..e5394debd2 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,6 +1,7 @@ import torch -from flair.models import DeepNCMClassifier, MultitaskModel +from flair.models import MultitaskModel +from flair.models.deepncm_classification_model import DeepNCMDecoder from flair.trainers.plugins.base import TrainerPlugin @@ -11,7 +12,7 @@ class DeepNCMPlugin(TrainerPlugin): """ def _process_models(self, operation: str): - """Process updates for all DeepNCMClassifier models in the trainer. + """Process updates for all DeepNCMDecoder decoders in the trainer. Args: operation (str): The operation to perform ('condensation' or 'update') @@ -21,11 +22,11 @@ def _process_models(self, operation: str): models = model.tasks.values() if isinstance(model, MultitaskModel) else [model] for sub_model in models: - if isinstance(sub_model, DeepNCMClassifier): - if operation == "condensation" and sub_model.mean_update_method == "condensation": - sub_model.class_counts.data = torch.ones_like(sub_model.class_counts) + if hasattr(sub_model, "decoder") and isinstance(sub_model.decoder, DeepNCMDecoder): + if operation == "condensation" and sub_model.decoder.mean_update_method == "condensation": + sub_model.decoder.class_counts.data = torch.ones_like(sub_model.decoder.class_counts) elif operation == "update": - sub_model.update_prototypes() + sub_model.decoder.update_prototypes() @TrainerPlugin.hook def after_training_epoch(self, **kwargs): diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py index 3b76b6c0b9..b587a33142 100644 --- a/tests/models/test_deepncm_classifier.py +++ b/tests/models/test_deepncm_classifier.py @@ -4,14 +4,14 @@ from flair.data import Sentence from flair.datasets import ClassificationCorpus from flair.embeddings import TransformerDocumentEmbeddings -from flair.models import DeepNCMClassifier +from flair.models import DeepNCMDecoder, TextClassifier from flair.trainers import ModelTrainer from flair.trainers.plugins import DeepNCMPlugin from tests.model_test_utils import BaseModelTest -class TestDeepNCMClassifier(BaseModelTest): - model_cls = DeepNCMClassifier +class TestDeepNCMDecoder(BaseModelTest): + model_cls = TextClassifier train_label_type = "class" multiclass_prediction_labels = ["POSITIVE", "NEGATIVE"] training_args = { @@ -33,6 +33,7 @@ def multiclass_train_test_sentence(self): return Sentence("This movie was great!") def build_model(self, embeddings, label_dict, **kwargs): + model_args = { "embeddings": embeddings, "label_dictionary": label_dict, @@ -40,9 +41,27 @@ def build_model(self, embeddings, label_dict, **kwargs): "use_encoder": False, "encoding_dim": 64, "alpha": 0.95, + "mean_update_method": "online", } model_args.update(kwargs) - return self.model_cls(**model_args) + + deepncm_decoder = DeepNCMDecoder( + label_dictionary=model_args["label_dictionary"], + embeddings_size=model_args["embeddings"].embedding_length, + alpha=model_args["alpha"], + encoding_dim=model_args["encoding_dim"], + mean_update_method=model_args["mean_update_method"], + ) + + model = self.model_cls( + embeddings=model_args["embeddings"], + label_dictionary=model_args["label_dictionary"], + label_type=model_args["label_type"], + multi_label=model_args.get("multi_label", False), + decoder=deepncm_decoder, + ) + + return model @pytest.mark.integration() def test_train_load_use_classifier( @@ -76,24 +95,24 @@ def test_get_prototype(self, corpus, embeddings): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) model = self.build_model(embeddings, label_dict) - prototype = model.get_prototype(next(iter(label_dict.get_items()))) + prototype = model.decoder.get_prototype(next(iter(label_dict.get_items()))) assert isinstance(prototype, torch.Tensor) - assert prototype.shape == (model.encoding_dim,) + assert prototype.shape == (model.decoder.encoding_dim,) with pytest.raises(ValueError): - model.get_prototype("NON_EXISTENT_CLASS") + model.decoder.get_prototype("NON_EXISTENT_CLASS") def test_get_closest_prototypes(self, corpus, embeddings): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) model = self.build_model(embeddings, label_dict) - input_vector = torch.randn(model.encoding_dim) - closest_prototypes = model.get_closest_prototypes(input_vector, top_k=2) + input_vector = torch.randn(model.decoder.encoding_dim) + closest_prototypes = model.decoder.get_closest_prototypes(input_vector, top_k=2) assert len(closest_prototypes) == 2 assert all(isinstance(item, tuple) and len(item) == 2 for item in closest_prototypes) with pytest.raises(ValueError): - model.get_closest_prototypes(torch.randn(model.encoding_dim + 1)) + model.decoder.get_closest_prototypes(torch.randn(model.decoder.encoding_dim + 1)) def test_forward_loss(self, corpus, embeddings): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) @@ -113,16 +132,16 @@ def test_mean_update_methods(self, corpus, embeddings, mean_update_method): label_dict = corpus.make_label_dictionary(label_type=self.train_label_type) model = self.build_model(embeddings, label_dict, mean_update_method=mean_update_method) - initial_prototypes = model.class_prototypes.clone() + initial_prototypes = model.decoder.class_prototypes.clone() sentences = [Sentence("This movie was great!"), Sentence("I didn't enjoy this film at all.")] for sentence, label in zip(sentences, list(label_dict.get_items())[:2]): sentence.add_label(self.train_label_type, label) model.forward_loss(sentences) - model.update_prototypes() + model.decoder.update_prototypes() - assert not torch.all(torch.eq(initial_prototypes, model.class_prototypes)) + assert not torch.all(torch.eq(initial_prototypes, model.decoder.class_prototypes)) @pytest.mark.parametrize("mean_update_method", ["online", "condensation", "decay"]) def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): @@ -133,17 +152,19 @@ def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): plugin = DeepNCMPlugin() plugin.attach_to(trainer) - initial_class_counts = model.class_counts.clone() - initial_prototypes = model.class_prototypes.clone() + initial_class_counts = model.decoder.class_counts.clone() + initial_prototypes = model.decoder.class_prototypes.clone() # Simulate training epoch plugin.after_training_epoch() if mean_update_method == "condensation": - assert torch.all(model.class_counts == 1), "Class counts should be 1 for condensation method after epoch" + assert torch.all( + model.decoder.class_counts == 1 + ), "Class counts should be 1 for condensation method after epoch" elif mean_update_method == "online": assert torch.all( - torch.eq(model.class_counts, initial_class_counts) + torch.eq(model.decoder.class_counts, initial_class_counts) ), "Class counts should not change for online method after epoch" # Simulate training batch @@ -154,14 +175,14 @@ def test_deepncm_plugin(self, corpus, embeddings, mean_update_method): plugin.after_training_batch() assert not torch.all( - torch.eq(initial_prototypes, model.class_prototypes) + torch.eq(initial_prototypes, model.decoder.class_prototypes) ), "Prototypes should be updated after a batch" if mean_update_method == "condensation": assert torch.all( - model.class_counts >= 1 + model.decoder.class_counts >= 1 ), "Class counts should be >= 1 for condensation method after a batch" elif mean_update_method == "online": assert torch.all( - model.class_counts > initial_class_counts + model.decoder.class_counts > initial_class_counts ), "Class counts should increase for online method after a batch" From 649e68dfbfb509b2f38d14aafb215ade283b0364 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Wed, 18 Dec 2024 15:19:54 -0500 Subject: [PATCH 03/46] refactor: move DeepNCMDecoder to decoder.py --- flair/models/__init__.py | 2 - flair/models/deepncm_classification_model.py | 208 ----------------- flair/nn/__init__.py | 3 +- flair/nn/decoder.py | 217 +++++++++++++++++- .../functional/deepncm_trainer_plugin.py | 2 +- tests/models/test_deepncm_classifier.py | 3 +- 6 files changed, 221 insertions(+), 214 deletions(-) diff --git a/flair/models/__init__.py b/flair/models/__init__.py index d9fca4a706..e75daf074b 100644 --- a/flair/models/__init__.py +++ b/flair/models/__init__.py @@ -1,4 +1,3 @@ -from .deepncm_classification_model import DeepNCMDecoder from .entity_linker_model import SpanClassifier from .entity_mention_linking import EntityMentionLinker from .language_model import LanguageModel @@ -38,5 +37,4 @@ "TextClassifier", "TextRegressor", "MultitaskModel", - "DeepNCMDecoder", ] diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py index ec3385a78a..be1b5788a0 100644 --- a/flair/models/deepncm_classification_model.py +++ b/flair/models/deepncm_classification_model.py @@ -9,211 +9,3 @@ log = logging.getLogger("flair") -class DeepNCMDecoder(torch.nn.Module): - """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. - - This model combines deep learning with the Nearest Class Mean (NCM) approach. - It uses document embeddings to represent text, optionally applies an encoder, - and classifies based on the nearest class prototype in the embedded space. - - The model supports various methods for updating class prototypes during training, - making it adaptable to different learning scenarios. - - This implementation is based on the research paper: - Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers. - In International Conference on Learning Representations (ICLR) 2018 Workshop. - URL: https://openreview.net/forum?id=rkPLZ4JPM - """ - - def __init__( - self, - label_dictionary: Dictionary, - embeddings_size: int, - encoding_dim: Optional[int] = None, - alpha: float = 0.9, - mean_update_method: Literal["online", "condensation", "decay"] = "online", - use_encoder: bool = True, - multi_label: bool = False, # should get from the Model it belongs to - ) -> None: - """Initialize a DeepNCMDecoder. - - Args: - encoding_dim: The dimensionality of the encoded embeddings (default is the same as the input embeddings). - alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'. - mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). - use_encoder: Whether to apply an encoder to the input embeddings (default is True). - multi_label: Whether to predict multiple labels per sentence (default is False). - """ - - super().__init__() - - self.label_dictionary = label_dictionary - self._num_prototypes = len(label_dictionary) - - self.alpha = alpha - self.mean_update_method = mean_update_method - self.use_encoder = use_encoder - self.multi_label = multi_label - - self.embedding_dim = embeddings_size - - if use_encoder: - self.encoding_dim = encoding_dim or self.embedding_dim - else: - self.encoding_dim = self.embedding_dim - - self.class_prototypes = torch.nn.Parameter( - torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False - ) - - self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) - self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) - self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) - self.to(flair.device) - - self._validate_parameters() - - if self.use_encoder: - self.encoder = torch.nn.Sequential( - torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), - torch.nn.ReLU(), - torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), - ) - else: - self.encoder = torch.nn.Sequential(torch.nn.Identity()) - - # all parameters will be pushed internally to the specified device - self.to(flair.device) - - def _validate_parameters(self) -> None: - """Validate that the input parameters have valid and compatible values.""" - assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" - assert self.mean_update_method in [ - "online", - "condensation", - "decay", - ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" - assert self.encoding_dim > 0, "encoding_dim must be greater than 0" - - @property - def num_prototypes(self) -> int: - """The number of class prototypes.""" - return self.class_prototypes.size(0) - - def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: - """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. - - Args: - encoded_embeddings: Encoded representations of the input sentences. - - Returns: - torch.Tensor: Distances between encoded embeddings and class prototypes. - """ - return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) - - def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: - """Calculate updates for class prototypes based on the current batch. - - Args: - encoded_embeddings: Encoded representations of the input sentences. - labels: True labels for the input sentences. - """ - one_hot = ( - labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() - ) - - updates = torch.matmul(one_hot.t(), encoded_embeddings) - counts = one_hot.sum(dim=0) - mask = counts > 0 - self.prototype_updates[mask] += updates[mask] - self.prototype_update_counts[mask] += counts[mask] - - def update_prototypes(self) -> None: - """Apply accumulated updates to class prototypes.""" - with torch.no_grad(): - update_mask = self.prototype_update_counts > 0 - if update_mask.any(): - if self.mean_update_method in ["online", "condensation"]: - new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] - self.class_prototypes[update_mask] = ( - self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] - + self.prototype_updates[update_mask] - ) / new_counts.unsqueeze(1) - self.class_counts[update_mask] = new_counts - elif self.mean_update_method == "decay": - new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ - update_mask - ].unsqueeze(1) - self.class_prototypes[update_mask] = ( - self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes - ) - self.class_counts[update_mask] += self.prototype_update_counts[update_mask] - - # Reset prototype updates - self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) - self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) - - def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: - """Forward pass of the decoder, which calculates the scores as prototype distances. - - :param embedded: Embedded representations of the input sentences. - :param label_tensor: True labels for the input sentences as a tensor. - :return: Scores as a tensor of distances to class prototypes. - """ - encoded_embeddings = self.encoder(embedded) - - distances = self._calculate_distances(encoded_embeddings) - - if label_tensor is not None: - self._calculate_prototype_updates(encoded_embeddings, label_tensor) - - scores = -distances - - return scores - - def get_prototype(self, class_name: str) -> torch.Tensor: - """Get the prototype vector for a given class name. - - Args: - class_name: The name of the class whose prototype vector is requested. - - Returns: - torch.Tensor: The prototype vector for the given class. - - Raises: - ValueError: If the class name is not found in the label dictionary. - """ - try: - class_idx = self.label_dictionary.get_idx_for_item(class_name) - except IndexError as exc: - raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc - - return self.class_prototypes[class_idx].clone() - - def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: - """Get the k closest prototype vectors to the given input vector using the configured distance metric. - - Args: - input_vector (torch.Tensor): The input vector to compare against prototypes. - top_k (int): The number of closest prototypes to return (default is 5). - - Returns: - list[tuple[str, float]]: Each tuple contains (class_name, distance). - """ - if input_vector.dim() != 1: - raise ValueError("Input vector must be a 1D tensor") - if input_vector.size(0) != self.class_prototypes.size(1): - raise ValueError( - f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" - ) - - input_vector = input_vector.unsqueeze(0) - distances = self._calculate_distances(input_vector) - top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) - - nearest_prototypes = [] - for idx, value in zip(top_k_indices, top_k_values): - class_name = self.label_dictionary.get_item_for_index(idx.item()) - nearest_prototypes.append((class_name, value.item())) - - return nearest_prototypes diff --git a/flair/nn/__init__.py b/flair/nn/__init__.py index 1ceae91859..9ced1753c1 100644 --- a/flair/nn/__init__.py +++ b/flair/nn/__init__.py @@ -1,4 +1,4 @@ -from .decoder import LabelVerbalizerDecoder, PrototypicalDecoder +from .decoder import DeepNCMDecoder, LabelVerbalizerDecoder, PrototypicalDecoder from .dropout import LockedDropout, WordDropout from .model import Classifier, DefaultClassifier, Model @@ -9,5 +9,6 @@ "DefaultClassifier", "Model", "PrototypicalDecoder", + "DeepNCMDecoder", "LabelVerbalizerDecoder", ] diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py index 48cdbf39b0..b5fc49ecf0 100644 --- a/flair/nn/decoder.py +++ b/flair/nn/decoder.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Literal, Optional import torch @@ -123,6 +123,221 @@ def forward(self, embedded): return scores +class DeepNCMDecoder(torch.nn.Module): + """Deep Nearest Class Mean (DeepNCM) Classifier for text classification tasks. + + This model combines deep learning with the Nearest Class Mean (NCM) approach. + It uses document embeddings to represent text, optionally applies an encoder, + and classifies based on the nearest class prototype in the embedded space. + + The model supports various methods for updating class prototypes during training, + making it adaptable to different learning scenarios. + + This implementation is based on the research paper: + Guerriero, S., Caputo, B., & Mensink, T. (2018). DeepNCM: Deep Nearest Class Mean Classifiers. + In International Conference on Learning Representations (ICLR) 2018 Workshop. + URL: https://openreview.net/forum?id=rkPLZ4JPM + """ + + def __init__( + self, + label_dictionary: Dictionary, + embeddings_size: int, + use_encoder: bool = True, + encoding_dim: Optional[int] = None, + alpha: float = 0.9, + mean_update_method: Literal["online", "condensation", "decay"] = "online", + multi_label: bool = False, # should get from the Model it belongs to + ) -> None: + """Initialize a DeepNCMDecoder. + + Args: + label_dictionary: Label dictionary from the corpus + embeddings_size: The dimensionality of the input embeddings, usually the same as the model embeddings + use_encoder: Whether to apply an encoder to the input embeddings (default is True). + encoding_dim: The dimensionality of the encoded embeddings if an encoder is used (default is the same as the input embeddings). + alpha: The decay factor for updating class prototypes (default is 0.9). This only applies when mean_update_method is 'decay'. + mean_update_method: The method for updating class prototypes ('online', 'condensation', or 'decay'). + online - + condensation - + decay - after every batch, + multi_label: Whether to predict multiple labels per sentence (default is False, and performs multi-class clsasification). + """ + + super().__init__() + + self.label_dictionary = label_dictionary + self._num_prototypes = len(label_dictionary) + + self.alpha = alpha + self.mean_update_method = mean_update_method + self.use_encoder = use_encoder + self.multi_label = multi_label + + self.embedding_dim = embeddings_size + + if use_encoder: + self.encoding_dim = encoding_dim or self.embedding_dim + else: + self.encoding_dim = self.embedding_dim + + self.class_prototypes = torch.nn.Parameter( + torch.nn.functional.normalize(torch.randn(self._num_prototypes, self.encoding_dim)), requires_grad=False + ) + + self.class_counts = torch.nn.Parameter(torch.zeros(self._num_prototypes), requires_grad=False) + self.prototype_updates = torch.zeros_like(self.class_prototypes).to(flair.device) + self.prototype_update_counts = torch.zeros(self._num_prototypes).to(flair.device) + self.to(flair.device) + + self._validate_parameters() + + if self.use_encoder: + self.encoder = torch.nn.Sequential( + torch.nn.Linear(self.embedding_dim, self.encoding_dim * 2), + torch.nn.ReLU(), + torch.nn.Linear(self.encoding_dim * 2, self.encoding_dim), + ) + else: + self.encoder = torch.nn.Sequential(torch.nn.Identity()) + + # all parameters will be pushed internally to the specified device + self.to(flair.device) + + def _validate_parameters(self) -> None: + """Validate that the input parameters have valid and compatible values.""" + assert 0 <= self.alpha <= 1, "alpha must be in the range [0, 1]" + assert self.mean_update_method in [ + "online", + "condensation", + "decay", + ], f"Invalid mean_update_method: {self.mean_update_method}. Must be 'online', 'condensation', or 'decay'" + assert self.encoding_dim > 0, "encoding_dim must be greater than 0" + + @property + def num_prototypes(self) -> int: + """The number of class prototypes.""" + return self.class_prototypes.size(0) + + def _calculate_distances(self, encoded_embeddings: torch.Tensor) -> torch.Tensor: + """Calculate the squared Euclidean distance between encoded embeddings and class prototypes. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + + Returns: + torch.Tensor: Distances between encoded embeddings and class prototypes. + """ + return torch.cdist(encoded_embeddings, self.class_prototypes).pow(2) + + def _calculate_prototype_updates(self, encoded_embeddings: torch.Tensor, labels: torch.Tensor) -> None: + """Calculate updates for class prototypes based on the current batch. + + Args: + encoded_embeddings: Encoded representations of the input sentences. + labels: True labels for the input sentences. + """ + one_hot = ( + labels if self.multi_label else torch.nn.functional.one_hot(labels, num_classes=self.num_prototypes).float() + ) + + updates = torch.matmul(one_hot.t(), encoded_embeddings) + counts = one_hot.sum(dim=0) + mask = counts > 0 + self.prototype_updates[mask] += updates[mask] + self.prototype_update_counts[mask] += counts[mask] + + def update_prototypes(self) -> None: + """Apply accumulated updates to class prototypes.""" + with torch.no_grad(): + update_mask = self.prototype_update_counts > 0 + if update_mask.any(): + if self.mean_update_method in ["online", "condensation"]: + new_counts = self.class_counts[update_mask] + self.prototype_update_counts[update_mask] + self.class_prototypes[update_mask] = ( + self.class_counts[update_mask].unsqueeze(1) * self.class_prototypes[update_mask] + + self.prototype_updates[update_mask] + ) / new_counts.unsqueeze(1) + self.class_counts[update_mask] = new_counts + elif self.mean_update_method == "decay": + new_prototypes = self.prototype_updates[update_mask] / self.prototype_update_counts[ + update_mask + ].unsqueeze(1) + self.class_prototypes[update_mask] = ( + self.alpha * self.class_prototypes[update_mask] + (1 - self.alpha) * new_prototypes + ) + self.class_counts[update_mask] += self.prototype_update_counts[update_mask] + + # Reset prototype updates + self.prototype_updates = torch.zeros_like(self.class_prototypes, device=flair.device) + self.prototype_update_counts = torch.zeros(self.num_prototypes, device=flair.device) + + def forward(self, embedded: torch.Tensor, label_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass of the decoder, which calculates the scores as prototype distances. + + :param embedded: Embedded representations of the input sentences. + :param label_tensor: True labels for the input sentences as a tensor. + :return: Scores as a tensor of distances to class prototypes. + """ + encoded_embeddings = self.encoder(embedded) + + distances = self._calculate_distances(encoded_embeddings) + + if label_tensor is not None: + self._calculate_prototype_updates(encoded_embeddings, label_tensor) + + scores = -distances + + return scores + + def get_prototype(self, class_name: str) -> torch.Tensor: + """Get the prototype vector for a given class name. + + Args: + class_name: The name of the class whose prototype vector is requested. + + Returns: + torch.Tensor: The prototype vector for the given class. + + Raises: + ValueError: If the class name is not found in the label dictionary. + """ + try: + class_idx = self.label_dictionary.get_idx_for_item(class_name) + except IndexError as exc: + raise ValueError(f"Class name '{class_name}' not found in the label dictionary") from exc + + return self.class_prototypes[class_idx].clone() + + def get_closest_prototypes(self, input_vector: torch.Tensor, top_k: int = 5) -> list[tuple[str, float]]: + """Get the k closest prototype vectors to the given input vector using the configured distance metric. + + Args: + input_vector (torch.Tensor): The input vector to compare against prototypes. + top_k (int): The number of closest prototypes to return (default is 5). + + Returns: + list[tuple[str, float]]: Each tuple contains (class_name, distance). + """ + if input_vector.dim() != 1: + raise ValueError("Input vector must be a 1D tensor") + if input_vector.size(0) != self.class_prototypes.size(1): + raise ValueError( + f"Input vector dimension ({input_vector.size(0)}) does not match prototype dimension ({self.class_prototypes.size(1)})" + ) + + input_vector = input_vector.unsqueeze(0) + distances = self._calculate_distances(input_vector) + top_k_values, top_k_indices = torch.topk(distances.squeeze(), k=top_k, largest=False) + + nearest_prototypes = [] + for idx, value in zip(top_k_indices, top_k_values): + class_name = self.label_dictionary.get_item_for_index(idx.item()) + nearest_prototypes.append((class_name, value.item())) + + return nearest_prototypes + + class LabelVerbalizerDecoder(torch.nn.Module): """A class for decoding labels using the idea of siamese networks / bi-encoders. This can be used for all classification tasks in flair. diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index e5394debd2..981d413d61 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,7 +1,7 @@ import torch from flair.models import MultitaskModel -from flair.models.deepncm_classification_model import DeepNCMDecoder +from flair.nn import DeepNCMDecoder from flair.trainers.plugins.base import TrainerPlugin diff --git a/tests/models/test_deepncm_classifier.py b/tests/models/test_deepncm_classifier.py index b587a33142..5324f08fc3 100644 --- a/tests/models/test_deepncm_classifier.py +++ b/tests/models/test_deepncm_classifier.py @@ -4,7 +4,8 @@ from flair.data import Sentence from flair.datasets import ClassificationCorpus from flair.embeddings import TransformerDocumentEmbeddings -from flair.models import DeepNCMDecoder, TextClassifier +from flair.models import TextClassifier +from flair.nn import DeepNCMDecoder from flair.trainers import ModelTrainer from flair.trainers.plugins import DeepNCMPlugin from tests.model_test_utils import BaseModelTest From 0ccb22f2c5acde978d1d9ed61ae98da2c33f5a4f Mon Sep 17 00:00:00 2001 From: alanakbik Date: Wed, 1 Jan 2025 11:58:44 +0100 Subject: [PATCH 04/46] Unclutter printouts of RelationClassifier during evaluation --- flair/models/relation_classifier_model.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 9c6c69577f..389bf05157 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -667,6 +667,28 @@ def predict( return loss if return_loss else None + def _print_predictions(self, batch, gold_label_type: str) -> list[str]: + lines = [] + for datapoint in batch: + # check if there is a label mismatch + g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)] + p = [label.labeled_identifier for label in datapoint.get_labels("predicted")] + g.sort() + p.sort() + + # if the gold label is O and is correctly predicted as no label, do not print out as this clutters + # the output file with trivial predictions + if not (len(datapoint.get_labels(gold_label_type)) == 1 and datapoint.get_label(gold_label_type).value == "O" and len(datapoint.get_labels("predicted")) == 0): + correct_string = " -> MISMATCH!\n" if g != p else "" + eval_line = ( + f"{datapoint.text}\n" + f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n" + f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n" + f"{correct_string}\n" + ) + lines.append(eval_line) + return lines + def _get_state_dict(self) -> dict[str, Any]: model_state: dict[str, Any] = { **super()._get_state_dict(), From 58e903cf60d207f93bce0a4108ae9c34a8762e3c Mon Sep 17 00:00:00 2001 From: alanakbik Date: Wed, 1 Jan 2025 12:08:03 +0100 Subject: [PATCH 05/46] Formatting --- flair/models/relation_classifier_model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 389bf05157..8aca236230 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -678,7 +678,11 @@ def _print_predictions(self, batch, gold_label_type: str) -> list[str]: # if the gold label is O and is correctly predicted as no label, do not print out as this clutters # the output file with trivial predictions - if not (len(datapoint.get_labels(gold_label_type)) == 1 and datapoint.get_label(gold_label_type).value == "O" and len(datapoint.get_labels("predicted")) == 0): + if not ( + len(datapoint.get_labels(gold_label_type)) == 1 + and datapoint.get_label(gold_label_type).value == "O" + and len(datapoint.get_labels("predicted")) == 0 + ): correct_string = " -> MISMATCH!\n" if g != p else "" eval_line = ( f"{datapoint.text}\n" From 0f81dd29b74af52bef06a0f9df686b9ece995154 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 05:55:14 +0100 Subject: [PATCH 06/46] Add option to customize SegtokTokenizer --- flair/tokenization.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/flair/tokenization.py b/flair/tokenization.py index b377c419e6..13ddf20f28 100644 --- a/flair/tokenization.py +++ b/flair/tokenization.py @@ -1,4 +1,5 @@ import logging +import re import sys from abc import ABC, abstractmethod from typing import Callable @@ -79,10 +80,37 @@ class SegtokTokenizer(Tokenizer): For further details see: https://github.com/fnl/segtok """ - def __init__(self) -> None: + def __init__(self, additional_split_characters: list[str] = None) -> None: + """Initializes the SegtokTokenizer with an optional parameter for additional characters that should always + be split. + + The default behavior uses simple rules to split text into tokens. If you want to ensure that certain characters + always become their own token, you can change default behavior by setting the ``additional_split_characters`` + parameter. + + Args: + additional_split_characters: An optional list of characters that should always be split. For instance, if + you want to make sure that paragraph symbols always become their own token, instantiate with + additional_split_characters = ['§'] + """ + self.additional_split_characters = additional_split_characters super().__init__() + def _add_whitespace_around_symbols(self, text, symbols): + # Build the regular expression pattern dynamically based on the provided symbols + # This will match any character from the symbols list that doesn't have spaces around it + symbol_pattern = f"[{re.escape(''.join(symbols))}]" + + # Add space before and after symbols, where necessary + # Ensure that we are adding a space only if there isn't one already + text = re.sub(r"(\S)(" + symbol_pattern + r")", r"\1 \2", text) # Space before symbol + text = re.sub(r"(" + symbol_pattern + r")(\S)", r"\1 \2", text) # Space after symbol + + return text + def tokenize(self, text: str) -> list[str]: + if self.additional_split_characters: + text = self._add_whitespace_around_symbols(text, self.additional_split_characters) return SegtokTokenizer.run_tokenize(text) @staticmethod From fc786b3d47bda71fb170ef43aa6004c0d5470241 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 05:59:31 +0100 Subject: [PATCH 07/46] Optimize RelationClassifier by filtering long sentences --- flair/models/relation_classifier_model.py | 79 +++++++++++++++-------- 1 file changed, 51 insertions(+), 28 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 8aca236230..ee6e610062 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -256,6 +256,8 @@ def __init__( encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, + max_allowed_tokens_between_entities: int = 50, + max_encoded_sentence_length: int = 100, **classifierargs, ) -> None: """Initializes a `RelationClassifier`. @@ -271,6 +273,8 @@ def __init__( encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol zero_tag_value: The label to use for out-of-class relations allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. + max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. + max_encoded_sentence_length: The maximum length of encoded sentences. Smaller values speed up processing but potentially remove important context. classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier` """ # Set label type and prepare label dictionary @@ -278,6 +282,18 @@ def __init__( self._zero_tag_value = zero_tag_value self._allow_unk_tag = allow_unk_tag + if max_encoded_sentence_length - 2 < max_allowed_tokens_between_entities: + logger.warning( + "You set 'max_encoded_sentence_length' to be potentially smaller than 'max_allowed_tokens_between_entities'." + "To ensure that each encoded sentence at least contains the entities in a relation, " + "'max_encoded_sentence_length' should be at least 2 tokens larger than 'max_allowed_tokens_between_entities'." + "I am automatically changing 'max_encoded_sentence_length' to 'max_allowed_tokens_between_entities' + 2" + ) + max_encoded_sentence_length = max_allowed_tokens_between_entities + 2 + + self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities + self._max_encoded_sentence_length = max_encoded_sentence_length + modified_label_dictionary: Dictionary = Dictionary(add_unk=self._allow_unk_tag) modified_label_dictionary.add_item(self._zero_tag_value) for label in label_dictionary.get_items(): @@ -398,7 +414,7 @@ def _encode_sentence( head: _Entity, tail: _Entity, gold_label: Optional[str] = None, - ) -> EncodedSentence: + ) -> Optional[EncodedSentence]: """Returns a new Sentence object with masked/marked head and tail spans according to the encoding strategy. If provided, the encoded sentence also has the corresponding gold label annotation from :attr:`~label_type`. @@ -422,11 +438,15 @@ def _encode_sentence( # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. encoded_sentence_tokens: list[str] = [] + head_idx = None + tail_idx = None for token in original_sentence: if token is head.span[0]: + head_idx = len(encoded_sentence_tokens) encoded_sentence_tokens.append(self.encoding_strategy.encode_head(head.span, head.label)) elif token is tail.span[0]: + tail_idx = len(encoded_sentence_tokens) encoded_sentence_tokens.append(self.encoding_strategy.encode_tail(tail.span, tail.label)) elif all( @@ -435,6 +455,15 @@ def _encode_sentence( ): encoded_sentence_tokens.append(token.text) + # filter cases in which the distance between the two entities is too large + if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities: + return None + + # remove excess tokens left and right of entity pair to make encoded sentence shorter + encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length( + encoded_sentence_tokens, head_idx, tail_idx + ) + # Create masked sentence encoded_sentence: EncodedSentence = EncodedSentence( " ".join(encoded_sentence_tokens), use_tokenizer=SpaceTokenizer() @@ -448,6 +477,23 @@ def _encode_sentence( encoded_sentence.copy_context_from_sentence(original_sentence) return encoded_sentence + def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, head_idx, tail_idx): + if len(encoded_sentence_tokens) > self._max_encoded_sentence_length: + begin_slice = head_idx if head_idx < tail_idx else tail_idx + end_slice = tail_idx if head_idx < tail_idx else head_idx + distance = end_slice - begin_slice + padding_amount = self._max_encoded_sentence_length - distance + padding_per_side = padding_amount // 2 + begin_slice = begin_slice - padding_per_side if begin_slice - padding_per_side > 0 else 0 + end_slice = ( + end_slice + padding_per_side + if end_slice + padding_per_side < len(encoded_sentence_tokens) + else len(encoded_sentence_tokens) + ) + + encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice] + return encoded_sentence_tokens + def _encode_sentence_for_inference( self, sentence: Sentence, @@ -520,6 +566,7 @@ def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list encoded_sentence for sentence in sentences for encoded_sentence in self._encode_sentence_for_training(sentence) + if encoded_sentence is not None ] def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset[EncodedSentence]: @@ -643,7 +690,9 @@ def predict( # Deal with the case where all sentences are standard (non-encoded) sentences Sentence.set_context_for_sentences(cast(list[Sentence], sentences)) sentences_with_relation_reference: list[tuple[EncodedSentence, Relation]] = list( - itertools.chain.from_iterable(self._encode_sentence_for_inference(sentence) for sentence in sentences) + itertools.chain.from_iterable( + self._encode_sentence_for_inference(sentence) for sentence in sentences if sentence is not None + ) ) encoded_sentences = [x[0] for x in sentences_with_relation_reference] @@ -667,32 +716,6 @@ def predict( return loss if return_loss else None - def _print_predictions(self, batch, gold_label_type: str) -> list[str]: - lines = [] - for datapoint in batch: - # check if there is a label mismatch - g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)] - p = [label.labeled_identifier for label in datapoint.get_labels("predicted")] - g.sort() - p.sort() - - # if the gold label is O and is correctly predicted as no label, do not print out as this clutters - # the output file with trivial predictions - if not ( - len(datapoint.get_labels(gold_label_type)) == 1 - and datapoint.get_label(gold_label_type).value == "O" - and len(datapoint.get_labels("predicted")) == 0 - ): - correct_string = " -> MISMATCH!\n" if g != p else "" - eval_line = ( - f"{datapoint.text}\n" - f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n" - f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n" - f"{correct_string}\n" - ) - lines.append(eval_line) - return lines - def _get_state_dict(self) -> dict[str, Any]: model_state: dict[str, Any] = { **super()._get_state_dict(), From 594d8583f3550ba75f7799342cb93f61385e3e5a Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 06:08:30 +0100 Subject: [PATCH 08/46] Optimize RelationClassifier by filtering long sentences --- flair/models/relation_classifier_model.py | 26 +++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index ee6e610062..da25fdefd9 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -716,6 +716,32 @@ def predict( return loss if return_loss else None + def _print_predictions(self, batch, gold_label_type: str) -> list[str]: + lines = [] + for datapoint in batch: + # check if there is a label mismatch + g = [label.labeled_identifier for label in datapoint.get_labels(gold_label_type)] + p = [label.labeled_identifier for label in datapoint.get_labels("predicted")] + g.sort() + p.sort() + + # if the gold label is O and is correctly predicted as no label, do not print out as this clutters + # the output file with trivial predictions + if not ( + len(datapoint.get_labels(gold_label_type)) == 1 + and datapoint.get_label(gold_label_type).value == "O" + and len(datapoint.get_labels("predicted")) == 0 + ): + correct_string = " -> MISMATCH!\n" if g != p else "" + eval_line = ( + f"{datapoint.text}\n" + f" - Gold: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels(gold_label_type))}\n" + f" - Pred: {', '.join(label.value if label.data_point == datapoint else label.labeled_identifier for label in datapoint.get_labels('predicted'))}\n" + f"{correct_string}\n" + ) + lines.append(eval_line) + return lines + def _get_state_dict(self) -> dict[str, Any]: model_state: dict[str, Any] = { **super()._get_state_dict(), From 8fc8a58228b4f548c0eae48589f4fc2f5c1e9ca6 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 06:42:46 +0100 Subject: [PATCH 09/46] Fix serialization --- flair/models/relation_classifier_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index da25fdefd9..3c459da21d 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -755,6 +755,8 @@ def _get_state_dict(self) -> dict[str, Any]: "encoding_strategy": self.encoding_strategy, "zero_tag_value": self.zero_tag_value, "allow_unk_tag": self.allow_unk_tag, + "max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities, + "max_encoded_sentence_length": self._max_encoded_sentence_length, } return model_state @@ -772,6 +774,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): encoding_strategy=state["encoding_strategy"], zero_tag_value=state["zero_tag_value"], allow_unk_tag=state["allow_unk_tag"], + max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities", 25), + max_encoded_sentence_length=state.get("max_encoded_sentence_length", 50), **kwargs, ) From 1fd18513247ff3ac57ff123ead507206b9be28fb Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 08:47:38 +0100 Subject: [PATCH 10/46] Change context window calculation --- flair/models/relation_classifier_model.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 3c459da21d..3758489080 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -257,7 +257,7 @@ def __init__( zero_tag_value: str = "O", allow_unk_tag: bool = True, max_allowed_tokens_between_entities: int = 50, - max_encoded_sentence_length: int = 100, + max_surrounding_context_length: int = 10, **classifierargs, ) -> None: """Initializes a `RelationClassifier`. @@ -274,7 +274,7 @@ def __init__( zero_tag_value: The label to use for out-of-class relations allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. - max_encoded_sentence_length: The maximum length of encoded sentences. Smaller values speed up processing but potentially remove important context. + max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier` """ # Set label type and prepare label dictionary @@ -282,17 +282,17 @@ def __init__( self._zero_tag_value = zero_tag_value self._allow_unk_tag = allow_unk_tag - if max_encoded_sentence_length - 2 < max_allowed_tokens_between_entities: + if max_surrounding_context_length - 2 < max_allowed_tokens_between_entities: logger.warning( "You set 'max_encoded_sentence_length' to be potentially smaller than 'max_allowed_tokens_between_entities'." "To ensure that each encoded sentence at least contains the entities in a relation, " "'max_encoded_sentence_length' should be at least 2 tokens larger than 'max_allowed_tokens_between_entities'." "I am automatically changing 'max_encoded_sentence_length' to 'max_allowed_tokens_between_entities' + 2" ) - max_encoded_sentence_length = max_allowed_tokens_between_entities + 2 + max_surrounding_context_length = max_allowed_tokens_between_entities + 2 self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities - self._max_encoded_sentence_length = max_encoded_sentence_length + self._max_surrounding_context_length = max_surrounding_context_length modified_label_dictionary: Dictionary = Dictionary(add_unk=self._allow_unk_tag) modified_label_dictionary.add_item(self._zero_tag_value) @@ -478,11 +478,11 @@ def _encode_sentence( return encoded_sentence def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, head_idx, tail_idx): - if len(encoded_sentence_tokens) > self._max_encoded_sentence_length: + if len(encoded_sentence_tokens) > self._max_surrounding_context_length: begin_slice = head_idx if head_idx < tail_idx else tail_idx end_slice = tail_idx if head_idx < tail_idx else head_idx distance = end_slice - begin_slice - padding_amount = self._max_encoded_sentence_length - distance + padding_amount = self._max_surrounding_context_length padding_per_side = padding_amount // 2 begin_slice = begin_slice - padding_per_side if begin_slice - padding_per_side > 0 else 0 end_slice = ( @@ -756,7 +756,7 @@ def _get_state_dict(self) -> dict[str, Any]: "zero_tag_value": self.zero_tag_value, "allow_unk_tag": self.allow_unk_tag, "max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities, - "max_encoded_sentence_length": self._max_encoded_sentence_length, + "max_encoded_sentence_length": self._max_surrounding_context_length, } return model_state From 7f89bb093a5ec1ae1fc26152a6a8e003c2c35748 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 08:51:12 +0100 Subject: [PATCH 11/46] Change context window calculation --- flair/models/relation_classifier_model.py | 32 +++++++---------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 3758489080..247c7f11a7 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -282,15 +282,6 @@ def __init__( self._zero_tag_value = zero_tag_value self._allow_unk_tag = allow_unk_tag - if max_surrounding_context_length - 2 < max_allowed_tokens_between_entities: - logger.warning( - "You set 'max_encoded_sentence_length' to be potentially smaller than 'max_allowed_tokens_between_entities'." - "To ensure that each encoded sentence at least contains the entities in a relation, " - "'max_encoded_sentence_length' should be at least 2 tokens larger than 'max_allowed_tokens_between_entities'." - "I am automatically changing 'max_encoded_sentence_length' to 'max_allowed_tokens_between_entities' + 2" - ) - max_surrounding_context_length = max_allowed_tokens_between_entities + 2 - self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities self._max_surrounding_context_length = max_surrounding_context_length @@ -478,20 +469,17 @@ def _encode_sentence( return encoded_sentence def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, head_idx, tail_idx): - if len(encoded_sentence_tokens) > self._max_surrounding_context_length: - begin_slice = head_idx if head_idx < tail_idx else tail_idx - end_slice = tail_idx if head_idx < tail_idx else head_idx - distance = end_slice - begin_slice - padding_amount = self._max_surrounding_context_length - padding_per_side = padding_amount // 2 - begin_slice = begin_slice - padding_per_side if begin_slice - padding_per_side > 0 else 0 - end_slice = ( - end_slice + padding_per_side - if end_slice + padding_per_side < len(encoded_sentence_tokens) - else len(encoded_sentence_tokens) - ) + begin_slice = head_idx if head_idx < tail_idx else tail_idx + end_slice = tail_idx if head_idx < tail_idx else head_idx + padding_amount = self._max_surrounding_context_length + begin_slice = begin_slice - padding_amount if begin_slice - padding_amount > 0 else 0 + end_slice = ( + end_slice + padding_amount + if end_slice + padding_amount < len(encoded_sentence_tokens) + else len(encoded_sentence_tokens) + ) - encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice] + encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice] return encoded_sentence_tokens def _encode_sentence_for_inference( From 70148da636493e345f58b414a70c5e777d0bb3ad Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 12:45:22 +0100 Subject: [PATCH 12/46] Add sanity check to ensure entities are not contained in one another --- flair/models/relation_classifier_model.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 247c7f11a7..7faca14783 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -421,6 +421,12 @@ def _encode_sentence( original_sentence: Sentence = head.span.sentence assert original_sentence is tail.span.sentence, "The head and tail need to come from the same sentence." + # Sanity check: Do not create a labeled span if one entity contains the other + if head.span[0].idx <= tail.span[0].idx and head.span[-1].idx >= tail.span[-1].idx: + return None + if head.span[0].idx >= tail.span[0].idx and head.span[-1].idx <= tail.span[-1].idx: + return None + # Pre-compute non-leading head and tail tokens for entity masking non_leading_head_tokens: list[Token] = head.span.tokens[1:] non_leading_tail_tokens: list[Token] = tail.span.tokens[1:] @@ -683,6 +689,8 @@ def predict( ) ) + sentences_with_relation_reference = [item for item in sentences_with_relation_reference if item[0] is not None] + encoded_sentences = [x[0] for x in sentences_with_relation_reference] loss = super().predict( encoded_sentences, @@ -744,7 +752,7 @@ def _get_state_dict(self) -> dict[str, Any]: "zero_tag_value": self.zero_tag_value, "allow_unk_tag": self.allow_unk_tag, "max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities, - "max_encoded_sentence_length": self._max_surrounding_context_length, + "max_surrounding_context_length": self._max_surrounding_context_length, } return model_state @@ -763,7 +771,7 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): zero_tag_value=state["zero_tag_value"], allow_unk_tag=state["allow_unk_tag"], max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities", 25), - max_encoded_sentence_length=state.get("max_encoded_sentence_length", 50), + max_surrounding_context_length=state.get("max_surrounding_context_length", 50), **kwargs, ) From f50c3b3dfae80b6c26777528ae03e09592871ac3 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 15:41:43 +0100 Subject: [PATCH 13/46] Fix slicing such that left and right context are of equal length --- flair/models/relation_classifier_model.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 7faca14783..dadc17c053 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -256,7 +256,7 @@ def __init__( encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, - max_allowed_tokens_between_entities: int = 50, + max_allowed_tokens_between_entities: int = 20, max_surrounding_context_length: int = 10, **classifierargs, ) -> None: @@ -456,6 +456,8 @@ def _encode_sentence( if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities: return None + print(head_idx, tail_idx) + # remove excess tokens left and right of entity pair to make encoded sentence shorter encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length( encoded_sentence_tokens, head_idx, tail_idx @@ -480,8 +482,8 @@ def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, padding_amount = self._max_surrounding_context_length begin_slice = begin_slice - padding_amount if begin_slice - padding_amount > 0 else 0 end_slice = ( - end_slice + padding_amount - if end_slice + padding_amount < len(encoded_sentence_tokens) + end_slice + padding_amount + 1 + if end_slice + padding_amount + 1 < len(encoded_sentence_tokens) else len(encoded_sentence_tokens) ) @@ -689,7 +691,9 @@ def predict( ) ) - sentences_with_relation_reference = [item for item in sentences_with_relation_reference if item[0] is not None] + sentences_with_relation_reference = [ + item for item in sentences_with_relation_reference if item[0] is not None + ] encoded_sentences = [x[0] for x in sentences_with_relation_reference] loss = super().predict( From 142703b0fd6ad5d95d25d30f8054ae063c2d8caf Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 21:44:28 +0100 Subject: [PATCH 14/46] Make mypy happy --- flair/models/relation_classifier_model.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index dadc17c053..fe52791479 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -435,8 +435,8 @@ def _encode_sentence( # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. encoded_sentence_tokens: list[str] = [] - head_idx = None - tail_idx = None + head_idx = -10000 + tail_idx = 10000 for token in original_sentence: if token is head.span[0]: head_idx = len(encoded_sentence_tokens) @@ -456,8 +456,6 @@ def _encode_sentence( if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities: return None - print(head_idx, tail_idx) - # remove excess tokens left and right of entity pair to make encoded sentence shorter encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length( encoded_sentence_tokens, head_idx, tail_idx @@ -511,13 +509,15 @@ def _encode_sentence_for_inference( Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence """ for head, tail, gold_label in self._entity_pair_permutations(sentence): - masked_sentence: EncodedSentence = self._encode_sentence( + masked_sentence = self._encode_sentence( head=head, tail=tail, gold_label=gold_label if gold_label is not None else self.zero_tag_value, ) original_relation: Relation = Relation(first=head.span, second=tail.span) - yield masked_sentence, original_relation + + if masked_sentence is not None: + yield masked_sentence, original_relation def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedSentence]: """Create Encoded Sentences and Relation pairs for Training. @@ -534,13 +534,14 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS else: continue # Skip generated data points that do not express an originally annotated relation - masked_sentence: EncodedSentence = self._encode_sentence( + masked_sentence = self._encode_sentence( head=head, tail=tail, gold_label=gold_label, ) - yield masked_sentence + if masked_sentence is not None: + yield masked_sentence def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list[EncodedSentence]: """Transforms sentences into encoded sentences specific to the `RelationClassifier`. @@ -562,7 +563,6 @@ def transform_sentence(self, sentences: Union[Sentence, list[Sentence]]) -> list encoded_sentence for sentence in sentences for encoded_sentence in self._encode_sentence_for_training(sentence) - if encoded_sentence is not None ] def transform_dataset(self, dataset: Dataset[Sentence]) -> FlairDatapointDataset[EncodedSentence]: @@ -691,10 +691,6 @@ def predict( ) ) - sentences_with_relation_reference = [ - item for item in sentences_with_relation_reference if item[0] is not None - ] - encoded_sentences = [x[0] for x in sentences_with_relation_reference] loss = super().predict( encoded_sentences, From 3ad499bb4ecb89f7d1fac059ef3303b920dacfe8 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 21:52:28 +0100 Subject: [PATCH 15/46] Remove unnecessary if statement --- flair/models/relation_classifier_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index fe52791479..eb7f4ab73b 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -686,9 +686,7 @@ def predict( # Deal with the case where all sentences are standard (non-encoded) sentences Sentence.set_context_for_sentences(cast(list[Sentence], sentences)) sentences_with_relation_reference: list[tuple[EncodedSentence, Relation]] = list( - itertools.chain.from_iterable( - self._encode_sentence_for_inference(sentence) for sentence in sentences if sentence is not None - ) + itertools.chain.from_iterable(self._encode_sentence_for_inference(sentence) for sentence in sentences) ) encoded_sentences = [x[0] for x in sentences_with_relation_reference] From 1d7a5fb3df15ed86dab1df125d808b91e431f13d Mon Sep 17 00:00:00 2001 From: alanakbik Date: Thu, 2 Jan 2025 22:28:53 +0100 Subject: [PATCH 16/46] make mypy happy --- flair/tokenization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/tokenization.py b/flair/tokenization.py index 13ddf20f28..10d5a79960 100644 --- a/flair/tokenization.py +++ b/flair/tokenization.py @@ -2,7 +2,7 @@ import re import sys from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, Optional from segtok.segmenter import split_single from segtok.tokenizer import split_contractions, word_tokenizer @@ -80,7 +80,7 @@ class SegtokTokenizer(Tokenizer): For further details see: https://github.com/fnl/segtok """ - def __init__(self, additional_split_characters: list[str] = None) -> None: + def __init__(self, additional_split_characters: Optional[list[str]] = None) -> None: """Initializes the SegtokTokenizer with an optional parameter for additional characters that should always be split. From 46379c9142a52d36c910f6f00bc5066a40f24a31 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Fri, 3 Jan 2025 14:25:46 +0100 Subject: [PATCH 17/46] Removed deprecated deepncm classifier file --- flair/models/deepncm_classification_model.py | 11 ----------- 1 file changed, 11 deletions(-) delete mode 100644 flair/models/deepncm_classification_model.py diff --git a/flair/models/deepncm_classification_model.py b/flair/models/deepncm_classification_model.py deleted file mode 100644 index be1b5788a0..0000000000 --- a/flair/models/deepncm_classification_model.py +++ /dev/null @@ -1,11 +0,0 @@ -import logging -from typing import Literal, Optional - -import torch - -import flair -from flair.data import Dictionary - -log = logging.getLogger("flair") - - From cab51053956347c269f5e4a9802c848778a31cb9 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Fri, 3 Jan 2025 14:26:13 +0100 Subject: [PATCH 18/46] Slightly refactored deepncm trainer plugin --- .../functional/deepncm_trainer_plugin.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py index 981d413d61..006396b760 100644 --- a/flair/trainers/plugins/functional/deepncm_trainer_plugin.py +++ b/flair/trainers/plugins/functional/deepncm_trainer_plugin.py @@ -1,3 +1,5 @@ +from collections.abc import Iterable + import torch from flair.models import MultitaskModel @@ -11,32 +13,29 @@ class DeepNCMPlugin(TrainerPlugin): Handles both multitask and single-task scenarios. """ - def _process_models(self, operation: str): - """Process updates for all DeepNCMDecoder decoders in the trainer. - - Args: - operation (str): The operation to perform ('condensation' or 'update') - """ + @property + def decoders(self) -> Iterable[DeepNCMDecoder]: + """Iterator over all DeepNCMDecoder decoders in the trainer.""" model = self.trainer.model models = model.tasks.values() if isinstance(model, MultitaskModel) else [model] for sub_model in models: if hasattr(sub_model, "decoder") and isinstance(sub_model.decoder, DeepNCMDecoder): - if operation == "condensation" and sub_model.decoder.mean_update_method == "condensation": - sub_model.decoder.class_counts.data = torch.ones_like(sub_model.decoder.class_counts) - elif operation == "update": - sub_model.decoder.update_prototypes() + yield sub_model.decoder @TrainerPlugin.hook def after_training_epoch(self, **kwargs): - """Update prototypes after each training epoch.""" - self._process_models("condensation") + """Reset class counts after each training epoch.""" + for decoder in self.decoders: + if decoder.mean_update_method == "condensation": + decoder.class_counts.data = torch.ones_like(decoder.class_counts) @TrainerPlugin.hook def after_training_batch(self, **kwargs): """Update prototypes after each training batch.""" - self._process_models("update") + for decoder in self.decoders: + decoder.update_prototypes() def __str__(self) -> str: return "DeepNCMPlugin" From 5ae1508e44dbe19d0dc913240c35db205654cc0b Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Fri, 3 Jan 2025 14:26:22 +0100 Subject: [PATCH 19/46] Fixed formatting --- flair/nn/decoder.py | 1 - flair/nn/model.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/flair/nn/decoder.py b/flair/nn/decoder.py index b5fc49ecf0..5499f03fe5 100644 --- a/flair/nn/decoder.py +++ b/flair/nn/decoder.py @@ -163,7 +163,6 @@ def __init__( decay - after every batch, multi_label: Whether to predict multiple labels per sentence (default is False, and performs multi-class clsasification). """ - super().__init__() self.label_dictionary = label_dictionary diff --git a/flair/nn/model.py b/flair/nn/model.py index 69c51f7a5e..cc38c56c5a 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import Counter from pathlib import Path -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional, Union import torch.nn from torch import Tensor @@ -780,7 +780,7 @@ def forward_loss(self, sentences: list[DT]) -> tuple[torch.Tensor, int]: # decode, passing label tensor if needed, such as for prototype updates if "label_tensor" in inspect.signature(self.decoder.forward).parameters: - scores = self.decoder(data_point_tensor, label_tensor) + scores = self.decoder(data_point_tensor, label_tensor=label_tensor) else: scores = self.decoder(data_point_tensor) @@ -817,7 +817,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ) -> Optional[Union[List[DT], Tuple[float, int]]]: + ) -> Optional[Union[list[DT], tuple[float, int]]]: """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: From 8398be2abc194a735140771dfc1373742557ca90 Mon Sep 17 00:00:00 2001 From: Max Ploner Date: Fri, 3 Jan 2025 15:25:58 +0100 Subject: [PATCH 20/46] Removed predict return types The specified return types were overly resetrictive (e.g. did not include sequence labelling models) --- flair/nn/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index 623907727b..7c10c82ee8 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -817,7 +817,7 @@ def predict( label_name: Optional[str] = None, return_loss: bool = False, embedding_storage_mode: EmbeddingStorageMode = "none", - ) -> Optional[Union[list[DT], tuple[float, int]]]: + ): """Predicts the class labels for the given sentences. The labels are directly added to the sentences. Args: From 1c85e086966881b9cf82f3ea2093d978bd94a67b Mon Sep 17 00:00:00 2001 From: alanakbik Date: Sat, 11 Jan 2025 15:17:48 +0100 Subject: [PATCH 21/46] Add options to load full documents as Sentence objects --- flair/datasets/sequence_labeling.py | 25 +++++++- .../trivial_bioes_with_boundaries/dev.txt | 37 ++++++++++++ .../trivial_bioes_with_boundaries/test.txt | 39 ++++++++++++ .../trivial_bioes_with_boundaries/train.txt | 59 +++++++++++++++++++ tests/test_datasets.py | 48 +++++++++++++++ 5 files changed, 207 insertions(+), 1 deletion(-) create mode 100644 tests/resources/tasks/trivial/trivial_bioes_with_boundaries/dev.txt create mode 100644 tests/resources/tasks/trivial/trivial_bioes_with_boundaries/test.txt create mode 100644 tests/resources/tasks/trivial/trivial_bioes_with_boundaries/train.txt diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index b2ab2f45dd..479e9e71e5 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -326,6 +326,8 @@ def __init__( label_name_map: Optional[dict[str, str]] = None, banned_sentences: Optional[list[str]] = None, default_whitespace_after: int = 1, + every_sentence_is_independent: bool = False, + documents_as_sentences: bool = False, **corpusargs, ) -> None: r"""Instantiates a Corpus from CoNLL column-formatted task data such as CoNLL03 or CoNLL2000. @@ -361,6 +363,8 @@ def __init__( skip_first_line=skip_first_line, label_name_map=label_name_map, default_whitespace_after=default_whitespace_after, + every_sentence_is_independent=every_sentence_is_independent, + documents_as_sentences=documents_as_sentences, ) for train_file in train_files ] @@ -385,6 +389,8 @@ def __init__( skip_first_line=skip_first_line, label_name_map=label_name_map, default_whitespace_after=default_whitespace_after, + every_sentence_is_independent=every_sentence_is_independent, + documents_as_sentences=documents_as_sentences, ) for test_file in test_files ] @@ -409,6 +415,8 @@ def __init__( skip_first_line=skip_first_line, label_name_map=label_name_map, default_whitespace_after=default_whitespace_after, + every_sentence_is_independent=every_sentence_is_independent, + documents_as_sentences=documents_as_sentences, ) for dev_file in dev_files ] @@ -481,10 +489,12 @@ def __init__( banned_sentences: Optional[list[str]] = None, in_memory: bool = True, document_separator_token: Optional[str] = None, + every_sentence_is_independent: bool = False, encoding: str = "utf-8", skip_first_line: bool = False, label_name_map: Optional[dict[str, str]] = None, default_whitespace_after: int = 1, + documents_as_sentences: bool = False, ) -> None: r"""Instantiates a column dataset. @@ -505,9 +515,11 @@ def __init__( self.column_delimiter = re.compile(column_delimiter) self.comment_symbol = comment_symbol self.document_separator_token = document_separator_token + self.every_sentence_is_independent = every_sentence_is_independent self.label_name_map = label_name_map self.banned_sentences = banned_sentences self.default_whitespace_after = default_whitespace_after + self.documents_as_sentences = documents_as_sentences # store either Sentence objects in memory, or only file offsets self.in_memory = in_memory @@ -702,6 +714,9 @@ def _convert_lines_to_sentence( if sentence.to_original_text() == self.document_separator_token: sentence.is_document_boundary = True + if self.every_sentence_is_independent or self.documents_as_sentences: + sentence.is_document_boundary = True + # add span labels if span_level_tag_columns: for span_column in span_level_tag_columns: @@ -818,6 +833,13 @@ def _remap_label(self, tag): return tag def __line_completes_sentence(self, line: str) -> bool: + + if self.documents_as_sentences: + if line.startswith(self.document_separator_token): + return True + else: + return False + sentence_completed = line.isspace() or line == "" return sentence_completed @@ -5035,7 +5057,8 @@ def __init__( test_file=None, column_format=columns, in_memory=in_memory, - sample_missing_splits=False, # No test data is available, so do not shrink dev data for shared task preparation! + sample_missing_splits=False, + # No test data is available, so do not shrink dev data for shared task preparation! **corpusargs, ) corpora.append(corpus) diff --git a/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/dev.txt b/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/dev.txt new file mode 100644 index 0000000000..b741ce5ab7 --- /dev/null +++ b/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/dev.txt @@ -0,0 +1,37 @@ +this O +is O +New B-LOC +York I-LOC + +here O +is O +New B-LOC +York I-LOC + +I O +like O +New B-LOC +York I-LOC + +we O +like O +New B-LOC +York I-LOC + +-DOCSTART- + +this O +is O +Berlin B-LOC + +here O +is O +Berlin B-LOC + +I O +like O +Berlin B-LOC + +we O +like O +Berlin B-LOC \ No newline at end of file diff --git a/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/test.txt b/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/test.txt new file mode 100644 index 0000000000..64a127bd88 --- /dev/null +++ b/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/test.txt @@ -0,0 +1,39 @@ +this O +is O +New B-LOC +York I-LOC + +here O +is O +New B-LOC +York I-LOC + +I O +like O +New B-LOC +York I-LOC + +we O +like O +New B-LOC +York I-LOC + +-DOCSTART- + +this O +is O +Berlin B-LOC + +here O +is O +Berlin B-LOC + +I O +like O +Berlin B-LOC + +we O +like O +Berlin B-LOC + +-DOCSTART- \ No newline at end of file diff --git a/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/train.txt b/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/train.txt new file mode 100644 index 0000000000..4f934bcfd5 --- /dev/null +++ b/tests/resources/tasks/trivial/trivial_bioes_with_boundaries/train.txt @@ -0,0 +1,59 @@ +this O +is O +New B-LOC +York I-LOC + +here O +is O +New B-LOC +York I-LOC + +I O +like O +New B-LOC +York I-LOC + +we O +like O +New B-LOC +York I-LOC + +-DOCSTART- + +this O +is O +Berlin B-LOC + +here O +is O +Berlin B-LOC + +I O +like O +Berlin B-LOC + +we O +like O +Berlin B-LOC + +-DOCSTART- + +this O +is O +New B-LOC +York I-LOC + +here O +is O +New B-LOC +York I-LOC + +I O +like O +New B-LOC +York I-LOC + +we O +like O +New B-LOC +York I-LOC \ No newline at end of file diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 25a99f87e0..8e6a7019b0 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -75,6 +75,54 @@ def test_load_sequence_labeling_data(tasks_base_path): assert len(corpus.test) == 1 +def test_load_sequence_labeling_data_with_boundaries(tasks_base_path): + # get training, test and dev data + corpus = flair.datasets.ColumnCorpus( + tasks_base_path / "trivial" / "trivial_bioes_with_boundaries", column_format={0: "text", 1: "ner"} + ) + + assert len(corpus.train) == 14 + assert len(corpus.dev) == 9 + assert len(corpus.test) == 10 + + # now exclude -DOCSTART- sentences + corpus = flair.datasets.ColumnCorpus( + tasks_base_path / "trivial" / "trivial_bioes_with_boundaries", + column_format={0: "text", 1: "ner"}, + banned_sentences=["-DOCSTART-"], + ) + + assert len(corpus.train) == 12 + assert len(corpus.dev) == 8 + assert len(corpus.test) == 8 + + assert len(corpus.train[0].right_context(5)) == 5 + + # now load whole documents as sentences + corpus = flair.datasets.ColumnCorpus( + tasks_base_path / "trivial" / "trivial_bioes_with_boundaries", + column_format={0: "text", 1: "ner"}, + document_separator_token="-DOCSTART-", + documents_as_sentences=True, + ) + + assert len(corpus.train) == 3 + assert len(corpus.dev) == 2 + assert len(corpus.test) == 2 + + assert len(corpus.train[0].right_context(5)) == 0 + + # ban each boundary but set each sentence to be independent + corpus = flair.datasets.ColumnCorpus( + tasks_base_path / "trivial" / "trivial_bioes_with_boundaries", + column_format={0: "text", 1: "ner"}, + banned_sentences=["-DOCSTART-"], + every_sentence_is_independent=True, + ) + + assert len(corpus.train[0].right_context(5)) == 0 + + def test_load_sequence_labeling_whitespace_after(tasks_base_path): # get training, test and dev data corpus = flair.datasets.ColumnCorpus( From 06a5c0cad99e7576883d359dddcbc8d54bb01d9e Mon Sep 17 00:00:00 2001 From: alanakbik Date: Sat, 11 Jan 2025 15:42:13 +0100 Subject: [PATCH 22/46] Mypy fix --- flair/datasets/sequence_labeling.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 479e9e71e5..80fc6d38ba 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -521,6 +521,12 @@ def __init__( self.default_whitespace_after = default_whitespace_after self.documents_as_sentences = documents_as_sentences + if documents_as_sentences and not document_separator_token: + log.error( + "document_as_sentences was set to True, but no document_separator_token was provided. Please set" + "a value for document_separator_token in order to enable the document_as_sentence functionality." + ) + # store either Sentence objects in memory, or only file offsets self.in_memory = in_memory @@ -834,7 +840,7 @@ def _remap_label(self, tag): def __line_completes_sentence(self, line: str) -> bool: - if self.documents_as_sentences: + if self.documents_as_sentences and self.document_separator_token: if line.startswith(self.document_separator_token): return True else: From 0cca6a8e6296574b357249dc4538f164c7e84d7b Mon Sep 17 00:00:00 2001 From: dobbersc Date: Wed, 15 Jan 2025 01:41:00 +0100 Subject: [PATCH 23/46] Ensure presence of head and tail entity in the original sentence --- flair/models/relation_classifier_model.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index eb7f4ab73b..e60756a74d 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -258,7 +258,7 @@ def __init__( allow_unk_tag: bool = True, max_allowed_tokens_between_entities: int = 20, max_surrounding_context_length: int = 10, - **classifierargs, + **classifierargs: Any, ) -> None: """Initializes a `RelationClassifier`. @@ -435,8 +435,8 @@ def _encode_sentence( # since there may be multiple occurrences of the same entity mentioned in the sentence. # Therefore, we use the span's position in the sentence. encoded_sentence_tokens: list[str] = [] - head_idx = -10000 - tail_idx = 10000 + head_idx: Optional[int] = None + tail_idx: Optional[int] = None for token in original_sentence: if token is head.span[0]: head_idx = len(encoded_sentence_tokens) @@ -452,11 +452,19 @@ def _encode_sentence( ): encoded_sentence_tokens.append(token.text) - # filter cases in which the distance between the two entities is too large + msg: str + if head_idx is None: + msg = f"The head entity ({head!r}) is not located inside the original sentence ({original_sentence!r})." + raise AssertionError(msg) + if tail_idx is None: + msg = f"The tail entity ({tail!r}) is not located inside the original sentence ({original_sentence!r})." + raise AssertionError(msg) + + # Filter cases in which the distance between the two entities is too large if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities: return None - # remove excess tokens left and right of entity pair to make encoded sentence shorter + # Remove excess tokens left and right of entity pair to make encoded sentence shorter encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length( encoded_sentence_tokens, head_idx, tail_idx ) From a0696bbb2df934a209c1cabb3fe38f27a7f46cf8 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Fri, 17 Jan 2025 09:29:29 +0100 Subject: [PATCH 24/46] Update RegexpTagger to be able to specify matching groups --- flair/models/regexp_tagger.py | 32 ++++++++++++++++++++---------- tests/models/test_regexp_tagger.py | 0 2 files changed, 21 insertions(+), 11 deletions(-) create mode 100644 tests/models/test_regexp_tagger.py diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py index e41981c899..1707fee41c 100644 --- a/flair/models/regexp_tagger.py +++ b/flair/models/regexp_tagger.py @@ -45,7 +45,9 @@ def get_token_span(self, span: tuple[int, int]) -> Span: class RegexpTagger: - def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> None: + def __init__( + self, mapping: Union[list[Union[tuple[str, str], tuple[str, str, int]]], tuple[str, str], tuple[str, str, int]] + ) -> None: r"""This tagger is capable of tagging sentence objects with given regexp -> label mappings. I.e: The tuple (r'(["\'])(?:(?=(\\?))\2.)*?\1', 'QUOTE') maps every match of the regexp to @@ -58,14 +60,18 @@ def __init__(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]) -> No Args: mapping: A list of tuples or a single tuple representing a mapping as regexp -> label """ - self._regexp_mapping: dict[str, typing.Pattern] = {} + self._regexp_mapping: list[str, typing.Pattern, int] = [] self.register_labels(mapping=mapping) + def label_type(self): + for regexp, label, group in self._regexp_mapping: + return label + @property def registered_labels(self): return self._regexp_mapping - def register_labels(self, mapping: Union[list[tuple[str, str]], tuple[str, str]]): + def register_labels(self, mapping: Union[list[tuple[str, str, int]], tuple[str, str, int]]): """Register a regexp -> label mapping. Args: @@ -73,9 +79,14 @@ def register_labels(self, mapping: Union[list[tuple[str, str]], tuple[str, str]] """ mapping = self._listify(mapping) - for regexp, label in mapping: + for entry in mapping: + regexp = entry[0] + label = entry[1] + group = entry[2] if len(entry) > 2 else 0 try: - self._regexp_mapping[label] = re.compile(regexp) + pattern = re.compile(regexp) + self._regexp_mapping.append((pattern, label, group)) + except re.error as err: raise re.error( f"Couldn't compile regexp '{regexp}' for label '{label}'. Aborted with error: '{err.msg}'" @@ -89,10 +100,7 @@ def remove_labels(self, labels: Union[list[str], str]): """ labels = self._listify(labels) - for label in labels: - if not self._regexp_mapping.get(label): - continue - self._regexp_mapping.pop(label) + self._regexp_mapping = [mapping for mapping in self._regexp_mapping if mapping[1] not in labels] @staticmethod def _listify(element: object) -> list: @@ -120,9 +128,11 @@ def _label(self, sentence: Sentence): """ collection = TokenCollection(sentence) - for label, pattern in self._regexp_mapping.items(): + for pattern, label, group in self._regexp_mapping: for match in pattern.finditer(sentence.to_original_text()): - span: tuple[int, int] = match.span() + # print(match) + span: tuple[int, int] = match.span(group) + # print(span) try: token_span = collection.get_token_span(span) except ValueError: diff --git a/tests/models/test_regexp_tagger.py b/tests/models/test_regexp_tagger.py new file mode 100644 index 0000000000..e69de29bb2 From 7eb13a3a1ebc8c82608e43c5932b07da8912eef4 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Fri, 17 Jan 2025 09:30:29 +0100 Subject: [PATCH 25/46] Ruff fixes --- docs/conf.py | 1 - flair/data.py | 7 ++----- flair/datasets/document_classification.py | 4 ++-- flair/distributed_utils.py | 3 ++- flair/models/sequence_tagger_model.py | 1 - flair/nn/model.py | 3 +-- flair/splitter.py | 3 +-- flair/trainers/plugins/functional/checkpoints.py | 2 -- tests/models/test_regexp_tagger.py | 16 ++++++++++++++++ 9 files changed, 24 insertions(+), 16 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7ac895ac0e..a9194e50cb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -5,7 +5,6 @@ # -- Project information ----------------------------------------------------- from sphinx_github_style import get_linkcode_resolve -from torch.nn import Module version = "0.15.0" release = "0.15.0" diff --git a/flair/data.py b/flair/data.py index 7ee32f40b9..40175418a4 100644 --- a/flair/data.py +++ b/flair/data.py @@ -1389,8 +1389,7 @@ def __init__( sample_missing_splits: Union[bool, str] = True, random_seed: Optional[int] = None, ) -> None: - """ - Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test split + """Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test split by passing the corresponding Dataset object to the constructor. At least one split should be defined. If the option `sample_missing_splits` is set to True, missing splits will be randomly sampled from the train split. @@ -1484,7 +1483,6 @@ def downsample( Returns: A pointer to itself for optional use in method chaining. """ - if downsample_train and self._train is not None: self._train = self._downsample_to_proportion(self._train, percentage, random_seed) @@ -1511,8 +1509,7 @@ def filter_empty_sentences(self): log.info(self) def filter_long_sentences(self, max_charlength: int): - """ - A method that filters all sentences for which the plain text is longer than a specified number of characters. + """A method that filters all sentences for which the plain text is longer than a specified number of characters. This is an in-place operation that directly modifies the Corpus object itself by removing these sentences. diff --git a/flair/datasets/document_classification.py b/flair/datasets/document_classification.py index 363c84e561..82dbda97a1 100644 --- a/flair/datasets/document_classification.py +++ b/flair/datasets/document_classification.py @@ -798,7 +798,7 @@ def __init__( if not rebalance_corpus and dataset == "test": data_file = test_data_file - with open(data_file, "at") as f_p: + with open(data_file, "a") as f_p: current_path = data_path / "aclImdb" / dataset / label for file_name in current_path.iterdir(): if file_name.is_file() and file_name.name.endswith(".txt"): @@ -891,7 +891,7 @@ def __init__( data_path / "original", members=[m for m in f_in.getmembers() if f"{dataset}/{label}" in m.name], ) - with open(f"{data_path}/{dataset}.txt", "at", encoding="utf-8") as f_p: + with open(f"{data_path}/{dataset}.txt", "a", encoding="utf-8") as f_p: current_path = data_path / "original" / dataset / label for file_name in current_path.iterdir(): if file_name.is_file(): diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index e774084009..99ad28f9c0 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -73,7 +73,8 @@ def aggregate(value, aggregation_fn=np.mean): def validate_corpus_same_each_process(corpus: Corpus) -> None: """Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two - reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable""" + reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable + """ for dataset in [corpus.train, corpus.dev, corpus.test]: if dataset is not None: _validate_dataset_same_each_process(dataset) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index 7bc9ec051b..ea028c3e70 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -286,7 +286,6 @@ def forward_loss(self, sentences: list[Sentence]) -> tuple[torch.Tensor, int]: A tuple consisting of the loss tensor and the number of tokens in the batch. """ - # if there are no sentences, there is no loss if len(sentences) == 0: return torch.tensor(0.0, dtype=torch.float, device=flair.device, requires_grad=True), 0 diff --git a/flair/nn/model.py b/flair/nn/model.py index 7c10c82ee8..54c52eba3e 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -208,8 +208,7 @@ def load(cls, model_path: Union[str, Path, dict[str, Any]]) -> "Model": return model def print_model_card(self): - """ - This method produces a log message that includes all recorded parameters the model was trained with. + """This method produces a log message that includes all recorded parameters the model was trained with. The model card includes information such as the Flair, PyTorch and Transformers versions used during training, and the training parameters. diff --git a/flair/splitter.py b/flair/splitter.py index 6246969f28..ac4a22bde5 100644 --- a/flair/splitter.py +++ b/flair/splitter.py @@ -28,8 +28,7 @@ class SentenceSplitter(ABC): """ def split(self, text: str, link_sentences: bool = True) -> list[Sentence]: - """ - Takes as input a text as a plain string and outputs a list of :class:`flair.data.Sentence` objects. + """Takes as input a text as a plain string and outputs a list of :class:`flair.data.Sentence` objects. If link_sentences is set (by default, it is). The :class:`flair.data.Sentence` objects will include pointers to the preceding and following sentences in the original text. This way, the original sequence information will diff --git a/flair/trainers/plugins/functional/checkpoints.py b/flair/trainers/plugins/functional/checkpoints.py index 4261a56a23..cf1a21468a 100644 --- a/flair/trainers/plugins/functional/checkpoints.py +++ b/flair/trainers/plugins/functional/checkpoints.py @@ -1,8 +1,6 @@ import logging from typing import Any -import torch - from flair.trainers.plugins.base import TrainerPlugin log = logging.getLogger("flair") diff --git a/tests/models/test_regexp_tagger.py b/tests/models/test_regexp_tagger.py index e69de29bb2..1e3f8d2a39 100644 --- a/tests/models/test_regexp_tagger.py +++ b/tests/models/test_regexp_tagger.py @@ -0,0 +1,16 @@ +from flair.data import Sentence +from flair.models import RegexpTagger + + +def test_regexp_tagger(): + + sentence = Sentence('Der sagte: "das ist durchaus interessant"') + + tagger = RegexpTagger( + mapping=[(r'["„»]((?:(?=(\\?))\2.)*?)[”"“«]', "quote_part", 1), (r'["„»]((?:(?=(\\?))\2.)*?)[”"“«]', "quote")] + ) + + tagger.predict(sentence) + + assert sentence.get_label("quote_part").data_point.text == "das ist durchaus interessant" + assert sentence.get_label("quote").data_point.text == '"das ist durchaus interessant"' From 0128130f89fbf261eb4859b1dcc3308241ab7b3e Mon Sep 17 00:00:00 2001 From: alanakbik Date: Fri, 17 Jan 2025 09:37:11 +0100 Subject: [PATCH 26/46] Make mypy happy --- flair/models/regexp_tagger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py index 1707fee41c..37906d55aa 100644 --- a/flair/models/regexp_tagger.py +++ b/flair/models/regexp_tagger.py @@ -60,7 +60,7 @@ def __init__( Args: mapping: A list of tuples or a single tuple representing a mapping as regexp -> label """ - self._regexp_mapping: list[str, typing.Pattern, int] = [] + self._regexp_mapping: list = [] self.register_labels(mapping=mapping) def label_type(self): @@ -71,7 +71,7 @@ def label_type(self): def registered_labels(self): return self._regexp_mapping - def register_labels(self, mapping: Union[list[tuple[str, str, int]], tuple[str, str, int]]): + def register_labels(self, mapping): """Register a regexp -> label mapping. Args: From 4ed2e49d16f4a7f6937086244f4cc29af28ce08d Mon Sep 17 00:00:00 2001 From: dobbersc Date: Tue, 21 Jan 2025 23:11:15 +0100 Subject: [PATCH 27/46] Refactor `_slice_encoded_sentence_to_max_allowed_length` to use min/max operations instead of if-statements --- flair/models/relation_classifier_model.py | 47 +++++++++++++++-------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index e60756a74d..1d3136a9c0 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -400,6 +400,34 @@ def _entity_pair_permutations( yield head, tail, gold_label + @staticmethod + def _truncate_context_around_entities( + encoded_sentence_tokens: list[str], + head_idx: int, + tail_idx: int, + context_length: int, + ) -> list[str]: + """Truncates the encoded sentence to include the head and tail entity and their surrounding context. + The context, in between the entity pairs will always be included. + + Args: + encoded_sentence_tokens: The list of tokens corresponding to the encoded sentence. + head_idx: The index of the head entity in the token list. + tail_idx: The index of the tail entity in the token list. + context_length: The maximum number of tokens to include as surrounding context around the head and tail entities. + + Returns: + The tokens of the truncated sentence. + """ + begin_slice: int = min(head_idx, tail_idx) + end_slice: int = max(head_idx, tail_idx) + + # Preserve context around the entities. Always include their in-between context. + begin_slice = max(begin_slice - context_length, 0) + end_slice = min(end_slice + context_length + 1, len(encoded_sentence_tokens)) + + return encoded_sentence_tokens[begin_slice:end_slice] + def _encode_sentence( self, head: _Entity, @@ -465,8 +493,8 @@ def _encode_sentence( return None # Remove excess tokens left and right of entity pair to make encoded sentence shorter - encoded_sentence_tokens = self._slice_encoded_sentence_to_max_allowed_length( - encoded_sentence_tokens, head_idx, tail_idx + encoded_sentence_tokens = self._truncate_context_around_entities( + encoded_sentence_tokens, head_idx, tail_idx, self._max_surrounding_context_length ) # Create masked sentence @@ -479,23 +507,10 @@ def _encode_sentence( # Using the sentence label instead of annotating a separate `Relation` object is easier to manage since, # during prediction, the forward pass does not need any knowledge about the entities in the sentence. encoded_sentence.add_label(typename=self.label_type, value=gold_label, score=1.0) + encoded_sentence.copy_context_from_sentence(original_sentence) return encoded_sentence - def _slice_encoded_sentence_to_max_allowed_length(self, encoded_sentence_tokens, head_idx, tail_idx): - begin_slice = head_idx if head_idx < tail_idx else tail_idx - end_slice = tail_idx if head_idx < tail_idx else head_idx - padding_amount = self._max_surrounding_context_length - begin_slice = begin_slice - padding_amount if begin_slice - padding_amount > 0 else 0 - end_slice = ( - end_slice + padding_amount + 1 - if end_slice + padding_amount + 1 < len(encoded_sentence_tokens) - else len(encoded_sentence_tokens) - ) - - encoded_sentence_tokens = encoded_sentence_tokens[begin_slice:end_slice] - return encoded_sentence_tokens - def _encode_sentence_for_inference( self, sentence: Sentence, From 306412fcdde07e2faea6f24581662f7c3fd72417 Mon Sep 17 00:00:00 2001 From: dobbersc Date: Tue, 21 Jan 2025 23:31:17 +0100 Subject: [PATCH 28/46] Allow to disable the `max_allowed_tokens_between_entities` and `max_surrounding_context_length` filter for backwards compatibility --- flair/models/relation_classifier_model.py | 28 +++++++++++++---------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 1d3136a9c0..14009075df 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -256,8 +256,8 @@ def __init__( encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, - max_allowed_tokens_between_entities: int = 20, - max_surrounding_context_length: int = 10, + max_allowed_tokens_between_entities: Optional[int] = 20, + max_surrounding_context_length: Optional[int] = 10, **classifierargs: Any, ) -> None: """Initializes a `RelationClassifier`. @@ -273,8 +273,8 @@ def __init__( encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol zero_tag_value: The label to use for out-of-class relations allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. - max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. - max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. + max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter will be disabled. + max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context, in between the entity pairs will always be included. If `None`, the filter will be disabled. classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier` """ # Set label type and prepare label dictionary @@ -489,13 +489,17 @@ def _encode_sentence( raise AssertionError(msg) # Filter cases in which the distance between the two entities is too large - if abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities: + if ( + self._max_allowed_tokens_between_entities is not None + and abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities + ): return None # Remove excess tokens left and right of entity pair to make encoded sentence shorter - encoded_sentence_tokens = self._truncate_context_around_entities( - encoded_sentence_tokens, head_idx, tail_idx, self._max_surrounding_context_length - ) + if self._max_surrounding_context_length is not None: + encoded_sentence_tokens = self._truncate_context_around_entities( + encoded_sentence_tokens, head_idx, tail_idx, self._max_surrounding_context_length + ) # Create masked sentence encoded_sentence: EncodedSentence = EncodedSentence( @@ -532,7 +536,7 @@ def _encode_sentence_for_inference( Returns: Encoded sentences annotated with their gold relation and the corresponding relation in the original sentence """ for head, tail, gold_label in self._entity_pair_permutations(sentence): - masked_sentence = self._encode_sentence( + masked_sentence: Optional[EncodedSentence] = self._encode_sentence( head=head, tail=tail, gold_label=gold_label if gold_label is not None else self.zero_tag_value, @@ -557,7 +561,7 @@ def _encode_sentence_for_training(self, sentence: Sentence) -> Iterator[EncodedS else: continue # Skip generated data points that do not express an originally annotated relation - masked_sentence = self._encode_sentence( + masked_sentence: Optional[EncodedSentence] = self._encode_sentence( head=head, tail=tail, gold_label=gold_label, @@ -791,8 +795,8 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): encoding_strategy=state["encoding_strategy"], zero_tag_value=state["zero_tag_value"], allow_unk_tag=state["allow_unk_tag"], - max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities", 25), - max_surrounding_context_length=state.get("max_surrounding_context_length", 50), + max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"), + max_surrounding_context_length=state.get("max_surrounding_context_length"), **kwargs, ) From 4fc487807733182acfbd8fad1335297679aa7ed7 Mon Sep 17 00:00:00 2001 From: dobbersc Date: Wed, 22 Jan 2025 00:59:02 +0100 Subject: [PATCH 29/46] Rearrange parameters and make sentence filters public --- flair/models/relation_classifier_model.py | 30 +++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/flair/models/relation_classifier_model.py b/flair/models/relation_classifier_model.py index 14009075df..61c0454244 100644 --- a/flair/models/relation_classifier_model.py +++ b/flair/models/relation_classifier_model.py @@ -252,12 +252,12 @@ def __init__( entity_label_types: Union[str, Sequence[str], dict[str, Optional[set[str]]]], entity_pair_labels: Optional[set[tuple[str, str]]] = None, entity_threshold: Optional[float] = None, + max_allowed_tokens_between_entities: Optional[int] = 20, + max_surrounding_context_length: Optional[int] = 10, cross_augmentation: bool = True, encoding_strategy: EncodingStrategy = TypedEntityMarker(), zero_tag_value: str = "O", allow_unk_tag: bool = True, - max_allowed_tokens_between_entities: Optional[int] = 20, - max_surrounding_context_length: Optional[int] = 10, **classifierargs: Any, ) -> None: """Initializes a `RelationClassifier`. @@ -269,12 +269,12 @@ def __init__( entity_label_types: A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only 'PER' and 'ORG' labels from a NER-tagger: `{'ner': {'PER', 'ORG'}}`. To use all labels from 'ner', pass 'ner'. entity_pair_labels: A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (, ). E.g. for the `born_in` relation, only relations from 'PER' to 'LOC' make sense. Here, relations from 'PER' to 'PER' are not meaningful, so it is advised to specify the `entity_pair_labels` as `{('PER', 'ORG')}`. This setting may help to reduce the number of relation candidates. Leaving this parameter as `None` (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of *all* entity pairs (inefficient). entity_threshold: Only pre-labelled entities above this threshold are taken into account by the model. + max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter will be disabled. + max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context, in between the entity pairs will always be included. If `None`, the filter will be disabled. cross_augmentation: If `True`, use cross augmentation to transform `Sentence`s into `EncodedSentenece`s. When cross augmentation is enabled, the transformation functions, e.g. `transform_corpus`, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When disabling cross augmentation, the transform functions only generate encoded sentences for each gold relation annotation in the original sentence. encoding_strategy: An instance of a class conforming the :class:`EncodingStrategy` protocol zero_tag_value: The label to use for out-of-class relations allow_unk_tag: If `False`, removes `` from the passed label dictionary, otherwise do nothing. - max_allowed_tokens_between_entities: The maximum allowed number of allowed tokens between entities. All other entity pairs are filtered from consideration. If `None`, the filter will be disabled. - max_surrounding_context_length: The maximum length of context around entity pairs that will be considered. The context, in between the entity pairs will always be included. If `None`, the filter will be disabled. classifierargs: The remaining parameters passed to the underlying :class:`flair.models.DefaultClassifier` """ # Set label type and prepare label dictionary @@ -282,9 +282,6 @@ def __init__( self._zero_tag_value = zero_tag_value self._allow_unk_tag = allow_unk_tag - self._max_allowed_tokens_between_entities = max_allowed_tokens_between_entities - self._max_surrounding_context_length = max_surrounding_context_length - modified_label_dictionary: Dictionary = Dictionary(add_unk=self._allow_unk_tag) modified_label_dictionary.add_item(self._zero_tag_value) for label in label_dictionary.get_items(): @@ -309,6 +306,8 @@ def __init__( self.entity_pair_labels = entity_pair_labels self.entity_threshold = entity_threshold + self.max_allowed_tokens_between_entities = max_allowed_tokens_between_entities + self.max_surrounding_context_length = max_surrounding_context_length self.cross_augmentation = cross_augmentation self.encoding_strategy = encoding_strategy @@ -408,6 +407,7 @@ def _truncate_context_around_entities( context_length: int, ) -> list[str]: """Truncates the encoded sentence to include the head and tail entity and their surrounding context. + The context, in between the entity pairs will always be included. Args: @@ -490,15 +490,15 @@ def _encode_sentence( # Filter cases in which the distance between the two entities is too large if ( - self._max_allowed_tokens_between_entities is not None - and abs(head_idx - tail_idx) > self._max_allowed_tokens_between_entities + self.max_allowed_tokens_between_entities is not None + and abs(head_idx - tail_idx) > self.max_allowed_tokens_between_entities ): return None # Remove excess tokens left and right of entity pair to make encoded sentence shorter - if self._max_surrounding_context_length is not None: + if self.max_surrounding_context_length is not None: encoded_sentence_tokens = self._truncate_context_around_entities( - encoded_sentence_tokens, head_idx, tail_idx, self._max_surrounding_context_length + encoded_sentence_tokens, head_idx, tail_idx, self.max_surrounding_context_length ) # Create masked sentence @@ -772,12 +772,12 @@ def _get_state_dict(self) -> dict[str, Any]: "entity_label_types": self.entity_label_types, "entity_pair_labels": self.entity_pair_labels, "entity_threshold": self.entity_threshold, + "max_allowed_tokens_between_entities": self.max_allowed_tokens_between_entities, + "max_surrounding_context_length": self.max_surrounding_context_length, "cross_augmentation": self.cross_augmentation, "encoding_strategy": self.encoding_strategy, "zero_tag_value": self.zero_tag_value, "allow_unk_tag": self.allow_unk_tag, - "max_allowed_tokens_between_entities": self._max_allowed_tokens_between_entities, - "max_surrounding_context_length": self._max_surrounding_context_length, } return model_state @@ -791,12 +791,12 @@ def _init_model_with_state_dict(cls, state: dict[str, Any], **kwargs): entity_label_types=state["entity_label_types"], entity_pair_labels=state["entity_pair_labels"], entity_threshold=state["entity_threshold"], + max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"), + max_surrounding_context_length=state.get("max_surrounding_context_length"), cross_augmentation=state["cross_augmentation"], encoding_strategy=state["encoding_strategy"], zero_tag_value=state["zero_tag_value"], allow_unk_tag=state["allow_unk_tag"], - max_allowed_tokens_between_entities=state.get("max_allowed_tokens_between_entities"), - max_surrounding_context_length=state.get("max_surrounding_context_length"), **kwargs, ) From de8b7f4559b69e593f990902f976b4abf2672881 Mon Sep 17 00:00:00 2001 From: dobbersc Date: Wed, 22 Jan 2025 01:00:18 +0100 Subject: [PATCH 30/46] Add test cases for `max_allowed_tokens_between_entities` and `max_surrounding_context_length` parameters --- tests/models/test_relation_classifier.py | 85 ++++++++++++++++++---- tests/resources/tasks/conllu/train.conllup | 25 +++++++ 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/tests/models/test_relation_classifier.py b/tests/models/test_relation_classifier.py index da4de52bfc..2d8aeb78f7 100644 --- a/tests/models/test_relation_classifier.py +++ b/tests/models/test_relation_classifier.py @@ -1,8 +1,4 @@ -from operator import itemgetter -from typing import Optional - import pytest -from torch.utils.data import Dataset from flair.data import Relation, Sentence from flair.datasets import ColumnCorpus, DataLoader @@ -21,7 +17,7 @@ from tests.model_test_utils import BaseModelTest encoding_strategies: dict[EncodingStrategy, list[tuple[str, str]]] = { - EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(7)], + EntityMask(): [("[HEAD]", "[TAIL]") for _ in range(8)], TypedEntityMask(): [ ("[HEAD-ORG]", "[TAIL-PER]"), ("[HEAD-ORG]", "[TAIL-PER]"), @@ -30,6 +26,7 @@ ("[HEAD-LOC]", "[TAIL-PER]"), ("[HEAD-LOC]", "[TAIL-PER]"), ("[HEAD-ORG]", "[TAIL-PER]"), + ("[HEAD-LOC]", "[TAIL-PER]"), ], EntityMarker(): [ ("[HEAD] Google [/HEAD]", "[TAIL] Larry Page [/TAIL]"), @@ -39,6 +36,7 @@ ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), ("[HEAD] Germany [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), ("[HEAD] MIT [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), + ("[HEAD] Berlin [/HEAD]", "[TAIL] Joseph Weizenbaum [/TAIL]"), ], TypedEntityMarker(): [ ("[HEAD-ORG] Google [/HEAD-ORG]", "[TAIL-PER] Larry Page [/TAIL-PER]"), @@ -48,6 +46,7 @@ ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), ("[HEAD-LOC] Germany [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), ("[HEAD-ORG] MIT [/HEAD-ORG]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), + ("[HEAD-LOC] Berlin [/HEAD-LOC]", "[TAIL-PER] Joseph Weizenbaum [/TAIL-PER]"), ], EntityMarkerPunct(): [ ("@ Google @", "# Larry Page #"), @@ -57,6 +56,7 @@ ("@ Berlin @", "# Joseph Weizenbaum #"), ("@ Germany @", "# Joseph Weizenbaum #"), ("@ MIT @", "# Joseph Weizenbaum #"), + ("@ Berlin @", "# Joseph Weizenbaum #"), ], TypedEntityMarkerPunct(): [ ("@ * ORG * Google @", "# ^ PER ^ Larry Page #"), @@ -66,6 +66,7 @@ ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), ("@ * LOC * Germany @", "# ^ PER ^ Joseph Weizenbaum #"), ("@ * ORG * MIT @", "# ^ PER ^ Joseph Weizenbaum #"), + ("@ * LOC * Berlin @", "# ^ PER ^ Joseph Weizenbaum #"), ], } @@ -104,7 +105,7 @@ def transform_corpus(self, model, corpus): @pytest.fixture() def example_sentence(self): - sentence = Sentence(["Microsoft", "was", "found", "by", "Bill", "Gates"]) + sentence = Sentence(["Microsoft", "was", "founded", "by", "Bill", "Gates"]) sentence[:1].add_label(typename="ner", value="ORG", score=1.0) sentence[4:].add_label(typename="ner", value="PER", score=1.0) return sentence @@ -163,17 +164,14 @@ def assert_training_example(self, predicted_training_example): @staticmethod def check_transformation_correctness( - split: Optional[Dataset], + encoded_sentences: list[EncodedSentence], ground_truth: set[tuple[str, tuple[str, ...]]], ) -> None: # Ground truth is a set of tuples of (, ) - assert split is not None - - data_loader = DataLoader(split, batch_size=1) - assert all(isinstance(sentence, EncodedSentence) for sentence in map(itemgetter(0), data_loader)) + assert all(isinstance(sentence, EncodedSentence) for sentence in encoded_sentences) assert { (sentence.to_tokenized_string(), tuple(label.value for label in sentence.get_labels("relation"))) - for sentence in map(itemgetter(0), data_loader) + for sentence in encoded_sentences } == ground_truth @pytest.mark.parametrize( @@ -194,7 +192,12 @@ def test_transform_corpus( ) -> None: label_dictionary = corpus.make_label_dictionary("relation") model: RelationClassifier = self.build_model( - embeddings, label_dictionary, cross_augmentation=cross_augmentation, encoding_strategy=encoding_strategy + embeddings, + label_dictionary, + cross_augmentation=cross_augmentation, + encoding_strategy=encoding_strategy, + max_allowed_tokens_between_entities=None, + max_surrounding_context_length=None, ) transformed_corpus = model.transform_corpus(corpus) @@ -211,7 +214,7 @@ def test_transform_corpus( f"{encoded_entity_pairs[3][1]} was born in {encoded_entity_pairs[3][0]} on 22 June 1910 .", ("place_of_birth",), ), - # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany." + # Entity pair permutations of: "Joseph Weizenbaum , a professor at MIT , was born in Berlin , Germany ." ( f"{encoded_entity_pairs[4][1]} , a professor at MIT , " f"was born in {encoded_entity_pairs[4][0]} , Germany .", @@ -222,6 +225,12 @@ def test_transform_corpus( f"was born in Berlin , {encoded_entity_pairs[5][0]} .", ("place_of_birth",), ), + # Entity pair permutations of: "The German - American computer scientist Joseph Weizenbaum ( 8 January 1923 - 5 March 2008 ) was born in Berlin ." + ( + f"The German - American computer scientist {encoded_entity_pairs[7][1]} " + f"( 8 January 1923 - 5 March 2008 ) was born in {encoded_entity_pairs[7][0]} .", + ("place_of_birth",), + ), } if cross_augmentation: @@ -235,4 +244,50 @@ def test_transform_corpus( ) for split in (transformed_corpus.train, transformed_corpus.dev, transformed_corpus.test): - self.check_transformation_correctness(split, ground_truth) + self.check_transformation_correctness( + encoded_sentences=[batch[0] for batch in DataLoader(split, batch_size=1)], ground_truth=ground_truth + ) + + def test_transform_max_allowed_tokens_between_entities( + self, + corpus: ColumnCorpus, + embeddings: TransformerDocumentEmbeddings, + ) -> None: + assert corpus.train is not None + + label_dictionary = corpus.make_label_dictionary("relation") + model: RelationClassifier = self.build_model( + embeddings, label_dictionary, max_allowed_tokens_between_entities=12, max_surrounding_context_length=None + ) + + # "The German - American computer scientist Joseph Weizenbaum ( 8 January 1923 - 5 March 2008 ) was born in Berlin ." + sentence: Sentence = corpus.train[-1] + self.check_transformation_correctness(encoded_sentences=model.transform_sentence(sentence), ground_truth=set()) + + def test_transform_max_surrounding_context_length( + self, + corpus: ColumnCorpus, + embeddings: TransformerDocumentEmbeddings, + ) -> None: + assert corpus.train is not None + + label_dictionary = corpus.make_label_dictionary("relation") + model: RelationClassifier = self.build_model( + embeddings, + label_dictionary, + encoding_strategy=EntityMask(), + max_allowed_tokens_between_entities=None, + max_surrounding_context_length=2, + ) + + # "The German - American computer scientist Joseph Weizenbaum ( 8 January 1923 - 5 March 2008 ) was born in Berlin ." + sentence: Sentence = corpus.train[-1] + self.check_transformation_correctness( + encoded_sentences=model.transform_sentence(sentence), + ground_truth={ + ( + "computer scientist [TAIL] ( 8 January 1923 - 5 March 2008 ) was born in [HEAD] .", + ("place_of_birth",), + ), + }, + ) diff --git a/tests/resources/tasks/conllu/train.conllup b/tests/resources/tasks/conllu/train.conllup index 83ab9f305f..811efe5ce0 100644 --- a/tests/resources/tasks/conllu/train.conllup +++ b/tests/resources/tasks/conllu/train.conllup @@ -51,3 +51,28 @@ 13 , PUNCT O _ 14 Germany PROPN B-LOC SpaceAfter=No 15 . PUNCT O _ + +# text = The German-American computer scientist Joseph Weizenbaum (8 January 1923 - 5 March 2008) was born in Berlin. +# relations = 21;21;7;8;place_of_birth +1 The DET O _ +2 German PROPN O SpaceAfter=No +3 - PUNCT O SpaceAfter=No +4 American PROPN O _ +5 computer PROPN O _ +6 scientist NOUN O _ +7 Joseph PROPN B-PER _ +8 Weizenbaum PROPN I-PER _ +9 ( PUNCT O SpaceAfter=No +10 8 NUM O _ +11 January PROPN O _ +12 1923 NUM O _ +13 - SYM O _ +14 5 NUM O _ +15 March PROPN O _ +16 2008 NUM O SpaceAfter=No +17 ) PUNCT O _ +18 was PRON O _ +19 born ADV O _ +20 in ADP O _ +21 Berlin PROPN B-LOC SpaceAfter=No +22 . PUNCT O _ From a06fa304c6d07df5e48507ae98a28e57d8238041 Mon Sep 17 00:00:00 2001 From: dobbersc Date: Wed, 22 Jan 2025 01:32:40 +0100 Subject: [PATCH 31/46] Fix tests due to additional training data point in `train.conllup` --- tests/models/test_relation_extractor.py | 2 +- tests/test_datasets.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/models/test_relation_extractor.py b/tests/models/test_relation_extractor.py index 9f009fa14f..7e97b85ab8 100644 --- a/tests/models/test_relation_extractor.py +++ b/tests/models/test_relation_extractor.py @@ -21,7 +21,7 @@ class TestRelationExtractor(BaseModelTest): } training_args = { "max_epochs": 4, - "mini_batch_size": 2, + "mini_batch_size": 4, "learning_rate": 0.1, } diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8e6a7019b0..d988b03be9 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -376,9 +376,9 @@ def test_load_conllu_plus_corpus(tasks_base_path): in_memory=False, ) - assert len(corpus.train) == 4 - assert len(corpus.dev) == 4 - assert len(corpus.test) == 4 + assert len(corpus.train) == 5 + assert len(corpus.dev) == 5 + assert len(corpus.test) == 5 _assert_conllu_dataset(corpus.train) @@ -393,9 +393,9 @@ def test_load_conllu_corpus_plus_in_memory(tasks_base_path): in_memory=True, ) - assert len(corpus.train) == 4 - assert len(corpus.dev) == 4 - assert len(corpus.test) == 4 + assert len(corpus.train) == 5 + assert len(corpus.dev) == 5 + assert len(corpus.test) == 5 _assert_conllu_dataset(corpus.train) From 6fc2848cd75520597d47261b06c16977d3813a6c Mon Sep 17 00:00:00 2001 From: Nathaniel Travis Date: Fri, 24 Jan 2025 12:43:41 -0800 Subject: [PATCH 32/46] add per-task metrics --- flair/models/multitask_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/flair/models/multitask_model.py b/flair/models/multitask_model.py index 414eb46197..d789e057ce 100644 --- a/flair/models/multitask_model.py +++ b/flair/models/multitask_model.py @@ -164,6 +164,7 @@ def evaluate( # type: ignore[override] main_score = 0.0 all_detailed_results = "" all_classification_report: dict[str, dict[str, Any]] = {} + scores: dict[Any, float] = {} for task_id, split in batch_split.items(): result = self.tasks[task_id].evaluate( @@ -194,7 +195,12 @@ def evaluate( # type: ignore[override] ) all_classification_report[task_id] = result.classification_report - scores = {"loss": loss.item() / len(batch_split)} + # Add metrics so they will be available to _publish_eval_result. + for avg_type in ("micro avg", "macro avg"): + for metric_type in ("f1-score", "precision", "recall"): + scores[(task_id, avg_type, metric_type)] = result.classification_report[avg_type][metric_type] + + scores["loss"] = loss.item() / len(batch_split) return Result( main_score=main_score / len(batch_split), From 32c875b2e8a981ebf39ce7c0a453d2eebbd1bee9 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 24 Jan 2025 23:54:11 -0800 Subject: [PATCH 33/46] fix: cast indices tensor to int to fix bug --- flair/nn/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flair/nn/model.py b/flair/nn/model.py index 54c52eba3e..04a9c24a3b 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -899,7 +899,9 @@ def predict( if has_unknown_label: has_any_unknown_label = True - scores = torch.index_select(scores, 0, torch.tensor(filtered_indices, device=flair.device)) + scores = torch.index_select( + scores, 0, torch.tensor(filtered_indices, device=flair.device, dtype=torch.int32) + ) gold_labels = self._prepare_label_tensor([data_points[index] for index in filtered_indices]) overall_loss += self._calculate_loss(scores, gold_labels)[0] From 7c302c6afea086acd2aa0a29475ee81887f4a049 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Sun, 26 Jan 2025 12:20:25 -0800 Subject: [PATCH 34/46] fix: use proper eval default main eval metrics for text regression model also refactor variables to avoid type conflicts --- flair/models/pairwise_regression_model.py | 8 +++---- flair/models/text_regression_model.py | 27 ++++++++++++++--------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py index 9a1c2704be..bc77b54dce 100644 --- a/flair/models/pairwise_regression_model.py +++ b/flair/models/pairwise_regression_model.py @@ -345,7 +345,7 @@ def evaluate( f"spearman: {metric.spearmanr():.4f}" ) - scores = { + eval_metrics = { "loss": eval_loss.item(), "mse": metric.mean_squared_error(), "mae": metric.mean_absolute_error(), @@ -354,12 +354,12 @@ def evaluate( } if main_evaluation_metric[0] in ("correlation", "other"): - main_score = scores[main_evaluation_metric[1]] + main_score = eval_metrics[main_evaluation_metric[1]] else: - main_score = scores["spearman"] + main_score = eval_metrics["spearman"] return Result( main_score=main_score, detailed_results=detailed_result, - scores=scores, + scores=eval_metrics, ) diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index d1ad98d4e0..a0a99e6402 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -137,7 +137,7 @@ def evaluate( out_path: Optional[Union[str, Path]] = None, embedding_storage_mode: EmbeddingStorageMode = "none", mini_batch_size: int = 32, - main_evaluation_metric: tuple[str, str] = ("micro avg", "f1-score"), + main_evaluation_metric: tuple[str, str] = ("correlation", "pearson"), exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, @@ -195,16 +195,23 @@ def evaluate( f"spearman: {metric.spearmanr():.4f}" ) - result: Result = Result( - main_score=metric.pearsonr(), + eval_metrics = { + "loss": eval_loss.item(), + "mse": metric.mean_squared_error(), + "mae": metric.mean_absolute_error(), + "pearson": metric.pearsonr(), + "spearman": metric.spearmanr(), + } + + if main_evaluation_metric[0] in ("correlation", "other"): + main_score = eval_metrics[main_evaluation_metric[1]] + else: + main_score = eval_metrics["spearman"] + + result = Result( + main_score=main_score, detailed_results=detailed_result, - scores={ - "loss": eval_loss.item(), - "mse": metric.mean_squared_error(), - "mae": metric.mean_absolute_error(), - "pearson": metric.pearsonr(), - "spearman": metric.spearmanr(), - }, + scores=eval_metrics, ) return result From 6f8d168b27032e3146c701728abf1c808dbe58d3 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Mon, 27 Jan 2025 09:41:06 +0100 Subject: [PATCH 35/46] Fix serialization issue in ModelTrainer --- flair/models/regexp_tagger.py | 1 - flair/trainers/trainer.py | 20 ++++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/flair/models/regexp_tagger.py b/flair/models/regexp_tagger.py index 37906d55aa..00cc1073d3 100644 --- a/flair/models/regexp_tagger.py +++ b/flair/models/regexp_tagger.py @@ -1,5 +1,4 @@ import re -import typing from dataclasses import dataclass, field from typing import Union diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index 4f5d5ff7b8..fcfa100b68 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -966,14 +966,18 @@ def _initialize_model_card(self, **training_parameters): except ImportError: pass - # remember all parameters used in train() call - model_card["training_parameters"] = { - k: str(v) if isinstance(v, Path) else v for k, v in training_parameters.items() - } - - model_card["training_parameters"] = { - k: f"{v.__module__}.{v.__name__}" if inspect.isclass(v) else v for k, v in training_parameters.items() - } + # remember the training parameters + model_card["training_parameters"] = {} + for k, v in training_parameters.items(): + + # special rule for Path variables to make sure models can be deserialized on other OS + if isinstance(v, Path): + v = str(v) + # classes are only serialized as names + if inspect.isclass(v): + v = f"{v.__module__}.{v.__name__}" + + model_card["training_parameters"][k] = v plugins = [plugin.get_state() for plugin in model_card["training_parameters"]["plugins"]] model_card["training_parameters"]["plugins"] = plugins From 0b95bcd0dba10ac714cbf32ae7d3528f9d9ba5ad Mon Sep 17 00:00:00 2001 From: Samule Garda Date: Tue, 28 Jan 2025 10:56:38 +0100 Subject: [PATCH 36/46] fix: update scipy .A to toarray() --- flair/embeddings/document.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 8f66a198ed..69fe9405b3 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -210,7 +210,7 @@ def embed(self, sentences: Union[list[Sentence], Sentence]): sentences = [sentences] raw_sentences = [s.to_original_text() for s in sentences] - tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).A) + tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).toarray()) for sentence_id, sentence in enumerate(sentences): sentence.set_embedding(self.name, tfidf_vectors[sentence_id]) From 79aa33706e7f753f2edf962feb1d75de22af0d1d Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 31 Jan 2025 14:09:37 +0100 Subject: [PATCH 37/46] add compability to torch 2.6 --- flair/file_utils.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/file_utils.py b/flair/file_utils.py index 518d69e809..53fac0f75b 100644 --- a/flair/file_utils.py +++ b/flair/file_utils.py @@ -382,4 +382,4 @@ def load_torch_state(model_file: str) -> dict[str, typing.Any]: # to load models on some Mac/Windows setups # see https://github.com/zalandoresearch/flair/issues/351 f = load_big_file(model_file) - return torch.load(f, map_location="cpu") + return torch.load(f, map_location="cpu", weights_only=False) diff --git a/requirements.txt b/requirements.txt index 2704114ace..39cf750c66 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ scikit-learn>=1.0.2 segtok>=1.5.11 sqlitedict>=2.0.0 tabulate>=0.8.10 -torch>=1.5.0,!=1.8 +torch>=1.13.1 tqdm>=4.63.0 transformer-smaller-training-vocab>=0.2.3 transformers[sentencepiece]>=4.25.0,<5.0.0 From 087b74efa3c1839010bb8dcc3bbc740d57da131a Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 31 Jan 2025 15:05:56 +0100 Subject: [PATCH 38/46] fix some typing --- flair/embeddings/document.py | 3 +-- flair/embeddings/token.py | 2 +- flair/models/sequence_tagger_model.py | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/flair/embeddings/document.py b/flair/embeddings/document.py index 8f66a198ed..91aac0b9e0 100644 --- a/flair/embeddings/document.py +++ b/flair/embeddings/document.py @@ -691,10 +691,9 @@ def _add_embeddings_internal(self, sentences: list[Sentence]): lengths: list[int] = [len(sentence.tokens) for sentence in sentences] padding_length: int = max(max(lengths), self.min_sequence_length) - pre_allocated_zero_tensor = torch.zeros( self.embeddings.embedding_length * padding_length, - dtype=self.convs[0].weight.dtype, + dtype=cast(torch.nn.Conv1d, self.convs[0]).weight.dtype, device=flair.device, ) diff --git a/flair/embeddings/token.py b/flair/embeddings/token.py index 3d95c8ee0b..a9e6ba2199 100644 --- a/flair/embeddings/token.py +++ b/flair/embeddings/token.py @@ -1466,7 +1466,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]: word = token.text if self.field is None else token.get_label(self.field).value if word.strip() == "": - ids = [self.spm.vocab_size(), self.embedder.spm.vocab_size()] + ids = [self.spm.vocab_size(), self.spm.vocab_size()] else: if self.do_preproc: word = self._preprocess(word) diff --git a/flair/models/sequence_tagger_model.py b/flair/models/sequence_tagger_model.py index ea028c3e70..25b5b374f3 100644 --- a/flair/models/sequence_tagger_model.py +++ b/flair/models/sequence_tagger_model.py @@ -862,8 +862,8 @@ def push_to_hub( self.save(local_model_path) # Determine if model card already exists - info = model_info(repo_id, use_auth_token=token) - write_readme = all(f.rfilename != "README.md" for f in info.siblings) + info = model_info(repo_id, token=token) + write_readme = info.siblings is None or all(f.rfilename != "README.md" for f in info.siblings) # Generate and save model card if write_readme: From 6a5fc2cb7671ecbf837364a541a38dc94b88e2c6 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 27 Jan 2025 15:15:50 +0100 Subject: [PATCH 39/46] dataset: add support for BarNER dataset --- flair/datasets/sequence_labeling.py | 91 +++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 80fc6d38ba..22f6f90650 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -5530,3 +5530,94 @@ def __init__( corpora, name="masakha-pos-" + "-".join(languages), ) + + +class NER_BAVARIAN_WIKI(ColumnCorpus): + def __init__( + self, + fine_grained: bool = False, + revision: str = "main", + base_path: Optional[Union[str, Path]] = None, + in_memory: bool = True, + **corpusargs, + ) -> None: + """Initialize the Bavarian NER Bavarian NER Dataset (BarNER). + + The dataset was proposed in the 2024 LREC-COLING paper + "Sebastian, Basti, Wastl?! Recognizing Named Entities in Bavarian Dialectal Data" paper by Peng et al. + :param fine_grained: Defines if the fine-grained or coarse-grained (default) should be used. + :param revision: Defines the revision/commit of BarNER dataset, by default dataset from 'main' branch is used. + :param base_path: Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this + to point to a different folder but typically this should not be necessary. + :param in_memory: If True, keeps dataset in memory giving speedups in training. + """ + base_path = flair.cache_root / "datasets" if not base_path else Path(base_path) + dataset_name = self.__class__.__name__.lower() + data_folder = base_path / dataset_name + data_path = flair.cache_root / "datasets" / dataset_name + + document_boundary_marker = "-DOCSTART-" + + for split in ["train", "dev", "test"]: + # Get original version + original_split_filename = data_path / "original" / f"bar-wiki-{split}.tsv" + if not original_split_filename.is_file(): + original_split_url = ( + f"https://raw.githubusercontent.com/mainlp/BarNER/{revision}/data/BarNER-final/bar-wiki-{split}.tsv" + ) + cached_path(original_split_url, data_path / "original") + + # Add sentence boundary marker + modified_split_filename = data_path / f"bar-wiki-{split}.tsv" + if not modified_split_filename.is_file(): + f_out = open(modified_split_filename, "w") + + with open(original_split_filename) as f_p: + for line in f_p: + line = line.strip() + if line.startswith("# newdoc id = "): + f_out.write(f"{document_boundary_marker}\tO\n\n") + continue + if line.startswith("# "): + continue + f_out.write(f"{line}\n") + f_out.close() + + columns = {0: "text", 1: "ner"} + + label_name_map = None + + if not fine_grained: + # Only allowed classes in course setting are: PER, LOC, ORG and MISC. + # All other NEs are normalized to O, except EVENT and WOA are normalized to MISC (cf. Table 3 of paper). + label_name_map = { + "EVENT": "MISC", + "EVENTderiv": "O", + "EVENTpart": "O", + "LANG": "O", + "LANGderiv": "O", + "LANGpart": "O", + "LOCderiv": "O", + "LOCpart": "O", + "MISCderiv": "O", + "MISCpart": "O", + "ORGderiv": "O", + "ORGpart": "O", + "PERderiv": "O", + "PERpart": "O", + "RELIGION": "O", + "RELIGIONderiv": "O", + "WOA": "MISC", + "WOAderiv": "O", + "WOApart": "O", + } + + super().__init__( + data_folder, + columns, + in_memory=in_memory, + comment_symbol="# ", + document_separator_token="-DOCSTART-", + label_name_map=label_name_map, + **corpusargs, + ) From f8186a301cdd4c980bbb353add6688422eb24d75 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 27 Jan 2025 15:16:27 +0100 Subject: [PATCH 40/46] dataset: make NER_BAVARIAN_WIKI (BarNER) globally available) --- flair/datasets/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py index d54ff35e01..30fa00fc67 100644 --- a/flair/datasets/__init__.py +++ b/flair/datasets/__init__.py @@ -185,6 +185,7 @@ NER_ARABIC_ANER, NER_ARABIC_AQMAR, NER_BASQUE, + NER_BAVARIAN_WIKI, NER_CHINESE_WEIBO, NER_DANISH_DANE, NER_ENGLISH_MOVIE_COMPLEX, @@ -477,6 +478,7 @@ "NER_ARABIC_ANER", "NER_ARABIC_AQMAR", "NER_BASQUE", + "NER_BAVARIAN_WIKI", "NER_CHINESE_WEIBO", "NER_DANISH_DANE", "NER_ENGLISH_MOVIE_COMPLEX", From dcd029befcd007211c69ff346691a620db1a50b2 Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 27 Jan 2025 15:17:08 +0100 Subject: [PATCH 41/46] tests: add basic sentence & token count test for new BarNER dataset --- tests/test_datasets.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8e6a7019b0..9acd753337 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -954,6 +954,29 @@ def test_german_mobie(tasks_base_path): ), f"Number of parsed tokens ({actual_tokens}) does not match with reported number of tokens ({ref_tokens})!" +@pytest.mark.skip() +def test_bavarian_wiki(tasks_base_path): + corpus = flair.datasets.NER_BAVARIAN_WIKI() + + ref_sentences = 3_577 + ref_tokens = 75_690 + + actual_sentences = sum( + [1 for sentence in corpus.train + corpus.dev + corpus.test if sentence[0].text != "-DOCSTART-"] + ) + actual_tokens = sum( + [len(sentence) for sentence in corpus.train + corpus.dev + corpus.test if sentence[0].text != "-DOCSTART-"] + ) + + assert ref_sentences == actual_sentences, ( + f"Number of parsed sentences ({actual_sentences}) does not match with " + f"reported number of sentences ({ref_sentences})!" + ) + assert ( + ref_tokens == actual_tokens + ), f"Number of parsed tokens ({actual_tokens}) does not match with reported number of tokens ({ref_tokens})!" + + def test_multi_file_jsonl_corpus_should_use_label_type(tasks_base_path): corpus = MultiFileJsonlCorpus( train_files=[tasks_base_path / "jsonl/train.jsonl"], From 4332d7928f8ab25b5bbe51a5cb0dea4ef9902ad2 Mon Sep 17 00:00:00 2001 From: alanakbik Date: Mon, 3 Feb 2025 12:31:25 +0100 Subject: [PATCH 42/46] Specify UTF-8 encoding --- flair/datasets/sequence_labeling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flair/datasets/sequence_labeling.py b/flair/datasets/sequence_labeling.py index 22f6f90650..863446c1cc 100644 --- a/flair/datasets/sequence_labeling.py +++ b/flair/datasets/sequence_labeling.py @@ -5570,9 +5570,9 @@ def __init__( # Add sentence boundary marker modified_split_filename = data_path / f"bar-wiki-{split}.tsv" if not modified_split_filename.is_file(): - f_out = open(modified_split_filename, "w") + f_out = open(modified_split_filename, "w", encoding="utf-8") - with open(original_split_filename) as f_p: + with open(original_split_filename, encoding="utf-8") as f_p: for line in f_p: line = line.strip() if line.startswith("# newdoc id = "): From 2a3f6cb2db5d321489ebc3a4b42b69208c3be370 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 2 Aug 2024 18:10:03 -0700 Subject: [PATCH 43/46] feat: add chunking function to allow sequence tagger training on sentences exceeding the token limit, including tests --- flair/class_utils.py | 9 +- flair/training_utils.py | 217 +++++++++++++++++---- requirements-dev.txt | 2 +- tests/resources/text_sequences/resume1.txt | 85 ++++++++ tests/test_sentence_labeling.py | 191 ++++++++++++++++++ 5 files changed, 461 insertions(+), 43 deletions(-) create mode 100644 tests/resources/text_sequences/resume1.txt create mode 100644 tests/test_sentence_labeling.py diff --git a/flair/class_utils.py b/flair/class_utils.py index 7e01f4ff42..ec6666c99f 100644 --- a/flair/class_utils.py +++ b/flair/class_utils.py @@ -2,12 +2,17 @@ import inspect from collections.abc import Iterable from types import ModuleType -from typing import Any, Optional, TypeVar, Union, overload +from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload + T = TypeVar("T") -def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]: +class StringLike(Protocol): + def __str__(self) -> str: ... + + +def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: for subclass in cls.__subclasses__(): yield from get_non_abstract_subclasses(subclass) if inspect.isabstract(subclass): diff --git a/flair/training_utils.py b/flair/training_utils.py index 9b38ec1ddb..69430a17b3 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -1,22 +1,27 @@ import logging +import pathlib import random from collections import defaultdict from enum import Enum from functools import reduce from math import inf from pathlib import Path -from typing import Literal, Optional, Union +from typing import Dict, List, Literal, NamedTuple, Optional, Union +from numpy import ndarray from scipy.stats import pearsonr, spearmanr +from scipy.stats._stats_py import PearsonRResult, SignificanceResult from sklearn.metrics import mean_absolute_error, mean_squared_error from torch.optim import Optimizer from torch.utils.data import Dataset import flair -from flair.data import DT, Dictionary, Sentence, _iter_dataset +from flair.class_utils import StringLike +from flair.data import DT, Dictionary, Sentence, Token, _iter_dataset EmbeddingStorageMode = Literal["none", "cpu", "gpu"] -log = logging.getLogger("flair") +MinMax = Literal["min", "max"] +logger = logging.getLogger("flair") class Result: @@ -33,7 +38,7 @@ def __init__( self.main_score: float = main_score self.scores = scores self.detailed_results: str = detailed_results - self.classification_report = classification_report + self.classification_report = classification_report if classification_report is not None else {} @property def loss(self): @@ -44,40 +49,36 @@ def __str__(self) -> str: class MetricRegression: - def __init__(self, name) -> None: + def __init__(self, name: str) -> None: self.name = name self.true: list[float] = [] self.pred: list[float] = [] - def mean_squared_error(self): + def mean_squared_error(self) -> Union[float, ndarray]: return mean_squared_error(self.true, self.pred) def mean_absolute_error(self): return mean_absolute_error(self.true, self.pred) - def pearsonr(self): + def pearsonr(self) -> PearsonRResult: return pearsonr(self.true, self.pred)[0] - def spearmanr(self): + def spearmanr(self) -> SignificanceResult: return spearmanr(self.true, self.pred)[0] - # dummy return to fulfill trainer.train() needs - def micro_avg_f_score(self): - return self.mean_squared_error() - - def to_tsv(self): + def to_tsv(self) -> str: return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}" @staticmethod - def tsv_header(prefix=None): + def tsv_header(prefix: StringLike = None) -> str: if prefix: return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN" return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN" @staticmethod - def to_empty_tsv(): + def to_empty_tsv() -> str: return "\t_\t_\t_\t_" def __str__(self) -> str: @@ -101,13 +102,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list)) self.number_of_weights = number_of_weights - def extract_weights(self, state_dict, iteration): + def extract_weights(self, state_dict: Dict, iteration: int) -> None: for key in state_dict: vec = state_dict[key] - # print(vec) try: weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size()))) - except Exception: + except Exception as e: + logger.debug(e) continue if key not in self.weights_dict: @@ -195,15 +196,15 @@ class AnnealOnPlateau: def __init__( self, optimizer, - mode="min", - aux_mode="min", - factor=0.1, - patience=10, - initial_extra_patience=0, - verbose=False, - cooldown=0, - min_lr=0, - eps=1e-8, + mode: MinMax = "min", + aux_mode: MinMax = "min", + factor: float = 0.1, + patience: int = 10, + initial_extra_patience: int = 0, + verbose: bool = False, + cooldown: int = 0, + min_lr: float = 0.0, + eps: float = 1e-8, ) -> None: if factor >= 1.0: raise ValueError("Factor should be < 1.0.") @@ -214,6 +215,7 @@ def __init__( raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") self.optimizer = optimizer + self.min_lrs: List[float] if isinstance(min_lr, (list, tuple)): if len(min_lr) != len(optimizer.param_groups): raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}") @@ -231,7 +233,7 @@ def __init__( self.best = None self.best_aux = None self.num_bad_epochs = None - self.mode_worse = None # the worse value for the chosen mode + self.mode_worse: Optional[float] = None # the worse value for the chosen mode self.eps = eps self.last_epoch = 0 self._init_is_better(mode=mode) @@ -258,7 +260,7 @@ def step(self, metric, auxiliary_metric=None) -> bool: if self.mode == "max" and current > self.best: is_better = True - if current == self.best and auxiliary_metric: + if current == self.best and auxiliary_metric is not None: current_aux = float(auxiliary_metric) if self.aux_mode == "min" and current_aux < self.best_aux: is_better = True @@ -289,20 +291,20 @@ def step(self, metric, auxiliary_metric=None) -> bool: return reduce_learning_rate - def _reduce_lr(self, epoch): + def _reduce_lr(self, epoch: int) -> None: for i, param_group in enumerate(self.optimizer.param_groups): old_lr = float(param_group["lr"]) new_lr = max(old_lr * self.factor, self.min_lrs[i]) if old_lr - new_lr > self.eps: param_group["lr"] = new_lr if self.verbose: - log.info(f" - reducing learning rate of group {epoch} to {new_lr}") + logger.info(f" - reducing learning rate of group {epoch} to {new_lr}") @property def in_cooldown(self): return self.cooldown_counter > 0 - def _init_is_better(self, mode): + def _init_is_better(self, mode: MinMax) -> None: if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") @@ -313,10 +315,10 @@ def _init_is_better(self, mode): self.mode = mode - def state_dict(self): + def state_dict(self) -> Dict: return {key: value for key, value in self.__dict__.items() if key != "optimizer"} - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Dict) -> None: self.__dict__.update(state_dict) self._init_is_better(mode=self.mode) @@ -350,11 +352,11 @@ def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionar return [[1 if label in labels else 0 for label in label_dict.get_items()] for labels in label_list] -def log_line(log): +def log_line(log: logging.Logger) -> None: log.info("-" * 100, stacklevel=3) -def add_file_handler(log, output_file): +def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.FileHandler: init_output_file(output_file.parents[0], output_file.name) fh = logging.FileHandler(output_file, mode="w", encoding="utf-8") fh.setLevel(logging.INFO) @@ -367,12 +369,20 @@ def add_file_handler(log, output_file): def store_embeddings( data_points: Union[list[DT], Dataset], storage_mode: EmbeddingStorageMode, - dynamic_embeddings: Optional[list[str]] = None, -): + dynamic_embeddings: Optional[List[str]] = None, +) -> None: + """Stores embeddings of data points in memory or on disk. + + Args: + data_points: a DataSet or list of DataPoints for which embeddings should be stored + storage_mode: store in either CPU or GPU memory, or delete them if set to 'none' + dynamic_embeddings: these are always deleted. If not passed, they are identified automatically. + """ + if isinstance(data_points, Dataset): data_points = list(_iter_dataset(data_points)) - # if memory mode option 'none' delete everything + # if storage mode option 'none' delete everything if storage_mode == "none": dynamic_embeddings = None @@ -391,7 +401,7 @@ def store_embeddings( data_point.to("cpu", pin_memory=pin_memory) -def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: +def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]: dynamic_embeddings = [] all_embeddings = [] for data_point in data_points: @@ -411,3 +421,130 @@ def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: if not all_embeddings: return None return list(set(dynamic_embeddings)) + + +class TokenEntity(NamedTuple): + """Entity represented by token indices.""" + + start_token_idx: int + end_token_idx: int + label: str + value: str = "" # text value of the entity + score: float = 1.0 + + +class CharEntity(NamedTuple): + """Entity represented by character indices.""" + + start_char_idx: int + end_char_idx: int + label: str + value: str + score: float = 1.0 + + +def create_labeled_sentence_from_tokens( + tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner" +) -> Sentence: + """Creates a new Sentence object from a list of tokens or strings and applies entity labels. + + Tokens are recreated with the same text, but not attached to the previous sentence. + + Args: + tokens: a list of Token objects or strings - only the text is used, not any labels + token_entities: a list of TokenEntity objects representing entity annotations + type_name: the type of entity label to apply + Returns: + A labeled Sentence object + """ + tokens = [Token(token.text) for token in tokens] # create new tokens that do not already belong to a sentence + sentence = Sentence(tokens, use_tokenizer=True) + for entity in token_entities: + sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score) + return sentence + + +def create_sentence_chunks( + text: str, + entities: List[CharEntity], + token_limit: int = 512, + use_context: bool = True, + overlap: int = 0, # TODO: implement overlap +) -> List[Sentence]: + """Chunks and labels a text from a list of entity annotations. + + The function explicitly tokenizes the text and labels separately, ensuring entity labels are + not partially split across tokens. + + Args: + text (str): The full text to be tokenized and labeled. + entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the + format (start_char_index, end_char_index, entity_class, entity_text). + token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking + use_context: whether to add context to the sentence + overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context + + Returns: + A list of labeled Sentence objects representing the chunks of the original text + """ + chunks = [] + + tokens: List[Token] = [] + current_index = 0 + token_entities: List[TokenEntity] = [] + end_token_idx = 0 + + for entity in entities: + + if entity.start_char_idx > current_index: # add non-entity text + non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens + while end_token_idx + len(non_entity_tokens) > token_limit: + num_tokens = token_limit - len(tokens) + tokens.extend(non_entity_tokens[:num_tokens]) + non_entity_tokens = non_entity_tokens[num_tokens:] + # skip any fully negative samples, they cause fine_tune to fail with + # `torch.cat(): expected a non-empty list of Tensors` + if len(token_entities) > 0: + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + tokens, token_entities = [], [] + end_token_idx = 0 + tokens.extend(non_entity_tokens) + + # add new entity tokens + start_token_idx = len(tokens) + entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx]) + if len(entity_sentence) > token_limit: + logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}") + end_token_idx = start_token_idx + len(entity_sentence) + + if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + + tokens, token_entities = [], [] + start_token_idx, end_token_idx = 0, len(entity_sentence) + + token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score) + token_entities.append(token_entity) + tokens.extend(entity_sentence) + + current_index = entity.end_char_idx + + # add any remaining tokens to a new chunk + if current_index < len(text): + remaining_sentence = Sentence(text[current_index:]) + if end_token_idx + len(remaining_sentence) > token_limit: + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + tokens, token_entities = [], [] + tokens.extend(remaining_sentence) + + if tokens: + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + + for chunk in chunks: + if len(chunk) > token_limit: + logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}") + + if use_context: + Sentence.set_context_for_sentences(chunks) + + return chunks diff --git a/requirements-dev.txt b/requirements-dev.txt index 3b8fbde79c..8053d231b8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,4 +12,4 @@ types-Deprecated>=1.2.9.2 types-requests>=2.28.11.17 types-tabulate>=0.9.0.2 pyab3p -transformers!=4.40.1,!=4.40.0 \ No newline at end of file +transformers!=4.40.1,!=4.40.0 diff --git a/tests/resources/text_sequences/resume1.txt b/tests/resources/text_sequences/resume1.txt new file mode 100644 index 0000000000..6be7107559 --- /dev/null +++ b/tests/resources/text_sequences/resume1.txt @@ -0,0 +1,85 @@ +Nate Diaz +_____________________________________________________________________ +         5151 Morada Lane Stockton, CA | (555) 445-6487| lkji8765@ziprecruiter.com + +Summary +____________________________________________________________________________ + +Reliable, service-focused nursing professional with excellent patient-care and charting skills gained through five years of experience as a CNA. Compassionate and technically skilled in attending to patients in diverse healthcare settings. BLS and CPR certified (current). + +Highlights +____________________________________________________________________________ + +· Patient Care & Safety +· Diagnostic Testing +· Customer loyalty +· Medical Terminology +· Privacy / HIPAA Regulations +Experience +____________________________________________________________________________ +ABC STAFFING COMPANY +Certified Nursing Assistant 2012- Present +Stockton, CA + +· Provide high-quality patient care as an in-demand per-diem CNA within surgical, acute-care, rehabilitation, home-healthcare and nursing-home settings. +· Preserve patient dignity and minimize discomfort while carrying out duties such as bedpan changes, diapering, emptying drainage bags and bathing. +· Commended for chart accuracy, effective team collaboration, patient relations and consistent delivery of empathetic care. + +DEF REHABILITATION COMPANY 2010-2012 +Certified Nursing Assistant Stockton, CA + +· Displayed strong clinical skills in assessing vital signs, performing lab draws and glucose checks, and providing pre- and post-operative care. +· Adhered to safety guidelines; completed hospital’s three-hour Patient Safety Training Program. +· Ensured the accurate, timely flow of information by maintaining thorough patient records and updating healthcare team on patients’ status. +· Complied with HIPAA standards in all patient documentation and interactions. + +Education + +____________________________________________________________________________ + +Nurse’s Aid Program- Completed 60 hours of clinical training +Arizona State University +Tempe, AZ + +Nate Diaz +_____________________________________________________________________ +         5151 Morada Lane Stockton, CA | (555) 445-6487| lkji8765@ziprecruiter.com + +Summary +____________________________________________________________________________ + +Reliable, service-focused nursing professional with excellent patient-care and charting skills gained through five years of experience as a CNA. Compassionate and technically skilled in attending to patients in diverse healthcare settings. BLS and CPR certified (current). + +Highlights +____________________________________________________________________________ + +· Patient Care & Safety +· Diagnostic Testing +· Customer loyalty +· Medical Terminology +· Privacy / HIPAA Regulations +Experience +____________________________________________________________________________ +ABC STAFFING COMPANY +Certified Nursing Assistant 2012- Present +Stockton, CA + +· Provide high-quality patient care as an in-demand per-diem CNA within surgical, acute-care, rehabilitation, home-healthcare and nursing-home settings. +· Preserve patient dignity and minimize discomfort while carrying out duties such as bedpan changes, diapering, emptying drainage bags and bathing. +· Commended for chart accuracy, effective team collaboration, patient relations and consistent delivery of empathetic care. + +DEF REHABILITATION COMPANY 2010-2012 +Certified Nursing Assistant Stockton, CA + +· Displayed strong clinical skills in assessing vital signs, performing lab draws and glucose checks, and providing pre- and post-operative care. +· Adhered to safety guidelines; completed hospital’s three-hour Patient Safety Training Program. +· Ensured the accurate, timely flow of information by maintaining thorough patient records and updating healthcare team on patients’ status. +· Complied with HIPAA standards in all patient documentation and interactions. + +Education + +____________________________________________________________________________ + +Nurse’s Aid Program- Completed 60 hours of clinical training +Arizona State University +Tempe, AZ \ No newline at end of file diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py new file mode 100644 index 0000000000..810ae21038 --- /dev/null +++ b/tests/test_sentence_labeling.py @@ -0,0 +1,191 @@ +from typing import Dict, List + +import pytest + +from flair.data import Sentence +from flair.training_utils import CharEntity, create_sentence_chunks + + +@pytest.fixture(params=["resume1.txt"]) +def resume(request, resources_path) -> str: + filepath = resources_path / "text_sequences" / request.param + with open(filepath, encoding="utf8") as file: + text_content = file.read() + return text_content + + +@pytest.fixture +def parsed_resume_dict(resume) -> dict: + return { + "raw_text": resume, + "entities": [ + CharEntity(20, 40, "dummy_label1", "Dummy Text 1"), + CharEntity(250, 300, "dummy_label2", "Dummy Text 2"), + CharEntity(700, 810, "dummy_label3", "Dummy Text 3"), + CharEntity(3900, 4000, "dummy_label4", "Dummy Text 4"), + ], + } + + +@pytest.fixture +def small_token_limit_resume() -> Dict: + return { + "raw_text": "Professional Clown June 2020 - August 2021 Entertaining students of all ages. Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Gained " + "proficiency in juggling and scaring children.", + "entities": [ + CharEntity(0, 18, "EXPERIENCE.TITLE", ""), + CharEntity(19, 29, "DATE.START_DATE", ""), + CharEntity(31, 42, "DATE.END_DATE", ""), + CharEntity(450, 510, "EXPERIENCE.DESCRIPTION", ""), + ], + } + + +@pytest.fixture +def small_token_limit_response() -> List[Sentence]: + """Recreates expected response Sentences.""" + chunk0 = Sentence("Professional Clown June 2020 - August 2021 Entertaining students of") + chunk0[0:2].add_label("Professional Clown", "EXPERIENCE.TITLE") + chunk0[2:4].add_label("June 2020", "DATE.START_DATE") + chunk0[5:7].add_label("August 2021", "DATE.END_DATE") + + chunk1 = Sentence("Blah Blah Blah Blah Blah Blah Blah Bl") + + chunk2 = Sentence("ah Blah Gained proficiency in juggling and scaring children .") + chunk2[0:10].add_label("ah Blah Gained proficiency in juggling and scaring children .", "EXPERIENCE.DESCRIPTION") + + return [chunk0, chunk1, chunk2] + + +class TestChunking: + def test_empty_string(self): + sentences = create_sentence_chunks("", []) + assert len(sentences) == 0 + + def check_split_entities(self, entity_labels, chunks, max_token_limit): + """Ensure that no entities are split over chunks (except entities longer than the token limit).""" + chunk_intervals = [] + start_index = 0 + for chunk in chunks: + end_index = start_index + len(chunk.text) + chunk_intervals.append((start_index, end_index)) + start_index = end_index + + for entity in entity_labels: + entity_start, entity_end = entity.start_char_idx, entity.end_char_idx + entity_length = entity_end - entity_start + + # Skip the check if the entity itself is longer than the maximum token limit + if entity_length > max_token_limit: + continue + + assert any( + start <= entity_start and entity_end <= end for start, end in chunk_intervals + ), f"Entity {entity} is not within a single chunk interval" + + @pytest.mark.parametrize( + "test_text, expected_text", + [ + ("test text", "test text"), + ("a", "a"), + ("this ", "this"), + ], + ) + def test_short_text(self, test_text, expected_text): + """Short texts that should fit nicely into a single chunk.""" + chunks = create_sentence_chunks(test_text, []) + assert chunks[0].text == expected_text + + def test_create_flair_sentence(self, parsed_resume_dict): + chunks = create_sentence_chunks(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) + assert len(chunks) == 2 + + max_token_limit = 512 # default + assert all(len(c) <= max_token_limit for c in chunks) + + self.check_split_entities(parsed_resume_dict["entities"], chunks, max_token_limit) + + def test_small_token_limit(self, small_token_limit_resume, small_token_limit_response): + max_token_limit = 10 # test a small max token limit + chunks = create_sentence_chunks( + small_token_limit_resume["raw_text"], small_token_limit_resume["entities"], token_limit=max_token_limit + ) + + for response, expected in zip(chunks, small_token_limit_response): + assert response.to_tagged_string() == expected.to_tagged_string() + + assert all(len(c) <= max_token_limit for c in chunks) + + self.check_split_entities(small_token_limit_resume["entities"], chunks, max_token_limit) + + @pytest.mark.parametrize( + "test_text, entities, expected_chunks", + [ + ( + "Led a team of five engineers. It's important to note the project's success. We've implemented state-of-the-art technologies. Co-ordinated efforts with cross-functional teams.", + [ + CharEntity(0, 25, "RESPONSIBILITY", "Led a team of five engineers"), + CharEntity(27, 72, "ACHIEVEMENT", "It's important to note the project's success"), + CharEntity(74, 117, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"), + CharEntity(119, 168, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"), + ], + [ + "Led a team of five engine er s. It 's important to note the project 's succe ss", + ". We 've implemented state-of-the-art techno lo gies . Co-ordinated efforts with cross-functional teams .", + ], + ), + ], + ) + def test_contractions_and_hyphens(self, test_text, entities, expected_chunks): + max_token_limit = 20 + chunks = create_sentence_chunks(test_text, entities, max_token_limit) + for i, chunk in enumerate(expected_chunks): + assert chunks[i].text == chunk + self.check_split_entities(entities, chunks, max_token_limit) + + @pytest.mark.parametrize( + "test_text, entities", + [ + ( + "This is a long text. " * 100, + [CharEntity(0, 1000, "dummy_label1", "Dummy Text 1")], + ) + ], + ) + def test_long_text(self, test_text, entities): + """Test for handling long texts that should be split into multiple chunks.""" + max_token_limit = 512 + chunks = create_sentence_chunks(test_text, entities, max_token_limit) + assert len(chunks) > 1 + assert all(len(c) <= max_token_limit for c in chunks) + self.check_split_entities(entities, chunks, max_token_limit) + + @pytest.mark.parametrize( + "test_text, entities, expected_chunks", + [ + ( + "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.", + [ + CharEntity(0, 6, "LABEL", "Hello!"), + CharEntity(7, 31, "LABEL", "Is your company hiring?"), + CharEntity(32, 65, "LABEL", "I am available for employment."), + CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."), + ], + [ + "Hello ! Is your company hiring ? I", + "am available for employment . Con t", + "act me at 5:00 p.m .", + ], + ) + ], + ) + def test_text_with_punctuation(self, test_text, entities, expected_chunks): + max_token_limit = 10 + chunks = create_sentence_chunks(test_text, entities, max_token_limit) + for i, chunk in enumerate(expected_chunks): + assert chunks[i].text == chunk + self.check_split_entities(entities, chunks, max_token_limit) From 1860930ac90b9af40c12d3dede858002fe33635c Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 22 Nov 2024 11:53:14 -0800 Subject: [PATCH 44/46] remove chunking logic to have simple sentence labeler. fix tests. --- flair/data.py | 2 +- flair/training_utils.py | 82 +++++----------- tests/test_sentence_labeling.py | 168 +++++++++++++++++--------------- 3 files changed, 116 insertions(+), 136 deletions(-) diff --git a/flair/data.py b/flair/data.py index 40175418a4..30a2daf313 100644 --- a/flair/data.py +++ b/flair/data.py @@ -565,7 +565,7 @@ def __init__( head_id: Optional[int] = None, whitespace_after: int = 1, start_position: int = 0, - sentence=None, + sentence: Optional["Sentence"] = None, ) -> None: super().__init__(sentence=sentence) diff --git a/flair/training_utils.py b/flair/training_utils.py index 69430a17b3..ea5b576f0c 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -6,7 +6,7 @@ from functools import reduce from math import inf from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Union +from typing import Literal, NamedTuple, Optional, Union from numpy import ndarray from scipy.stats import pearsonr, spearmanr @@ -102,7 +102,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list)) self.number_of_weights = number_of_weights - def extract_weights(self, state_dict: Dict, iteration: int) -> None: + def extract_weights(self, state_dict: dict, iteration: int) -> None: for key in state_dict: vec = state_dict[key] try: @@ -215,7 +215,7 @@ def __init__( raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") self.optimizer = optimizer - self.min_lrs: List[float] + self.min_lrs: list[float] if isinstance(min_lr, (list, tuple)): if len(min_lr) != len(optimizer.param_groups): raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}") @@ -315,10 +315,10 @@ def _init_is_better(self, mode: MinMax) -> None: self.mode = mode - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return {key: value for key, value in self.__dict__.items() if key != "optimizer"} - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) self._init_is_better(mode=self.mode) @@ -369,7 +369,7 @@ def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging. def store_embeddings( data_points: Union[list[DT], Dataset], storage_mode: EmbeddingStorageMode, - dynamic_embeddings: Optional[List[str]] = None, + dynamic_embeddings: Optional[list[str]] = None, ) -> None: """Stores embeddings of data points in memory or on disk. @@ -401,7 +401,7 @@ def store_embeddings( data_point.to("cpu", pin_memory=pin_memory) -def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]: +def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: dynamic_embeddings = [] all_embeddings = [] for data_point in data_points: @@ -444,7 +444,7 @@ class CharEntity(NamedTuple): def create_labeled_sentence_from_tokens( - tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner" + tokens: Union[list[Token]], token_entities: list[TokenEntity], type_name: str = "ner" ) -> Sentence: """Creates a new Sentence object from a list of tokens or strings and applies entity labels. @@ -457,20 +457,18 @@ def create_labeled_sentence_from_tokens( Returns: A labeled Sentence object """ - tokens = [Token(token.text) for token in tokens] # create new tokens that do not already belong to a sentence - sentence = Sentence(tokens, use_tokenizer=True) + tokens_ = [token.text for token in tokens] # create new tokens that do not already belong to a sentence + sentence = Sentence(tokens_, use_tokenizer=True) for entity in token_entities: sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score) return sentence -def create_sentence_chunks( +def create_labeled_sentence( text: str, - entities: List[CharEntity], - token_limit: int = 512, - use_context: bool = True, - overlap: int = 0, # TODO: implement overlap -) -> List[Sentence]: + entities: list[CharEntity], + token_limit: float = inf, +) -> Sentence: """Chunks and labels a text from a list of entity annotations. The function explicitly tokenizes the text and labels separately, ensuring entity labels are @@ -481,48 +479,25 @@ def create_sentence_chunks( entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the format (start_char_index, end_char_index, entity_class, entity_text). token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking - use_context: whether to add context to the sentence - overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context Returns: A list of labeled Sentence objects representing the chunks of the original text """ - chunks = [] - - tokens: List[Token] = [] + tokens: list[Token] = [] current_index = 0 - token_entities: List[TokenEntity] = [] - end_token_idx = 0 + token_entities: list[TokenEntity] = [] for entity in entities: - - if entity.start_char_idx > current_index: # add non-entity text - non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens - while end_token_idx + len(non_entity_tokens) > token_limit: - num_tokens = token_limit - len(tokens) - tokens.extend(non_entity_tokens[:num_tokens]) - non_entity_tokens = non_entity_tokens[num_tokens:] - # skip any fully negative samples, they cause fine_tune to fail with - # `torch.cat(): expected a non-empty list of Tensors` - if len(token_entities) > 0: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - tokens, token_entities = [], [] - end_token_idx = 0 - tokens.extend(non_entity_tokens) + if current_index < entity.start_char_idx: + # add tokens before the entity + sentence = Sentence(text[current_index : entity.start_char_idx]) + tokens.extend(sentence) # add new entity tokens start_token_idx = len(tokens) entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx]) - if len(entity_sentence) > token_limit: - logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}") end_token_idx = start_token_idx + len(entity_sentence) - if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - - tokens, token_entities = [], [] - start_token_idx, end_token_idx = 0, len(entity_sentence) - token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score) token_entities.append(token_entity) tokens.extend(entity_sentence) @@ -532,19 +507,10 @@ def create_sentence_chunks( # add any remaining tokens to a new chunk if current_index < len(text): remaining_sentence = Sentence(text[current_index:]) - if end_token_idx + len(remaining_sentence) > token_limit: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - tokens, token_entities = [], [] tokens.extend(remaining_sentence) - if tokens: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - - for chunk in chunks: - if len(chunk) > token_limit: - logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}") - - if use_context: - Sentence.set_context_for_sentences(chunks) + if isinstance(token_limit, int) and token_limit < len(tokens): + tokens = tokens[:token_limit] + token_entities = [entity for entity in token_entities if entity.end_token_idx <= token_limit] - return chunks + return create_labeled_sentence_from_tokens(tokens, token_entities) diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py index 810ae21038..56742da4df 100644 --- a/tests/test_sentence_labeling.py +++ b/tests/test_sentence_labeling.py @@ -1,9 +1,9 @@ -from typing import Dict, List +from typing import cast import pytest from flair.data import Sentence -from flair.training_utils import CharEntity, create_sentence_chunks +from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence @pytest.fixture(params=["resume1.txt"]) @@ -28,7 +28,7 @@ def parsed_resume_dict(resume) -> dict: @pytest.fixture -def small_token_limit_resume() -> Dict: +def small_token_limit_resume() -> dict: return { "raw_text": "Professional Clown June 2020 - August 2021 Entertaining students of all ages. Blah Blah Blah " "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " @@ -46,7 +46,7 @@ def small_token_limit_resume() -> Dict: @pytest.fixture -def small_token_limit_response() -> List[Sentence]: +def small_token_limit_response() -> list[Sentence]: """Recreates expected response Sentences.""" chunk0 = Sentence("Professional Clown June 2020 - August 2021 Entertaining students of") chunk0[0:2].add_label("Professional Clown", "EXPERIENCE.TITLE") @@ -63,28 +63,32 @@ def small_token_limit_response() -> List[Sentence]: class TestChunking: def test_empty_string(self): - sentences = create_sentence_chunks("", []) + sentences = create_labeled_sentence("", []) assert len(sentences) == 0 - def check_split_entities(self, entity_labels, chunks, max_token_limit): - """Ensure that no entities are split over chunks (except entities longer than the token limit).""" - chunk_intervals = [] - start_index = 0 - for chunk in chunks: - end_index = start_index + len(chunk.text) - chunk_intervals.append((start_index, end_index)) - start_index = end_index + def check_tokens(self, sentence: Sentence, expected_tokens: list[str]): + assert len(sentence.tokens) == len(expected_tokens) + assert [token.text for token in sentence.tokens] == expected_tokens + for token, expected_token in zip(sentence.tokens, expected_tokens): + assert token.text == expected_token - for entity in entity_labels: - entity_start, entity_end = entity.start_char_idx, entity.end_char_idx - entity_length = entity_end - entity_start + def check_token_entities(self, sentence: Sentence, expected_labels: list[TokenEntity]): + assert len(sentence.labels) == len(expected_labels) + for label, expected_label in zip(sentence.labels, expected_labels): - # Skip the check if the entity itself is longer than the maximum token limit - if entity_length > max_token_limit: - continue + assert label.value == expected_label.label + span = cast(Sentence, label.data_point) + assert span.tokens[0]._internal_index is not None + assert span.tokens[0]._internal_index - 1 == expected_label.start_token_idx + assert span.tokens[-1]._internal_index is not None + assert span.tokens[-1]._internal_index - 1 == expected_label.end_token_idx - assert any( - start <= entity_start and entity_end <= end for start, end in chunk_intervals + def check_split_entities(self, entity_labels, sentence: Sentence): + """Ensure that no entities are split over chunks (except entities longer than the token limit).""" + for entity in entity_labels: + entity_start, entity_end = entity.start_char_idx, entity.end_char_idx + assert entity_start >= 0 and entity_end <= len( + sentence ), f"Entity {entity} is not within a single chunk interval" @pytest.mark.parametrize( @@ -95,57 +99,71 @@ def check_split_entities(self, entity_labels, chunks, max_token_limit): ("this ", "this"), ], ) - def test_short_text(self, test_text, expected_text): + def test_short_text(self, test_text: str, expected_text: str): """Short texts that should fit nicely into a single chunk.""" - chunks = create_sentence_chunks(test_text, []) - assert chunks[0].text == expected_text - - def test_create_flair_sentence(self, parsed_resume_dict): - chunks = create_sentence_chunks(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) - assert len(chunks) == 2 + chunks = create_labeled_sentence(test_text, []) + assert chunks.text == expected_text - max_token_limit = 512 # default - assert all(len(c) <= max_token_limit for c in chunks) - - self.check_split_entities(parsed_resume_dict["entities"], chunks, max_token_limit) - - def test_small_token_limit(self, small_token_limit_resume, small_token_limit_response): - max_token_limit = 10 # test a small max token limit - chunks = create_sentence_chunks( - small_token_limit_resume["raw_text"], small_token_limit_resume["entities"], token_limit=max_token_limit - ) - - for response, expected in zip(chunks, small_token_limit_response): - assert response.to_tagged_string() == expected.to_tagged_string() - - assert all(len(c) <= max_token_limit for c in chunks) - - self.check_split_entities(small_token_limit_resume["entities"], chunks, max_token_limit) + def test_create_labeled_sentence(self, parsed_resume_dict: dict): + create_labeled_sentence(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) @pytest.mark.parametrize( - "test_text, entities, expected_chunks", + "test_text, entities, expected_tokens, expected_labels", [ ( "Led a team of five engineers. It's important to note the project's success. We've implemented state-of-the-art technologies. Co-ordinated efforts with cross-functional teams.", [ - CharEntity(0, 25, "RESPONSIBILITY", "Led a team of five engineers"), - CharEntity(27, 72, "ACHIEVEMENT", "It's important to note the project's success"), - CharEntity(74, 117, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"), - CharEntity(119, 168, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"), + CharEntity(0, 28, "RESPONSIBILITY", "Led a team of five engineers"), + CharEntity(30, 74, "ACHIEVEMENT", "It's important to note the project's success"), + CharEntity(76, 123, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"), + CharEntity(125, 173, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"), + ], + [ + "Led", + "a", + "team", + "of", + "five", + "engineers", + ".", + "It", + "'s", + "important", + "to", + "note", + "the", + "project", + "'s", + "success", + ".", + "We", + "'ve", + "implemented", + "state-of-the-art", + "technologies", + ".", + "Co-ordinated", + "efforts", + "with", + "cross-functional", + "teams", + ".", ], [ - "Led a team of five engine er s. It 's important to note the project 's succe ss", - ". We 've implemented state-of-the-art techno lo gies . Co-ordinated efforts with cross-functional teams .", + TokenEntity(0, 5, "RESPONSIBILITY"), + TokenEntity(7, 15, "ACHIEVEMENT"), + TokenEntity(17, 21, "ACHIEVEMENT"), + TokenEntity(23, 27, "RESPONSIBILITY"), ], ), ], ) - def test_contractions_and_hyphens(self, test_text, entities, expected_chunks): - max_token_limit = 20 - chunks = create_sentence_chunks(test_text, entities, max_token_limit) - for i, chunk in enumerate(expected_chunks): - assert chunks[i].text == chunk - self.check_split_entities(entities, chunks, max_token_limit) + def test_contractions_and_hyphens( + self, test_text: str, entities: list[CharEntity], expected_tokens: list[str], expected_labels: list[TokenEntity] + ): + sentence = create_labeled_sentence(test_text, entities) + self.check_tokens(sentence, expected_tokens) + self.check_token_entities(sentence, expected_labels) @pytest.mark.parametrize( "test_text, entities", @@ -156,36 +174,32 @@ def test_contractions_and_hyphens(self, test_text, entities, expected_chunks): ) ], ) - def test_long_text(self, test_text, entities): + def test_long_text(self, test_text: str, entities: list[CharEntity]): """Test for handling long texts that should be split into multiple chunks.""" - max_token_limit = 512 - chunks = create_sentence_chunks(test_text, entities, max_token_limit) - assert len(chunks) > 1 - assert all(len(c) <= max_token_limit for c in chunks) - self.check_split_entities(entities, chunks, max_token_limit) + create_labeled_sentence(test_text, entities) @pytest.mark.parametrize( - "test_text, entities, expected_chunks", + "test_text, entities, expected_labels", [ ( "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.", [ CharEntity(0, 6, "LABEL", "Hello!"), - CharEntity(7, 31, "LABEL", "Is your company hiring?"), - CharEntity(32, 65, "LABEL", "I am available for employment."), - CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."), + CharEntity(7, 30, "LABEL", "Is your company hiring?"), + CharEntity(31, 61, "LABEL", "I am available for employment."), + CharEntity(62, 85, "LABEL", "Contact me at 5:00 p.m."), ], [ - "Hello ! Is your company hiring ? I", - "am available for employment . Con t", - "act me at 5:00 p.m .", + TokenEntity(0, 1, "LABEL"), + TokenEntity(2, 6, "LABEL"), + TokenEntity(7, 12, "LABEL"), + TokenEntity(13, 18, "LABEL"), ], ) ], ) - def test_text_with_punctuation(self, test_text, entities, expected_chunks): - max_token_limit = 10 - chunks = create_sentence_chunks(test_text, entities, max_token_limit) - for i, chunk in enumerate(expected_chunks): - assert chunks[i].text == chunk - self.check_split_entities(entities, chunks, max_token_limit) + def test_text_with_punctuation( + self, test_text: str, entities: list[CharEntity], expected_labels: list[TokenEntity] + ): + sentence = create_labeled_sentence(test_text, entities) + self.check_token_entities(sentence, expected_labels) From 5784834c14a8cc1834595bede7bcc113633f36c3 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 6 Dec 2024 23:39:42 -0800 Subject: [PATCH 45/46] fix: remove type hints from private module --- flair/training_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flair/training_utils.py b/flair/training_utils.py index ea5b576f0c..ce15bdb6e5 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -10,7 +10,6 @@ from numpy import ndarray from scipy.stats import pearsonr, spearmanr -from scipy.stats._stats_py import PearsonRResult, SignificanceResult from sklearn.metrics import mean_absolute_error, mean_squared_error from torch.optim import Optimizer from torch.utils.data import Dataset @@ -61,10 +60,10 @@ def mean_squared_error(self) -> Union[float, ndarray]: def mean_absolute_error(self): return mean_absolute_error(self.true, self.pred) - def pearsonr(self) -> PearsonRResult: + def pearsonr(self): return pearsonr(self.true, self.pred)[0] - def spearmanr(self) -> SignificanceResult: + def spearmanr(self): return spearmanr(self.true, self.pred)[0] def to_tsv(self) -> str: From 082e84569841625ecf467c05bfc6e9c12464ad3c Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Mon, 3 Feb 2025 20:28:07 -0800 Subject: [PATCH 46/46] fix: doc strings and function name --- flair/training_utils.py | 11 ++++++----- tests/test_sentence_labeling.py | 14 +++++++------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/flair/training_utils.py b/flair/training_utils.py index ce15bdb6e5..915df609a8 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -463,24 +463,25 @@ def create_labeled_sentence_from_tokens( return sentence -def create_labeled_sentence( +def create_labeled_sentence_from_entity_offsets( text: str, entities: list[CharEntity], token_limit: float = inf, ) -> Sentence: - """Chunks and labels a text from a list of entity annotations. + """Creates a labeled sentence from a text and a list of entity annotations. The function explicitly tokenizes the text and labels separately, ensuring entity labels are - not partially split across tokens. + not partially split across tokens. The sentence is truncated if a token limit is set. Args: text (str): The full text to be tokenized and labeled. entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the format (start_char_index, end_char_index, entity_class, entity_text). - token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking + token_limit: numerical value that determines the maximum token length of the sentence. + use inf to not perform chunking Returns: - A list of labeled Sentence objects representing the chunks of the original text + A labeled Sentence objects representing the text and entity annotations. """ tokens: list[Token] = [] current_index = 0 diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py index 56742da4df..0bfb6dce94 100644 --- a/tests/test_sentence_labeling.py +++ b/tests/test_sentence_labeling.py @@ -3,7 +3,7 @@ import pytest from flair.data import Sentence -from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence +from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence_from_entity_offsets @pytest.fixture(params=["resume1.txt"]) @@ -63,7 +63,7 @@ def small_token_limit_response() -> list[Sentence]: class TestChunking: def test_empty_string(self): - sentences = create_labeled_sentence("", []) + sentences = create_labeled_sentence_from_entity_offsets("", []) assert len(sentences) == 0 def check_tokens(self, sentence: Sentence, expected_tokens: list[str]): @@ -101,11 +101,11 @@ def check_split_entities(self, entity_labels, sentence: Sentence): ) def test_short_text(self, test_text: str, expected_text: str): """Short texts that should fit nicely into a single chunk.""" - chunks = create_labeled_sentence(test_text, []) + chunks = create_labeled_sentence_from_entity_offsets(test_text, []) assert chunks.text == expected_text def test_create_labeled_sentence(self, parsed_resume_dict: dict): - create_labeled_sentence(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) + create_labeled_sentence_from_entity_offsets(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) @pytest.mark.parametrize( "test_text, entities, expected_tokens, expected_labels", @@ -161,7 +161,7 @@ def test_create_labeled_sentence(self, parsed_resume_dict: dict): def test_contractions_and_hyphens( self, test_text: str, entities: list[CharEntity], expected_tokens: list[str], expected_labels: list[TokenEntity] ): - sentence = create_labeled_sentence(test_text, entities) + sentence = create_labeled_sentence_from_entity_offsets(test_text, entities) self.check_tokens(sentence, expected_tokens) self.check_token_entities(sentence, expected_labels) @@ -176,7 +176,7 @@ def test_contractions_and_hyphens( ) def test_long_text(self, test_text: str, entities: list[CharEntity]): """Test for handling long texts that should be split into multiple chunks.""" - create_labeled_sentence(test_text, entities) + create_labeled_sentence_from_entity_offsets(test_text, entities) @pytest.mark.parametrize( "test_text, entities, expected_labels", @@ -201,5 +201,5 @@ def test_long_text(self, test_text: str, entities: list[CharEntity]): def test_text_with_punctuation( self, test_text: str, entities: list[CharEntity], expected_labels: list[TokenEntity] ): - sentence = create_labeled_sentence(test_text, entities) + sentence = create_labeled_sentence_from_entity_offsets(test_text, entities) self.check_token_entities(sentence, expected_labels)