forked from wyharveychen/CloserLookFewShot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsave_features.py
94 lines (76 loc) · 3.12 KB
/
save_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import torch
from torch.autograd import Variable
import os
import glob
import h5py
import configs
import backbone
from data.datamgr import SimpleDataManager
from methods.protonet_S import ProtoNet_S
from methods.protonet_DR import ProtoNet_DR
from methods.softmax_1nn import Softmax1NN
from methods.DR_1nn import DR1NN
from io_utils import model_dict, parse_args, get_resume_file, get_best_file, get_assigned_file
def save_features(model, data_loader, outfile ):
    """Run `model` over every batch of `data_loader` and dump the outputs to HDF5.

    `outfile` is created (or overwritten) with three datasets:
      - 'all_feats':  model outputs; first axis sized
                      len(data_loader) * data_loader.batch_size
      - 'all_labels': int labels, same first-axis size
      - 'count':      one-element int array with the number of rows actually
                      written (the last batch may be short, so rows past
                      `count` are unused padding)

    Args:
        model: callable mapping a batch tensor to a feature tensor whose first
            axis is the batch axis. Inputs are moved with `.cuda()`, so the
            model must accept CUDA tensors.
        data_loader: iterable of (x, y) tensor batches that also exposes
            `__len__` and `batch_size`.
        outfile: destination HDF5 path.

    Note:
        If `data_loader` yields no batches, 'all_feats' is never created,
        because the feature shape is only known after a forward pass.
    """
    num_batches = len(data_loader)
    max_count = num_batches * data_loader.batch_size
    count = 0
    # Context manager guarantees the file is closed even if a batch fails.
    with h5py.File(outfile, 'w') as f:
        all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
        all_feats = None
        # Pure inference: skip building the autograd graph to save GPU memory.
        with torch.no_grad():
            for i, (x, y) in enumerate(data_loader):
                if i % 10 == 0:
                    print('{:d}/{:d}'.format(i, num_batches))
                feats = model(x.cuda())
                if all_feats is None:
                    # Feature shape is only known after the first forward pass.
                    all_feats = f.create_dataset(
                        'all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
                batch_len = feats.size(0)
                all_feats[count:count + batch_len] = feats.cpu().numpy()
                all_labels[count:count + batch_len] = y.cpu().numpy()
                count += batch_len
        count_var = f.create_dataset('count', (1,), dtype='i')
        count_var[0] = count
def remap_backbone_state(state):
    """Return the backbone weights from a method checkpoint as a new dict.

    The trained method model keeps its backbone under the attribute `feature`,
    so its keys look like 'feature.trunk.xx'. The standalone backbone built by
    `model_dict[...]()` expects 'trunk.xx'. Every key containing 'feature.' has
    that substring removed. All other keys (non-backbone parameters) are dropped.

    Unlike an in-place pop/insert loop, this never loses a remapped entry when
    its stripped name coincides with another key in `state`; the remapped
    backbone value wins.

    Args:
        state: mapping of parameter name -> tensor (e.g. checkpoint['state']).

    Returns:
        A new dict; `state` itself is not modified.
    """
    return {key.replace("feature.", ""): value
            for key, value in state.items()
            if "feature." in key}


if __name__ == '__main__':
    params = parse_args('save_features')

    # Conv4 backbones use 84x84 inputs; others (e.g. ResNet18) use 224x224.
    image_size = 84 if 'Conv' in params.model else 224

    split = params.split
    loadfile = configs.data_dir[params.dataset] + split + '.json'

    # Presumably mirrors the directory layout written at training time.
    checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        checkpoint_dir += '_aug'
    checkpoint_dir += '_%dway_%dshot' % (params.train_n_way, params.n_shot)
    print(checkpoint_dir)

    # Features are written to a parallel 'features' tree next to 'checkpoints'.
    feature_dir = checkpoint_dir.replace("checkpoints", "features")
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
        outfile = os.path.join(feature_dir, split + "_" + str(params.save_iter) + ".hdf5")
    else:
        modelfile = get_best_file(checkpoint_dir)
        outfile = os.path.join(feature_dir, split + ".hdf5")
    print(loadfile)

    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    model = model_dict[params.model]()
    model = model.cuda()
    tmp = torch.load(modelfile)
    model.load_state_dict(remap_backbone_state(tmp['state']))
    model.eval()

    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(os.path.dirname(outfile), exist_ok=True)
    save_features(model, data_loader, outfile)