From 3fe5ee164da40e3d6d7d99967ea30ed55a5234fd Mon Sep 17 00:00:00 2001 From: Lukas Garbas Date: Tue, 12 Nov 2024 09:24:06 +0100 Subject: [PATCH] Refactor label map creation --- transformer_ranker/datacleaner.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/transformer_ranker/datacleaner.py b/transformer_ranker/datacleaner.py index e40af8f..c323b97 100644 --- a/transformer_ranker/datacleaner.py +++ b/transformer_ranker/datacleaner.py @@ -176,7 +176,12 @@ def _find_text_and_label_columns(dataset: Dataset, text_column: Optional[str] = def _merge_textpairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Tuple[Dataset, str]: """Concatenate text pairs into a single text using separator token""" new_text_column_name = text_column + '+' + text_pair_column + if text_pair_column not in dataset.column_names: + raise ValueError( + f"Text pair column name '{text_pair_column}' can not be found in the dataset. " + f"Use one of the following names for text pair: {dataset.column_names}." 
+ ) def merge_texts(dataset_entry: Dict[str, str]) -> Dict[str, str]: dataset_entry[text_column] = dataset_entry[text_column] + " [SEP] " + dataset_entry[text_pair_column] dataset_entry[new_text_column_name] = dataset_entry.pop(text_column) @@ -229,7 +234,7 @@ def is_valid_entry(sample) -> bool: has_text = bool(text) and (not isinstance(text, list) or '\uFE0F' not in text) # Check if label is non-null and all elements are non-negative - valid_label = label is not None and (not isinstance(label, list) or all(value >= 0 for value in label)) + valid_label = label is not None and (all(value >= 0 for value in label) if isinstance(label, list) else label >= 0) return has_text and valid_label @@ -250,13 +255,19 @@ def map_labels(dataset_entry): @staticmethod def _create_label_map(dataset: Dataset, label_column: str) -> Dict[str, int]: - """Try to find feature names in a hf dataset""" + """Try to find feature names in a hf dataset.""" label_names = ( - getattr(getattr(dataset.features[label_column], 'feature', None), 'names', None) - or getattr(dataset.features[label_column], 'names', None) - or sorted({str(label) for labels in dataset[label_column] for label in labels}) + getattr(getattr(dataset.features[label_column], 'feature', None), 'names', None) + or getattr(dataset.features[label_column], 'names', None) ) + # If label names are missing, create them manually + if not label_names: + label_names = sorted( + {str(label) for sublist in dataset[label_column] + for label in (sublist if isinstance(sublist, list) else [sublist])} + ) + return {label: idx for idx, label in enumerate(label_names)} @staticmethod