-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdb_fingerprint_nadav.py
381 lines (316 loc) · 14.8 KB
/
db_fingerprint_nadav.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
__author__ = 'Nadav Paz'
import logging
import multiprocessing as mp
import argparse
import time
import signal
import traceback
from rq import Queue
import pymongo.errors
import numpy as np
import cv2
from . import geometry
from . import fingerprint_core as fp
from . import background_removal
from . import Utils
from . import constants
from .constants import db
from .constants import redis_conn
q2 = Queue('change_fp_to_dict', connection=redis_conn)
# globals
CLASSIFIER_FOR_CATEGORY = {}
TOTAL_PRODUCTS = mp.Value("i", 0)
CURRENT = Utils.ThreadSafeCounter()
DB = None
FP_VERSION = 0
START_TIME = 0
CONTINUE = mp.Value("b", True)
Q = mp.Queue(1000)
MAIN_PID = 0
NUM_PROCESSES = mp.Value("i", 0)
def get_all_subcategories(category_collection, category_id):
    """
    Create a list of all subcategory ids under category_id, including itself.

    Assumes category_collection is a mongodb Collection of category dictionaries
    with keys "id" and "childrenIds".

    Fixes over the original recursive version:
      - a category id with no matching document no longer raises AttributeError
        (find_one returning None); the id is kept and its children are skipped
      - cyclic or shared "childrenIds" links no longer recurse forever / duplicate;
        each id is visited at most once

    :param category_collection: mongodb Collection (anything with find_one)
    :param category_id: string id of the root category
    :return: list of all subcategory ids under category_id, including itself,
             in depth-first pre-order
    """
    subcategories = []
    seen = set()

    def walk(c_id):
        # visit each id once: guards against cycles and shared children
        if c_id in seen:
            return
        seen.add(c_id)
        subcategories.append(c_id)
        curr_cat = category_collection.find_one({"id": c_id})
        if curr_cat is None:
            # dangling child reference: keep the id, nothing to descend into
            return
        for child_id in curr_cat.get("childrenIds", []):
            walk(child_id)

    walk(category_id)
    return subcategories
def create_classifier_for_category_dict(db):
    """
    Build a mapping of category_id -> cv2.CascadeClassifier.

    Each classifier xml listed in constants.classifier_to_category_dict is
    loaded once from constants.classifiers_folder; every category it covers
    (and all of that category's subcategories, per db.categories) is mapped
    to the loaded classifier instance.

    :param db: connected pymongo.MongoClient().db object
    :return: dict of category_id -> cv2.CascadeClassifier
    """
    # load each cascade xml exactly once, keyed by its file name
    loaded_classifiers = {}
    for xml_name in constants.classifier_to_category_dict.keys():
        loaded_classifiers[xml_name] = cv2.CascadeClassifier(constants.classifiers_folder + xml_name)

    # fan each classifier out to every category (and subcategory) it serves
    category_to_classifier = {}
    for xml_name, category_ids in constants.classifier_to_category_dict.iteritems():
        for category_id in category_ids:
            for sub_category_id in get_all_subcategories(db.categories, category_id):
                category_to_classifier[sub_category_id] = loaded_classifiers[xml_name]
    return category_to_classifier
