-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsaved_func_load_data_nmt.py
163 lines (141 loc) · 6.05 KB
/
saved_func_load_data_nmt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import collections
import hashlib
import os
import tarfile
import zipfile

import requests
import tensorflow as tf
def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Build the batched translation dataset together with both vocabularies.

    The raw corpus is read, preprocessed and tokenized (``num_examples`` is
    forwarded to ``tokenize_nmt``). A separate ``Vocab`` is built for the
    source and the target side, each with ``min_freq=2`` and the reserved
    tokens ``<pad>``, ``<bos>`` and ``<eos>``. Every sentence is then turned
    into a fixed-length row of ``num_steps`` indices.

    Returns:
        A tuple ``(data_iter, src_vocab, tgt_vocab)`` where ``data_iter`` is
        the dataset produced by ``load_array`` over
        ``(src_array, src_valid_len, tgt_array, tgt_valid_len)``.
    """
    raw_text = read_data_nmt()
    source, target = tokenize_nmt(preprocess_nmt(raw_text), num_examples)

    # Each vocabulary receives its own list object, as in a literal call.
    special_tokens = ('<pad>', '<bos>', '<eos>')
    src_vocab = Vocab(source, min_freq=2,
                      reserved_tokens=list(special_tokens))
    tgt_vocab = Vocab(target, min_freq=2,
                      reserved_tokens=list(special_tokens))

    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)

    data_iter = load_array(
        (src_array, src_valid_len, tgt_array, tgt_valid_len), batch_size)
    return data_iter, src_vocab, tgt_vocab
def preprocess_nmt(text):
    """Normalize raw English-French text before tokenization.

    Steps, in order:
      1. Replace the narrow no-break space (U+202F) and the no-break space
         (U+00A0) with a regular space.
      2. Lowercase the whole text.
      3. Insert a single space in front of every ``,``, ``.``, ``!`` or ``?``
         that is not already preceded by a space, so that punctuation becomes
         a separate token when the text is later split on spaces. The very
         first character of the text never gets a space inserted before it.

    Args:
        text: The raw corpus as one string.

    Returns:
        The normalized string.
    """
    # Built once per call; the original rebuilt this set for every character.
    punctuation = set(',.!?')

    def no_space(char, prev_char):
        return char in punctuation and prev_char != ' '

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Insert space between words and punctuation marks
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)
def read_data_nmt():
    """Load the English-French dataset and return it as one string.

    The archive registered under ``'fra-eng'`` is downloaded and extracted via
    ``download_extract`` and the file ``fra.txt`` inside the extracted
    directory is read in full.
    """
    data_dir = download_extract('fra-eng')
    # The encoding is given explicitly: the corpus contains accented French
    # text, and the platform default encoding is not UTF-8 everywhere.
    with open(os.path.join(data_dir, 'fra.txt'), 'r', encoding='utf-8') as f:
        return f.read()
def download_extract(name, folder=None):
    """Download a zip/tar archive and extract it next to the archive file.

    Args:
        name: Key of the dataset, passed on to ``download``.
        folder: Optional directory name. When given, the returned path is
            ``<archive directory>/<folder>``; otherwise it is the archive
            path with its last extension removed.

    Returns:
        The path of the directory expected to hold the extracted data.

    Raises:
        AssertionError: If the archive extension is not ``.zip``, ``.tar`` or
            ``.gz``.
    """
    fname = download(name)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        open_archive = zipfile.ZipFile
    elif ext in ('.tar', '.gz'):
        open_archive = tarfile.open
    else:
        # Raised explicitly (instead of `assert False`) so the check survives
        # `python -O`; the exception type is unchanged for callers.
        raise AssertionError('Only zip/tar files can be extracted.')
    # The context manager closes the archive handle even if extraction fails.
    with open_archive(fname, 'r') as fp:
        fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir
def download(name, cache_dir=os.path.join('..', 'data')):
    """Download a file registered in DATA_HUB and return the local filename.

    If a file with the expected name already exists in ``cache_dir`` and its
    SHA-1 digest matches the registered one, it is reused and nothing is
    downloaded.

    Args:
        name: Key of the dataset in DATA_HUB (only ``'fra-eng'`` is known).
        cache_dir: Directory for downloaded files; created if missing.

    Returns:
        Path of the local file.

    Raises:
        AssertionError: If ``name`` is not a key of DATA_HUB.
        requests.HTTPError: If the server answers with an error status.
    """
    DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'
    DATA_HUB = dict()
    DATA_HUB['fra-eng'] = (DATA_URL + 'fra-eng.zip',
                           '94646ad1522d915e7b0f9296181140edcf86a4f5')
    # Raised explicitly (instead of `assert`) so the check survives
    # `python -O`; the exception type is unchanged for callers.
    if name not in DATA_HUB:
        raise AssertionError(f"{name} does not exist in {DATA_HUB}.")
    url, sha1_hash = DATA_HUB[name]
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split('/')[-1])
    if os.path.exists(fname):
        sha1 = hashlib.sha1()
        with open(fname, 'rb') as f:
            # Hash in 1 MiB chunks to keep memory use flat.
            for data in iter(lambda: f.read(1048576), b''):
                sha1.update(data)
        if sha1.hexdigest() == sha1_hash:
            return fname  # Hit cache
    print(f'Downloading {fname} from {url}...')
    with requests.get(url, stream=True, verify=True) as r:
        # Fail before touching the cache file, so an HTTP error page is never
        # stored under the dataset's filename.
        r.raise_for_status()
        with open(fname, 'wb') as f:
            # Stream the body to disk instead of buffering it all in memory.
            for chunk in r.iter_content(chunk_size=1048576):
                f.write(chunk)
    return fname
def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset.

    The text is split into lines on ``'\\n'`` and each line into fields on
    ``'\\t'``. Only lines with exactly two fields are kept; the first field
    becomes a source sentence and the second a target sentence, each split
    into tokens on single spaces.

    Args:
        text: Preprocessed corpus, one tab-separated sentence pair per line.
        num_examples: Maximum number of lines to examine. Lines that are
            skipped for not having two fields still count towards this limit.
            ``None`` (or any other falsy value) means no limit.

    Returns:
        A tuple ``(source, target)`` of two equally long lists of token lists.
    """
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        # `i` is zero-based, so `>=` stops after exactly `num_examples` lines
        # (the former `>` let one extra line through).
        if num_examples and i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target
