-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsample_wavegan.py
executable file
·176 lines (148 loc) · 5.82 KB
/
sample_wavegan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sample from a pre-trained WaveGAN model.

This script provides sampling from a pre-trained WaveGAN model, done through
the original author's code (https://github.com/chrisdonahue/wavegan).
Its main purpose is to help manually check the quality of the WaveGAN model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from operator import itemgetter
from os.path import join, expanduser
import numpy as np
from scipy.io import wavfile
import tensorflow as tf
from tqdm import tqdm
# Command-line flags read by `main`: sampling budget per label, whether to
# keep only the highest-scoring samples, and the checkpoint / output paths.
FLAGS = tf.flags.FLAGS

# Sampling continues until the least-frequent label has been seen this often.
tf.flags.DEFINE_integer(
    'total_per_label', '7000', 'Minimal # samples per label')
# Upper bound on how many samples are kept for each label.
tf.flags.DEFINE_integer(
    'top_per_label', '1700', '# of top samples per label')
# When set, samples are ranked by classifier score before truncation.
tf.flags.DEFINE_boolean(
    'selective', True, 'Whethor to be selective in finding top samples')
tf.flags.DEFINE_string(
    'gen_ckpt_dir', '', "The directory to WaveGAN generator's ckpt.")
tf.flags.DEFINE_string(
    'inception_ckpt_dir', '',
    "The directory to WaveGAN inception (classifier)'s ckpt.")
# Used by `main` as the output directory for the samples and the npz file.
tf.flags.DEFINE_string(
    'latent_dir', '', "The directory to WaveGAN's latent space.")
def _restore_session(meta_path, ckpt_path):
    """Imports a meta graph into a fresh graph and restores a checkpoint.

    Args:
        meta_path: Path of the `.meta` file to import.
        ckpt_path: Checkpoint path passed to `Saver.restore`.

    Returns:
        A `(graph, session)` pair. The caller owns the session and must
        close it.
    """
    graph = tf.Graph()
    with graph.as_default():
        sess = tf.Session(graph=graph)
        try:
            saver = tf.train.import_meta_graph(meta_path)
            saver.restore(sess, ckpt_path)
        except Exception:
            # Do not leak the session when the checkpoint cannot be loaded.
            sess.close()
            raise
    return graph, sess


def _keep_top(entries, limit, selective):
    """Returns at most `limit` entries, highest score first if `selective`.

    Args:
        entries: List of `(score, payload)` tuples; sorted in place when
            `selective` is true.
        limit: Maximum number of entries to return.
        selective: Whether to rank by score before truncating. When false
            the first `limit` entries (insertion order) are kept.

    Returns:
        A new list holding the retained entries.
    """
    if selective:
        # Sort on the score only, so the array payloads are never compared.
        entries.sort(key=itemgetter(0), reverse=True)
    return entries[:limit]


def _collect_samples(generate, classify, num_labels, total_per_label,
                     top_per_label, selective, batch_size, hidden_dim):
    """Draws latents, generates and classifies them, and groups by label.

    Args:
        generate: Callable mapping a `(batch_size, hidden_dim)` latent array
            to a batch of generated outputs.
        classify: Callable mapping a generated batch to a 2-D score array
            with one row per sample and one column per label.
        num_labels: Number of labels (buckets) to maintain.
        total_per_label: Sampling stops once the rarest label has been
            predicted at least this many times.
        top_per_label: Maximum number of samples retained per label.
        selective: Whether retained samples are the highest-scoring ones.
        batch_size: Number of latents drawn per iteration.
        hidden_dim: Dimensionality of each latent vector.

    Returns:
        A list with `num_labels` entries; entry `k` is a list of
        `(score, (latent, generated))` tuples predicted as label `k`.
    """
    group_by_label = [[] for _ in range(num_labels)]
    label_count = [0] * num_labels
    last_min_label_count = 0
    with tqdm(desc='min label count', unit=' #',
              total=total_per_label) as pbar:
        while True:
            min_label_count = min(label_count)
            pbar.update(min_label_count - last_min_label_count)
            last_min_label_count = min_label_count

            latents = np.random.randn(batch_size, hidden_dim)
            generated = generate(latents)
            batch_scores = classify(generated)
            max_scores = np.max(batch_scores, axis=1)
            labels = np.argmax(batch_scores, axis=1)

            for score, label, latent, sample in zip(
                    max_scores, labels, latents, generated):
                bucket = group_by_label[label]
                bucket.append((score, (latent, sample)))
                label_count[label] += 1
                if len(bucket) >= top_per_label * 2:
                    # Drop the tail early so a bucket never grows beyond
                    # twice the number of samples that will be kept.
                    group_by_label[label] = _keep_top(
                        bucket, top_per_label, selective)

            # NOTE: the test uses the count measured *before* this batch, so
            # one more batch is drawn after the threshold is first reached.
            # Kept as-is so seeded runs produce the same data as before.
            if last_min_label_count >= total_per_label:
                break

    return [
        _keep_top(entries, top_per_label, selective)
        for entries in group_by_label
    ]