def run_fp_on_db_product(doc):
    # Fingerprint a single product document and write the result back to DB.products.
    # Reads globals: CURRENT (progress counter), TOTAL_PRODUCTS, CLASSIFIER_FOR_CATEGORY,
    # DB, FP_VERSION. Intended to run inside a worker process (see do_work_on_q).
    CURRENT.increment()
    if CURRENT.value % 50 == 0:
        # periodic progress line per worker process
        print "Process {process} starting {i} of {total}...".format(process=mp.current_process().name,
                                                                    i=CURRENT.value, total=TOTAL_PRODUCTS.value)
    # assumes doc["image"]["sizes"]["XLarge"]["url"] exists -- TODO confirm schema upstream
    image_url = doc["image"]["sizes"]["XLarge"]["url"]
    image = Utils.get_cv2_img_array(image_url)
    if not Utils.is_valid_image(image):
        logging.warning("image is None. url: {url}".format(url=image_url))
        return
    small_image, resize_ratio = background_removal.standard_resize(image, 400)
    # I think we can delete this... memory management FTW??
    del image
    # print "Image URL: {0}".format(image_url)
    # if there is a valid human BB, use it
    if "human_bb" in doc.keys() and doc["human_bb"] != [0, 0, 0, 0] and doc["human_bb"] is not None:
        chosen_bounding_box = doc["human_bb"]
        # the human bb was drawn on the full-size image; scale it down to small_image coords
        chosen_bounding_box = [int(b) for b in (np.array(chosen_bounding_box) / resize_ratio)]
        mask = background_removal.get_fg_mask(small_image, chosen_bounding_box)
        logging.debug("Human bb found: {bb} for item: {id}".format(bb=chosen_bounding_box, id=doc["id"]))
    # otherwise use the largest of possibly many classifier bb's
    else:
        if "categories" in doc:
            # NOTE(review): the "" fallback is falsy, so the later `if classifier` check
            # treats a missing category the same as classifier = None
            classifier = CLASSIFIER_FOR_CATEGORY.get(doc["categories"][0]["id"], "")
        else:
            classifier = None
        # first try grabcut with no bb
        if not Utils.is_valid_image(small_image):
            logging.warning("small_image is Bad. {img}".format(img=small_image))
            return
        mask = background_removal.get_fg_mask(small_image)
        bounding_box_list = []
        if classifier and not classifier.empty():
            # then - try to classify the image (white backgrounded) and get a more accurate bb
            white_bckgnd_image = background_removal.image_white_bckgnd(small_image, mask)
            try:
                bounding_box_list = classifier.detectMultiScale(white_bckgnd_image)
            except KeyError:
                # NOTE(review): cv2's detectMultiScale raises cv2.error, not KeyError --
                # this handler likely never fires; confirm which exception was intended
                logging.info("Could not classify with {0}".format(classifier))
        # choosing the biggest bounding box if there are a few
        max_bb_area = 0
        chosen_bounding_box = None
        for possible_bb in bounding_box_list:
            # bb format is (x, y, w, h); area = w * h
            if possible_bb[2] * possible_bb[3] > max_bb_area:
                chosen_bounding_box = possible_bb
                max_bb_area = possible_bb[2] * possible_bb[3]
        if chosen_bounding_box is None:
            # keep the whole-image grabcut mask computed above
            logging.info("No Bounding Box found, using the whole image. "
                         "Document id: {0}, BB_list: {1}".format(doc.get("id"), str(bounding_box_list)))
        else:
            # redo the foreground mask constrained to the winning classifier bb
            mask = background_removal.get_fg_mask(small_image, chosen_bounding_box)
    try:
        fingerprint = fp.fp(small_image, mask)
        # stamp fp_version so this doc is excluded from future runs of the same version
        DB.products.update({"id": doc["id"]},
                           {"$set": {"fingerprint": fingerprint.tolist(),
                                     "fp_version": FP_VERSION,
                                     "bounding_box": np.array(chosen_bounding_box).tolist()}})
    except Exception as ex:
        # best-effort: log and move on so one bad doc doesn't kill the worker
        logging.warning("Exception caught while fingerprinting: {0}".format(ex))
def do_work_on_q(some_func, q):
    """
    Worker loop: pop items from q and apply some_func to each.

    Stops when a None sentinel is popped, or when the shared CONTINUE flag is
    cleared. Exceptions from some_func are logged and the loop continues with
    the next item.

    Fix over the original: the original retried by calling itself recursively
    from its except handler, so a stream of failing items grew the call stack
    without bound; the retry is now a plain loop iteration.

    :param some_func: callable applied to each popped item
    :param q: queue-like object with a blocking get()
    :return: a short status string (or None when stopped by the sentinel)
    """
    current_pid = mp.current_process().pid
    print("{0} Getting ready to do some work...".format(str(current_pid)))
    while CONTINUE.value:
        try:
            popped_item = q.get()
            if popped_item is None:
                # feeder's end-of-work sentinel
                print("Process {0} finished".format(str(current_pid)))
                return
            some_func(popped_item)
        except BaseException:
            # deliberately broad: keep the worker alive no matter what the
            # work item raised; log the traceback and get back to work
            print("Process {0}, exception reached do_work:\n".format(str(current_pid)))
            traceback.print_exc()
    print("{0} all done...".format(str(current_pid)))
    return "{0} returned".format(str(current_pid))
