Skip to content

Commit

Permalink
Merge branch 'master' into gh-3488/save-column-corpus-to-files
Browse files Browse the repository at this point in the history
  • Loading branch information
alanakbik authored Feb 4, 2025
2 parents 522621a + 8acd698 commit 95998f3
Show file tree
Hide file tree
Showing 39 changed files with 1,554 additions and 135 deletions.
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# -- Project information -----------------------------------------------------
from sphinx_github_style import get_linkcode_resolve
from torch.nn import Module

version = "0.15.0"
release = "0.15.0"
Expand Down
9 changes: 7 additions & 2 deletions flair/class_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
import inspect
from collections.abc import Iterable
from types import ModuleType
from typing import Any, Optional, TypeVar, Union, overload
from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload


T = TypeVar("T")


def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]:
class StringLike(Protocol):
def __str__(self) -> str: ...


def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]:
for subclass in cls.__subclasses__():
yield from get_non_abstract_subclasses(subclass)
if inspect.isabstract(subclass):
Expand Down
9 changes: 3 additions & 6 deletions flair/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def __init__(
head_id: Optional[int] = None,
whitespace_after: int = 1,
start_position: int = 0,
sentence=None,
sentence: Optional["Sentence"] = None,
) -> None:
super().__init__(sentence=sentence)

