Skip to content
This repository has been archived by the owner on Aug 18, 2020. It is now read-only.

Commit

Permalink
Merge in competition changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bbridges committed Jul 14, 2018
2 parents e4c49e1 + dc2a93f commit 0bfbf97
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 52 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,6 @@ ENV/
# target-finder-cli output
/blob-*.jpg
/target-*.jpg

# target-finder model files
/target_finder/data/
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
Pillow>=4.3.0
scipy
webcolors>=1.7
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def load_requirements():
author='Unmanned Aerial Vehicle Team | UT Austin',
url='https://github.com/uavaustin/target-finder',
packages=find_packages(),
include_package_data=True,
package_data={
'target_finder': [
'data/retrained_graph.pb', 'data/retrained_labels.txt'
]
},
install_requires=[load_requirements()],
entry_points='''
[console_scripts]
Expand Down
1 change: 0 additions & 1 deletion target_finder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@

from .classification import find_targets
from .preprocessing import find_blobs
from .training import train
from .types import Blob, Color, Shape, Target
from .version import __version__
194 changes: 183 additions & 11 deletions target_finder/classification.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,56 @@
"""Contains logic for finding targets in blobs."""

import os
from pkg_resources import resource_filename

import cv2
import numpy as np
import os
import PIL.Image
import scipy.cluster
import scipy.misc
import tensorflow as tf
import warnings
import webcolors

from .preprocessing import find_blobs
from .types import Color, Shape, Target


# TODO: Make a Tensorflow session we can use. There should be two
# seperate lookups for models. The first one will see if we
# have a user-made one, if so, we'll use that one. Otherwise,
# we'll use a default model that ships with the library.
tf_session = None
softmax_tensor = None
label_lines = None

graph_loc = resource_filename(__name__, 'data/graph.pb')
labels_loc = resource_filename(__name__, 'data/labels.txt')

# Configure if the graph and label files exist, otherwise, send a
# warning.
if os.path.isfile(graph_loc) and os.path.isfile(labels_loc):
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Load label file and strip off newlines.
label_lines = [line.rstrip() for line in tf.gfile.GFile(labels_loc)]

# Register the graph with tensorflow.
with tf.gfile.FastGFile(graph_loc, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

tf.import_graph_def(graph_def, name='')

def find_targets(image=None, blobs=None, min_confidence=0.95, limit=10):
"""Returns the targets found in an image.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

tf_session = tf.Session(config=config)
softmax_tensor = tf_session.graph.get_tensor_by_name('final_result:0')
else:
warnings.warn('Missing model files. Classification will not return any ' +
'targets.')


def find_targets(image=None, blobs=None, min_confidence=0.85, limit=10):
"""Return the targets found in an image.
Targets are returned in the order of highest confidence. Once the
limit is hit, classification will stop and just take the first
Expand All @@ -36,11 +75,16 @@ def find_targets(image=None, blobs=None, min_confidence=0.95, limit=10):
List[Target]: The list of targets found.
"""

# Check that if we don't have an image passed, that the each
# blobs have their own image.
# Check that when is not an image passed, the each blob have
# their own image.
if image is None and blobs is None:
raise Exception('Blobs must be provided if an image is not.')

# If there is not a tensorflow session because of a missing graph
# or labels, then there's nothing to do.
if tf_session is None:
return []

# If we didn't get blobs, then we'll find them.
if blobs is None:
blobs = find_blobs(image)
Expand All @@ -50,7 +94,8 @@ def find_targets(image=None, blobs=None, min_confidence=0.95, limit=10):
# Try and find a target for each blob, if it exists then register
# it. Stop if we hit the limit.
for blob in blobs:
if len(targets) == limit: break
if len(targets) == limit:
break

target = _do_classify(blob, min_confidence)

Expand All @@ -71,5 +116,132 @@ def _do_classify(blob, min_confidence):
Returns None if it's not a target.
"""

# TODO: Implement.
return None
cropped_img = blob.image

image_array = cropped_img.convert('RGB')
predictions = tf_session.run(softmax_tensor, {'DecodeJpeg:0': image_array})

top_k = predictions[0].argsort()[-len(predictions[0]):][::-1]

shape = Shape[label_lines[top_k[0]]]
confidence = predictions[0][top_k[0]]

shape = None

if confidence >= min_confidence and shape != Shape.NAS:
primary, secondary = _get_color(blob)
shape = Target(blob.x, blob.y, blob.width, blob.height, shape=shape,
background_color=primary, alphanumeric_color=secondary,
image=blob.image, confidence=confidence)

return shape


def _get_color(blob):
colors_set = {
'#000000': None,
'#000001': Color.BLACK,
'#ffffff': Color.WHITE,
'#407340': Color.GREEN,
'#94ff94': Color.GREEN,
'#00ff00': Color.GREEN,
'#008004': Color.GREEN,
'#525294': Color.BLUE,
'#7f7fff': Color.BLUE,
'#0000ff': Color.BLUE,
'#000087': Color.BLUE,
'#808080': Color.GRAY,
'#994c00': Color.BROWN,
'#e1dd68': Color.YELLOW,
'#fffc7a': Color.YELLOW,
'#fff700': Color.YELLOW,
'#d2cb00': Color.YELLOW,
'#d8ac53': Color.ORANGE,
'#FFCC65': Color.ORANGE,
'#ffa500': Color.ORANGE,
'#d28c00': Color.ORANGE,
'#bc3c3c': Color.RED,
'#ff5050': Color.RED,
'#ff0000': Color.RED,
'#9a0000': Color.RED,
'#800080': Color.PURPLE
}

mask_img = np.array(blob.image)

dst = mask_img
width, height = dst.shape[:2]

if width > height:
y1 = blob.y
y2 = blob.y + blob.height
x1 = blob.x
x2 = blob.x + blob.width
else:
x1 = blob.y
x2 = blob.y + blob.height
y1 = blob.x
y2 = blob.x + blob.width

if blob.has_mask:
mask = np.zeros(mask_img.shape[:2], dtype='uint8')
cv2.drawContours(mask, [blob.cnt], -1, 255, -1)
dst = cv2.bitwise_and(mask_img, mask_img, mask=mask)
else:
y1 = y1 + 5
y2 = y2 - 5
x1 = x1 + 5
x2 = x2 - 5

cropped_img = PIL.Image.fromarray(dst)
cropped_img.crop((x1, y1, x2, y2))

ar = scipy.misc.fromimage(cropped_img)
dim = ar.shape
ar = ar.reshape(scipy.product(dim[:2]), dim[2])
codes, dist = scipy.cluster.vq.kmeans(ar.astype(float), 3)

primary = _get_color_name(codes[0].astype(int), None, colors_set)

if len(codes) > 1:
secondary = _get_color_name(codes[1].astype(int), primary, colors_set)
else:
secondary = Color.NONE

# Ignore black mask for color detection, return the most
# prominent color as shape.
if primary is None:
primary = secondary
secondary = Color.NONE

if secondary == Color.NONE and len(codes) > 2:
tertiary = _get_color_name(codes[2].astype(int), secondary, colors_set)
secondary = tertiary

return primary, secondary


def _get_color_name(requested_color, prev_color, colors_set):
color_codes = {}
i = 0

# Makes sure alpha color and shape color are different.
if prev_color is not None:
for key, name in colors_set.items():
if name == prev_color:
color_codes[i] = key
i = i + 1
for i in color_codes:
del colors_set[color_codes[i]]

min_colors = {}

# Find closest color with a given RGB value.
for key, name in colors_set.items():
r_c, g_c, b_c = webcolors.hex_to_rgb(key)
rd = (r_c - requested_color[0]) ** 2
gd = (g_c - requested_color[1]) ** 2
bd = (b_c - requested_color[2]) ** 2
min_colors[(rd + gd + bd)] = name

return min_colors[min(min_colors.keys())]
21 changes: 10 additions & 11 deletions target_finder/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from .preprocessing import find_blobs


# Creating the top level parser.
# Create the top level parser.
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='subcommand', title='subcommands')

# Making our parser for the blobs subcommand.
# Parser for the blobs subcommand.
blob_parser = subparsers.add_parser('blobs', help='finds the interesting '
'blobs in images')
blob_parser.add_argument('filename', type=str, nargs='+',
Expand Down Expand Up @@ -42,16 +42,17 @@ def run_blobs(args):
"""Run the blobs subcommand."""
blob_num = 0

# Making the output directory if it doesn't already exist.
# Create the output directory if it doesn't already exist.
os.makedirs(args.output, exist_ok=True)

for filename in _list_images(args.filename):
image = PIL.Image.open(filename)
mask_img = cv2.imread(filename)

blobs = find_blobs(image, min_width=args.min_width, limit=args.limit,
padding=args.padding)
blobs = find_blobs(image, mask_img, min_width=args.min_width,
limit=args.limit, padding=args.padding)

# Saving each blob we find with an incrementing number.
# Save each blob found with an incrementing number.
for blob in blobs:
print('Saving blob #{:06d} from {:s}'.format(blob_num, filename))

Expand All @@ -66,8 +67,7 @@ def _list_images(filenames):
images = []

for filename in filenames:
# If we're just dealing with a normal filename, add it to the
# list directly.
# If this is a normal filename, add it to the list directly.
if os.path.isfile(filename):
images.append(filename)

Expand All @@ -84,7 +84,7 @@ def _list_images(filenames):
print('Bad filename: "{:s}".'.format(filename))
sys.exit(1)

# We have a problem if we can't find any images.
# There's a problem if we can't find any images.
if images == []:
print('No images found.')
sys.exit(1)
Expand All @@ -93,7 +93,6 @@ def _list_images(filenames):


# Set the functions to run for each subcommand. If a subcommand was
# not provided, we'll just print the usage message and set the exit
# code to 1.
# not provided, print the usage message and set the exit code to 1.
blob_parser.set_defaults(func=run_blobs)
parser.set_defaults(func=lambda _: parser.print_usage() or sys.exit(1))
Loading

0 comments on commit 0bfbf97

Please sign in to comment.