class Vocab:
    """Two-way mapping between tokens and integer indices.

    Index 0 is always ``'<unk>'``; the reserved tokens follow in the order
    given, then the corpus tokens in order of decreasing frequency.
    """

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """Build the vocabulary.

        Args:
            tokens: A flat list of tokens or a list of token lists; counted
                with ``count_corpus``. Defaults to an empty corpus.
            min_freq: Corpus tokens seen fewer than this many times are left
                out (and therefore map to the unknown index).
            reserved_tokens: Extra tokens placed right after ``'<unk>'``.
        """
        tokens = [] if tokens is None else tokens
        reserved_tokens = [] if reserved_tokens is None else reserved_tokens

        # (token, count) pairs, most frequent first.
        frequencies = count_corpus(tokens)
        self.token_freqs = sorted(frequencies.items(),
                                  key=lambda pair: pair[1], reverse=True)

        # The unknown token owns index 0.
        self.unk = 0
        specials = ['<unk>'] + reserved_tokens
        kept = []
        for token, freq in self.token_freqs:
            if freq >= min_freq and token not in specials:
                kept.append(token)

        self.idx_to_token = specials + kept
        self.token_to_idx = {
            token: idx for idx, token in enumerate(self.idx_to_token)}

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        """Map a token, or a list/tuple of tokens (recursively), to indices.

        Tokens not in the vocabulary map to the unknown index.
        """
        if isinstance(tokens, (list, tuple)):
            return [self.__getitem__(token) for token in tokens]
        return self.token_to_idx.get(tokens, self.unk)

    def to_tokens(self, indices):
        """Map an index, or a list/tuple of indices, back to tokens."""
        if isinstance(indices, (list, tuple)):
            return [self.idx_to_token[index] for index in indices]
        return self.idx_to_token[indices]
def count_corpus(tokens):
    """Return a ``collections.Counter`` of token frequencies.

    ``tokens`` is either a flat list of tokens or a list of token lists. It
    is treated as nested when it is empty or when its first element is a
    ``list``; a nested input is flattened before counting.
    """
    is_nested = len(tokens) == 0 or isinstance(tokens[0], list)
    if is_nested:
        # Concatenate all lines into a single flat token list.
        flat = []
        for line in tokens:
            flat.extend(line)
        tokens = flat
    return collections.Counter(tokens)
def build_array_nmt(lines, vocab, num_steps):
    """Turn tokenized sentences into a fixed-width index tensor.

    Each sentence is mapped to indices through ``vocab``, gets the ``<eos>``
    index appended, and is truncated or padded with the ``<pad>`` index to
    exactly ``num_steps`` entries.

    Returns:
        A tuple ``(array, valid_len)``: the constant tensor of index rows and,
        per row, the count of entries that differ from the ``<pad>`` index.
    """
    eos_id = vocab['<eos>']
    pad_id = vocab['<pad>']

    rows = []
    for sentence in lines:
        indices = vocab[sentence] + [eos_id]
        rows.append(truncate_pad(indices, num_steps, pad_id))
    array = tf.constant(rows)

    # Count the non-padding positions along each row.
    not_pad = tf.cast(array != pad_id, tf.int32)
    valid_len = tf.reduce_sum(not_pad, 1)
    return array, valid_len
def truncate_pad(line, num_steps, padding_token):
    """Return ``line`` cut or right-padded to ``num_steps`` entries.

    Longer lists are sliced to their first ``num_steps`` items; shorter ones
    are extended with copies of ``padding_token``. The input list is never
    modified.
    """
    missing = num_steps - len(line)
    if missing < 0:
        # Too long: keep only the leading part.
        return line[:num_steps]
    # Short enough: append as many padding tokens as are missing.
    return line + [padding_token] * missing
def load_array(data_arrays, batch_size, is_train=True):
    """Wrap tensors in a batched ``tf.data.Dataset``.

    The dataset is created with ``from_tensor_slices(data_arrays)``. When
    ``is_train`` is true it is shuffled with a buffer of 1000 elements before
    being grouped into batches of ``batch_size``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(data_arrays)
    if not is_train:
        return dataset.batch(batch_size)
    return dataset.shuffle(buffer_size=1000).batch(batch_size)