def _write_preview_wavs(group_by_label, wav_dir, per_label=10, rate=16000):
    """Writes the first `per_label` samples of every label as WAV files.

    Args:
        group_by_label: Per-label lists of `(score, (latent, generated))`.
        wav_dir: Existing directory that receives the files.
        per_label: Number of samples written for each label.
        rate: Sample rate passed to `scipy.io.wavfile.write`.
    """
    for label, entries in enumerate(group_by_label):
        # `enumerate` gives every file of a label a distinct index; the
        # previous counter was never incremented and always printed 00.
        for index, (confidence, (_, waveform)) in enumerate(
                entries[:per_label]):
            output_basename = 'predlabel=%d_index=%02d_confidence=%.6f' % (
                label, index, confidence)
            wavfile.write(
                filename=join(wav_dir, output_basename + '_sound.wav'),
                rate=rate,
                data=waveform)


def _save_npz(group_by_label, path):
    """Stores labels, latents and generated samples in one npz archive.

    Args:
        group_by_label: Per-label lists of `(score, (latent, generated))`.
        path: Destination file name handed to `np.savez`.
    """
    labels, latents, generated = [], [], []
    for label, entries in enumerate(group_by_label):
        for _, (latent, sample) in entries:
            labels.append(label)
            latents.append(latent)
            generated.append(sample)
    np.savez(
        path,
        label=np.array(labels, dtype='i'),
        z=np.array(latents),
        G_z=np.array(generated),
    )


def main(unused_argv):
    """Samples from the WaveGAN generator and keeps the best ones per label.

    Restores the generator and the classifier from the directories given by
    `--gen_ckpt_dir` and `--inception_ckpt_dir`, repeatedly generates batches
    from random latents, assigns each output the classifier's arg-max label,
    and retains up to `--top_per_label` samples per label. Results are
    written below `--latent_dir`: a handful of WAV files per label for manual
    listening and `data_train.npz` with arrays `label`, `z` and `G_z`.

    Args:
        unused_argv: Positional arguments supplied by `tf.app.run`; ignored.
    """
    # pylint:disable=invalid-name
    # Reason: the tensor handles below (`z`, `G_z`, `x`) intentionally carry
    # the names of the graph tensors they refer to.
    del unused_argv

    num_labels = 10
    batch_size = 200
    hidden_dim = 100

    gen_ckpt_dir = expanduser(FLAGS.gen_ckpt_dir)
    inception_ckpt_dir = expanduser(FLAGS.inception_ckpt_dir)
    output_dir = expanduser(FLAGS.latent_dir)

    tf.reset_default_graph()

    sessions = []
    try:
        # Generative model.
        graph_gan, sess_gan = _restore_session(
            join(gen_ckpt_dir, '..', 'infer', 'infer.meta'),
            join(gen_ckpt_dir, 'model.ckpt'))
        sessions.append(sess_gan)

        # Classifier (inception).
        graph_class, sess_class = _restore_session(
            join(inception_ckpt_dir, 'infer.meta'),
            join(inception_ckpt_dir, 'best_acc-103005'))
        sessions.append(sess_class)

        # Generator tensors. Index 0 of the last axis is selected, which
        # presumably drops a singleton channel axis - TODO confirm.
        z = graph_gan.get_tensor_by_name('z:0')
        G_z = graph_gan.get_tensor_by_name('G_z:0')[:, :, 0]

        # Classifier tensors.
        x = graph_class.get_tensor_by_name('x:0')
        scores = graph_class.get_tensor_by_name('scores:0')

        tf.gfile.MakeDirs(output_dir)
        # Fixed seed so repeated runs draw the same latent vectors.
        np.random.seed(19260817)
        tf.logging.info('`selective` is %s', FLAGS.selective)

        group_by_label = _collect_samples(
            generate=lambda latents: sess_gan.run(G_z, {z: latents}),
            classify=lambda batch: sess_class.run(scores, {x: batch}),
            num_labels=num_labels,
            total_per_label=FLAGS.total_per_label,
            top_per_label=FLAGS.top_per_label,
            selective=FLAGS.selective,
            batch_size=batch_size,
            hidden_dim=hidden_dim)
    finally:
        # Release the sessions as soon as sampling is over (or has failed).
        for sess in sessions:
            sess.close()

    # Output a few samples per label for manual inspection. The directory
    # name's spelling is kept unchanged so the on-disk layout stays the same.
    image_output_dir = join(output_dir, 'sample_iamge')
    tf.gfile.MakeDirs(image_output_dir)
    _write_preview_wavs(group_by_label, image_output_dir)

    # Save everything that was retained as a single npz file.
    _save_npz(group_by_label, join(output_dir, 'data_train.npz'))
    # pylint:enable=invalid-name
# Script entry point: hand control to `tf.app.run`, which invokes `main`.
if __name__ == '__main__':
    tf.app.run(main=main)