-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraininator.py
194 lines (176 loc) · 8.05 KB
/
traininator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import numpy as np
import time
import cv2
from trendi import background_removal
from trendi import constants
from trendi import Utils
from gcloud import storage
from oauth2client.client import GoogleCredentials
import sys
import urllib2
#TO UPLOAD IMAGES TO BUCKET:
# gsutil -m cp -r new_photos_512x512/ gs://tg-training/tamara_berg_street2shop_dataset/images
# TO GRANT PERMISSION TO WORLD TO SEE:
#gsutil acl ch -u AllUsers:R gs://tg-training/tamara_berg_street2shop_dataset/images/*
# Database handle shared by every function in this module; collections are
# looked up on it by name (db[collection]).
db = constants.db
# Ordered category names; an item's index in this list is the value written
# into its mask pixels by create_training_set_with_grabcut.
cats = constants.tamara_berg_categories
# Application-default Google credentials, used to open the storage client.
credentials = GoogleCredentials.get_application_default()
# The "tg-training" bucket that generated masks are uploaded to.
bucket = storage.Client(credentials=credentials).bucket("tg-training")
def get_margined_bb(image, bb, buffer):
    """Grow a bounding box by a fractional margin and clip it to the image.

    :param image: array whose shape[0] is its height and shape[1] its width
    :param bb: box given as (x, y, w, h)
    :param buffer: margin added on each side, as a fraction of w
        (horizontally) and of h (vertically); 0 leaves the size untouched
    :return: [x, y, w, h] of the grown box, clipped to the image bounds
    """
    left, top, width, height = bb

    # Horizontal extent: never left of column 0, never past the last column.
    x_start = np.max([left - int(buffer * width), 0])
    x_end = np.min([left + int((1 + buffer) * width), image.shape[1] - 1])

    # Vertical extent: never above row 0, never past the last row.
    y_start = np.max([top - int(buffer * height), 0])
    y_end = np.min([top + int((1 + buffer) * height), image.shape[0] - 1])

    return [x_start, y_start, x_end - x_start, y_end - y_start]
def save_to_storage(buck, data, filename):
    """PNG-encode an image array and upload the bytes to a storage bucket.

    :param buck: bucket object; ``buck.blob(filename)`` must return an object
        with an ``upload_from_string`` method
    :param data: image or 2d mask array accepted by ``cv2.imencode``
    :param filename: object name (path inside the bucket) to write
    :return: None
    """
    blb = buck.blob(filename)
    # imencode returns (flag, buffer); only the encoded buffer is uploaded.
    encoded = cv2.imencode('.png', data)[1]
    # tobytes() is the supported spelling of the deprecated tostring() alias.
    blb.upload_from_string(encoded.tobytes())
def create_training_set_with_grabcut(collection):
    """Build a per-pixel category mask for every document in a collection.

    For each document the image named by the last path component of
    doc['url'] is fetched from the tg-training bucket and resized. GrabCut is
    then run inside each item's bounding box, and the pixels GrabCut marks
    (value 255) are labelled with the item's index in ``cats``. The mask is
    uploaded to the bucket and its url is stored on the document as
    'mask_url'.

    :param collection: name of the mongodb collection to process
    :return: None
    """
    coll = db[collection]
    i = 1
    # Count the collection actually being processed so the progress line
    # is right for any collection, not only 'training_images'.
    total = coll.count()
    start = time.time()
    for doc in coll.find():
        if not i % 10:
            print("did {0}/{1} documents in {2} seconds".format(i, total, time.time()-start))
            print("average time for image = {0}".format((time.time()-start)/i))
        url = doc['url'].split('/')[-1]
        img_url = 'https://tg-training.storage.googleapis.com/tamara_berg_street2shop_dataset/images/' + url
        image = Utils.get_cv2_img_array(img_url)
        if image is None:
            print("{0} is a bad image".format(img_url))
            continue
        i += 1
        small_image, ratio = background_removal.standard_resize(image, 600)
        # skin_mask = kassper.skin_detection_with_grabcut(small_image, small_image, skin_or_clothes='skin')
        # mask = np.where(skin_mask == 255, 35, 0).astype(np.uint8)
        mask = np.zeros(small_image.shape[:2], dtype=np.uint8)
        for item in doc['items']:
            try:
                # Boxes are stored in original-image coordinates; scale them
                # down to the resized image.
                bb = [int(c/ratio) for c in item['bb']]
                item_bb = get_margined_bb(small_image, bb, 0)
                if item['category'] not in cats:
                    continue
                category_num = cats.index(item['category'])
                item_mask = background_removal.simple_mask_grabcut(small_image, rect=item_bb)
            except Exception as e:
                # Best effort: a malformed item or a GrabCut failure skips
                # only this item, but no longer hides Ctrl-C / SystemExit.
                print("skipping item in {0}: {1}".format(img_url, e))
                continue
            # Later items overwrite earlier ones where their masks overlap.
            # The cast keeps the mask 8-bit however np.where promotes the
            # Python int label.
            mask = np.where(item_mask == 255, category_num, mask).astype(np.uint8)
        # NOTE(review): the object is named '.txt' although the uploaded
        # payload is PNG-encoded; kept as-is because stored mask_url values
        # may already rely on this name - confirm before renaming.
        filename = 'tamara_berg_street2shop_dataset/masks/' + url[:-4] + '.txt'
        save_to_storage(bucket, mask, filename)
        coll.update_one({'_id': doc['_id']}, {'$set': {'mask_url': 'https://tg-training.storage.googleapis.com/' + filename}})
    print("Done masking! took {0} seconds".format(time.time()-start))
