utils.py
import collections
import random

import numpy as np
import torch
import yaml
from easydict import EasyDict
from transformers import AutoTokenizer, default_data_collator

def load_config(config_path="./config.yaml"):
    # Read the config.yaml file and return its CFG section with attribute-style access
    with open(config_path) as infile:
        SAVED_CFG = yaml.load(infile, Loader=yaml.FullLoader)
    CFG = EasyDict(SAVED_CFG["CFG"])
    return CFG
CFG = load_config()
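
# The config file is expected to contain a top-level "CFG" mapping. The exact keys
# depend on the repo's actual config.yaml; the shape below is only an illustrative
# sketch based on the fields referenced in this module (model_checkpoint, chunk_size,
# whole_word_masking_probability), with hypothetical values:
#
#   CFG:
#     model_checkpoint: bert-base-uncased    # hypothetical value
#     chunk_size: 128                        # hypothetical value
#     whole_word_masking_probability: 0.2    # hypothetical value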

def seed_everything(seed):
    # Fix all RNG seeds (PyTorch CPU/GPU, NumPy, Python) and make cuDNN deterministic
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
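
# Typical usage (a sketch; assumes the config defines a `seed` field, which is not
# shown in this file):
#
#   seed_everything(CFG.seed)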
tokenizer = AutoTokenizer.from_pretrained(CFG.model_checkpoint)

def tokenize_function(examples):
    # Tokenize the raw text; keep word_ids so the whole-word-masking collator can
    # group sub-word tokens back into words (available for fast tokenizers only)
    result = tokenizer(examples["text"])
    if tokenizer.is_fast:
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]
    return result
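
# A minimal sketch of how tokenize_function is typically applied, assuming a
# Hugging Face `datasets` Dataset with a "text" column (the dataset name and the
# column list below are assumptions, not part of this file):
#
#   from datasets import load_dataset
#   raw_datasets = load_dataset("imdb")
#   tokenized_datasets = raw_datasets.map(
#       tokenize_function, batched=True, remove_columns=["text", "label"]
#   )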

def group_texts(examples):
    # Concatenate all texts
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    # Compute the length of the concatenated texts
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the last chunk if it is smaller than CFG.chunk_size
    total_length = (total_length // CFG.chunk_size) * CFG.chunk_size
    # Split into chunks of CFG.chunk_size
    result = {
        k: [t[i : i + CFG.chunk_size] for i in range(0, total_length, CFG.chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # Create a new labels column
    result["labels"] = result["input_ids"].copy()
    return result
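
# group_texts is likewise meant to be used with `Dataset.map(batched=True)`, turning
# variable-length tokenized examples into fixed-size chunks of CFG.chunk_size tokens.
# Sketch only; `tokenized_datasets` is the hypothetical output of the previous step:
#
#   lm_datasets = tokenized_datasets.map(group_texts, batched=True)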

def whole_word_masking_data_collator(features):
    for feature in features:
        word_ids = feature.pop("word_ids")
        # Create a map between words and their corresponding token indices
        mapping = collections.defaultdict(list)
        current_word_index = -1
        current_word = None
        for idx, word_id in enumerate(word_ids):
            if word_id is not None:
                if word_id != current_word:
                    current_word = word_id
                    current_word_index += 1
                mapping[current_word_index].append(idx)
        # Randomly select whole words to mask
        mask = np.random.binomial(1, CFG.whole_word_masking_probability, (len(mapping),))
        input_ids = feature["input_ids"]
        labels = feature["labels"]
        new_labels = [-100] * len(labels)
        for word_id in np.where(mask)[0]:
            word_id = word_id.item()
            for idx in mapping[word_id]:
                new_labels[idx] = labels[idx]
                input_ids[idx] = tokenizer.mask_token_id
        # Only the masked tokens contribute to the loss; all other positions stay -100
        feature["labels"] = new_labels
    return default_data_collator(features)
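
# The collator expects each feature dict to still contain "word_ids" (as produced by
# tokenize_function) alongside "input_ids", "attention_mask", and "labels". A sketch
# of wiring it into a PyTorch DataLoader; the dataset split name and batch size are
# assumptions, not part of this file:
#
#   from torch.utils.data import DataLoader
#   train_dataloader = DataLoader(
#       lm_datasets["train"],
#       batch_size=32,
#       shuffle=True,
#       collate_fn=whole_word_masking_data_collator,
#   )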