Skip to content

Commit 734189c

Browse files
Merge pull request #1 from sovrasov/asynch_detector_call
Asynch detector call
2 parents 49565d2 + e243362 commit 734189c

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

demos/python_demos/multi_camera_multi_person_tracking/multi_camera_multi_person_tracking.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ def run(params, capture, detector, reid):
8080
else:
8181
output_video = None
8282

83+
prev_frames = thread_body.frames_queue.get()
84+
detector.run_asynch(prev_frames, frame_number)
85+
8386
while thread_body.process:
8487
key = check_pressed_keys(key)
8588
if key == 27:
@@ -93,30 +96,30 @@ def run(params, capture, detector, reid):
9396
if frames is None:
9497
continue
9598

96-
if params.detections:
97-
all_detections = detector.get_detections(frame_number)
98-
else:
99-
all_detections = detector.get_detections(frames)
99+
frame_number += 1
100+
all_detections = detector.wait_and_grab()
101+
detector.run_asynch(frames, frame_number)
102+
100103
all_masks = [[] for _ in range(len(all_detections))]
101104
for i, detections in enumerate(all_detections):
102105
all_detections[i] = [det[0] for det in detections]
103106
all_masks[i] = [det[2] for det in detections if len(det) == 3]
104107

105-
tracker.process(frames, all_detections, all_masks)
108+
tracker.process(prev_frames, all_detections, all_masks)
106109
tracked_objects = tracker.get_tracked_objects()
107110

108111
latency = time.time() - start
109112
avg_latency.update(latency)
110113
fps = round(1. / latency, 1)
111114

112-
vis = visualize_multicam_detections(frames, tracked_objects, fps, **config['visualization_config'])
115+
vis = visualize_multicam_detections(prev_frames, tracked_objects, fps, **config['visualization_config'])
113116
cv.imshow(win_name, vis)
114117
if output_video:
115118
output_video.write(cv.resize(vis, video_output_size))
116119

117120
print('\rProcessing frame: {}, fps = {} (avg_fps = {:.3})'.format(
118121
frame_number, fps, 1. / avg_latency.get()), end="")
119-
frame_number += 1
122+
prev_frames, frames = frames, prev_frames
120123
print('')
121124

122125
thread_body.process = False

demos/python_demos/multi_camera_multi_person_tracking/utils/network_wrappers.py

+39-13
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import json
1515
import logging as log
1616
from collections import namedtuple
17+
from abc import ABC, abstractmethod
1718

1819
import cv2
1920
import numpy as np
@@ -22,7 +23,17 @@
2223
from .segm_postrocess import postprocess
2324

2425

25-
class Detector:
26+
class DetectorInterface(ABC):
27+
@abstractmethod
28+
def run_asynch(self, frames, index):
29+
pass
30+
31+
@abstractmethod
32+
def wait_and_grab(self):
33+
pass
34+
35+
36+
class Detector(DetectorInterface):
2637
"""Wrapper class for detector"""
2738

2839
def __init__(self, ie, model_path, conf=.6, device='CPU', ext_path='', max_num_frames=1):
@@ -31,21 +42,26 @@ def __init__(self, ie, model_path, conf=.6, device='CPU', ext_path='', max_num_f
3142
self.expand_ratio = (1., 1.)
3243
self.max_num_frames = max_num_frames
3344

34-
def get_detections(self, frames):
35-
"""Returns all detections on frames"""
45+
def run_asynch(self, frames, index):
3646
assert len(frames) <= self.max_num_frames
37-
38-
all_detections = []
47+
self.shapes = []
3948
for i in range(len(frames)):
49+
self.shapes.append(frames[i].shape)
4050
self.net.forward_async(frames[i])
41-
outputs = self.net.grab_all_async()
4251

52+
def wait_and_grab(self):
53+
all_detections = []
54+
outputs = self.net.grab_all_async()
4355
for i, out in enumerate(outputs):
44-
detections = self.__decode_detections(out, frames[i].shape)
56+
detections = self.__decode_detections(out, self.shapes[i])
4557
all_detections.append(detections)
46-
4758
return all_detections
4859

60+
def get_detections(self, frames):
61+
"""Returns all detections on frames"""
62+
self.run_asynch(frames)
63+
return self.wait_and_grab()
64+
4965
def __decode_detections(self, out, frame_shape):
5066
"""Decodes raw SSD output"""
5167
detections = []
@@ -130,7 +146,7 @@ def forward(self, rois):
130146
return embeddings
131147

132148

133-
class MaskRCNN:
149+
class MaskRCNN(DetectorInterface):
134150
"""Wrapper class for a network returning masks of objects"""
135151

136152
def __init__(self, ie, model_path, conf=.6, device='CPU', ext_path='',
@@ -213,6 +229,12 @@ def get_detections(self, frames, return_cropped_masks=True, only_class_person=Tr
213229
outputs.append(frame_output)
214230
return outputs
215231

232+
def run_asynch(self, frames, index):
233+
self.frames = frames
234+
235+
def wait_and_grab(self):
236+
return self.get_detections(self.frames)
237+
216238
class Compose(object):
217239
def __init__(self, transforms):
218240
self.transforms = transforms
@@ -286,7 +308,7 @@ def __call__(self, sample):
286308
return sample
287309

288310

289-
class DetectionsFromFileReader(object):
311+
class DetectionsFromFileReader(DetectorInterface):
290312
"""Read detection from *.json file.
291313
Format of the file should be:
292314
[
@@ -296,6 +318,7 @@ class DetectionsFromFileReader(object):
296318
...
297319
]
298320
"""
321+
299322
def __init__(self, input_files, score_thresh):
300323
self.input_files = input_files
301324
self.score_thresh = score_thresh
@@ -309,12 +332,15 @@ def __init__(self, input_files, score_thresh):
309332
detections_dict[det['frame_id']] = {'boxes': det['boxes'], 'scores': det['scores']}
310333
self.detections.append(detections_dict)
311334

312-
def get_detections(self, frame_id):
335+
def run_asynch(self, frames, index):
336+
self.last_index = index
337+
338+
def wait_and_grab(self):
313339
output = []
314340
for source in self.detections:
315341
valid_detections = []
316-
if frame_id in source:
317-
for bbox, score in zip(source[frame_id]['boxes'], source[frame_id]['scores']):
342+
if self.last_index in source:
343+
for bbox, score in zip(source[self.last_index]['boxes'], source[self.last_index]['scores']):
318344
if score > self.score_thresh:
319345
bbox = [int(value) for value in bbox]
320346
valid_detections.append((bbox, score))

0 commit comments

Comments
 (0)