-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_objetcs_and_coordinates.py
146 lines (124 loc) · 5.25 KB
/
get_objetcs_and_coordinates.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
This script tests the trained model and performs object detection on a given image; it prints
to the console a JSON with the detected objects and their info.
E.g.
>>> python get_objects_and_coordinates.py -i ../images/test_model/IMG_2362.jpg -o some_folder/ -v true
output:
{
"jugo_jumex": {
"score": 0.9995954632759094,
"ymin": 435.3594779968262,
"xmin": 185.599547624588,
"ymax": 547.5826263427734,
"xmax": 248.1843888759613
},
"coca_cola": {
"score": 0.9212740063667297,
"ymin": 314.3073320388794,
"xmin": 328.5956025123596,
"ymax": 473.17280769348145,
"xmax": 405.4909944534302
}
}
"""
import os
import sys
import json
import argparse
import cv2
import numpy as np
import tensorflow as tf
from utils import label_map_util
from utils import visualization_utils as vis_util
def run_model(image_directory):
    """
    Load the necessary files, run the detection model on one image, and
    publish the raw results through module globals.

    Sets the globals ``boxes``, ``scores``, ``classes`` and ``num`` (the raw
    detector outputs), ``category_index`` (class id -> category info mapping)
    and ``image`` (the BGR image as read by OpenCV) for later use by
    ``get_objects`` and ``display_image_results``.

    :param image_directory: path to the image file to analyze.
    :raises FileNotFoundError: if the image cannot be read from disk.
    """
    global boxes, scores, classes, num, category_index, image
    MODEL_NAME = 'inference_graph'
    CWD_PATH = os.getcwd()
    PATH_TO_CKPT = os.path.join(CWD_PATH, MODEL_NAME, 'frozen_inference_graph.pb')
    PATH_TO_LABELS = os.path.join(CWD_PATH, 'training', 'labelmap.pbtxt')
    PATH_TO_IMAGE = image_directory
    NUM_CLASSES = 3  # number of classes declared in labelmap.pbtxt

    # Load the label map and build the class-id -> category lookup.
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    # Load the frozen TensorFlow graph into memory.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    # Load the image with OpenCV. cv2.imread returns None on failure, which
    # would otherwise surface later as a confusing error inside TensorFlow,
    # so fail fast with a clear message instead.
    image = cv2.imread(PATH_TO_IMAGE)
    if image is None:
        raise FileNotFoundError('Could not read image: {}'.format(PATH_TO_IMAGE))
    image_expanded = np.expand_dims(image, axis=0)

    # Run the detector. The session is scoped to a context manager so it is
    # closed deterministically instead of being leaked as before.
    with tf.Session(graph=detection_graph) as sess:
        # Input and output tensors of the object detection graph.
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_expanded})
def get_objects():
    """
    Build a dict mapping each detected class name to its best detection.

    Each entry holds the confidence score and the bounding-box corners
    converted from normalized coordinates to pixels of the loaded image.
    Reads the module globals set by ``run_model``.
    """
    MIN_SCORE_THRESH = 0.85
    img_height = image.shape[0]
    img_width = image.shape[1]
    detections = {}
    for idx, class_id in enumerate(classes[0]):
        score = scores[0, idx]
        if score <= MIN_SCORE_THRESH:
            continue
        name = category_index.get(class_id)['name']
        # If the same class was detected more than once, keep only the
        # detection with the highest score.
        if name in detections and detections[name]['score'] > score:
            continue
        ymin, xmin, ymax, xmax = boxes[0][idx]
        detections[name] = {
            'score': float(score),
            'ymin': float(ymin * img_height),
            'xmin': float(xmin * img_width),
            'ymax': float(ymax * img_height),
            'xmax': float(xmax * img_width)
        }
    return detections
def display_image_results():
    """
    Draw the detection boxes and labels onto the loaded image and show it
    in an OpenCV window until any key is pressed.

    Reads the module globals set by ``run_model``.
    """
    drawn_boxes = np.squeeze(boxes)
    drawn_classes = np.squeeze(classes).astype(np.int32)
    drawn_scores = np.squeeze(scores)
    vis_util.visualize_boxes_and_labels_on_image_array(
        image,
        drawn_boxes,
        drawn_classes,
        drawn_scores,
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8,
        min_score_thresh=0.85)
    # Show the annotated image; waitKey(0) blocks until a key is pressed.
    cv2.imshow('image_results', image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Command-line interface: input image, output directory for the JSON
    # results, and whether to pop up a window with the annotated image.
    parser = argparse.ArgumentParser(description="Get objects and coordinates")
    parser.add_argument('-i', '--input_directory', type=str, required=True,
                        help='Input directory containing the image to analize')
    parser.add_argument('-o', '--output_directory', type=str, required=True,
                        help='Output directory to write the JSON results')
    parser.add_argument('-v', '--image_view', type=str, required=True,
                        help='Display the image results (true/false)')
    args = parser.parse_args()

    run_model(args.input_directory)
    detected_objects = get_objects()

    # Persist the JSON results to a txt file in the output directory.
    results_path = os.path.join(args.output_directory, 'detected_objects.txt')
    with open(results_path, 'w') as outfile:
        json.dump(detected_objects, outfile)

    # Echo the results on the terminal as well.
    print('Results')
    print(json.dumps(detected_objects, indent=4))
    if args.image_view == 'true':
        display_image_results()