14
14
import json
15
15
import logging as log
16
16
from collections import namedtuple
17
+ from abc import ABC , abstractmethod
17
18
18
19
import cv2
19
20
import numpy as np
22
23
from .segm_postrocess import postprocess
23
24
24
25
25
- class Detector :
26
+ class DetectorInterface (ABC ):
27
+ @abstractmethod
28
+ def run_asynch (self , frames , index ):
29
+ pass
30
+
31
+ @abstractmethod
32
+ def wait_and_grab (self ):
33
+ pass
34
+
35
+
36
+ class Detector (DetectorInterface ):
26
37
"""Wrapper class for detector"""
27
38
28
39
def __init__ (self , ie , model_path , conf = .6 , device = 'CPU' , ext_path = '' , max_num_frames = 1 ):
@@ -31,21 +42,26 @@ def __init__(self, ie, model_path, conf=.6, device='CPU', ext_path='', max_num_f
31
42
self .expand_ratio = (1. , 1. )
32
43
self .max_num_frames = max_num_frames
33
44
34
- def get_detections (self , frames ):
35
- """Returns all detections on frames"""
45
+ def run_asynch (self , frames , index ):
36
46
assert len (frames ) <= self .max_num_frames
37
-
38
- all_detections = []
47
+ self .shapes = []
39
48
for i in range (len (frames )):
49
+ self .shapes .append (frames [i ].shape )
40
50
self .net .forward_async (frames [i ])
41
- outputs = self .net .grab_all_async ()
42
51
52
+ def wait_and_grab (self ):
53
+ all_detections = []
54
+ outputs = self .net .grab_all_async ()
43
55
for i , out in enumerate (outputs ):
44
- detections = self .__decode_detections (out , frames [i ]. shape )
56
+ detections = self .__decode_detections (out , self . shapes [i ])
45
57
all_detections .append (detections )
46
-
47
58
return all_detections
48
59
60
+ def get_detections (self , frames ):
61
+ """Returns all detections on frames"""
62
+ self .run_asynch (frames )
63
+ return self .wait_and_grab ()
64
+
49
65
def __decode_detections (self , out , frame_shape ):
50
66
"""Decodes raw SSD output"""
51
67
detections = []
@@ -130,7 +146,7 @@ def forward(self, rois):
130
146
return embeddings
131
147
132
148
133
- class MaskRCNN :
149
+ class MaskRCNN ( DetectorInterface ) :
134
150
"""Wrapper class for a network returning masks of objects"""
135
151
136
152
def __init__ (self , ie , model_path , conf = .6 , device = 'CPU' , ext_path = '' ,
@@ -213,6 +229,12 @@ def get_detections(self, frames, return_cropped_masks=True, only_class_person=Tr
213
229
outputs .append (frame_output )
214
230
return outputs
215
231
232
+ def run_asynch (self , frames , index ):
233
+ self .frames = frames
234
+
235
+ def wait_and_grab (self ):
236
+ return self .get_detections (self .frames )
237
+
216
238
class Compose (object ):
217
239
def __init__ (self , transforms ):
218
240
self .transforms = transforms
@@ -286,7 +308,7 @@ def __call__(self, sample):
286
308
return sample
287
309
288
310
289
- class DetectionsFromFileReader (object ):
311
+ class DetectionsFromFileReader (DetectorInterface ):
290
312
"""Read detection from *.json file.
291
313
Format of the file should be:
292
314
[
@@ -296,6 +318,7 @@ class DetectionsFromFileReader(object):
296
318
...
297
319
]
298
320
"""
321
+
299
322
def __init__ (self , input_files , score_thresh ):
300
323
self .input_files = input_files
301
324
self .score_thresh = score_thresh
@@ -309,12 +332,15 @@ def __init__(self, input_files, score_thresh):
309
332
detections_dict [det ['frame_id' ]] = {'boxes' : det ['boxes' ], 'scores' : det ['scores' ]}
310
333
self .detections .append (detections_dict )
311
334
312
- def get_detections (self , frame_id ):
335
+ def run_asynch (self , frames , index ):
336
+ self .last_index = index
337
+
338
+ def wait_and_grab (self ):
313
339
output = []
314
340
for source in self .detections :
315
341
valid_detections = []
316
- if frame_id in source :
317
- for bbox , score in zip (source [frame_id ]['boxes' ], source [frame_id ]['scores' ]):
342
+ if self . last_index in source :
343
+ for bbox , score in zip (source [self . last_index ]['boxes' ], source [self . last_index ]['scores' ]):
318
344
if score > self .score_thresh :
319
345
bbox = [int (value ) for value in bbox ]
320
346
valid_detections .append ((bbox , score ))
0 commit comments