-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdataset.py
104 lines (86 loc) · 3.86 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
import random
import pickle
class SelectionDataset(Dataset):
    """Response-selection dataset loaded from a tab-separated file.

    Each input line has the form ``label<TAB>context<TAB>response``.
    A label of ``1`` marks the positive response that begins a new group;
    subsequent ``0`` lines are negative candidates for the same context.
    Lines are accumulated into group dicts of the form
    ``{'context': str, 'responses': [str, ...], 'labels': [int, ...]}``.
    """

    def __init__(self, file_path, context_transform, response_transform, concat_transform, sample_cnt=None, mode='poly'):
        """Parse ``file_path`` into grouped samples.

        Args:
            file_path: path to a UTF-8 TSV file of ``label\\tcontext\\tresponse`` lines.
            context_transform: callable tokenizing one context string
                (used when ``mode != 'cross'``).
            response_transform: callable tokenizing a list of response strings
                (used when ``mode != 'cross'``).
            concat_transform: callable joining a context with its responses
                (used only when ``mode == 'cross'``).
            sample_cnt: optional cap on the number of groups to load; parsing
                stops as soon as the cap is reached.
            mode: ``'cross'`` for concatenated cross-encoder inputs; any other
                value (default ``'poly'``) keeps context and responses separate.
        """
        self.context_transform = context_transform
        self.response_transform = response_transform
        self.concat_transform = concat_transform
        self.data_source = []
        self.mode = mode
        with open(file_path, encoding='utf-8') as f:
            group = {
                'context': None,
                'responses': [],
                'labels': []
            }
            for line in f:
                split = line.strip().split('\t')
                lbl, context, response = int(split[0]), split[1], split[2]
                # A positive label signals the start of a new group, so the
                # previous group (if any) is flushed one line "late", here.
                if lbl == 1 and len(group['responses']) > 0:
                    self.data_source.append(group)
                    group = {
                        'context': None,
                        'responses': [],
                        'labels': []
                    }
                    if sample_cnt is not None and len(self.data_source) >= sample_cnt:
                        break
                group['responses'].append(response)
                group['labels'].append(lbl)
                group['context'] = context
            # Flush the trailing group; it is empty when the sample_cnt break
            # fired (the group was reset just before breaking), so nothing
            # beyond the cap is ever appended.
            if len(group['responses']) > 0:
                self.data_source.append(group)

    def __len__(self):
        """Number of loaded groups."""
        return len(self.data_source)

    def __getitem__(self, index):
        """Return the transformed sample for one group.

        Returns ``(concat_transform(context, responses), labels)`` in
        ``'cross'`` mode, otherwise
        ``(context_transform(context), response_transform(responses), labels)``.
        """
        group = self.data_source[index]
        context, responses, labels = group['context'], group['responses'], group['labels']
        if self.mode == 'cross':
            transformed_text = self.concat_transform(context, responses)
            ret = transformed_text, labels
        else:
            transformed_context = self.context_transform(context)  # [token_ids],[seg_ids],[masks]
            transformed_responses = self.response_transform(responses)  # [token_ids],[seg_ids],[masks]
            ret = transformed_context, transformed_responses, labels
        return ret

    def batchify_join_str(self, batch):
        """Collate a list of ``__getitem__`` outputs into long tensors.

        In ``'cross'`` mode each sample's first element must unpack into
        (token_ids, input_masks, segment_ids); otherwise the first two
        elements must each unpack into (token_ids, input_masks) for the
        context and responses respectively.  Labels are taken from the
        last element of each sample and stacked.
        """
        if self.mode == 'cross':
            text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch = [], [], []
            labels_batch = []
            for sample in batch:
                text_token_ids_list, text_input_masks_list, text_segment_ids_list = sample[0]
                text_token_ids_list_batch.append(text_token_ids_list)
                text_input_masks_list_batch.append(text_input_masks_list)
                text_segment_ids_list_batch.append(text_segment_ids_list)
                labels_batch.append(sample[1])
            long_tensors = [text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch]
            text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch = (
                torch.tensor(t, dtype=torch.long) for t in long_tensors)
            labels_batch = torch.tensor(labels_batch, dtype=torch.long)
            return text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch, labels_batch
        else:
            contexts_token_ids_list_batch, contexts_input_masks_list_batch, \
            responses_token_ids_list_batch, responses_input_masks_list_batch = [], [], [], []
            labels_batch = []
            for sample in batch:
                (contexts_token_ids_list, contexts_input_masks_list), (responses_token_ids_list, responses_input_masks_list) = sample[:2]
                contexts_token_ids_list_batch.append(contexts_token_ids_list)
                contexts_input_masks_list_batch.append(contexts_input_masks_list)
                responses_token_ids_list_batch.append(responses_token_ids_list)
                responses_input_masks_list_batch.append(responses_input_masks_list)
                labels_batch.append(sample[-1])
            long_tensors = [contexts_token_ids_list_batch, contexts_input_masks_list_batch,
                            responses_token_ids_list_batch, responses_input_masks_list_batch]
            contexts_token_ids_list_batch, contexts_input_masks_list_batch, \
            responses_token_ids_list_batch, responses_input_masks_list_batch = (
                torch.tensor(t, dtype=torch.long) for t in long_tensors)
            labels_batch = torch.tensor(labels_batch, dtype=torch.long)
            return contexts_token_ids_list_batch, contexts_input_masks_list_batch, \
                   responses_token_ids_list_batch, responses_input_masks_list_batch, labels_batch