-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path mirdnn_eval.py
executable file
·32 lines (27 loc) · 982 Bytes
/
mirdnn_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#!/usr/bin/python3
import sys
import os.path
import torch as tr
import csv
from torch.utils.data import DataLoader
from src.model import mirDNN
from src.parameters import ParameterParser
from src.fold_dataset import FoldDataset
def main(argv):
    """Score every input file with a trained mirDNN model, one CSV per file.

    Args:
        argv: command-line arguments (without the program name), handed to
            ``ParameterParser``. The parsed object must provide
            ``model_file``, ``input_files``, ``output_file`` (indexed in
            parallel with ``input_files``), ``seq_len`` and ``batch_size``.

    For each input file, writes one CSV row per sample:
    ``[dataset.name[row], first value of the model output for that sample]``.
    """
    pp = ParameterParser(argv)
    model = mirDNN(pp)
    model.load(pp.model_file)
    model.eval()

    for file_idx, ifile in enumerate(pp.input_files):
        dataset = FoldDataset([ifile], pp.seq_len)
        # DataLoader defaults to shuffle=False, so batches arrive in dataset
        # order and a running counter maps each prediction to its name.
        loader = DataLoader(dataset, batch_size=pp.batch_size, pin_memory=True)
        # newline='' lets the csv module control line endings itself;
        # without it rows end in "\r\r\n" on Windows.
        with open(pp.output_file[file_idx], 'w', newline='') as csvfile:
            of = csv.writer(csvfile, delimiter=',')
            row = 0
            # Pure inference: skip autograd bookkeeping to save memory/time.
            with tr.no_grad():
                for seq, val, _ in loader:
                    res = model(seq, val).tolist()
                    for pred in res:
                        of.writerow([dataset.name[row], pred[0]])
                        row += 1
if __name__ == "__main__":
    # Drop the program name and hand the remaining CLI arguments to main().
    cli_args = sys.argv[1:]
    main(cli_args)