-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathbearsvm.cpp
425 lines (376 loc) · 15.3 KB
/
bearsvm.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
// The contents of this file are licensed under the MIT license.
// See LICENSE.txt for more information.
/*
This program trains a linear SVM to classify bears from bear face
embeddings, and can also test a trained SVM or use it to infer bear IDs.
*/
#include <iostream>
#include <ctime>
#include <vector>
#include <dlib/svm_threaded.h>
#include <dlib/svm.h>
#include <dlib/image_io.h>
#include <dlib/cmd_line_parser.h>
#include <boost/filesystem.hpp>
#include <boost/foreach.hpp>
#include <boost/algorithm/string.hpp>
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/xml_parser.hpp>
using namespace std;
using namespace dlib;
using boost::property_tree::ptree;
// Output XML document built up and written out by the --infer mode.
ptree g_xml_tree{};
// Run mode chosen on the command line: "train", "test", "infer" or "none".
std::string g_mode{"none"};
// One <chip> subtree per embedding accepted by load_embeds_map(); an empty
// ptree stands in for embeddings that carry no face information.
std::vector<ptree> g_chips{};
// This function spiders the top level directory and obtains a list of all the
// files.
std::vector<std::vector<string>> load_objects_list (
const string& dir
)
{
std::vector<std::vector<string>> objects;
for (auto subdir : directory(dir).get_dirs())
{
std::vector<string> imgs;
for (auto img : subdir.get_files())
imgs.push_back(img);
if (imgs.size() != 0)
objects.push_back(imgs);
}
return objects;
}
//-----------------------------------------------------------------
// Grab embeddings from xml and store each under its label.
// populate obj_labels and returns vector of embeddings.
//-----------------------------------------------------------------
std::vector<std::vector<string>> load_embeds_map (
const std::string& xml_file, std::vector<std::string>& obj_labels)
{
std::map<string,std::vector<std::string>> embeds_map;
ptree tree, empty_tree;
boost::property_tree::read_xml (xml_file, tree);
std::vector<std::vector<string>> objects; // return object
// add all embedsfiles to map by bearID
BOOST_FOREACH(ptree::value_type& child, tree.get_child("dataset.embeddings"))
{
std::string child_name = child.first;
if (child_name == "embedding")
{
ptree embedding = child.second;
std::string embedfile = child.second.get<std::string>("<xmlattr>.file");
std::string bearID = child.second.get<std::string>("label");
if (bearID.empty())
{
std::cout << "Error: embedfile " << embedfile << " has no bearID.\n" << endl;
continue;
}
ptree chip = embedding.get_child ("chip", empty_tree); // from bearface
if (chip == empty_tree)
std::cout << "Warning: embedfile " << embedfile << " has no face information.\n" << endl;
g_chips.push_back (chip);
embeds_map[bearID].push_back (embedfile);
}
}
// massage map of vector to return vector of vector
std::string key;
std::vector<std::string> value;
obj_labels.clear ();
std::map<std::string, std::vector<std::string>>::iterator it;
for ( it = embeds_map.begin(); it != embeds_map.end(); it++ )
{
objects.push_back (it->second);
obj_labels.push_back (it->first);
}
return objects;
}
//-----------------------------------------------------------------
// Load every embedding listed in an embeddings metadata XML file and
// flatten them into parallel vectors.  On return:
//   embeddings      - deserialized embeddings appended, one per file,
//                     e.g. from [b1.dat, b2-1.dat, b2-2.dat, b3.dat]
//   labels_idx      - numeric label appended for each embedding,
//                     e.g. [0, 1, 1, 2]
//   ids             - replaced with the label names returned by
//                     load_embeds_map(), e.g. ["b1", "b2", "b3"]
//   embed_filenames - path of each embedding file appended
// embeddings, labels_idx and embed_filenames grow together and stay
// index-aligned with each other.
//
// label_id_map: when empty, the k-th label in `ids` gets numeric label k.
// When non-empty (e.g. the mapping saved by a training run), known labels
// use the mapped index; unknown labels are skipped with a message, except
// in "infer" mode, where they keep their position k in `ids`.
//-----------------------------------------------------------------
void extract_embeds (const std::string& embed_xml,
	std::vector<matrix<float,128,1>> &embeddings,
	std::vector<double> & labels_idx,
	std::vector <std::string> & ids,
	std::vector <std::string> & embed_filenames,
	const std::map<std::string,int>& label_id_map)
{
	// load_embeds_map() fills `ids` with one label per inner vector.
	const std::vector<std::vector<string>> emb_objs = load_embeds_map (embed_xml, ids);
	for (size_t i = 0; i < emb_objs.size(); ++i)
	{
		double label_idx = static_cast<double>(i);
		if (!label_id_map.empty ()) // use existing mapping
		{
			const std::map<std::string,int>::const_iterator found = label_id_map.find (ids[i]);
			if (found != label_id_map.end ()) // label exists
				label_idx = found->second;
			else if (g_mode != "infer") // unknown label, skip to next label
			{
				cout << "Ignoring unrecognized label: " << ids[i] << endl;
				continue;
			}
		}
		for (const std::string& embed_filename : emb_objs[i])
		{
			matrix<float,128,1> embedding;
			embed_filenames.push_back (embed_filename);
			deserialize(embed_filename) >> embedding;
			embeddings.push_back(embedding);
			labels_idx.push_back(label_idx);
		}
	}
}
//-----------------------------------------------------------------
// returns network_ids.dat from network.dat
//-----------------------------------------------------------------
std::string get_network_id_name (std::string network_str)
{
boost::filesystem::path network_path (network_str);
std::string ids_str = network_path.parent_path().c_str();
if (!ids_str.empty())
ids_str += "/";
ids_str += network_path.stem().c_str();
ids_str += "_ids";
ids_str += network_path.extension().c_str();
return ids_str;
}
// ----------------------------------------------------------------------------
// Copy of test_multiclass_decision_function from dlib/svm, extended to report
// misclassified samples.
//
// Evaluates dec_funct once per x_test[i] and accumulates a confusion matrix
// res(truth, predicted), whose rows/columns are indexed by position in
// dec_funct.get_labels().  Samples whose y_test label is unknown to dec_funct
// are ignored.  For every wrong prediction it prints
//   "Matched <parent_dir>/<file> to <label>"
// where <parent_dir>/<file> are the last two components of source_files[i]
// and <label> is the predicted label's name from id_label_map ("<unknown>"
// if the predicted id has no entry there).
//
// y_test and source_files must be index-aligned with x_test.
// ----------------------------------------------------------------------------
template <
	typename dec_funct_type,
	typename sample_type,
	typename label_type
	>
