forked from AlexeyAB/darknet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdarknet.py
218 lines (172 loc) · 7.54 KB
/
darknet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#!python3
"""
Python 3 wrapper for identifying objects in images
Requires DLL compilation
Both the GPU and no-GPU version should be compiled; the no-GPU version should be renamed "yolo_cpp_dll_nogpu.dll".
On a GPU system, you can force CPU evaluation by any of:
- Set global variable DARKNET_FORCE_CPU to True
- Set environment variable CUDA_VISIBLE_DEVICES to -1
- Set environment variable "FORCE_CPU" to "true"
To use, create a Darknet(configPath, weightPath, metaPath) instance after import and call its detect() method on an image array.
See the signature of Darknet.detect() for the parameters.
Directly viewing or returning bounding-boxed images requires scikit-image to be installed (`pip install scikit-image`)
Original *nix 2.7: https://github.com/pjreddie/darknet/blob/0f110834f4e18b30d5f101bf8f1724c34b7b83db/python/darknet.py
Windows Python 2.7 version: https://github.com/AlexeyAB/darknet/blob/fc496d52bf22a0bb257300d3c79be9cd80e722cb/build/darknet/x64/darknet.py
@author: Philip Kahn
@date: 20180503
Modified by Guilherme Fickel - Meerkat
"""
#pylint: disable=R, W0401, W0614, W0703
from ctypes import *
import math
import random
import os
import cv2
import numpy as np
def sample(probs):
    """Draw a random index, weighted by ``probs``.

    The weights are normalised by their sum, a uniform random number in
    [0, 1] is drawn, and the normalised weights are subtracted from it one
    by one; the index at which the remainder first drops to zero or below
    is returned.

    Args:
        probs: sequence of non-negative weights (need not sum to 1).

    Returns:
        The selected index. If the loop finishes without the remainder
        reaching zero (floating-point round-off, or an empty ``probs``),
        ``len(probs) - 1`` is returned, which is ``-1`` for an empty input.

    Raises:
        ZeroDivisionError: if ``probs`` is non-empty and sums to zero.
    """
    total = sum(probs)
    weights = [p / total for p in probs]
    remaining = random.uniform(0, 1)
    for index, weight in enumerate(weights):
        remaining -= weight
        if remaining <= 0:
            return index
    return len(weights) - 1
def c_array(ctype, values):
    """Build a ctypes array of ``ctype`` holding a copy of ``values``.

    Args:
        ctype: ctypes element type (e.g. ``c_float``).
        values: sequence of items convertible to ``ctype``; its length
            fixes the length of the array.

    Returns:
        A new ``ctype * len(values)`` array instance filled from ``values``.
    """
    array_type = ctype * len(values)
    buffer = array_type()
    # Slice assignment converts and copies every element in one step.
    buffer[:] = values
    return buffer
class BOX(Structure):
    """ctypes structure with four ``c_float`` fields: ``x``, ``y``, ``w``, ``h``.

    The detection code in this file reads ``x``/``y`` as the box centre and
    ``w``/``h`` as the full width and height (corners are computed as
    ``x -/+ w/2`` and ``y -/+ h/2``).
    """
    # Field order defines the memory layout; keep it as x, y, w, h.
    _fields_ = [(name, c_float) for name in ("x", "y", "w", "h")]
class DETECTION(Structure):
    """ctypes structure describing one detection returned by the library.

    Only ``bbox`` and ``prob`` are read by the Python code in this file:
    ``prob`` is indexed by class index and ``bbox`` supplies the box
    geometry. The remaining fields are declared for layout purposes only.

    NOTE(review): presumably this mirrors the native ``detection`` struct of
    the loaded darknet build; field order and types must match that build
    exactly -- verify against the library's header.
    """
    _fields_ = [
        ("bbox", BOX),
        ("classes", c_int),
        ("prob", POINTER(c_float)),
        ("mask", POINTER(c_float)),
        ("objectness", c_float),
        ("sort_class", c_int),
        ("uc", POINTER(c_float)),
        ("points", c_int),
    ]
class DETNUMPAIR(Structure):
    """ctypes structure pairing a count (``num``) with a ``DETECTION`` pointer (``dets``).

    Not referenced anywhere else in this file.
    """
    _fields_ = [
        ("num", c_int),
        ("dets", POINTER(DETECTION)),
    ]
class IMAGE(Structure):
    """ctypes structure for an image: width, height, channel count and float data.

    ``Darknet.array_to_image`` in this file fills it with ``w``, ``h``, ``c``
    and a pointer to a flat ``float32`` buffer laid out channel-first
    (the array is transposed to (c, h, w) before flattening).
    """
    _fields_ = [
        ("w", c_int),
        ("h", c_int),
        ("c", c_int),
        ("data", POINTER(c_float)),
    ]
