EfficientDet Python Sample Updates
Signed-off-by: Rajeev Rao <rajeevrao@nvidia.com>
wraveane authored and rajeevsrao committed Mar 24, 2022
1 parent d868c4c commit 763dbfd
Showing 8 changed files with 611 additions and 362 deletions.
145 changes: 110 additions & 35 deletions samples/python/efficientdet/README.md

Large diffs are not rendered by default.

153 changes: 108 additions & 45 deletions samples/python/efficientdet/build_engine.py
@@ -91,7 +91,7 @@ def read_calibration_cache(self):
Read the calibration cache file stored on disk, if it exists.
:return: The contents of the cache file, if any.
"""
if os.path.exists(self.cache_file):
if self.cache_file is not None and os.path.exists(self.cache_file):
with open(self.cache_file, "rb") as f:
log.info("Using calibration cache file: {}".format(self.cache_file))
return f.read()
@@ -102,6 +102,8 @@ def write_calibration_cache(self, cache):
Store the calibration cache to a file on disk.
:param cache: The contents of the calibration cache to store.
"""
if self.cache_file is None:
return
with open(self.cache_file, "wb") as f:
log.info("Writing calibration cache data to: {}".format(self.cache_file))
f.write(cache)
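For context, these two methods are the cache hooks TensorRT invokes on an INT8 calibrator. Below is a minimal sketch of the surrounding interface with the same optional-cache behavior as this diff; the class name and the trivial `get_batch` are illustrative, not part of the commit:

```python
import os
import tensorrt as trt

class SketchCalibrator(trt.IInt8EntropyCalibrator2):
    """Illustrative calibrator showing where the cache hooks fit."""

    def __init__(self, cache_file=None):
        super().__init__()
        self.cache_file = cache_file

    def get_batch_size(self):
        return 1

    def get_batch(self, names):
        # Returning None tells TensorRT there are no more calibration batches.
        return None

    def read_calibration_cache(self):
        # TensorRT calls this first; returning None forces calibration to run.
        if self.cache_file is not None and os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        # Called with the serialized calibration scales once calibration ends.
        if self.cache_file is None:
            return
        with open(self.cache_file, "wb") as f:
            f.write(cache)
```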
@@ -127,14 +129,16 @@ def __init__(self, verbose=False, workspace=8):
self.config = self.builder.create_builder_config()
self.config.max_workspace_size = workspace * (2 ** 30)

self.batch_size = None
self.network = None
self.parser = None

def create_network(self, onnx_path):
def create_network(self, onnx_path, batch_size, dynamic_batch_size=None):
"""
Parse the ONNX graph and create the corresponding TensorRT network definition.
:param onnx_path: The path to the ONNX graph to load.
:param batch_size: Static batch size to build the engine with.
:param dynamic_batch_size: Dynamic batch size to build the engine with; if given,
batch_size is ignored. Pass as a comma-separated string or an int list in MIN,OPT,MAX order.
"""
network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

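The accepted `dynamic_batch_size` spellings can be shown in isolation. A small sketch of the normalization the method performs (the values are arbitrary):

```python
# Either form is accepted, per the docstring above:
dynamic_batch_size = "1,16,32"       # comma-separated string...
# dynamic_batch_size = [1, 16, 32]   # ...or a list of ints

if type(dynamic_batch_size) is str:
    dynamic_batch_size = [int(v) for v in dynamic_batch_size.split(",")]
assert len(dynamic_batch_size) == 3  # interpreted as MIN, OPT, MAX
min_batch, opt_batch, max_batch = dynamic_batch_size
```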
@@ -149,24 +153,72 @@ def create_network(self, onnx_path):
log.error(self.parser.get_error(error))
sys.exit(1)

inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]

log.info("Network Description")

inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
profile = self.builder.create_optimization_profile()
dynamic_inputs = False
for input in inputs:
self.batch_size = input.shape[0]
log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
if input.shape[0] == -1:
dynamic_inputs = True
if dynamic_batch_size:
if type(dynamic_batch_size) is str:
dynamic_batch_size = [int(v) for v in dynamic_batch_size.split(",")]
assert len(dynamic_batch_size) == 3
min_shape = [dynamic_batch_size[0]] + list(input.shape[1:])
opt_shape = [dynamic_batch_size[1]] + list(input.shape[1:])
max_shape = [dynamic_batch_size[2]] + list(input.shape[1:])
profile.set_shape(input.name, min_shape, opt_shape, max_shape)
log.info("Input '{}' Optimization Profile with shape MIN {} / OPT {} / MAX {}".format(
input.name, min_shape, opt_shape, max_shape))
else:
shape = [batch_size] + list(input.shape[1:])
profile.set_shape(input.name, shape, shape, shape)
log.info("Input '{}' Optimization Profile with shape {}".format(input.name, shape))
if dynamic_inputs:
self.config.add_optimization_profile(profile)

outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
for output in outputs:
log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
assert self.batch_size > 0
self.builder.max_batch_size = self.batch_size

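An engine built with a dynamic profile leaves the batch dimension at -1, so a concrete shape must be chosen at inference time. A consumer-side sketch, assuming the pre-8.5 TensorRT binding API (the function name is illustrative):

```python
def make_context(engine, batch_size):
    # Pick a concrete batch size within the profile's MIN..MAX range and
    # bind it before allocating I/O buffers.
    context = engine.create_execution_context()
    shape = list(engine.get_binding_shape(0))
    shape[0] = batch_size
    context.set_binding_shape(0, shape)
    assert context.all_binding_shapes_specified
    return context
```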
def set_mixed_precision(self):
"""
Experimental precision mode.
Enable mixed-precision mode. When set, the layers defined here will be forced to FP16 to maximize
INT8 inference accuracy, while having minimal impact on latency.
"""
self.config.set_flag(trt.BuilderFlag.STRICT_TYPES)

# All convolution operations in the first four blocks of the graph are pinned to FP16.
# These layers have been manually chosen as they give a good middle-point between int8 and fp16
# accuracy in COCO, while maintaining almost the same latency as a normal int8 engine.
# To experiment with other datasets, or a different balance between accuracy/latency, you may
# add or remove blocks.
for i in range(self.network.num_layers):
layer = self.network.get_layer(i)
if layer.type == trt.LayerType.CONVOLUTION and any([
# AutoML Layer Names:
"/stem/" in layer.name,
"/blocks_0/" in layer.name,
"/blocks_1/" in layer.name,
"/blocks_2/" in layer.name,
# TFOD Layer Names:
"/stem_conv2d/" in layer.name,
"/stack_0/block_0/" in layer.name,
"/stack_1/block_0/" in layer.name,
"/stack_1/block_1/" in layer.name,
]):
self.network.get_layer(i).precision = trt.DataType.HALF
log.info("Mixed-Precision Layer {} set to HALF STRICT data type".format(layer.name))

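To retune the pinned-layer list for another dataset or model variant, it helps to first dump the convolution layers the parser produced and verify their name prefixes. A small helper sketch (not part of this commit):

```python
import tensorrt as trt

def list_convolutions(network):
    # Print every convolution layer so the block prefixes above can be
    # adjusted for a different EfficientDet export.
    for i in range(network.num_layers):
        layer = network.get_layer(i)
        if layer.type == trt.LayerType.CONVOLUTION:
            print(i, layer.name)
```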
def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=5000,
calib_batch_size=8):
"""
Build the TensorRT engine and serialize it to disk.
:param engine_path: The path to serialize the engine to.
:param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
:param precision: The datatype to use for the engine, either 'fp32', 'fp16', 'int8', or 'mixed'.
:param calib_input: The path to a directory holding the calibration images.
:param calib_cache: The path to write the calibration cache to, or to load it from if it already exists.
:param calib_num_images: The maximum number of images to use for calibration.
@@ -179,62 +231,73 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No

inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]

if precision == "fp16":
if precision in ["fp16", "int8", "mixed"]:
if not self.builder.platform_has_fast_fp16:
log.warning("FP16 is not supported natively on this platform/device")
else:
self.config.set_flag(trt.BuilderFlag.FP16)
elif precision == "int8":
self.config.set_flag(trt.BuilderFlag.FP16)
if precision in ["int8", "mixed"]:
if not self.builder.platform_has_fast_int8:
log.warning("INT8 is not supported natively on this platform/device")
else:
if self.builder.platform_has_fast_fp16:
# Also enable fp16, as some layers may be even more efficient in fp16 than int8
self.config.set_flag(trt.BuilderFlag.FP16)
self.config.set_flag(trt.BuilderFlag.INT8)
self.config.int8_calibrator = EngineCalibrator(calib_cache)
if not os.path.exists(calib_cache):
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
calib_dtype = trt.nptype(inputs[0].dtype)
self.config.int8_calibrator.set_image_batcher(
ImageBatcher(calib_input, calib_shape, calib_dtype, max_num_images=calib_num_images,
exact_batches=True))

with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f:
self.config.set_flag(trt.BuilderFlag.INT8)
self.config.int8_calibrator = EngineCalibrator(calib_cache)
if calib_cache is None or not os.path.exists(calib_cache):
calib_shape = [calib_batch_size] + list(inputs[0].shape[1:])
calib_dtype = trt.nptype(inputs[0].dtype)
self.config.int8_calibrator.set_image_batcher(
ImageBatcher(calib_input, calib_shape, calib_dtype, max_num_images=calib_num_images,
exact_batches=True, shuffle_files=True))

engine_bytes = None
try:
engine_bytes = self.builder.build_serialized_network(self.network, self.config)
except AttributeError:
# Fallback for older TensorRT releases that predate build_serialized_network().
engine = self.builder.build_engine(self.network, self.config)
engine_bytes = engine.serialize()
del engine
assert engine_bytes
with open(engine_path, "wb") as f:
log.info("Serializing engine to file: {:}".format(engine_path))
f.write(engine.serialize())
f.write(engine_bytes)


def main(args):
builder = EngineBuilder(args.verbose, args.workspace)
builder.create_network(args.onnx)
builder.create_network(args.onnx, args.batch_size, args.dynamic_batch_size)
if args.precision == "mixed":
builder.set_mixed_precision()
builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images,
args.calib_batch_size)
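The same flow can also be driven from Python directly. A minimal sketch mirroring main(), with placeholder file names:

```python
from build_engine import EngineBuilder

builder = EngineBuilder(verbose=False, workspace=8)
# Dynamic-batch FP16 build: optimization profile with MIN=1, OPT=16, MAX=32.
builder.create_network("efficientdet.onnx", batch_size=1, dynamic_batch_size="1,16,32")
builder.create_engine("efficientdet_fp16.engine", "fp16")
```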


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--onnx", help="The input ONNX model file to load")
parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"],
help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'")
parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
parser.add_argument("-w", "--workspace", default=8, type=int, help="The max memory workspace size to allow in Gb, "
"default: 8")
parser.add_argument("--calib_input", help="The directory holding images to use for calibration")
parser.add_argument("--calib_cache", default="./calibration.cache",
parser.add_argument("-o", "--onnx", required=True,
help="The input ONNX model file to load")
parser.add_argument("-e", "--engine", required=True,
help="The output path for the TRT engine")
parser.add_argument("-b", "--batch_size", default=1, type=int,
help="The static batch size to build the engine with, default: 1")
parser.add_argument("-d", "--dynamic_batch_size", default=None,
help="Enable dynamic batch size by providing a comma-separated MIN,OPT,MAX batch size, "
"if this option is set, --batch_size is ignored, example: -d 1,16,32, "
"default: None, build static engine")
parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8", "mixed"],
help="The precision mode to build in, either fp32/fp16/int8/mixed, default: fp16")
parser.add_argument("-v", "--verbose", action="store_true",
help="Enable more verbose log output")
parser.add_argument("-w", "--workspace", default=8, type=int,
help="The max memory workspace size to allow in Gb, default: 8")
parser.add_argument("--calib_input",
help="The directory holding images to use for calibration")
parser.add_argument("--calib_cache", default=None,
help="The file path for INT8 calibration cache to use, default: ./calibration.cache")
parser.add_argument("--calib_num_images", default=5000, type=int,
help="The maximum number of images to use for calibration, default: 5000")
parser.add_argument("--calib_batch_size", default=8, type=int,
help="The batch size for the calibration process, default: 8")
args = parser.parse_args()
if not all([args.onnx, args.engine]):
parser.print_help()
log.error("These arguments are required: --onnx and --engine")
sys.exit(1)
if args.precision == "int8" and not (args.calib_input or os.path.exists(args.calib_cache)):
if args.precision in ["int8", "mixed"] and not (args.calib_input or (args.calib_cache and os.path.exists(args.calib_cache))):
parser.print_help()
log.error("When building in int8 precision, --calib_input or an existing --calib_cache file is required")
log.error("When building in int8 or mixed precision, --calib_input or an existing --calib_cache file is required")
sys.exit(1)
main(args)
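Putting the new flags together, an INT8 or mixed build needs either calibration images or an existing cache. A hedged end-to-end sketch with placeholder paths:

```python
from build_engine import EngineBuilder

builder = EngineBuilder(verbose=True)
builder.create_network("efficientdet.onnx", batch_size=8)
builder.set_mixed_precision()  # only for the "mixed" precision mode
builder.create_engine("efficientdet_mixed.engine", "mixed",
                      calib_input="datasets/coco/val2017",
                      calib_cache="calibration.cache",
                      calib_num_images=5000,
                      calib_batch_size=8)
```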
91 changes: 2 additions & 89 deletions samples/python/efficientdet/compare_tf.py
@@ -23,103 +23,16 @@
import tensorflow as tf

from infer import TensorRTInfer
from infer_tf import TensorFlowInfer
from image_batcher import ImageBatcher
from visualize import visualize_detections, concat_visualizations


class TensorFlowInfer:
"""
Implements TensorFlow inference of a saved model, following the same API as the TensorRTInfer class.
"""

def __init__(self, saved_model_path):
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)

self.model = tf.saved_model.load(saved_model_path)
self.pred_fn = self.model.signatures['serving_default']

# Setup I/O bindings
self.inputs = []
fn_inputs = self.pred_fn.structured_input_signature[1]
for i, input in enumerate(list(fn_inputs.values())):
self.inputs.append({
'index': i,
'name': input.name,
'dtype': np.dtype(input.dtype.as_numpy_dtype()),
'shape': [1, 512, 512, 3], # This can be overridden later
})
self.outputs = []
fn_outputs = self.pred_fn.structured_outputs
for i, output in enumerate(list(fn_outputs.values())):
self.outputs.append({
'index': i,
'name': output.name,
'dtype': np.dtype(output.dtype.as_numpy_dtype()),
'shape': output.shape.as_list(),
})

def override_input_shape(self, input, shape):
self.inputs[input]['shape'] = shape

def input_spec(self):
return self.inputs[0]['shape'], self.inputs[0]['dtype']

def output_spec(self):
return self.outputs[0]['shape'], self.outputs[0]['dtype']

def infer(self, batch, scales=None, nms_threshold=None):
# Process I/O and execute the network
input = {self.inputs[0]['name']: tf.convert_to_tensor(batch)}
output = self.pred_fn(**input)

# Extract the results depending on what kind of saved model this is
boxes = None
scores = None
classes = None
if len(self.outputs) == 1:
# Detected as AutoML Saved Model
assert len(self.outputs[0]['shape']) == 3 and self.outputs[0]['shape'][2] == 7
results = output[self.outputs[0]['name']].numpy()
boxes = results[:, :, 1:5]
scores = results[:, :, 5]
classes = results[:, :, 6].astype(np.int32)
elif len(self.outputs) >= 4:
# Detected as TFOD Saved Model
assert output['num_detections']
num = int(output['num_detections'].numpy().flatten()[0])
boxes = output['detection_boxes'].numpy()[:, 0:num, :]
scores = output['detection_scores'].numpy()[:, 0:num]
classes = output['detection_classes'].numpy()[:, 0:num]

# Process the results
detections = [[]]
normalized = (np.max(boxes) < 2.0)
for n in range(scores.shape[1]):
if scores[0][n] == 0.0:
break
scale = self.inputs[0]['shape'][2] if normalized else 1.0
if scales:
scale /= scales[0]
if nms_threshold and scores[0][n] < nms_threshold:
continue
detections[0].append({
'ymin': boxes[0][n][0] * scale,
'xmin': boxes[0][n][1] * scale,
'ymax': boxes[0][n][2] * scale,
'xmax': boxes[0][n][3] * scale,
'score': scores[0][n],
'class': int(classes[0][n]) - 1,
})
return detections


def run(batcher, inferer, framework, nms_threshold=None):
res_images = []
res_detections = []
for batch, images, scales in batcher.get_batch():
res_detections += inferer.infer(batch, scales, nms_threshold)
res_detections += inferer.process(batch, scales, nms_threshold)
res_images += images
print("Processing {} / {} images ({})".format(batcher.image_index, batcher.num_images, framework), end="\r")
print()
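With TensorFlowInfer relocated to infer_tf.py and its entry point renamed to process(), the comparison loop can be exercised as in this sketch (run within compare_tf.py; paths are placeholders, and the batcher arguments follow the input_spec() shown above):

```python
from image_batcher import ImageBatcher
from infer_tf import TensorFlowInfer

tf_infer = TensorFlowInfer("saved_model/")
batcher = ImageBatcher("images/", *tf_infer.input_spec())
run(batcher, tf_infer, "TensorFlow", nms_threshold=0.5)
```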
