-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
69 lines (60 loc) · 3.18 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import torch
from torch.utils.data import DataLoader
import hyperparams as hparams
from utils.audio_process import spectrogram
from utils.data import LJSpeechDataset, RandomBucketBatchSampler, TextAudioCollate
from src.loss import FeaturePredictNetLoss
from src.model import FeaturePredictNet
from utils.text_process import text_to_sequence
from utils.solver import Solver
# Command-line interface for Tacotron2 FeaturePredictNet training.
# Boolean-ish toggles (use_cuda, checkpoint, visdom*) follow the file's
# existing int-flag convention (0/1) to stay backward-compatible.
parser = argparse.ArgumentParser("Tacotron2 FeaturePredictNet Training")
parser.add_argument('--train_dir', type=str, required=True, help='dir including wav')
parser.add_argument('--train_csv', type=str, default='metadata.csv', help='csv file such as metadata.csv')
parser.add_argument('--use_cuda', type=int, default=1)
parser.add_argument('--epochs', default=500, type=int)
parser.add_argument('--max_norm', default=1, type=float, help='Gradient norm threshold to clip')
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('--lr', default=1e-3, type=float, help='Init learning rate')
parser.add_argument('--l2', default=0.0, type=float, help='weight decay (L2)')
parser.add_argument('--save_folder', default='exp/temp', help='Dir to save models')
parser.add_argument('--checkpoint', default=1, type=int, help='Enables checkpoint saving of model')
parser.add_argument('--continue_from', default='', help='Continue from checkpoint model')
parser.add_argument('--model_path', default='final.pth.tar', help='model name')
parser.add_argument('--print_freq', default=1, type=int, help='Frequency of printing training information')
parser.add_argument('--visdom', type=int, default=0, help='Turn on visdom graphing')
parser.add_argument('--visdom_epoch', type=int, default=0, help='Turn on visdom graphing each epoch')
parser.add_argument('--visdom_id', default='Taco2 training', help='Identifier for visdom run')
def main(args):
    """Build the LJSpeech data pipeline, model, loss and optimizer, then train.

    Args:
        args: argparse.Namespace produced by the module-level ``parser``
            (train_dir, train_csv, batch_size, use_cuda, lr, l2, ... —
            ``Solver`` reads the remaining fields such as epochs/save_folder).
    """
    # Dataset: text is mapped to integer id sequences, audio to spectrograms.
    dataset = LJSpeechDataset(args.train_dir, args.train_csv,
                              text_transformer=text_to_sequence,
                              audio_transformer=spectrogram)
    print(len(dataset))
    # Bucketing sampler groups similar-length utterances to limit padding.
    batch_sampler = RandomBucketBatchSampler(dataset,
                                             batch_size=args.batch_size,
                                             drop_last=False)
    collate_fn = TextAudioCollate()
    data_loader = DataLoader(dataset, batch_sampler=batch_sampler,
                             collate_fn=collate_fn, num_workers=1)
    # Build model
    print("{} {} {}".format(hparams.num_chars, hparams.padding_idx, hparams.feature_dim))
    model = FeaturePredictNet(hparams.num_chars, hparams.padding_idx,
                              hparams.feature_dim)
    if args.use_cuda:
        model.cuda()
    print(model)
    # Build criterion
    criterion = FeaturePredictNetLoss()
    # Build optimizer (Adam hyperparameters follow the Tacotron2 paper:
    # betas=(0.9, 0.999), eps=1e-6).
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.l2,
                                 betas=(0.9, 0.999), eps=1e-6)
    solver = Solver(data_loader, model, criterion, optimizer, args)
    solver.train()
if __name__ == '__main__':
    # Parse CLI options, echo them for the training log, then launch training.
    cli_args = parser.parse_args()
    print(cli_args)
    main(cli_args)