def connect_db_feed_q(q, query_doc, fields_doc, retry_num=0):
    """
    Connects to the DB, queries, and fills q with results.
    Also sets global TOTAL_PRODUCTS, DB.

    Runs as the "Feeder" process (see fingerprint_db). On a cursor failure it
    retries up to 5 times (recursively, 5 s apart); after that it clears the
    shared CONTINUE flag so workers stop. When feeding ends it puts one None
    sentinel per worker so each worker's loop can exit.

    :param q: multiprocessing queue to fill with product documents
    :param query_doc: mongo query filter
    :param fields_doc: mongo projection document
    :param retry_num: internal recursion depth for the reconnect retries
    :return:
    """
    global TOTAL_PRODUCTS, DB
    DB = DB or db
    product_cursor = DB.products.find(query_doc, fields_doc)  # .batch_size(n)
    TOTAL_PRODUCTS.value = product_cursor.count()
    print "Total tasks: {0}".format(str(TOTAL_PRODUCTS.value))
    try:
        for doc in product_cursor:
            q.put(doc)
    except pymongo.errors.OperationFailure:
        # I think this happens if cursor has been inactive too long
        traceback.print_exc()
        print "\n Trying reconnect in 5 seconds"
        if retry_num <= 5:
            time.sleep(5)
            connect_db_feed_q(q, query_doc, fields_doc, retry_num + 1)
        else:
            print "Could not reconnect..."
            CONTINUE.value = False
    # NOTE(review): sentinel placement reconstructed from a flattened source --
    # assumed to run after the try/except on every call (a retried call will
    # also put sentinels in its recursive inner call; extra Nones are harmless)
    for p in range(0, NUM_PROCESSES.value):
        q.put(None)
    print "Done putting all docs in Q"
    q.close()
def print_stats(start_time):
    """
    Print fingerprinting throughput stats accumulated since start_time.

    Reads globals CURRENT (completed count) and NUM_PROCESSES. Typically
    invoked from the SIGINT path of receive_signal, which can fire before any
    fingerprint has completed -- the original divided by CURRENT.value
    unconditionally and raised ZeroDivisionError in that case; that is now
    guarded.

    :param start_time: epoch seconds when the run started (time.time())
    """
    stop_time = time.time()
    total_time = stop_time - start_time
    completed = CURRENT.value
    if completed == 0:
        # nothing finished yet: averages are undefined, avoid ZeroDivisionError
        print("Stats:\n No fingerprints completed in {0} seconds.".format(total_time))
        return
    print("Stats:\n "
          "Completed {total} fingerprints in {seconds} seconds with {procs} processes.\n "
          "Average time per fingerprint: {avg}\n "
          "Average time per fingerprint per core: {avgc}\n"
          .format(avg=total_time / completed, total=completed,
                  seconds=total_time, procs=NUM_PROCESSES.value,
                  avgc=(total_time / completed) * NUM_PROCESSES.value))
def fingerprint_db(fp_version, category_id=None, num_processes=None):
    """
    main function - fingerprints items in category_id and its subcategories.
    If category_id is None, then fingerprints entire db. Also manages the
    multiprocessing: one Feeder process fills Q, NUM_PROCESSES workers drain it.

    Fixes over the original final-stats block:
      - it multiplied/printed the raw num_processes argument, which is None
        when the caller relies on the default (TypeError); NUM_PROCESSES.value
        is used instead
      - averages divided by TOTAL_PRODUCTS.value, which is 0 when nothing
        matched the query (ZeroDivisionError); now guarded

    :param fp_version: integer to keep track of which items have been already
                       fingerprinted with this version
    :param category_id: category to be fingerprinted (None = entire db)
    :param num_processes: worker count; defaults to 75% of cpu count
    :return:
    """
    global CURRENT, CLASSIFIER_FOR_CATEGORY, FP_VERSION, NUM_PROCESSES, DB, START_TIME, MAIN_PID
    MAIN_PID = mp.current_process().pid
    NUM_PROCESSES.value = num_processes or int(mp.cpu_count() * 0.75)
    DB = DB or db
    # select only docs whose fp_version is older than requested, or missing
    version_filter = {"$or": [{"fp_version": {"$lt": fp_version}}, {"fp_version": {"$exists": 0}}]}
    if category_id is not None:
        query_doc = {"$and": [
            {"categories": {"$elemMatch": {"id": {"$in": get_all_subcategories(DB.categories, category_id)}}}},
            version_filter
        ]}
    else:
        query_doc = version_filter
    fields = {"image": 1, "human_bb": 1, "fp_version": 1, "bounding_box": 1, "categories": 1, "id": 1}
    CLASSIFIER_FOR_CATEGORY = create_classifier_for_category_dict(DB)
    FP_VERSION = fp_version
    feeder = mp.Process(target=connect_db_feed_q, name="Feeder", args=[Q, query_doc, fields])
    worker_list = [mp.Process(target=do_work_on_q, name="Worker {0}".format(p), args=(run_fp_on_db_product, Q))
                   for p in range(0, NUM_PROCESSES.value)]
    START_TIME = time.time()
    feeder.start()
    for p in worker_list:
        p.start()
    for p in worker_list:
        p.join()
    feeder.join()
    total_time = time.time() - START_TIME
    total = TOTAL_PRODUCTS.value
    procs = NUM_PROCESSES.value  # fix: num_processes may be None here
    print("All done!!")
    print("Completed {total} fingerprints in {seconds} seconds "
          "with {procs} processes.".format(total=total, seconds=total_time, procs=procs))
    if total:
        # fix: guard ZeroDivisionError when no products matched the query
        print("Average time per fingerprint: {avg}".format(avg=total_time / total))
        print("Average time per fingerprint per core: {avgc}".format(avgc=(total_time / total) * procs))
