-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathpre_process.py
68 lines (49 loc) · 1.65 KB
/
pre_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import os
import zipfile
from collections import Counter

import jieba
from tqdm import tqdm

from config import *
from utils import ensure_folder
def extract(folder, dest='data'):
    """Extract the archive ``<folder>.zip`` into the directory *dest*.

    Args:
        folder: Archive path without the ``.zip`` suffix; ``'.zip'`` is
            appended to form the file name that is opened.
        dest: Directory the archive members are extracted into. Defaults
            to ``'data'``, the location this script has always used, so
            existing callers are unaffected.
    """
    filename = '{}.zip'.format(folder)
    print('Extracting {}...'.format(filename))
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall(dest)
def create_input_files():
    """Build the caption vocabulary and write it to ``WORDMAP.json``.

    Reads the training annotation file named by the config value
    ``train_annotations_filename``. Each caption in every sample's
    ``'caption'`` entry is segmented with jieba in full mode
    (``cut_all=True``), and token frequencies are counted.

    Tokens whose count is strictly greater than ``min_word_freq`` are kept
    and numbered from 1 in first-seen order. The special tokens ``<unk>``,
    ``<start>`` and ``<end>`` take the next three indices, and ``<pad>`` is
    mapped to 0.

    The mapping is written as JSON to ``<data_folder>/WORDMAP.json``.
    """
    json_path = train_annotations_filename

    # Captions are segmented with jieba (a Chinese tokenizer), so the file is
    # expected to hold non-ASCII text: decode it explicitly as UTF-8 instead
    # of relying on the platform's locale-dependent default encoding.
    with open(json_path, 'r', encoding='utf-8') as j:
        samples = json.load(j)

    # Count token frequencies over every caption of every sample.
    word_freq = Counter()
    for sample in tqdm(samples):
        # NOTE(review): presumably a list of caption strings per image;
        # each element is segmented on its own.
        for c in sample['caption']:
            word_freq.update(jieba.cut(c, cut_all=True))

    # Keep tokens seen more than min_word_freq times (strictly greater).
    words = [w for w, count in word_freq.items() if count > min_word_freq]

    # Index 0 is reserved for <pad>, so real words are numbered from 1.
    word_map = {w: i for i, w in enumerate(words, start=1)}
    word_map['<unk>'] = len(word_map) + 1
    word_map['<start>'] = len(word_map) + 1
    word_map['<end>'] = len(word_map) + 1
    word_map['<pad>'] = 0
    print(len(word_map))
    print(words[:10])

    # json.dump escapes non-ASCII by default, so the output is plain ASCII.
    # The explicit encoding just keeps the write independent of the locale.
    with open(os.path.join(data_folder, 'WORDMAP.json'), 'w', encoding='utf-8') as j:
        json.dump(word_map, j)
if __name__ == '__main__':
    ensure_folder('data')

    # Each dataset split is unpacked from its archive only when its image
    # directory does not exist yet; splits are handled in a fixed order.
    splits = (
        (train_image_folder, train_folder),
        (valid_image_folder, valid_folder),
        (test_a_image_folder, test_a_folder),
        (test_b_image_folder, test_b_folder),
    )
    for image_dir, archive_stem in splits:
        if os.path.isdir(image_dir):
            continue
        extract(archive_stem)

    # Build the vocabulary once all images are in place.
    create_input_files()