const matrix<double> bearid_test_multiclass_decision_function (
	const dec_funct_type& dec_funct,
	const std::vector<sample_type>& x_test,
	const std::vector<label_type>& y_test,
	const std::vector<std::string>& source_files,
	const std::map<int, std::string>& id_label_map
)
{
	const std::vector<label_type> all_labels = dec_funct.get_labels();
	// make a lookup table that maps from labels to their index in all_labels
	std::map<label_type,unsigned long> label_to_int;
	for (unsigned long i = 0; i < all_labels.size(); ++i)
		label_to_int[all_labels[i]] = i;
	matrix<double, 0, 0, typename dec_funct_type::mem_manager_type> res;
	res.set_size(all_labels.size(), all_labels.size());
	res = 0;
	for (unsigned long i = 0; i < x_test.size(); ++i)
	{
		const typename std::map<label_type,unsigned long>::const_iterator truth_it =
			label_to_int.find(y_test[i]);
		// ignore samples with labels that the decision function doesn't know about.
		if (truth_it == label_to_int.end())
			continue;
		const unsigned long truth = truth_it->second;
		// Evaluate the classifier once and reuse the result.
		const label_type predicted = dec_funct(x_test[i]);
		// Predictions are expected to be one of dec_funct's own labels, so
		// this lookup should hit an existing entry (as in dlib's original).
		const unsigned long pred = label_to_int[predicted];
		res(truth,pred) += 1;
		if (truth != pred)
		{
			// Guard the lookup: dereferencing end() would be undefined behavior.
			const std::map<int, std::string>::const_iterator name_it =
				id_label_map.find (static_cast<int>(predicted));
			const std::string label = (name_it != id_label_map.end ())
				? name_it->second : std::string("<unknown>");
			// Only build the short "<parent>/<file>" path when it is printed.
			const boost::filesystem::path path_full_imgfile (source_files[i]);
			const boost::filesystem::path path_source_path =
				path_full_imgfile.parent_path ().filename () / path_full_imgfile.filename ();
			cout << "Matched " << path_source_path.string() << " to " << label << endl;
		}
	}
	return res;
}
// ----------------------------------------------------------------------------
// Number of cross-validation folds; raise it for a finer-grained estimate.
// A bear with fewer embeddings than `folds` cannot be spread across every
// fold and would be skipped by cross validation.
// NOTE(review): not referenced by main() in this file; no cross validation
// is currently performed.
int folds{3};
//--------------------------------------------------
// initialize xml
//--------------------------------------------------
int xml_add_headers ()
{
g_xml_tree.add("dataset.name", "bearid dataset");
g_xml_tree.add("dataset.comment", "Created by bearsvm");
return 0;
}
//-----------------------------------------------------------------
// main function
//-----------------------------------------------------------------
int main(int argc, char** argv)
{
try
{
time_t timeStart = time(NULL);
command_line_parser parser;
parser.add_option("h","Display this help message.");
parser.add_option("train","Train the svm and write to network file. Also write an ids file.", 1);
// --test <network>
parser.add_option("test","Test the svm using the network file.", 1);
parser.add_option("infer","Infer the IDs of embeddings using network file.", 1);
// --output: <trained_network> with --train; <embed_directory> with --embed
parser.parse(argc, argv);
// Now we do a little command line validation. Each of the following functions
// checks something and throws an exception if the test fails.
const char* one_time_opts[] = {"h", "train", "test", "infer"};
parser.check_one_time_options(one_time_opts); // Can't give an option more than once
if (parser.option("h") || parser.number_of_arguments () != 1)
{
cout << "\nUsage : bearsvm <option network_file> <embed_xml>\n";
cout << "\nExample: bearsvm -test bearsvm_network.dat val_embeds.xml\n\n";
parser.print_options();
return EXIT_SUCCESS;
}
if (parser.option("train"))
g_mode = "train";
else if (parser.option("test"))
g_mode = "test";
else if (parser.option("infer"))
g_mode = "infer";
// Samples are embeddings, which ar 128D vector of floats
typedef matrix<float,128,1> sample_type;
typedef one_vs_one_trainer<any_trainer<sample_type> > ovo_trainer;
typedef linear_kernel<sample_type> kernel_type;
std::vector<sample_type> samples;
std::vector<double> label_indices;
std::vector <std::string> ids;
std::vector <std::string> embed_files;
std::map<std::string,int> label_id_map;
std::map<int, std::string> id_label_map;
std::string embed_xml = parser[0];
std::string svm_network_name, svm_network_ids_name;
// ----- Training ------------
if (parser.option("train"))
{
svm_network_name = parser.option("train").argument();
svm_network_ids_name = get_network_id_name (svm_network_name);
cout << "\nTraining with embed file.... : " << embed_xml << endl;
extract_embeds (embed_xml, samples, label_indices, ids, embed_files, label_id_map);
// Onve vs One trainer
ovo_trainer trainer;
// Linear SVM
svm_c_linear_trainer<kernel_type> linear_trainer;
linear_trainer.set_c(100);
// Use the SVM classifier for the OVO trainer
trainer.set_trainer(linear_trainer);
//------------------------------------------------------------
// Now let's do 5-fold cross-validation using the one_vs_one_trainer we just setup.
// As an aside, always shuffle the order of the samples before doing cross validation.
randomize_samples(samples, label_indices);
// Create a decision function
one_vs_one_decision_function<ovo_trainer> df = trainer.train(samples, label_indices);
// Test one_vs_one_decision_function
// cout << "predicted label: "<< df(samples[0]) << ", true label: "<< label_indices[0] << endl;
// Save SVM to disk
one_vs_one_decision_function<ovo_trainer,
decision_function<kernel_type> // This is the output of the linear_trainer
> df2, df3;
df2 = df;
serialize(svm_network_name) << df2;
serialize(svm_network_ids_name) << ids;
cout << "\nWrote " << svm_network_name << " and " << svm_network_ids_name << ".\n" << endl;
}
else if (parser.option("test"))
{
svm_network_name = parser.option("test").argument();
svm_network_ids_name = get_network_id_name (svm_network_name);
one_vs_one_decision_function<ovo_trainer,
decision_function<kernel_type> > df3;
// Check serialization
std::vector <string> ids2;
deserialize(svm_network_ids_name) >> ids2;
deserialize(svm_network_name) >> df3;
// recreate label:index map from training run
for (int i = 0; i < ids2.size (); ++i)
{
// cout << "ID: " << i << "\t: " << ids2[i] << endl;
label_id_map [ids2[i]] = i;
id_label_map [i] = ids2[i];
}
cout << "\nTesting with embed file.... : " << embed_xml << endl;
extract_embeds (embed_xml, samples, label_indices, ids, embed_files, label_id_map);
// call copy of test_multiclass_decision_function. Add embed_files
// to identify images with wrong answers.
matrix<double> cm = bearid_test_multiclass_decision_function(df3, samples, label_indices, embed_files, id_label_map);
// cout << "test df: \n" << cm << endl;
cout << "correct: " << sum(diag(cm)) << " : total : " << sum(cm) << endl;
cout << "accuracy: " << sum(diag(cm))/sum(cm) << endl;
}
else if (parser.option("infer")) // doing inference ----------
{
boost::filesystem::path current_dir (boost::filesystem::current_path());
g_xml_tree.add ("dataset.command", argv[0]);
g_xml_tree.add ("dataset.cwd", current_dir.string());
g_xml_tree.add ("dataset.filetype", "svm output");
ptree images = g_xml_tree.add ("dataset.images", "");
svm_network_name = parser.option("infer").argument();
svm_network_ids_name = get_network_id_name (svm_network_name);
one_vs_one_decision_function<ovo_trainer,
decision_function<kernel_type> > df3;
// Check serialization
std::vector <string> ids2;
int idx;
deserialize(svm_network_ids_name) >> ids2;
deserialize(svm_network_name) >> df3;
cout << "\nInferring with embed file.... : " << embed_xml << endl;
extract_embeds (embed_xml, samples, label_indices, ids, embed_files, label_id_map);
ptree empty_tree;
ptree image;
for (int i = 0 ; i < samples.size (); ++i)
{
idx = df3 (samples[i]);
boost::filesystem::path path_emb_file (embed_files[i]);
ptree chip_source = g_chips[i].get_child ("source", empty_tree);
if (chip_source != empty_tree)
{
// image = images.add_child ("image", chip_source);
ptree box = chip_source.get_child ("box", empty_tree);
ptree label = chip_source.get_child ("box.label");
chip_source.put ("box.label", ids2[idx]);
image = g_xml_tree.add_child ("dataset.images.image", chip_source);
}
else
cout << "missing image for embedding " << embed_files[i] << endl;
cout << ids2[idx] << " : " << path_emb_file.parent_path().stem().string() << "/" << path_emb_file.stem().string() << endl;
}
boost::filesystem::path xml_file (parser[0]);
std::string svm_xml_file;
if (xml_file.has_parent_path ())
svm_xml_file = xml_file.parent_path().string() + "/";
svm_xml_file += xml_file.filename().stem().string() + "_svm.xml";
xml_add_headers (); // put at end since writing reverse order added
// boost::property_tree::xml_writer_settings<char> settings ('\t', 1);
boost::property_tree::xml_writer_settings<std::string> settings('\t', 1, "utf-8\"?>\n<?xml-stylesheet type=\"text/xsl\" href=\"image_metadata_stylesheet.xsl");
// boost::property_tree::xml_writer_settings<std::string> settings (' ', 4);
write_xml(svm_xml_file, g_xml_tree,std::locale(),settings);
cout << "\ngenerated: \n\t-- " << svm_xml_file << endl;
}
else
{
cout << "Need one of <train|test|infer> to run bearsvm." << endl;
}
}
catch (exception& e)
{
// Note that this will catch any cmd_line_parse_error exceptions and print
// the default message.
cout << e.what() << endl;
}
}