def fingerprint_db_old(fp_version, category_id=None, num_processes=None):
    """
    main function - fingerprints items in category_id and its subcategories.
    If category_id is None, then fingerprints entire db. Pool-based
    predecessor of fingerprint_db.

    Fixes over the original:
      - it rebound the global TOTAL_PRODUCTS (an mp.Value) to a plain int via
        `TOTAL_PRODUCTS = product_cursor.count()`, then read
        TOTAL_PRODUCTS.value in the stats prints -> AttributeError; the count
        is now stored in TOTAL_PRODUCTS.value
      - the average prints divided by the product count, which can be 0
        (ZeroDivisionError); now guarded

    :param fp_version: integer to keep track of which items have been already
                       fingerprinted with this version
    :param category_id: category to be fingerprinted (None = entire db)
    :param num_processes: pool size; defaults to cpu_count() - 2
    :return:
    """
    global DB, TOTAL_PRODUCTS, CURRENT, CLASSIFIER_FOR_CATEGORY, FP_VERSION
    DB = DB or db
    num_processes = num_processes or mp.cpu_count() - 2
    version_filter = {"$or": [{"fp_version": {"$lt": fp_version}}, {"fp_version": {"$exists": 0}}]}
    if category_id is not None:
        query_doc = {"$and": [
            {"categories": {"$elemMatch": {"id": {"$in": get_all_subcategories(DB.categories, category_id)}}}},
            version_filter
        ]}
    else:
        query_doc = version_filter
    fields = {"image": 1, "human_bb": 1, "fp_version": 1, "bounding_box": 1, "categories": 1, "id": 1}
    # batch_size required because cursor timed out without it. Could use further investigation
    product_cursor = DB.products.find(query_doc, fields).batch_size(num_processes)
    TOTAL_PRODUCTS.value = product_cursor.count()  # fix: keep the mp.Value, don't rebind to int
    CLASSIFIER_FOR_CATEGORY = create_classifier_for_category_dict(DB)
    FP_VERSION = fp_version
    pool = mp.Pool(num_processes, maxtasksperchild=5)
    start_time = time.time()
    pool.map(run_fp_on_db_product, product_cursor)
    stop_time = time.time()
    total_time = stop_time - start_time
    pool.close()
    pool.join()
    total = TOTAL_PRODUCTS.value
    print("All done!!")
    print("Completed {total} fingerprints in {seconds} seconds "
          "with {procs} processes.".format(total=total, seconds=total_time, procs=num_processes))
    if total:
        # fix: guard ZeroDivisionError when no products matched the query
        print("Average time per fingerprint: {avg}".format(avg=total_time / total))
        print("Average time per fingerprint per core: {avgc}".format(avgc=(total_time / total) * num_processes))
def receive_signal(signum, stack):
    """
    Catch-all signal handler registered for every catchable signal in main.

    Behavior: ignore SIGCHLD (17) and SIGWINCH (28); on SIGINT (2) in the main
    process print running stats; for anything else, log the signal and dump the
    stack.

    Fix over the original: `if signum == 17 or 28:` is always true (28 is
    truthy), so the handler returned unconditionally and the SIGINT/stats
    branch below was dead code. Now tests membership correctly.

    :param signum: signal number delivered
    :param stack: current stack frame (as passed by the signal module)
    """
    if signum in (17, 28):
        # 17 creating child process (SIGCHLD), ignore
        # 28 SIGWINCH, ignore
        return
    if signum == 2 and mp.current_process().pid == MAIN_PID:
        # SIGINT in the main process: report progress, don't die
        print_stats(START_TIME)
        return
    print('{0} caught signal {1}.'.format(mp.current_process().pid, str(signum)))
    traceback.print_stack(stack)