class METADATA(Structure):
    """ctypes structure holding the class count and a pointer to the class names.

    The detection code in this file indexes ``names`` by class index and
    decodes each entry from UTF-8 bytes.
    """
    _fields_ = [
        ("classes", c_int),
        ("names", POINTER(c_char_p)),
    ]
# Module-level flag, hard-coded to True. Nothing in this file reads it.
# NOTE(review): presumably a leftover from the upstream wrapper's GPU/CPU
# library selection -- confirm before relying on it or removing it.
hasGPU = True
class Darknet():
    """Thin ctypes wrapper around a compiled darknet shared library.

    Constructing an instance loads the library, declares the C signatures of
    the functions used here, loads the network and its metadata, and tries
    to read the class-name list referenced by the data file.

    Attributes set by ``__init__``:
        lib: the ``CDLL`` handle of the loaded library.
        netMain: opaque network handle returned by ``load_network_custom``.
        metaMain: ``METADATA`` value returned by ``get_metadata``.
        altNames: list of stripped lines from the names file, or ``None``
            when the data file has no ``names = ...`` entry or the file it
            points to does not exist.
    """

    def __init__(self, configPath, weightPath, metaPath, libPath='/usr/local/libdarknet.so'):
        """Load the library, the network and the class metadata.

        Args:
            configPath: path to the network configuration file.
            weightPath: path to the weights file.
            metaPath: path to the data file handed to ``get_metadata``; it
                is also scanned for a ``names = <path>`` line.
            libPath: path of the darknet shared library to load.

        Raises:
            ValueError: if ``configPath``, ``weightPath`` or ``metaPath``
                does not exist.
            OSError: if the shared library cannot be loaded (raised by
                ``CDLL``).
            UnicodeEncodeError: if a path contains non-ASCII characters
                (paths are passed to the library ASCII-encoded).
        """
        required = (
            ("config", configPath),
            ("weight", weightPath),
            ("data file", metaPath),
        )
        for label, path in required:
            if not os.path.exists(path):
                raise ValueError("Invalid " + label + " path `" + os.path.abspath(path) + "`")

        lib = self.lib = CDLL(libPath, RTLD_GLOBAL)

        # Declare C signatures so ctypes marshals arguments/returns correctly.
        lib.network_width.argtypes = [c_void_p]
        lib.network_width.restype = c_int
        lib.network_height.argtypes = [c_void_p]
        lib.network_height.restype = c_int

        # Not called from this class, but configured on the handle so that
        # direct use through ``self.lib`` keeps working as before.
        lib.copy_image_from_bytes.argtypes = [IMAGE, c_char_p]
        lib.make_image.argtypes = [c_int, c_int, c_int]
        lib.make_image.restype = IMAGE

        self.get_network_boxes = lib.get_network_boxes
        self.get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int), c_int]
        self.get_network_boxes.restype = POINTER(DETECTION)

        self.free_detections = lib.free_detections
        self.free_detections.argtypes = [POINTER(DETECTION), c_int]

        self.network_predict = lib.network_predict_ptr
        self.network_predict.argtypes = [c_void_p, POINTER(c_float)]

        load_net_custom = lib.load_network_custom
        load_net_custom.argtypes = [c_char_p, c_char_p, c_int, c_int]
        load_net_custom.restype = c_void_p

        self.do_nms_obj = lib.do_nms_obj
        self.do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

        self.do_nms_sort = lib.do_nms_sort
        self.do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]

        self.free_image = lib.free_image
        self.free_image.argtypes = [IMAGE]

        load_meta = lib.get_metadata
        load_meta.argtypes = [c_char_p]
        load_meta.restype = METADATA

        self.load_image = lib.load_image_color
        self.load_image.argtypes = [c_char_p, c_int, c_int]
        self.load_image.restype = IMAGE

        self.predict_image = lib.network_predict_image
        self.predict_image.argtypes = [c_void_p, IMAGE]
        self.predict_image.restype = POINTER(c_float)

        self.netMain = load_net_custom(configPath.encode("ascii"), weightPath.encode("ascii"), 0, 1)  # batch size = 1
        self.metaMain = load_meta(metaPath.encode("ascii"))

        # Always defined, so callers can test ``altNames is None`` instead of
        # hitting AttributeError when no names file could be read.
        self.altNames = self._read_alt_names(metaPath)

    @staticmethod
    def _read_alt_names(metaPath):
        """Return the class names listed by the data file's ``names =`` entry, or None."""
        import re  # function-local, as in the original code

        with open(metaPath) as metaFH:
            metaContents = metaFH.read()
        match = re.search(r"names *= *(.*)$", metaContents, re.IGNORECASE | re.MULTILINE)
        if match is None:
            return None
        # Trailing whitespace after the path would make the existence check fail.
        namesPath = match.group(1).strip()
        if not os.path.exists(namesPath):
            return None
        with open(namesPath) as namesFH:
            namesList = namesFH.read().strip().split("\n")
        return [name.strip() for name in namesList]

    def network_width(self):
        """Return the value of the library's ``network_width`` for the loaded network."""
        return self.lib.network_width(self.netMain)

    def network_height(self):
        """Return the value of the library's ``network_height`` for the loaded network."""
        return self.lib.network_height(self.netMain)

    def array_to_image(self, arr):
        """Convert a 3-axis numpy array into an ``IMAGE`` structure.

        The array is transposed so that its last axis comes first, flattened
        into a contiguous ``float32`` buffer and divided by 255.0.

        Args:
            arr: numpy array with exactly three axes; the last axis is taken
                as the channel axis (e.g. height x width x channels).
                # assumes 8-bit pixel values, given the /255.0 scaling -- TODO confirm

        Returns:
            ``(im, data)``: the ``IMAGE`` and the float32 array backing
            ``im.data``. The caller must keep ``data`` referenced for as
            long as ``im`` is in use, otherwise Python may free the buffer.
        """
        channels_first = arr.transpose(2, 0, 1)
        c, h, w = channels_first.shape
        data = np.ascontiguousarray(channels_first.flat, dtype=np.float32) / 255.0
        im = IMAGE(w, h, c, data.ctypes.data_as(POINTER(c_float)))
        return im, data

    def detect(self, image, thresh=.5, hier_thresh=.5, nms=.45):
        """Run detection on a numpy image array.

        Args:
            image: array accepted by ``array_to_image``.
            thresh: threshold forwarded to ``get_network_boxes``.
            hier_thresh: hierarchical threshold forwarded to
                ``get_network_boxes``.
            nms: threshold forwarded to ``do_nms_sort``; a falsy value skips
                that call.

        Returns:
            The list produced by ``_detect_image``.
        """
        # ``pixels`` owns the memory ``im.data`` points at; holding it in a
        # named local keeps it alive until the native call has returned.
        im, pixels = self.array_to_image(image)
        return self._detect_image(im, thresh, hier_thresh, nms)

    def _detect_image(self, im, thresh=.5, hier_thresh=.5, nms=.45):
        """Run the network on an ``IMAGE`` and collect the detections.

        Args:
            im: ``IMAGE`` to run on; its ``w`` and ``h`` are passed to
                ``get_network_boxes``.
            thresh: threshold forwarded to ``get_network_boxes``.
            hier_thresh: hierarchical threshold forwarded to
                ``get_network_boxes``.
            nms: threshold forwarded to ``do_nms_sort``; a falsy value skips
                that call.

        Returns:
            A list of dicts sorted by descending confidence, one per
            (detection, class) pair with probability greater than zero::

                {'label': str, 'confidence': float,
                 'box': {'top_left': {'x', 'y'}, 'bottom_right': {'x', 'y'}}}
        """
        num = c_int(0)
        pnum = pointer(num)
        self.predict_image(self.netMain, im)
        letter_box = 0
        dets = self.get_network_boxes(self.netMain, im.w, im.h, thresh, hier_thresh, None, 0, pnum, letter_box)
        count = pnum[0]
        try:
            classes = self.metaMain.classes
            if nms:
                self.do_nms_sort(dets, count, classes, nms)
            res = []
            for j in range(count):
                det = dets[j]
                for i in range(classes):
                    confidence = det.prob[i]
                    if confidence > 0:
                        b = det.bbox
                        # bbox stores the centre and full size; convert to corners.
                        left, right = b.x - b.w / 2, b.x + b.w / 2
                        top, bottom = b.y - b.h / 2, b.y + b.h / 2
                        res.append({
                            'label': self.metaMain.names[i].decode('utf-8'),
                            'confidence': float(confidence),
                            'box': {
                                'top_left': {'x': left, 'y': top},
                                'bottom_right': {'x': right, 'y': bottom},
                            }})
            res.sort(key=lambda item: -item['confidence'])
            return res
        finally:
            # Release the native buffer even if NMS or result building fails.
            self.free_detections(dets, count)