def _fix_existing_training_doc(doc, img_url):
    """Apply the in-place fix-ups for a document that is already in the db."""
    # NOTE(review): the source indentation was lost; this assumes the
    # 'already_seen_image_level' assignment that follows the deletion of
    # 'already_done' belongs to the same branch - confirm.
    if 'already_done' in doc:
        del doc['already_done']
        doc['already_seen_image_level'] = 1
    if 'already_seen_image_level' in doc:
        doc['already_seen_image_level'] = 1
    # A single user name becomes a one-element list.
    user_name = doc.get('user_name')
    if isinstance(user_name, basestring):
        doc['user_name'] = [user_name]
    # Point the document at the bucket copy of the image.
    doc['url'] = img_url


def bucket_to_training_set(collection, max_photos=600000):
    '''
    Takes a bucket of data and adds to db collection
    if not in db, add
    if in db, fix url, make user a list, already_done is counter

    Probes photo_0.jpg .. photo_<max_photos-1>.jpg in the tg-training bucket.
    Every image that can be opened is either fixed up in place (document
    still stored under its old /home/jeremy/... url), skipped (document
    already stored under the bucket url), or inserted as a new empty
    'v2' document.

    :param collection: name of the mongodb collection
    :param max_photos: number of photo indices to probe (default 600000)
    :return: None
    '''
    coll = db[collection]
    total = coll.count()
    print(str(total)+' images in collection '+collection)
    for i in range(0, max_photos):
        photo_name = 'photo_'+str(i)+'.jpg'
        img_url = 'https://tg-training.storage.googleapis.com/tamara_berg_street2shop_dataset/images/'+photo_name
        print('\nattempting to open '+img_url)
        try:
            ret = urllib2.urlopen(img_url)
            try:
                status = ret.code
            finally:
                # Only the status is needed; release the connection now.
                ret.close()
        except Exception:
            print('error trying to open '+photo_name+' err:'+str(sys.exc_info()[0]))
            continue
        if status != 200:
            print('image '+photo_name +' not found (ret code not 200)')
            continue
        print(photo_name+" exists, checking if in db")
        try:
            doc = coll.find_one({'url': '/home/jeremy/dataset/images/'+photo_name})
            if doc:
                print('found doc for '+str(photo_name)+' in db already')
                _fix_existing_training_doc(doc, img_url)
                coll.replace_one({'_id': doc.get('_id')}, doc)
            elif coll.find_one({'url': img_url}):
                # Already stored under the bucket url by an earlier run.
                print('doc already replaced')
            else:
                print('doc for '+str(photo_name)+' not found, adding to db')
                new_doc = {'url': img_url, 'round': 'v2', 'items': []}
                try:
                    coll.insert(new_doc)
                except Exception:
                    print('error trying to insert doc , err:'+str(sys.exc_info()[0]))
        except Exception:
            print('error trying to get doc , err:'+str(sys.exc_info()[0]))
def clean_training(collection):
    '''
    Rewrites local image paths stored in a db collection to bucket urls.

    Every document whose 'url' contains '/home/jeremy' gets its url replaced
    by the tg-training bucket url for the same file name, and the document
    is written back with replace_one. Other documents are left untouched.

    :param collection: name of the mongodb collection
    :return: None
    '''
    coll = db[collection]
    total = coll.count()
    print(str(total)+' images in collection '+collection)
    # Iterating the cursor ends cleanly when it is exhausted; calling
    # cursor.next() by hand raises StopIteration after the last document.
    for doc in coll.find():
        url = doc['url']
        print('url:'+str(url))
        if '/home/jeremy' not in url:
            continue
        photo_name = url.split('/')[-1]
        new_url = 'https://tg-training.storage.googleapis.com/tamara_berg_street2shop_dataset/images/'+photo_name
        print('photoname:'+str(photo_name)+' newurl:'+str(new_url))
        doc['url'] = new_url
        try:
            res = coll.replace_one({'_id': doc['_id']}, doc)
            print('replace result:'+str(res))
        except Exception:
            # Keep going: one failed write should not abort the whole pass.
            print('error trying to replace doc , err:'+str(sys.exc_info()[0]))
if __name__ == '__main__':
    # Script entry point: fix up the urls stored in 'training_images'.
    clean_training('training_images')
    # bucket_to_training_set('training_images')