'''
Copyright 2020 Avnet Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
# USAGE
# python avnet_face_tracking_mt.py [--input 0] [--detthreshold 0.55] [--nmsthreshold 0.35] [--threads 4]
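
# Pipeline overview (as implemented below):
#   taskCapture : grabs frames from the camera and pushes them onto queueIn
#   taskWorker  : N threads, each with its own DPU runner; pops frames from
#                 queueIn, runs face detection and centroid tracking, and
#                 pushes the annotated frames onto queueOut
#   taskDisplay : pops frames from queueOut, displays them, and sets the
#                 global bQuit flag when the `q` key is pressed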
import numpy as np
import argparse
import imutils
import time
import cv2
import os, errno
import sys
import threading
import queue
from imutils.video import FPS
from pyimagesearch.centroidtracker import CentroidTracker
from vitis_ai_vart.facedetect import FaceDetect
import runner
import xir.graph
import pathlib
import xir.subgraph

def get_subgraph(g):
    root = g.get_root_subgraph()
    sub = [s for s in root.children
           if s.metadata.get_attr_str("device") == "DPU"]
    return sub
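
# Note: in this VART flow, the compiled DenseBox model is deserialized into
# an XIR graph, get_subgraph() extracts the subgraph mapped to the DPU, and
# main() creates one runner.Runner per worker thread, so each thread can own
# its own DPU runner instead of sharing a single runner across threads.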

# Global flag used by all threads to signal shutdown
bQuit = False

def taskCapture(inputId, queueIn):
    global bQuit

    #print("[INFO] taskCapture : starting thread ...")

    # Start the FPS counter
    fpsIn = FPS().start()

    # Initialize the camera input
    print("[INFO] taskCapture : starting camera input ...")
    cam = cv2.VideoCapture(inputId)
    cam.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
    cam.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
    if not cam.isOpened():
        print("[ERROR] taskCapture : Failed to open camera ", inputId)
        exit()

    while not bQuit:
        # Capture image from camera
        ret, frame = cam.read()
        if not ret:
            print("[ERROR] taskCapture : Failed to read frame from camera ", inputId)
            break
        # Update the FPS counter
        fpsIn.update()
        # Push captured image to input queue
        queueIn.put(frame)

    # Stop the timer and display FPS information
    fpsIn.stop()
    print("[INFO] taskCapture : elapsed time: {:.2f}".format(fpsIn.elapsed()))
    print("[INFO] taskCapture : elapsed FPS: {:.2f}".format(fpsIn.fps()))

    # Release the camera
    cam.release()

    #print("[INFO] taskCapture : exiting thread ...")

def taskWorker(worker, dpu, detThreshold, nmsThreshold, ct, queueIn, queueOut):
    global bQuit

    #print("[INFO] taskWorker[",worker,"] : starting thread ...")

    # Start the face detector
    dpu_face_detector = FaceDetect(dpu, detThreshold, nmsThreshold)
    dpu_face_detector.start()

    while not bQuit:
        # Pop captured image from input queue
        frame = queueIn.get()

        # Vitis-AI/DPU based face detector
        faces = dpu_face_detector.process(frame)

        # Update our centroid tracker using the computed set of bounding
        # box rectangles
        objects = ct.update(faces)

        # Loop over the tracked objects
        for (objectID, centroid) in objects.items():
            # Draw both the ID of the object and its bounding box
            # on the output frame
            text = "ID {}".format(objectID)
            bbox = ct.bboxes[objectID]
            cv2.putText(frame, text, (bbox[0], bbox[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
            cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)

        # Push processed image to output queue
        queueOut.put(frame)

    # Stop the face detector
    dpu_face_detector.stop()

    # Workaround : to ensure other worker threads blocked on queueIn.get()
    # also stop, make sure the input queue is not empty
    queueIn.put(frame)

    #print("[INFO] taskWorker[",worker,"] : exiting thread ...")
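
# Design note: the final queueIn.put(frame) above re-seeds the input queue so
# that sibling workers blocked in queueIn.get() wake up and re-check bQuit.
# A common alternative (not used here) would be to push one sentinel value,
# e.g. None, per worker and have each worker exit when it pops the sentinel.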

def taskDisplay(queueOut):
    global bQuit

    #print("[INFO] taskDisplay : starting thread ...")

    # Start the FPS counter
    fpsOut = FPS().start()

    while not bQuit:
        # Pop processed image from output queue
        frame = queueOut.get()

        # Display the processed image
        cv2.imshow("Face Tracking", frame)

        # Update the FPS counter
        fpsOut.update()

        # If the `q` key was pressed, break from the loop
        key = cv2.waitKey(1) & 0xFF
        if key == ord("q"):
            break

    # Trigger all threads to stop
    bQuit = True

    # Stop the timer and display FPS information
    fpsOut.stop()
    print("[INFO] taskDisplay : elapsed time: {:.2f}".format(fpsOut.elapsed()))
    print("[INFO] taskDisplay : elapsed FPS: {:.2f}".format(fpsOut.fps()))

    # Cleanup
    cv2.destroyAllWindows()

    #print("[INFO] taskDisplay : exiting thread ...")
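
# Note: all OpenCV HighGUI calls (imshow, waitKey, destroyAllWindows) are
# kept in this single display thread, since HighGUI is generally not safe
# to drive from multiple threads at once.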

def main(argv):
    global bQuit
    bQuit = False

    # Construct the argument parser and parse the arguments
    ap = argparse.ArgumentParser()
    ap.add_argument("-i", "--input", required=False,
                    help="input camera identifier (default = 0)")
    ap.add_argument("-d", "--detthreshold", required=False,
                    help="face detector softmax threshold (default = 0.55)")
    ap.add_argument("-n", "--nmsthreshold", required=False,
                    help="face detector NMS threshold (default = 0.35)")
    ap.add_argument("-t", "--threads", required=False,
                    help="number of worker threads (default = 4)")
    args = vars(ap.parse_args())

    if not args.get("input", False):
        inputId = 0
    else:
        inputId = int(args["input"])
    print('[INFO] input camera identifier = ', inputId)

    if not args.get("detthreshold", False):
        detThreshold = 0.55
    else:
        detThreshold = float(args["detthreshold"])
    print('[INFO] face detector - softmax threshold = ', detThreshold)

    if not args.get("nmsthreshold", False):
        nmsThreshold = 0.35
    else:
        nmsThreshold = float(args["nmsthreshold"])
    print('[INFO] face detector - NMS threshold = ', nmsThreshold)

    if not args.get("threads", False):
        threads = 4
    else:
        threads = int(args["threads"])
    print('[INFO] number of worker threads = ', threads)

    # Initialize VART API
    densebox_elf = "/usr/share/vitis_ai_library/models/densebox_640_360/densebox_640_360.elf"
    densebox_graph = xir.graph.Graph.deserialize(pathlib.Path(densebox_elf))
    densebox_subgraphs = get_subgraph(densebox_graph)
    assert len(densebox_subgraphs) == 1  # only one DPU kernel
    all_dpu_runners = []
    for i in range(threads):
        all_dpu_runners.append(runner.Runner(densebox_subgraphs[0], "run"))

    # Initialize our centroid tracker and frame dimensions
    ct = CentroidTracker()

    # Init synchronous queues for inter-thread communication
    queueIn = queue.Queue()
    queueOut = queue.Queue()

    # Launch threads
    threadAll = []
    tc = threading.Thread(target=taskCapture, args=(inputId, queueIn))
    threadAll.append(tc)
    for i in range(threads):
        tw = threading.Thread(target=taskWorker,
                              args=(i, all_dpu_runners[i], detThreshold,
                                    nmsThreshold, ct, queueIn, queueOut))
        threadAll.append(tw)
    td = threading.Thread(target=taskDisplay, args=(queueOut,))
    threadAll.append(td)

    for x in threadAll:
        x.start()

    # Wait for all threads to stop
    for x in threadAll:
        x.join()

    # Cleanup VART API
    del all_dpu_runners


if __name__ == "__main__":
    main(sys.argv)