if __name__ == "__main__":
    # Install receive_signal as the handler for every catchable signal, then
    # parse CLI args and run the fingerprinting. Top-level exceptions are
    # logged rather than propagated.
    uncatchable = ['SIG_DFL', 'SIGSTOP', 'SIGKILL']
    for sig_name in [x for x in dir(signal) if x.startswith("SIG")]:
        if sig_name not in uncatchable:  # idiom fix: was `not i in uncatchable`
            signum = getattr(signal, sig_name)
            signal.signal(signum, receive_signal)
    try:
        parser = argparse.ArgumentParser(description='Fingerprint the DB or part of it')
        parser.add_argument('-c', '--category_id', help='id of category to be fingerprinted', required=False)
        parser.add_argument('-p', '--num_processes', help='number of parallel processes to spawn',
                            required=False, type=int)
        parser.add_argument('-v', '--fp_version', help='current fp version', required=True)
        args = vars(parser.parse_args())
        fingerprint_db(int(args['fp_version']), args['category_id'], args['num_processes'])
    except Exception as e:
        # last-ditch: keep the traceback out of the console, record and exit
        logging.warning("Exception reached main!: {0}".format(e))
def make_the_change(collection):
    # Enqueue an asynchronous 'change_fp_to_dict' job for the named collection
    # on the module-level q2 redis queue (queue name 'change_fp_to_dict').
    q2.enqueue('change_fp_to_dict', collection)
def change_fp_to_dict(collection, category=None):
    # Migrate every document in `collection` that still has a flat 'fingerprint'
    # field to the newer 'fp_dict' layout: {'color': <old fingerprint>, 'lod': <float>}.
    # 'lod' is computed from the first detected face when the image is relevant,
    # otherwise defaults to 0.6. The old 'fingerprint' field is unset.
    # NOTE(review): `category` parameter is unused -- presumably reserved for a
    # filtered migration; confirm before removing.
    coll = db[collection]
    i = 0
    for doc in coll.find():
        if 'fingerprint' in doc.keys():
            if i % 10 == 0:
                # progress line every 10 migrated docs
                print "doing the {0}th item".format(i)
            fp_dict = {'color': doc['fingerprint']}
            # assumes doc["images"]["XLarge"] is a URL string -- TODO confirm schema
            image = Utils.get_cv2_img_array(str(doc["images"]["XLarge"]))
            result = background_removal.image_is_relevant(image)
            if result.is_relevant:
                faces = result.faces
                start = time.time()
                # length-of-lower-body-part feature, computed from the first face
                fp_dict['lod'] = geometry.length_of_lower_body_part_field(image, faces[0])[0]
                print "lod took {0} seconds..".format(time.time() - start)
                # try:
                # start = time.time()
                # fp_dict['collar'] = collar_classifier.collar_classifier(image, faces[0])
                # print "collar took {0} seconds..".format(time.time() - start)
                # except:
                # print "problem with collar.."
                # fp_dict['collar'] = {'roundneck': 0.5, 'squareneck': 0.5, 'v-neck': 0.5}
            else:
                # no relevant person found: fall back to a fixed default lod
                fp_dict['lod'] = 0.6
                # fp_dict['collar'] = {'roundneck': 0.5, 'squareneck': 0.5, 'v-neck': 0.5}
            print "updating fp with lod: {0}".format(fp_dict['lod'])
            coll.update_one({'_id': doc['_id']}, {'$set': {'fp_dict': fp_dict}, '$unset': {'fingerprint': ""}})
            # NOTE(review): indentation reconstructed from a flattened source; the
            # counter is assumed to track migrated docs only -- TODO confirm original
            i += 1
def fp_go_back(collection):
    # Reverse the fp_dict migration: for every document missing a flat
    # 'fingerprint' field, restore it from fp_dict['color'] and drop 'fp_dict'.
    # Prints the number of documents to revert before starting.
    coll = db[collection]
    migrated_docs = coll.find({'fingerprint': {'$exists': 0}})
    print(migrated_docs.count())
    for document in migrated_docs:
        revert_spec = {
            '$set': {'fingerprint': document['fp_dict']['color']},
            '$unset': {'fp_dict': ""},
        }
        coll.update_one({'_id': document['_id']}, revert_spec)