#!/usr/bin/env python
"""
Generate the vocabulary file for neural network training.
A vocabulary file is a mapping of tokens to their indices.

Usage:
    vocab.py --train-src=<file> --train-tgt=<file> [options] VOCAB_FILE

Options:
    -h --help                  Show this screen.
    --train-src=<file>         File of training source sentences
    --train-tgt=<file>         File of training target sentences
    --size=<int>               vocab size [default: 50000]
    --freq-cutoff=<int>        frequency cutoff [default: 2]
"""
import pickle
from collections import Counter
from itertools import chain

from docopt import docopt

from utils import read_corpus


class VocabEntry(object):
    """Vocabulary for one language: a word <-> index mapping."""

    def __init__(self):
        self.word2id = dict()
        self.unk_id = 3
        self.word2id['<pad>'] = 0   # padding token
        self.word2id['<s>'] = 1     # start-of-sentence token
        self.word2id['</s>'] = 2    # end-of-sentence token
        self.word2id['<unk>'] = 3   # unknown-word token

        self.id2word = {v: k for k, v in self.word2id.items()}
    def __getitem__(self, word):
        return self.word2id.get(word, self.unk_id)

    def __contains__(self, word):
        return word in self.word2id

    def __setitem__(self, key, value):
        raise ValueError('vocabulary is readonly')

    def __len__(self):
        return len(self.word2id)

    def __repr__(self):
        return 'Vocabulary[size=%d]' % len(self)
    def get_word(self, wid):
        # Renamed from `id2word`: the instance attribute `self.id2word`
        # (the dict set in __init__) shadowed the identically named method,
        # which made the method uncallable. The dict itself is unchanged.
        return self.id2word[wid]
    def add(self, word):
        if word not in self:
            wid = self.word2id[word] = len(self)
            self.id2word[wid] = word
            return wid
        else:
            return self[word]
    def words2indices(self, sents):
        # Accepts either a single sentence (list of words) or a batch
        # (list of sentences).
        if isinstance(sents[0], list):
            return [[self[w] for w in s] for s in sents]
        else:
            return [self[w] for w in sents]
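    # For example (a minimal sketch; the actual ids depend on the corpus):
    #   entry.words2indices(['a', 'b'])           -> [entry['a'], entry['b']]
    #   entry.words2indices([['a'], ['b', 'c']])  -> [[entry['a']], [entry['b'], entry['c']]]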

    @staticmethod
    def from_corpus(corpus, size, freq_cutoff=2):
        vocab_entry = VocabEntry()
        word_freq = Counter(chain(*corpus))
        valid_words = [w for w, v in word_freq.items() if v >= freq_cutoff]
        print(f'number of word types: {len(word_freq)}, '
              f'number of word types w/ frequency >= {freq_cutoff}: {len(valid_words)}')
        # Keep the `size` most frequent words that pass the cutoff.
        top_k_words = sorted(valid_words, key=lambda w: word_freq[w], reverse=True)[:size]
        for word in top_k_words:
            vocab_entry.add(word)
        return vocab_entry
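
# A minimal sketch of building a VocabEntry from a toy corpus (the sentences
# below are hypothetical; in this script the corpus comes from read_corpus()):
#
#   corpus = [['hello', 'world'], ['hello', 'there']]
#   entry = VocabEntry.from_corpus(corpus, size=100, freq_cutoff=1)
#   entry['hello']    # 4 -- ids 0-3 are reserved for the special tokens
#   entry['missing']  # 3 -- unseen words map to entry.unk_id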


class Vocab(object):
    """Container for the source and target vocabularies."""

    def __init__(self, src_sents, tgt_sents, vocab_size, freq_cutoff):
        assert len(src_sents) == len(tgt_sents)

        print('initialize source vocabulary ..')
        self.src = VocabEntry.from_corpus(src_sents, vocab_size, freq_cutoff)

        print('initialize target vocabulary ..')
        self.tgt = VocabEntry.from_corpus(tgt_sents, vocab_size, freq_cutoff)

    def __repr__(self):
        return 'Vocab(source %d words, target %d words)' % (len(self.src), len(self.tgt))


if __name__ == '__main__':
    args = docopt(__doc__)

    print('read in source sentences: %s' % args['--train-src'])
    print('read in target sentences: %s' % args['--train-tgt'])

    src_sents = read_corpus(args['--train-src'], source='src')
    tgt_sents = read_corpus(args['--train-tgt'], source='tgt')

    vocab = Vocab(src_sents, tgt_sents, int(args['--size']), int(args['--freq-cutoff']))
    print('generated vocabulary, source %d words, target %d words' % (len(vocab.src), len(vocab.tgt)))

    # Use a context manager so the file handle is closed after writing.
    with open(args['VOCAB_FILE'], 'wb') as f:
        pickle.dump(vocab, f)
    print('vocabulary saved to %s' % args['VOCAB_FILE'])
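
# To load the saved vocabulary later (a minimal sketch; 'vocab.bin' is a
# hypothetical name -- use whatever was passed as VOCAB_FILE):
#
#   with open('vocab.bin', 'rb') as f:
#       vocab = pickle.load(f)
#   print(len(vocab.src), len(vocab.tgt))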