Skip to content

Commit

Permalink
Refactor label map creation
Browse files: browse the repository at this point in the history
  • Loading branch information
lukasgarbas committed Nov 12, 2024
1 parent 0a763e6 commit 3fe5ee1
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions transformer_ranker/datacleaner.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,13 @@ def _find_text_and_label_columns(dataset: Dataset, text_column: Optional[str] =
def _merge_textpairs(dataset: Dataset, text_column: str, text_pair_column: str) -> Tuple[Dataset, str]:
"""Concatenate text pairs into a single text using separator token"""
new_text_column_name = text_column + '+' + text_pair_column
print(dataset.column_names)

if text_pair_column not in dataset.column_names:
raise ValueError(
f"Text pair column name '{text_pair_column}' can not be found in the dataset. "
f"Use one of the following names for tex pair: {dataset.column_names}."
)
def merge_texts(dataset_entry: Dict[str, str]) -> Dict[str, str]:
dataset_entry[text_column] = dataset_entry[text_column] + " [SEP] " + dataset_entry[text_pair_column]
dataset_entry[new_text_column_name] = dataset_entry.pop(text_column)
Expand Down Expand Up @@ -229,7 +235,7 @@ def is_valid_entry(sample) -> bool:
has_text = bool(text) and (not isinstance(text, list) or '\uFE0F' not in text)

# Check if label is non-null and all elements are non-negative
valid_label = label is not None and (not isinstance(label, list) or all(value >= 0 for value in label))
valid_label = label is not None and (all(l >= 0 for l in label) if isinstance(label, list) else label >= 0)

return has_text and valid_label

Expand All @@ -250,13 +256,19 @@ def map_labels(dataset_entry):

@staticmethod
def _create_label_map(dataset: Dataset, label_column: str) -> Dict[str, int]:
    """Build a mapping from label name to integer index for a hf dataset.

    Label names are read from the dataset's feature metadata when present:
    first from ``features[label_column].feature.names`` (nested feature),
    then from ``features[label_column].names``. If neither attribute yields
    any names, the names are derived from the values stored in the label
    column: every distinct value is converted to ``str`` and the results
    are sorted (string ordering, so ``"10"`` sorts before ``"2"``).

    Args:
        dataset: Dataset exposing ``features`` and column access by name.
        label_column: Name of the column that holds the labels.

    Returns:
        Dict mapping each label name to its position in the name list.
    """
    label_feature = dataset.features[label_column]
    label_names = (
        getattr(getattr(label_feature, 'feature', None), 'names', None)
        or getattr(label_feature, 'names', None)
    )

    # If label names are missing, create them manually
    if not label_names:
        label_names = sorted(
            {
                str(label)
                for entry in dataset[label_column]
                # A row may hold a single label or a list of labels
                for label in (entry if isinstance(entry, list) else [entry])
            }
        )

    return {label: idx for idx, label in enumerate(label_names)}

@staticmethod
Expand Down

0 comments on commit 3fe5ee1

Please sign in to comment.