Expand Down Expand Up @@ -1419,8 +1419,7 @@ def __init__(
sample_missing_splits: Union[bool, str] = True,
random_seed: Optional[int] = None,
) -> None:
"""
Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test split
"""Constructor method to initialize a :class:`Corpus`. You can define the train, dev and test split
by passing the corresponding Dataset object to the constructor. At least one split should be defined.
If the option `sample_missing_splits` is set to True, missing splits will be randomly sampled from the
train split.
Expand Down Expand Up @@ -1514,7 +1513,6 @@ def downsample(
Returns:
A pointer to itself for optional use in method chaining.
"""

if downsample_train and self._train is not None:
self._train = self._downsample_to_proportion(self._train, percentage, random_seed)

Expand All @@ -1541,8 +1539,7 @@ def filter_empty_sentences(self):
log.info(self)

def filter_long_sentences(self, max_charlength: int):
"""
A method that filters all sentences for which the plain text is longer than a specified number of characters.
"""A method that filters all sentences for which the plain text is longer than a specified number of characters.
This is an in-place operation that directly modifies the Corpus object itself by removing these sentences.
Expand Down
2 changes: 2 additions & 0 deletions flair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@
NER_ARABIC_ANER,
NER_ARABIC_AQMAR,
NER_BASQUE,
NER_BAVARIAN_WIKI,
NER_CHINESE_WEIBO,
NER_DANISH_DANE,
NER_ENGLISH_MOVIE_COMPLEX,
Expand Down Expand Up @@ -477,6 +478,7 @@
"NER_ARABIC_ANER",
"NER_ARABIC_AQMAR",
"NER_BASQUE",
"NER_BAVARIAN_WIKI",
"NER_CHINESE_WEIBO",
"NER_DANISH_DANE",
"NER_ENGLISH_MOVIE_COMPLEX",
Expand Down
4 changes: 2 additions & 2 deletions flair/datasets/document_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,7 @@ def __init__(
if not rebalance_corpus and dataset == "test":
data_file = test_data_file

with open(data_file, "at") as f_p:
with open(data_file, "a") as f_p:
current_path = data_path / "aclImdb" / dataset / label
for file_name in current_path.iterdir():
if file_name.is_file() and file_name.name.endswith(".txt"):
Expand Down Expand Up @@ -891,7 +891,7 @@ def __init__(
data_path / "original",
members=[m for m in f_in.getmembers() if f"{dataset}/{label}" in m.name],
)
with open(f"{data_path}/{dataset}.txt", "at", encoding="utf-8") as f_p:
with open(f"{data_path}/{dataset}.txt", "a", encoding="utf-8") as f_p:
current_path = data_path / "original" / dataset / label
for file_name in current_path.iterdir():
if file_name.is_file():
Expand Down
122 changes: 121 additions & 1 deletion flair/datasets/sequence_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ def __init__(
label_name_map: Optional[dict[str, str]] = None,
banned_sentences: Optional[list[str]] = None,
default_whitespace_after: int = 1,
every_sentence_is_independent: bool = False,
documents_as_sentences: bool = False,
**corpusargs,
) -> None:
r"""Instantiates a Corpus from CoNLL column-formatted task data such as CoNLL03 or CoNLL2000.
Expand Down Expand Up @@ -358,6 +360,8 @@ def __init__(
skip_first_line=skip_first_line,
label_name_map=label_name_map,
default_whitespace_after=default_whitespace_after,
every_sentence_is_independent=every_sentence_is_independent,
documents_as_sentences=documents_as_sentences,
)
for train_file in train_files
]
Expand All @@ -382,6 +386,8 @@ def __init__(
skip_first_line=skip_first_line,
label_name_map=label_name_map,
default_whitespace_after=default_whitespace_after,
every_sentence_is_independent=every_sentence_is_independent,
documents_as_sentences=documents_as_sentences,
)
for test_file in test_files
]
Expand All @@ -406,6 +412,8 @@ def __init__(
skip_first_line=skip_first_line,
label_name_map=label_name_map,
default_whitespace_after=default_whitespace_after,
every_sentence_is_independent=every_sentence_is_independent,
documents_as_sentences=documents_as_sentences,
)
for dev_file in dev_files
]
Expand Down Expand Up @@ -621,10 +629,12 @@ def __init__(
banned_sentences: Optional[list[str]] = None,
in_memory: bool = True,
document_separator_token: Optional[str] = None,
every_sentence_is_independent: bool = False,
encoding: str = "utf-8",
skip_first_line: bool = False,
label_name_map: Optional[dict[str, str]] = None,
default_whitespace_after: int = 1,
documents_as_sentences: bool = False,
) -> None:
r"""Instantiates a column dataset.
Expand All @@ -645,9 +655,17 @@ def __init__(
self.column_delimiter = re.compile(column_delimiter)
self.comment_symbol = comment_symbol
self.document_separator_token = document_separator_token
self.every_sentence_is_independent = every_sentence_is_independent
self.label_name_map = label_name_map
self.banned_sentences = banned_sentences
self.default_whitespace_after = default_whitespace_after
self.documents_as_sentences = documents_as_sentences

if documents_as_sentences and not document_separator_token:
log.error(
"document_as_sentences was set to True, but no document_separator_token was provided. Please set"
"a value for document_separator_token in order to enable the document_as_sentence functionality."
)

# store either Sentence objects in memory, or only file offsets
self.in_memory = in_memory
Expand Down Expand Up @@ -842,6 +860,9 @@ def _convert_lines_to_sentence(
if sentence.to_original_text() == self.document_separator_token:
sentence.is_document_boundary = True

if self.every_sentence_is_independent or self.documents_as_sentences:
sentence.is_document_boundary = True

# add span labels
if span_level_tag_columns:
for span_column in span_level_tag_columns:
Expand Down Expand Up @@ -978,6 +999,13 @@ def write_dataset_to_file(self, label_types: List[str], file_path: Path, column_
ColumnCorpus._write_dataset_to_file(self, label_types, file_path, column_delimiter)

def __line_completes_sentence(self, line: str) -> bool:

if self.documents_as_sentences and self.document_separator_token:
if line.startswith(self.document_separator_token):
return True
else:
return False

sentence_completed = line.isspace() or line == ""
return sentence_completed

Expand Down Expand Up @@ -5195,7 +5223,8 @@ def __init__(
test_file=None,
column_format=columns,
in_memory=in_memory,
sample_missing_splits=False, # No test data is available, so do not shrink dev data for shared task preparation!
sample_missing_splits=False,
# No test data is available, so do not shrink dev data for shared task preparation!
**corpusargs,
)
corpora.append(corpus)
Expand Down Expand Up @@ -5661,3 +5690,94 @@ def __init__(
corpora,
name="masakha-pos-" + "-".join(languages),
)


class NER_BAVARIAN_WIKI(ColumnCorpus):
def __init__(
self,
fine_grained: bool = False,
revision: str = "main",
base_path: Optional[Union[str, Path]] = None,
in_memory: bool = True,
**corpusargs,
) -> None:
"""Initialize the Bavarian NER Bavarian NER Dataset (BarNER).
The dataset was proposed in the 2024 LREC-COLING paper
"Sebastian, Basti, Wastl?! Recognizing Named Entities in Bavarian Dialectal Data" paper by Peng et al.
:param fine_grained: Defines if the fine-grained or coarse-grained (default) should be used.
:param revision: Defines the revision/commit of BarNER dataset, by default dataset from 'main' branch is used.
:param base_path: Default is None, meaning that corpus gets auto-downloaded and loaded. You can override this
to point to a different folder but typically this should not be necessary.
:param in_memory: If True, keeps dataset in memory giving speedups in training.
"""
base_path = flair.cache_root / "datasets" if not base_path else Path(base_path)
dataset_name = self.__class__.__name__.lower()
data_folder = base_path / dataset_name
data_path = flair.cache_root / "datasets" / dataset_name

document_boundary_marker = "-DOCSTART-"

for split in ["train", "dev", "test"]:
# Get original version
original_split_filename = data_path / "original" / f"bar-wiki-{split}.tsv"
if not original_split_filename.is_file():
original_split_url = (
f"https://raw.githubusercontent.com/mainlp/BarNER/{revision}/data/BarNER-final/bar-wiki-{split}.tsv"
)
cached_path(original_split_url, data_path / "original")

# Add sentence boundary marker
modified_split_filename = data_path / f"bar-wiki-{split}.tsv"
if not modified_split_filename.is_file():
f_out = open(modified_split_filename, "w", encoding="utf-8")

with open(original_split_filename, encoding="utf-8") as f_p:
for line in f_p:
line = line.strip()
if line.startswith("# newdoc id = "):
f_out.write(f"{document_boundary_marker}\tO\n\n")
continue
if line.startswith("# "):
continue
f_out.write(f"{line}\n")
f_out.close()

columns = {0: "text", 1: "ner"}

label_name_map = None

if not fine_grained:
# Only allowed classes in course setting are: PER, LOC, ORG and MISC.
# All other NEs are normalized to O, except EVENT and WOA are normalized to MISC (cf. Table 3 of paper).
label_name_map = {
"EVENT": "MISC",
"EVENTderiv": "O",
"EVENTpart": "O",
"LANG": "O",
"LANGderiv": "O",
"LANGpart": "O",
"LOCderiv": "O",
"LOCpart": "O",
"MISCderiv": "O",
"MISCpart": "O",
"ORGderiv": "O",
"ORGpart": "O",
"PERderiv": "O",
"PERpart": "O",
"RELIGION": "O",
"RELIGIONderiv": "O",
"WOA": "MISC",
"WOAderiv": "O",
"WOApart": "O",
}

super().__init__(
data_folder,
columns,
in_memory=in_memory,
comment_symbol="# ",
document_separator_token="-DOCSTART-",
label_name_map=label_name_map,
**corpusargs,
)
3 changes: 2 additions & 1 deletion flair/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def aggregate(value, aggregation_fn=np.mean):

def validate_corpus_same_each_process(corpus: Corpus) -> None:
"""Catches most cases in which a corpus is not the same on each process. However, there is no guarantee for two
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable"""
reasons: 1) It uses a sample for speed 2) It compares strings to avoid requiring the datasets to be serializable
"""
for dataset in [corpus.train, corpus.dev, corpus.test]:
if dataset is not None:
_validate_dataset_same_each_process(dataset)
Expand Down
5 changes: 2 additions & 3 deletions flair/embeddings/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def embed(self, sentences: Union[list[Sentence], Sentence]):
sentences = [sentences]

