-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmirdnn_explain.py
executable file
·46 lines (35 loc) · 1.27 KB
/
mirdnn_explain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#!/usr/bin/python3
import sys
import os.path
import torch as tr
from tqdm import tqdm
import csv
from torch.utils.data import DataLoader
from src.model import mirDNN
from src.parameters import ParameterParser
from src.fold_dataset import FoldDataset
def main(argv):
    """Write per-position occlusion scores for every input file to a CSV.

    Each sequence is scored once unmodified, and once more as a batch of
    ``pp.seq_len`` copies in which copy ``k`` has element ``[k, k]`` of the
    repeated input set to zero. The value reported for column ``Nk`` is the
    unmodified score minus the score of copy ``k``.

    Every output row is ``sequence_name, score, N0, N1, ...``.

    Args:
        argv: Command-line arguments without the program name; passed
            unchanged to ``ParameterParser``.
    """
    pp = ParameterParser(argv)
    model = mirDNN(pp)
    model.load(pp.model_file)
    model.eval()

    # Loop-invariant values: the diagonal index used to zero one position
    # per repeated copy, and the CSV header.
    ind = tr.arange(pp.seq_len)
    # The csv writer inserts the delimiter itself; field names must not
    # carry their own commas or they get quoted into the column names.
    header = ["sequence_name", "score"] + \
        ["N{0}".format(pos) for pos in range(pp.seq_len)]

    for file_idx, ifile in enumerate(pp.input_files):
        dataset = FoldDataset([ifile], pp.seq_len)
        # NOTE(review): pp.output_file is indexed per input file, so it is
        # presumably a list with one path per entry of pp.input_files --
        # confirm against ParameterParser.
        # newline='' is required by the csv module so that the writer's own
        # line terminator is not translated a second time on Windows.
        with open(pp.output_file[file_idx], 'w', newline='') as csvfile:
            of = csv.writer(csvfile, delimiter=',')
            of.writerow(header)
            # Inference only: skip building the autograd graph.
            with tr.no_grad():
                for seq_idx, data in enumerate(tqdm(dataset)):
                    x, v, _ = data
                    # Score of the unmodified sequence (batch of one).
                    mean = model(x.unsqueeze(0), v.unsqueeze(0)).cpu().item()
                    # assumes x is 1-D with pp.seq_len entries, so that
                    # row k of the repeated tensor is the sequence with
                    # position k zeroed -- TODO confirm with FoldDataset.
                    x = x.repeat(pp.seq_len, 1)
                    x[ind, ind] = 0
                    v = v.repeat(pp.seq_len, 1)
                    z = model(x, v).cpu().squeeze()
                    # Score drop caused by zeroing each position.
                    z = mean - z
                    of.writerow([dataset.name[seq_idx], mean] + z.tolist())
if __name__ == "__main__":
    # Run as a script: hand main() the command-line arguments with the
    # program name (argv[0]) stripped.
    cli_args = sys.argv[1:]
    main(cli_args)