forked from zkx06111/WSRGlow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
93 lines (79 loc) · 2.93 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import glob
import re
import pdb
import torch.nn as nn
import torch
import librosa
import soundfile as sf
from hparams import hparams, set_hparams
import numpy as np
import pyloudnorm as pyln
from model import WaveGlowMelHF
from utils import load_ckpt
def run(model, wav, sigma=1.0):
    """Run the model on one low-resolution waveform and return flat samples.

    Args:
        model: Object exposing ``infer(wav, sigma=...)``. Element ``[0]`` of
            the value it returns is taken as the prediction tensor.
        wav: 1-D sequence or array of audio samples.
        sigma: Forwarded unchanged to ``model.infer``.

    Returns:
        The prediction as a 1-D ``numpy.ndarray``.
    """
    # Follow the device the model already lives on instead of hard-coding
    # CUDA, so the helper also works for a CPU-resident model. Objects that
    # expose no parameters fall back to CUDA when present, else CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # The model is fed a batch of one: shape (1, n_samples).
    lr = torch.Tensor(wav).reshape(1, -1).to(device)
    pred = model.infer(lr, sigma=sigma)[0]
    return np.array(pred.cpu().detach()).reshape(-1)
def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
    """Return the ``(left, right)`` padding that extends *x* to the next hop.

    The total padding brings ``x.shape[0]`` up to the next multiple of
    ``fshift`` strictly greater than the current length (a full ``fshift`` is
    added when the length is already a multiple).

    Args:
        x: Array-like whose first dimension is the sample count.
        fsize: Frame size; accepted for signature compatibility, not used.
        fshift: Hop size in samples.
        pad_sides: ``1`` puts all padding on the right; ``2`` splits it
            between both sides, with the odd sample going to the right.

    Returns:
        Tuple ``(left_pad, right_pad)``.
    """
    assert pad_sides in (1, 2)
    n_samples = x.shape[0]
    n_frames = n_samples // fshift + 1
    total = n_frames * fshift - n_samples
    if pad_sides == 2:
        half, extra = divmod(total, 2)
        return half, half + extra
    return 0, total
def load_wav(wav_fn):
    """Load a low-resolution waveform and pad it to a whole number of hops.

    The file is resampled on load to a quarter of ``hparams['sampling_rate']``,
    optionally loudness-normalised to -22 LUFS when ``hparams['loud_norm']``
    is set, then right-padded with zeros and cut to ``n_frames * hop_size``
    samples, where ``n_frames`` is the STFT frame count.

    Args:
        wav_fn: Path of the audio file to read.

    Returns:
        Tuple ``(wav, sr)``: the padded samples and the rate they were
        loaded at.
    """
    wav, sr = librosa.load(wav_fn, sr=hparams['sampling_rate'] // 4)
    print(wav.shape, sr, hparams['sampling_rate'])

    if hparams['loud_norm']:
        print('LOUD NORM!', flush=True)
        meter = pyln.Meter(sr)  # create BS.1770 meter
        loudness = meter.integrated_loudness(wav)
        wav = pyln.normalize.loudness(wav, loudness, -22.0)
        # Loudness normalisation can push peaks past full scale; rescale.
        peak = np.abs(wav).max()
        if peak > 1:
            wav = wav / peak

    fft_size = hparams['fft_size']
    hop_size = hparams['hop_size']
    win_length = hparams['win_size']

    # Only the frame count is needed from the STFT. The mel / log-mel
    # projection that used to follow was never returned (only its second
    # dimension, identical to the STFT's, was read), and its positional
    # librosa.filters.mel(...) call fails on librosa >= 0.10 where those
    # parameters are keyword-only.
    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
                          win_length=win_length, window='hann',
                          pad_mode="constant")
    n_frames = x_stft.shape[1]

    l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
    wav = wav[:n_frames * hop_size]
    return wav, sr
if __name__ == '__main__':
    # Script entry point: load the checkpoint, run the fixed list of
    # low-resolution inputs through the model, and write each prediction
    # next to its input at four times the input sample rate.
    set_hparams()
    model = WaveGlowMelHF(**hparams['waveglow_config']).cuda()
    load_ckpt(model, 'checkpoints/DIDI_GLOWSR4/model_ckpt_best.pt')
    model.eval()

    fns = ['12k_db_001_000.wav',
           '12k_db_001_001.wav',
           '12k_db_001_002.wav',
           '12k_db_001_003.wav',
           '12k_db_001_004.wav']
    sigma = 1

    for lr_fn in fns:
        lr, sr = load_wav('inference_code/data/' + lr_fn)
        print(f'lr.shape = {lr.shape}', flush=True)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = run(model, lr, sigma=sigma)
        print(lr.shape, pred.shape)

        pred_fn = f'inference_code/data/{sigma}_pred_{lr_fn}'
        # Hand soundfile the path rather than an open() handle: the handle
        # was never closed, whereas soundfile opens and closes a path itself.
        sf.write(pred_fn, pred, sr * 4)