raw_sentences = [s.to_original_text() for s in sentences]
tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).A)
tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).toarray())

for sentence_id, sentence in enumerate(sentences):
sentence.set_embedding(self.name, tfidf_vectors[sentence_id])
Expand Down Expand Up @@ -691,10 +691,9 @@ def _add_embeddings_internal(self, sentences: list[Sentence]):

lengths: list[int] = [len(sentence.tokens) for sentence in sentences]
padding_length: int = max(max(lengths), self.min_sequence_length)

pre_allocated_zero_tensor = torch.zeros(
self.embeddings.embedding_length * padding_length,
dtype=self.convs[0].weight.dtype,
dtype=cast(torch.nn.Conv1d, self.convs[0]).weight.dtype,
device=flair.device,
)

Expand Down
2 changes: 1 addition & 1 deletion flair/embeddings/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,7 @@ def _add_embeddings_internal(self, sentences: list[Sentence]) -> list[Sentence]:
word = token.text if self.field is None else token.get_label(self.field).value

if word.strip() == "":
ids = [self.spm.vocab_size(), self.embedder.spm.vocab_size()]
ids = [self.spm.vocab_size(), self.spm.vocab_size()]
else:
if self.do_preproc:
word = self._preprocess(word)
Expand Down
2 changes: 1 addition & 1 deletion flair/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,4 @@ def load_torch_state(model_file: str) -> dict[str, typing.Any]:
# to load models on some Mac/Windows setups
# see https://github.com/zalandoresearch/flair/issues/351
f = load_big_file(model_file)
return torch.load(f, map_location="cpu")
return torch.load(f, map_location="cpu", weights_only=False)
8 changes: 7 additions & 1 deletion flair/models/multitask_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def evaluate( # type: ignore[override]
main_score = 0.0
all_detailed_results = ""
all_classification_report: dict[str, dict[str, Any]] = {}
scores: dict[Any, float] = {}

for task_id, split in batch_split.items():
result = self.tasks[task_id].evaluate(
Expand Down Expand Up @@ -194,7 +195,12 @@ def evaluate( # type: ignore[override]
)
all_classification_report[task_id] = result.classification_report

scores = {"loss": loss.item() / len(batch_split)}
# Add metrics so they will be available to _publish_eval_result.
for avg_type in ("micro avg", "macro avg"):
for metric_type in ("f1-score", "precision", "recall"):
scores[(task_id, avg_type, metric_type)] = result.classification_report[avg_type][metric_type]

scores["loss"] = loss.item() / len(batch_split)

return Result(
main_score=main_score / len(batch_split),
Expand Down
8 changes: 4 additions & 4 deletions flair/models/pairwise_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def evaluate(
f"spearman: {metric.spearmanr():.4f}"
)

scores = {
eval_metrics = {
"loss": eval_loss.item(),
"mse": metric.mean_squared_error(),
"mae": metric.mean_absolute_error(),
Expand All @@ -354,12 +354,12 @@ def evaluate(
}

if main_evaluation_metric[0] in ("correlation", "other"):
main_score = scores[main_evaluation_metric[1]]
main_score = eval_metrics[main_evaluation_metric[1]]
else:
main_score = scores["spearman"]
main_score = eval_metrics["spearman"]

return Result(
main_score=main_score,
detailed_results=detailed_result,
scores=scores,
scores=eval_metrics,
)
Loading

0 comments on commit 95998f3

Please sign in to comment.