From 763dbfdace47de69f846a2da954f25ee6833027a Mon Sep 17 00:00:00 2001 From: William Raveane Date: Sun, 30 Jan 2022 18:53:58 -0500 Subject: [PATCH] EfficientDet Python Sample Updates Signed-off-by: Rajeev Rao --- samples/python/efficientdet/README.md | 145 +++++++--- samples/python/efficientdet/build_engine.py | 153 +++++++--- samples/python/efficientdet/compare_tf.py | 91 +----- samples/python/efficientdet/create_onnx.py | 290 +++++++++---------- samples/python/efficientdet/eval_coco.py | 2 +- samples/python/efficientdet/image_batcher.py | 8 +- samples/python/efficientdet/infer.py | 127 +++++--- samples/python/efficientdet/infer_tf.py | 157 ++++++++++ 8 files changed, 611 insertions(+), 362 deletions(-) create mode 100644 samples/python/efficientdet/infer_tf.py diff --git a/samples/python/efficientdet/README.md b/samples/python/efficientdet/README.md index 7128d958..037842a8 100644 --- a/samples/python/efficientdet/README.md +++ b/samples/python/efficientdet/README.md @@ -5,6 +5,7 @@ These scripts help with conversion and execution of [Google EfficientDet](https://arxiv.org/abs/1911.09070) models with [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt). This process is compatible with models trained through either Google AutoML or the TensorFlow Object Detection API. ## Contents +- [Changelog](#changelog) - [Setup](#setup) - [Model Conversion](#model-conversion) * [TensorFlow Saved Model](#tensorflow-saved-model) @@ -15,12 +16,27 @@ These scripts help with conversion and execution of [Google EfficientDet](https: * [Evaluate mAP Metric](#evaluate-map-metric) * [TF vs TRT Comparison](#tf-vs-trt-comparison) +## Changelog + +- January 2022: + - Added support for EfficientDet Lite and AdvProp models. + - Added dynamic batch support. + - Added mixed precision engine builder. +- July 2021: + - Initial release. + ## Setup -For best results, we recommend running these scripts on an environment with TensorRT >= 8.0.1 and TensorFlow 2.5. +We recommend running these scripts on an environment with TensorRT >= 8.0.1 and TensorFlow >= 2.5. Install TensorRT as per the [TensorRT Install Guide](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html). You will need to make sure the Python bindings for TensorRT are also installed correctly, these are available by installing the `python3-libnvinfer` and `python3-libnvinfer-dev` packages on your TensorRT download. +To simplify TensorRT and TensorFlow installation, use an [NGC TensorFlow Docker Image](https://ngc.nvidia.com/catalog/containers/nvidia:tensorflow), such as: + +```bash +docker pull nvcr.io/nvidia/tensorflow:22.01-tf1-py3 +``` + Install all dependencies listed in `requirements.txt`: ```bash @@ -40,10 +56,12 @@ pip3 install onnx-graphsurgeon --index-url https://pypi.ngc.nvidia.com **NOTE:** Please make sure that the `onnx-graphsurgeon` module installed by pip is version >= 0.3.9. -Finally, you may want to clone the EfficientDet code from the [AutoML Repository](https://github.com/google/automl) to use some helper utilities from it: +Finally, you may want to clone the EfficientDet code from the [AutoML Repository](https://github.com/google/automl) to use some helper utilities from it. This exporter has been tested with commit [0b0ba5e](https://github.com/google/automl/tree/0b0ba5ebd0860edd939465fc4152da4ff9f79b44/efficientdet) from December 2021, so it may be a good idea to checkout the repository at that specific commit to avoid possible future incompatibilities: ```bash git clone https://github.com/google/automl +cd automl +git checkout 0b0ba5e ``` ## Model Conversion @@ -54,22 +72,17 @@ The workflow to convert an EfficientDet model is basically TensorFlow → ONNX The starting point of conversion is a TensorFlow saved model. This can be exported from your own trained models, or you can download a pre-trained model. This conversion script is compatible with two types of models: -1. EfficientDet models trained with the [AutoML](https://github.com/google/automl/tree/master/efficientdet) framework. +1. EfficientDet models trained with the [AutoML](https://github.com/google/automl/tree/master/efficientdet) framework. Compatible with all "d0-7", "lite0-4" and "AdvProp" variations. 2. EfficientDet models trained with the [TensorFlow Object Detection](https://github.com/tensorflow/models/tree/master/research/object_detection) API (TFOD). +3. EfficientDet models pre-trained on COCO and downloaded from [TFHub](https://tfhub.dev/s?network-architecture=efficientdet). #### 1. AutoML Models -You can download one of the pre-trained AutoML saved models from the [EfficientDet TFHub](https://tfhub.dev/s?network-architecture=efficientdet), such as: - -```bash -wget https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/d0/1.tar.gz -``` - -The contents of this package, when extracted, will hold a saved model ready for conversion. +If you are training your own model, you will need the training checkpoint. You can also download a pre-trained checkpoint from the "ckpt" links on the [AutoML Repository](https://github.com/google/automl/tree/master/efficientdet) README file, such as [this](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz). -**NOTE:** Some saved models in TFHub may give problems with ONNX conversion. If so, please download the original checkpoint and export the saved model manually as per the instructions below. +This converter is compatible with all *efficientdet-d0* through *efficientdet-d7x*, and *efficientdet-lite0* through *efficientdet-lite4* model variations. This converter also works with the [AdvProp](https://github.com/google/automl/blob/master/efficientdet/Det-AdvProp.md) models. However, AdvProp models are trained with the `scale_range` hparam, which changes the expected input image value range, so you will need to adjust the preprocessor argument when creating the ONNX graph. More details on the corresponding section below. -Alternatively, if you are training your own model, or if you need to re-export the saved model manually, you will need the training checkpoint (or a pre-trained "ckpt" from the [AutoML Repository](https://github.com/google/automl/tree/master/efficientdet) such as [this](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientdet/coco2/efficientdet-d0.tar.gz)). The checkpoint directory should have a file structure such as this: +The checkpoint directory should have a file structure such as this: ``` efficientdet-d0 @@ -103,7 +116,7 @@ wget http://download.tensorflow.org/models/object_detection/tf2/20200711/efficie When extracted, this package holds a directory named `saved_model` which holds the saved model ready for conversion. -However, if you are working with your own trained model, or if you need to re-export the saved model, you can do so from the training checkpoint. The downloaded package above also contains a pre-trained checkpoint. The structure is similar to this: +However, if you are working with your own trained EfficientDet model from the TensorFlow Object Detection API, or if you need to re-export the saved model, you can do so from the training checkpoint. The downloaded package above also contains a pre-trained checkpoint. The structure is similar to this: ``` efficientdet_d0_coco17_tpu-32 @@ -130,31 +143,45 @@ Where `--trained_checkpoint_dir` and `--pipeline_config_path` point to the corre **NOTE:** TFOD EfficientDet models will have a slightly reduced throughput than their AutoML model counterparts. This is due to differences in the graph construction that TFOD makes use of. -### Create ONNX Graph +#### 3. TFHub Models -To generate an ONNX model file, first find the input shape that corresponds to the model you're converting: +You can download one of the pre-trained AutoML saved models from the [EfficientDet TFHub](https://tfhub.dev/s?network-architecture=efficientdet). Currently, only the efficientdet/d0 - d7 models are compatible with this converter. If you need to work with a pre-trained lite model, please follow the AutoML checkpoint route above. -| **Model** | **Input Shape** | -| -----------------|-----------------| -| EfficientDet D0 | N,512,512,3 | -| EfficientDet D1 | N,640,640,3 | -| EfficientDet D2 | N,768,768,3 | -| EfficientDet D3 | N,896,896,3 | -| EfficientDet D4 | N,1024,1024,3 | -| EfficientDet D5 | N,1280,1280,3 | -| EfficientDet D6 | N,1280,1280,3 | -| EfficientDet D7 | N,1536,1536,3 | -| EfficientDet D7x | N,1536,1536,3 | +Download a model from TFHub, such as: -Where **N** is the batch size you would like to run inference at, such as `8,512,512,3` for a batch size of 8. If you exported the saved model with a custom input image size, you should use that specific shape instead. +```bash +wget https://storage.googleapis.com/tfhub-modules/tensorflow/efficientdet/d0/1.tar.gz +``` -The ONNX conversion process supports both `NHWC` and `NCHW` input formats, so if your input source is an `NCHW` data format, you can use the corresponding input shape, i.e. `1,512,512,3` -> `1,3,512,512`. +The contents of this package, when extracted, will hold a saved model ready for conversion. -With the correct input shape selected, and the TF saved model ready to be converted, run: +### Create ONNX Graph + +To generate an ONNX model file, first find the input size that corresponds to the model you're converting: + +| **Model** | **Input Size** | +| --------------------|----------------| +| efficientdet-d0 | 512,512 | +| efficientdet-d1 | 640,640 | +| efficientdet-d2 | 768,768 | +| efficientdet-d3 | 896,896 | +| efficientdet-d4 | 1024,1024 | +| efficientdet-d5 | 1280,1280 | +| efficientdet-d6 | 1280,1280 | +| efficientdet-d7 | 1536,1536 | +| efficientdet-d7x | 1536,1536 | +| efficientdet-lite0 | 320,320 | +| efficientdet-lite1 | 384,384 | +| efficientdet-lite2 | 448,448 | +| efficientdet-lite3 | 512,512 | +| efficientdet-lite3x | 640,640 | +| efficientdet-lite4 | 640,640 | + +If you've re-exported the model with a custom image size, then of course use that. With the correct input size and the TF saved model ready to be converted, run: ```bash python3 create_onnx.py \ - --input_shape '1,512,512,3' \ + --input_size 512,512 \ --saved_model /path/to/saved_model \ --onnx /path/to/model.onnx ``` @@ -163,8 +190,9 @@ This will create the file `model.onnx` which is ready to convert to TensorRT. The script has a few optional arguments, including: +* `--input_format [NHWC,NCHW]` allows switching between NHWC (default) and NCHW data format modes. If your data source is in NCHW format, you may want to select this mode to avoid extra transposes. * `--nms_threshold [...]` allows overriding the default NMS score threshold parameter, as the runtime latency of the NMS plugin is sensitive to this value. It's a good practice to set this value as high as possible, while still fulfilling your application requirements, to reduce inference latency. -* `--legacy_plugins` allows falling back to older plugins on systems where a version lower than TensorRT 8.0.1 is installed. This will result in substantially slower inference times however, but is provided for compatibility. +* `--preprocessor [imagenet,scale_range]` allows switching between two possible image preprocessing methods. Most EfficientDet models use the `imagenet` method, which this argument defaults to, and corresponds to standard ImageNet mean subtraction and standard deviation normalization. The `scale_range` method instead normalizes the image to a range of [-1,+1]. Please use this method only when converting the **AdvProp** pre-trained checkpoints, as they were created with this preprocessor operation. Optionally, you may wish to visualize the resulting ONNX graph with a tool such as [Netron](https://netron.app/). @@ -172,11 +200,23 @@ Optionally, you may wish to visualize the resulting ONNX graph with a tool such The input to the graph is a `float32` tensor with the selected input shape, containing RGB pixel data in the range of 0 to 255. Normalization, mean subtraction and scaling will be performed inside the EfficientDet graph, so it is not required to further pre-process the input data. -The outputs of the graph are the same as the outputs of the [EfficientNMS](https://github.com/NVIDIA/TensorRT/tree/master/plugin/efficientNMSPlugin) plugin. If the ONNX graph was created with `--legacy_plugins` for TensorRT 7 compatibility, the outputs will correspond to those of the [BatchedNMS](https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin) plugin instead. +The outputs of the graph are the same as the outputs of the [EfficientNMS](https://github.com/NVIDIA/TensorRT/tree/main/plugin/efficientNMSPlugin) plugin. ### Build TensorRT Engine -It is possible to build the TensorRT engine directly with `trtexec` using the ONNX graph generated in the previous step. However, the script `build_engine.py` is provided for convenience, as it has been tailored to EfficientDet engine building and calibration. Run `python3 build_engine.py --help` for details on available settings. +It is possible to build the TensorRT engine directly with `trtexec` using the ONNX graph generated in the previous step. You can do so by running: + +```bash +trtexec \ + --onnx=/path/to/model.onnx \ + --saveEngine=/path/to/engine.trt \ + --optShapes=input:$INPUT_SHAPE \ + --workspace=1024 +``` + +Where `$INPUT_SHAPE` defines the input spec to build the engine with, e.g. `--optShapes=input:8x512x512x3`. Other common `trtexec` functionality for lower precision modes or other options will also work as expected. + +However, the script `build_engine.py` is also provided in this repository for convenience, as it has been tailored to EfficientDet engine building and INT8 calibration. Run `python3 build_engine.py --help` for details on available settings. #### FP16 Precision @@ -208,7 +248,42 @@ python3 build_engine.py \ Where `--calib_input` points to a directory with several thousands of images. For example, this could be a subset of the training or validation datasets that were used for the model. It's important that this data represents the runtime data distribution relatively well, therefore, the more images that are used for calibration, the better accuracy that will be achieved in INT8 precision. For models trained for the [COCO dataset](https://cocodataset.org/#home), we have found that 5,000 images gives a good result. -The `--calib_cache` controls where the calibration cache file will be written to. This is useful to keep a cached copy of the calibration results. Next time you need to build the engine for the same network, if this file exists, it will skip the calibration step and use the cached values instead. +The `--calib_cache` is optional, and it controls where the calibration cache file will be written to. This is useful to keep a cached copy of the calibration results. Next time you need to build an int8 engine for the same network, if this file exists, the builder will skip the calibration step and use the cached values instead. + +#### Mixed Precision (Experimental) + +Mixed precision is a custom mode that pins some key layers to FP16, while the rest of the network is converted at INT8 precision. The purpose of this mode is to balance accuracy and throughput. It's experimental and is given here to show one possible way of balancing achieved accuracy according to an application's latency budget. This mode has been tuned for COCO pre-trained models. For other datasets, you may need to adjust the layers to pin. + +Some sample results of using this mode: + +| **Model / Precision** | **Latency** | **COCO mAP** | +| ------------------------|-------------|--------------| +| efficientdet-d0 / fp32 | 3.25 ms | 0.341 | +| efficientdet-d0 / fp16 | 2.27 ms | 0.341 | +| efficientdet-d0 / mixed | **1.75 ms** | **0.320** | +| efficientdet-d0 / int8 | 1.63 ms | 0.299 | + +To use mixed precision mode, follow the same instructions as for building and calibrating an INT8 engine as given above, but using the argument `--precision mixed` instead. + +#### Static and Dynamic Batch Size + +By default, `build_engine.py` creates a static batch size 1 engine. To build with a different static batch size, set the `--batch_size` argument accordingly: + +```bash +python3 build_engine.py \ + --onnx /path/to/model.onnx \ + --engine /path/to/engine.trt \ + --batch_size 8 +``` + +You can also build an engine with a dynamic batch size. To do so, select a minimum and maximum batch size, as well as an optimal batch size for which TensorRT will fine tune the engine performance best. These batch sizes should be given via the argument `--dynamic_batch_size MIN,OPT,MAX`, such as: + +```bash +python3 build_engine.py \ + --onnx /path/to/model.onnx \ + --engine /path/to/engine.trt \ + --dynamic_batch_size 1,16,32 +``` #### Benchmark Engine @@ -252,7 +327,7 @@ The detection results will be written out to the specified output directory, con ![infer](https://drive.google.com/uc?export=view&id=1ZzTHizLx65t_cJcIIflnzXA5yxCYsQz6) -> *This example is generated with a TensorRT engine for the pre-trained AutoML EfficientDet-D0 model re-exported with a custom image size of 1920x1080 as described above. The engine uses an NMS score threshold of 0.4. This is the same [sample image](https://user-images.githubusercontent.com/11736571/77320690-099af300-6d37-11ea-9d86-24f14dc2d540.png) and model parameters as used in the AutoML [inference tutorial](https://github.com/google/automl/blob/master/efficientdet/g3doc/street.jpg).* +> *This example is generated with a TensorRT engine for the pre-trained AutoML EfficientDet-D0 model re-exported with a custom image size of 1920x1080 as described above. The engine uses an NMS score threshold of 0.4. This is the same [sample image](https://user-images.githubusercontent.com/11736571/77320690-099af300-6d37-11ea-9d86-24f14dc2d540.png) and model parameters as used in the AutoML [inference tutorial](https://github.com/google/automl/blob/master/efficientdet/tutorial.ipynb) to produce this [sample TensorFlow inference image](https://github.com/google/automl/blob/master/efficientdet/g3doc/street.jpg).* ### Evaluate mAP Metric diff --git a/samples/python/efficientdet/build_engine.py b/samples/python/efficientdet/build_engine.py index dfae3588..6e4eecb2 100644 --- a/samples/python/efficientdet/build_engine.py +++ b/samples/python/efficientdet/build_engine.py @@ -91,7 +91,7 @@ def read_calibration_cache(self): Read the calibration cache file stored on disk, if it exists. :return: The contents of the cache file, if any. """ - if os.path.exists(self.cache_file): + if self.cache_file is not None and os.path.exists(self.cache_file): with open(self.cache_file, "rb") as f: log.info("Using calibration cache file: {}".format(self.cache_file)) return f.read() @@ -102,6 +102,8 @@ def write_calibration_cache(self, cache): Store the calibration cache to a file on disk. :param cache: The contents of the calibration cache to store. """ + if self.cache_file is None: + return with open(self.cache_file, "wb") as f: log.info("Writing calibration cache data to: {}".format(self.cache_file)) f.write(cache) @@ -127,14 +129,16 @@ def __init__(self, verbose=False, workspace=8): self.config = self.builder.create_builder_config() self.config.max_workspace_size = workspace * (2 ** 30) - self.batch_size = None self.network = None self.parser = None - def create_network(self, onnx_path): + def create_network(self, onnx_path, batch_size, dynamic_batch_size=None): """ Parse the ONNX graph and create the corresponding TensorRT network definition. :param onnx_path: The path to the ONNX graph to load. + :param batch_size: Static batch size to build the engine with. + :param dynamic_batch_size: Dynamic batch size to build the engine with, if given, + batch_size is ignored, pass as a comma-separated string or int list as MIN,OPT,MAX """ network_flags = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) @@ -149,24 +153,72 @@ def create_network(self, onnx_path): log.error(self.parser.get_error(error)) sys.exit(1) - inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] - outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] - log.info("Network Description") + + inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] + profile = self.builder.create_optimization_profile() + dynamic_inputs = False for input in inputs: - self.batch_size = input.shape[0] log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype)) + if input.shape[0] == -1: + dynamic_inputs = True + if dynamic_batch_size: + if type(dynamic_batch_size) is str: + dynamic_batch_size = [int(v) for v in dynamic_batch_size.split(",")] + assert len(dynamic_batch_size) == 3 + min_shape = [dynamic_batch_size[0]] + list(input.shape[1:]) + opt_shape = [dynamic_batch_size[1]] + list(input.shape[1:]) + max_shape = [dynamic_batch_size[2]] + list(input.shape[1:]) + profile.set_shape(input.name, min_shape, opt_shape, max_shape) + log.info("Input '{}' Optimization Profile with shape MIN {} / OPT {} / MAX {}".format( + input.name, min_shape, opt_shape, max_shape)) + else: + shape = [batch_size] + list(input.shape[1:]) + profile.set_shape(input.name, shape, shape, shape) + log.info("Input '{}' Optimization Profile with shape {}".format(input.name, shape)) + if dynamic_inputs: + self.config.add_optimization_profile(profile) + + outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] for output in outputs: log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype)) - assert self.batch_size > 0 - self.builder.max_batch_size = self.batch_size + + def set_mixed_precision(self): + """ + Experimental precision mode. + Enable mixed-precision mode. When set, the layers defined here will be forced to FP16 to maximize + INT8 inference accuracy, while having minimal impact on latency. + """ + self.config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + # All convolution operations in the first four blocks of the graph are pinned to FP16. + # These layers have been manually chosen as they give a good middle-point between int8 and fp16 + # accuracy in COCO, while maintining almost the same latency as a normal int8 engine. + # To experiment with other datasets, or a different balance between accuracy/latency, you may + # add or remove blocks. + for i in range(self.network.num_layers): + layer = self.network.get_layer(i) + if layer.type == trt.LayerType.CONVOLUTION and any([ + # AutoML Layer Names: + "/stem/" in layer.name, + "/blocks_0/" in layer.name, + "/blocks_1/" in layer.name, + "/blocks_2/" in layer.name, + # TFOD Layer Names: + "/stem_conv2d/" in layer.name, + "/stack_0/block_0/" in layer.name, + "/stack_1/block_0/" in layer.name, + "/stack_1/block_1/" in layer.name, + ]): + self.network.get_layer(i).precision = trt.DataType.HALF + log.info("Mixed-Precision Layer {} set to HALF STRICT data type".format(layer.name)) def create_engine(self, engine_path, precision, calib_input=None, calib_cache=None, calib_num_images=5000, calib_batch_size=8): """ Build the TensorRT engine and serialize it to disk. :param engine_path: The path where to serialize the engine to. - :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. + :param precision: The datatype to use for the engine, either 'fp32', 'fp16', 'int8', or 'mixed'. :param calib_input: The path to a directory holding the calibration images. :param calib_cache: The path where to write the calibration cache to, or if it already exists, load it from. :param calib_num_images: The maximum number of images to use for calibration. @@ -179,62 +231,73 @@ def create_engine(self, engine_path, precision, calib_input=None, calib_cache=No inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] - if precision == "fp16": + if precision in ["fp16", "int8", "mixed"]: if not self.builder.platform_has_fast_fp16: log.warning("FP16 is not supported natively on this platform/device") - else: - self.config.set_flag(trt.BuilderFlag.FP16) - elif precision == "int8": + self.config.set_flag(trt.BuilderFlag.FP16) + if precision in ["int8", "mixed"]: if not self.builder.platform_has_fast_int8: log.warning("INT8 is not supported natively on this platform/device") - else: - if self.builder.platform_has_fast_fp16: - # Also enable fp16, as some layers may be even more efficient in fp16 than int8 - self.config.set_flag(trt.BuilderFlag.FP16) - self.config.set_flag(trt.BuilderFlag.INT8) - self.config.int8_calibrator = EngineCalibrator(calib_cache) - if not os.path.exists(calib_cache): - calib_shape = [calib_batch_size] + list(inputs[0].shape[1:]) - calib_dtype = trt.nptype(inputs[0].dtype) - self.config.int8_calibrator.set_image_batcher( - ImageBatcher(calib_input, calib_shape, calib_dtype, max_num_images=calib_num_images, - exact_batches=True)) - - with self.builder.build_engine(self.network, self.config) as engine, open(engine_path, "wb") as f: + self.config.set_flag(trt.BuilderFlag.INT8) + self.config.int8_calibrator = EngineCalibrator(calib_cache) + if calib_cache is None or not os.path.exists(calib_cache): + calib_shape = [calib_batch_size] + list(inputs[0].shape[1:]) + calib_dtype = trt.nptype(inputs[0].dtype) + self.config.int8_calibrator.set_image_batcher( + ImageBatcher(calib_input, calib_shape, calib_dtype, max_num_images=calib_num_images, + exact_batches=True, shuffle_files=True)) + + engine_bytes = None + try: + engine_bytes = self.builder.build_serialized_network(self.network, self.config) + except AttributeError: + engine = self.builder.build_engine(self.network, self.config) + engine_bytes = engine.serialize() + del engine + assert engine_bytes + with open(engine_path, "wb") as f: log.info("Serializing engine to file: {:}".format(engine_path)) - f.write(engine.serialize()) + f.write(engine_bytes) def main(args): builder = EngineBuilder(args.verbose, args.workspace) - builder.create_network(args.onnx) + builder.create_network(args.onnx, args.batch_size, args.dynamic_batch_size) + if args.precision == "mixed": + builder.set_mixed_precision() builder.create_engine(args.engine, args.precision, args.calib_input, args.calib_cache, args.calib_num_images, args.calib_batch_size) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-o", "--onnx", help="The input ONNX model file to load") - parser.add_argument("-e", "--engine", help="The output path for the TRT engine") - parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8"], - help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'") - parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output") - parser.add_argument("-w", "--workspace", default=8, type=int, help="The max memory workspace size to allow in Gb, " - "default: 8") - parser.add_argument("--calib_input", help="The directory holding images to use for calibration") - parser.add_argument("--calib_cache", default="./calibration.cache", + parser.add_argument("-o", "--onnx", required=True, + help="The input ONNX model file to load") + parser.add_argument("-e", "--engine", required=True, + help="The output path for the TRT engine") + parser.add_argument("-b", "--batch_size", default=1, type=int, + help="The static batch size to build the engine with, default: 1") + parser.add_argument("-d", "--dynamic_batch_size", default=None, + help="Enable dynamic batch size by providing a comma-separated MIN,OPT,MAX batch size, " + "if this option is set, --batch_size is ignored, example: -d 1,16,32, " + "default: None, build static engine") + parser.add_argument("-p", "--precision", default="fp16", choices=["fp32", "fp16", "int8", "mixed"], + help="The precision mode to build in, either fp32/fp16/int8/mixed, default: fp16") + parser.add_argument("-v", "--verbose", action="store_true", + help="Enable more verbose log output") + parser.add_argument("-w", "--workspace", default=8, type=int, + help="The max memory workspace size to allow in Gb, default: 8") + parser.add_argument("--calib_input", + help="The directory holding images to use for calibration") + parser.add_argument("--calib_cache", default=None, help="The file path for INT8 calibration cache to use, default: ./calibration.cache") parser.add_argument("--calib_num_images", default=5000, type=int, help="The maximum number of images to use for calibration, default: 5000") parser.add_argument("--calib_batch_size", default=8, type=int, help="The batch size for the calibration process, default: 8") args = parser.parse_args() - if not all([args.onnx, args.engine]): - parser.print_help() - log.error("These arguments are required: --onnx and --engine") - sys.exit(1) - if args.precision == "int8" and not (args.calib_input or os.path.exists(args.calib_cache)): + if args.precision in ["int8", "mixed"] and not (args.calib_input or os.path.exists(args.calib_cache)): parser.print_help() - log.error("When building in int8 precision, --calib_input or an existing --calib_cache file is required") + log.error("When building in int8 or mixed precision, --calib_input or an existing --calib_cache file is required") sys.exit(1) main(args) diff --git a/samples/python/efficientdet/compare_tf.py b/samples/python/efficientdet/compare_tf.py index 26c65c4e..70709c89 100644 --- a/samples/python/efficientdet/compare_tf.py +++ b/samples/python/efficientdet/compare_tf.py @@ -23,103 +23,16 @@ import tensorflow as tf from infer import TensorRTInfer +from infer_tf import TensorFlowInfer from image_batcher import ImageBatcher from visualize import visualize_detections, concat_visualizations -class TensorFlowInfer: - """ - Implements TensorFlow inference of a saved model, following the same API as the TensorRTInfer class. - """ - - def __init__(self, saved_model_path): - gpus = tf.config.experimental.list_physical_devices('GPU') - for gpu in gpus: - tf.config.experimental.set_memory_growth(gpu, True) - - self.model = tf.saved_model.load(saved_model_path) - self.pred_fn = self.model.signatures['serving_default'] - - # Setup I/O bindings - self.inputs = [] - fn_inputs = self.pred_fn.structured_input_signature[1] - for i, input in enumerate(list(fn_inputs.values())): - self.inputs.append({ - 'index': i, - 'name': input.name, - 'dtype': np.dtype(input.dtype.as_numpy_dtype()), - 'shape': [1, 512, 512, 3], # This can be overridden later - }) - self.outputs = [] - fn_outputs = self.pred_fn.structured_outputs - for i, output in enumerate(list(fn_outputs.values())): - self.outputs.append({ - 'index': i, - 'name': output.name, - 'dtype': np.dtype(output.dtype.as_numpy_dtype()), - 'shape': output.shape.as_list(), - }) - - def override_input_shape(self, input, shape): - self.inputs[input]['shape'] = shape - - def input_spec(self): - return self.inputs[0]['shape'], self.inputs[0]['dtype'] - - def output_spec(self): - return self.outputs[0]['shape'], self.outputs[0]['dtype'] - - def infer(self, batch, scales=None, nms_threshold=None): - # Process I/O and execute the network - input = {self.inputs[0]['name']: tf.convert_to_tensor(batch)} - output = self.pred_fn(**input) - - # Extract the results depending on what kind of saved model this is - boxes = None - scores = None - classes = None - if len(self.outputs) == 1: - # Detected as AutoML Saved Model - assert len(self.outputs[0]['shape']) == 3 and self.outputs[0]['shape'][2] == 7 - results = output[self.outputs[0]['name']].numpy() - boxes = results[:, :, 1:5] - scores = results[:, :, 5] - classes = results[:, :, 6].astype(np.int32) - elif len(self.outputs) >= 4: - # Detected as TFOD Saved Model - assert output['num_detections'] - num = int(output['num_detections'].numpy().flatten()[0]) - boxes = output['detection_boxes'].numpy()[:, 0:num, :] - scores = output['detection_scores'].numpy()[:, 0:num] - classes = output['detection_classes'].numpy()[:, 0:num] - - # Process the results - detections = [[]] - normalized = (np.max(boxes) < 2.0) - for n in range(scores.shape[1]): - if scores[0][n] == 0.0: - break - scale = self.inputs[0]['shape'][2] if normalized else 1.0 - if scales: - scale /= scales[0] - if nms_threshold and scores[0][n] < nms_threshold: - continue - detections[0].append({ - 'ymin': boxes[0][n][0] * scale, - 'xmin': boxes[0][n][1] * scale, - 'ymax': boxes[0][n][2] * scale, - 'xmax': boxes[0][n][3] * scale, - 'score': scores[0][n], - 'class': int(classes[0][n]) - 1, - }) - return detections - - def run(batcher, inferer, framework, nms_threshold=None): res_images = [] res_detections = [] for batch, images, scales in batcher.get_batch(): - res_detections += inferer.infer(batch, scales, nms_threshold) + res_detections += inferer.process(batch, scales, nms_threshold) res_images += images print("Processing {} / {} images ({})".format(batcher.image_index, batcher.num_images, framework), end="\r") print() diff --git a/samples/python/efficientdet/create_onnx.py b/samples/python/efficientdet/create_onnx.py index 8f2beb87..3cfa962d 100644 --- a/samples/python/efficientdet/create_onnx.py +++ b/samples/python/efficientdet/create_onnx.py @@ -34,12 +34,11 @@ class EfficientDetGraphSurgeon: - def __init__(self, saved_model_path, legacy_plugins=False): + def __init__(self, saved_model_path): """ Constructor of the EfficientDet Graph Surgeon object, to do the conversion of an EfficientDet TF saved model to an ONNX-TensorRT parsable model. :param saved_model_path: The path pointing to the TensorFlow saved model to load. - :param legacy_plugins: If using TensorRT version < 8.0.1, set this to True to use older (but slower) plugins. """ saved_model_path = os.path.realpath(saved_model_path) assert os.path.exists(saved_model_path) @@ -69,10 +68,7 @@ def __init__(self, saved_model_path, legacy_plugins=False): assert self.api log.info("Graph was detected as {}".format(self.api)) - self.batch_size = None - self.legacy_plugins = legacy_plugins - - def infer(self): + def sanitize(self): """ Sanitize the graph by cleaning any unconnected nodes, do a topological resort, and fold constant inputs values. When possible, run shape inference on the ONNX graph to determine tensor shapes. @@ -114,32 +110,29 @@ def save(self, output_path): onnx.save(model, output_path) log.info("Saved ONNX model to {}".format(output_path)) - def update_preprocessor(self, input_shape): + def update_preprocessor(self, input_format, input_size, preprocessor="imagenet"): """ Remove all the pre-processing nodes in the ONNX graph and leave only the image normalization essentials. - :param input_shape: The input tensor shape to use for the ONNX graph. + :param input_format: The input data format, either "NCHW" or "NHWC". + :param input_size: The input size as a comma-separated string in H,W format, e.g. "512,512". + :param preprocessor: The preprocessor to use, either "imagenet" for imagenet mean and stdev normalization, + or "scale_range" for uniform [-1,+1] range normalization. """ # Update the input and output tensors shape - input_shape = input_shape.split(",") - assert len(input_shape) == 4 - for i in range(len(input_shape)): - input_shape[i] = int(input_shape[i]) - assert input_shape[i] >= 1 - input_format = None - if input_shape[1] == 3: - input_format = "NCHW" - if input_shape[3] == 3: - input_format = "NHWC" + input_size = input_size.split(",") + assert len(input_size) == 2 + for i in range(len(input_size)): + input_size[i] = int(input_size[i]) + assert input_size[i] >= 1 assert input_format in ["NCHW", "NHWC"] - self.batch_size = input_shape[0] - self.graph.inputs[0].shape = input_shape + if input_format == "NCHW": + self.graph.inputs[0].shape = ['N', 3, input_size[0], input_size[1]] + if input_format == "NHWC": + self.graph.inputs[0].shape = ['N', input_size[0], input_size[1], 3] self.graph.inputs[0].dtype = np.float32 - if self.api == "TFOD" and self.batch_size > 1 and self.legacy_plugins: - log.error("TFOD models with a batch size larger than 1 are not currently supported in legacy plugin mode. " - "Please upgrade to TensorRT >= 8.0.1 or use batch size 1 for now.") - sys.exit(1) - self.infer() - log.info("ONNX graph input shape: {} [{} format detected]".format(self.graph.inputs[0].shape, input_format)) + self.graph.inputs[0].name = "input" + log.info("ONNX graph input shape: {} [{} format]".format(self.graph.inputs[0].shape, input_format)) + self.sanitize() # Find the initial nodes of the graph, whatever the input is first connected to, and disconnect them for node in [node for node in self.graph.nodes if self.graph.inputs[0] in node.inputs]: @@ -150,13 +143,25 @@ def update_preprocessor(self, input_shape): if input_format == "NHWC": input_tensor = self.graph.transpose("preprocessor/transpose", input_tensor, [0, 3, 1, 2]) - # RGB Normalizers. The per-channel values are given with shape [1, 3, 1, 1] for proper NCHW shape broadcasting - scale_val = 1 / np.asarray([255], dtype=np.float32) - mean_val = -1 * np.expand_dims(np.asarray([0.485, 0.456, 0.406], dtype=np.float32), axis=(0, 2, 3)) - stddev_val = 1 / np.expand_dims(np.asarray([0.229, 0.224, 0.225], dtype=np.float32), axis=(0, 2, 3)) - # y = (x * scale + mean) * stddev --> y = x * scale * stddev + mean * stddev - scale_out = self.graph.elt_const("Mul", "preprocessor/scale", input_tensor, scale_val * stddev_val) - mean_out = self.graph.elt_const("Add", "preprocessor/mean", scale_out, mean_val * stddev_val) + assert preprocessor in ["imagenet", "scale_range"] + preprocessed_tensor = None + if preprocessor == "imagenet": + # RGB Normalizers. The per-channel values are given with shape [1, 3, 1, 1] for proper NCHW shape broadcasting + scale_val = 1 / np.asarray([255], dtype=np.float32) + mean_val = -1 * np.expand_dims(np.asarray([0.485, 0.456, 0.406], dtype=np.float32), axis=(0, 2, 3)) + stddev_val = 1 / np.expand_dims(np.asarray([0.229, 0.224, 0.225], dtype=np.float32), axis=(0, 2, 3)) + # y = (x * scale + mean) * stddev --> y = x * scale * stddev + mean * stddev + scale_out = self.graph.elt_const("Mul", "preprocessor/scale", input_tensor, scale_val * stddev_val) + mean_out = self.graph.elt_const("Add", "preprocessor/mean", scale_out, mean_val * stddev_val) + preprocessed_tensor = mean_out[0] + if preprocessor == "scale_range": + # RGB Normalizers. The per-channel values are given with shape [1, 3, 1, 1] for proper NCHW shape broadcasting + scale_val = 2 / np.asarray([255], dtype=np.float32) + offset_val = np.expand_dims(np.asarray([-1, -1, -1], dtype=np.float32), axis=(0, 2, 3)) + # y = (x * scale + mean) * stddev --> y = x * scale * stddev + mean * stddev + scale_out = self.graph.elt_const("Mul", "preprocessor/scale", input_tensor, scale_val) + range_out = self.graph.elt_const("Add", "preprocessor/range", scale_out, offset_val) + preprocessed_tensor = range_out[0] # Find the first stem conv node of the graph, and connect the normalizer directly to it stem_name = None @@ -166,20 +171,69 @@ def update_preprocessor(self, input_shape): stem_name = "/stem_conv2d/" stem = [node for node in self.graph.nodes if node.op == "Conv" and stem_name in node.name][0] log.info("Found {} node '{}' as stem entry".format(stem.op, stem.name)) - stem.inputs[0] = mean_out[0] + stem.inputs[0] = preprocessed_tensor - # Reshape nodes tend to update the batch dimension to a fixed value of 1, they should use the batch size instead + self.sanitize() + + def update_shapes(self): + # Reshape nodes have the batch dimension as a fixed value of 1, they should use the batch size instead + # Output-Head reshapes use [1, -1, C], corrected reshape value should be [-1, V, C] + for node in [node for node in self.graph.nodes if node.op == "Reshape"]: + shape_in = node.inputs[0].shape + if shape_in is None or len(shape_in) not in [4,5]: # TFOD graphs have 5-dim inputs on this Reshape + continue + if type(node.inputs[1]) != gs.Constant: + continue + shape_out = node.inputs[1].values + if len(shape_out) != 3 or shape_out[0] != 1 or shape_out[1] != -1: + continue + volume = shape_in[1] * shape_in[2] * shape_in[3] / shape_out[2] + if len(shape_in) == 5: + volume *= shape_in[4] + shape_corrected = np.asarray([-1, volume, shape_out[2]], dtype=np.int64) + node.inputs[1] = gs.Constant("{}_shape".format(node.name), values=shape_corrected) + log.info("Updating Output-Head Reshape node {} to {}".format(node.name, node.inputs[1].values)) + + # Other Reshapes only need to change the first dim to -1, as long as there are no -1's already for node in [node for node in self.graph.nodes if node.op == "Reshape"]: - if type(node.inputs[1]) == gs.Constant and node.inputs[1].values[0] == 1: - node.inputs[1].values[0] = self.batch_size + if type(node.inputs[1]) != gs.Constant or node.inputs[1].values[0] != 1 or -1 in node.inputs[1].values: + continue + node.inputs[1].values[0] = -1 + log.info("Updating Reshape node {} to {}".format(node.name, node.inputs[1].values)) - self.infer() + # Resize nodes try to calculate the output shape dynamically, it's more optimal to pre-compute the shape + if self.api == "AutoML": + # Resize on a BiFPN will always be 2x, but grab it from the graph just in case + for node in [node for node in self.graph.nodes if node.op == "Resize"]: + if len(node.inputs) < 4 or node.inputs[0].shape is None: + continue + scale_h, scale_w = None, None + if type(node.inputs[3]) == gs.Constant: + # The sizes input is already folded + if len(node.inputs[3].values) != 4: + continue + scale_h = node.inputs[3].values[2] / node.inputs[0].shape[2] + scale_w = node.inputs[3].values[3] / node.inputs[0].shape[3] + if type(node.inputs[3]) == gs.Variable: + # The sizes input comes from Shape+Slice+Concat + concat = node.i(3) + if concat.op != "Concat": + continue + if type(concat.inputs[1]) != gs.Constant or len(concat.inputs[1].values) != 2: + continue + scale_h = concat.inputs[1].values[0] / node.inputs[0].shape[2] + scale_w = concat.inputs[1].values[1] / node.inputs[0].shape[3] + scales = np.asarray([1, 1, scale_h, scale_w], dtype=np.float32) + del node.inputs[3] + node.inputs[2] = gs.Constant(name="{}_scales".format(node.name), values=scales) + log.info("Updating Resize node {} to {}".format(node.name, scales)) + + self.sanitize() def update_network(self): """ Updates the graph to replace certain nodes in the main EfficientDet network: - the global average pooling nodes are optimized when running for TFOD models. - - the nearest neighbor resize ops in the FPN are replaced by a TRT plugin nodes when running in legacy mode. """ if self.api == "TFOD": @@ -218,26 +272,6 @@ def update_network(self): reduce.attrs['keepdims'] = 1 # Keep the reduced dimensions log.info("Optimized subgraph around ReduceMean node '{}'".format(reduce.name)) - if self.legacy_plugins: - self.infer() - count = 1 - for node in [node for node in self.graph.nodes if node.op == "Resize" and node.attrs['mode'] == "nearest"]: - # Older versions of TensorRT do not understand nearest neighbor resize ops, so a plugin is used to - # perform this operation. - self.graph.plugin( - op="ResizeNearest_TRT", - name="resize_nearest_{}".format(count), - inputs=[node.inputs[0]], - outputs=node.outputs, - attrs={ - 'plugin_version': "1", - 'scale': 2.0, # All resize ops in the EfficientDet FPN should have an upscale factor of 2.0 - }) - node.outputs.clear() - log.info( - "Replaced '{}' ({}) with a ResizeNearest_TRT plugin node".format(node.name, count)) - count += 1 - def update_nms(self, threshold=None, detections=None): """ Updates the graph to replace the NMS op by BatchedNMS_TRT TensorRT plugin node. @@ -284,7 +318,7 @@ def get_anchor_np(output_idx, op): anchors = np.concatenate([anchors_y, anchors_x, anchors_h, anchors_w], axis=2) return gs.Constant(name="nms/anchors:0", values=anchors) - self.infer() + self.sanitize() head_names = [] if self.api == "AutoML": @@ -329,77 +363,37 @@ def get_anchor_np(output_idx, op): nms_op = None nms_attrs = None nms_inputs = None - if not self.legacy_plugins: - # EfficientNMS TensorRT Plugin - # Fusing the decoder will always be faster, so this is the default NMS method supported. In this case, - # three inputs are given to the NMS TensorRT node: - # - The box predictions (from the Box Net node found above) - # - The class predictions (from the Class Net node found above) - # - The default anchor coordinates (from the extracted anchor constants) - # As the original tensors from EfficientDet will be used, the NMS code type is set to 1 (Center+Size), - # because this is the internal box coding format used by the network. - anchors_tensor = extract_anchors_tensor(box_net_split) - nms_inputs = [box_net_tensor, class_net_tensor, anchors_tensor] - nms_op = "EfficientNMS_TRT" - nms_attrs = { - 'plugin_version': "1", - 'background_class': -1, - 'max_output_boxes': num_detections, - 'score_threshold': max(0.01, score_threshold), # Keep threshold to at least 0.01 for better efficiency - 'iou_threshold': iou_threshold, - 'score_activation': True, - 'box_coding': 1, - } - nms_output_classes_dtype = np.int32 - else: - # BatchedNMS TensorRT Plugin - # Alternatively, the ONNX box decoder can be used. This will be slower, as more element-wise and non-fused - # operations will need to be performed by TensorRT. However, it's easier to implement, so it is shown here - # for reference. In this case, only two inputs are given to the NMS TensorRT node: - # - The box predictions (already decoded through the ONNX Box Decoder node) - # - The class predictions (from the Class Net node found above, but also needs to pass through a sigmoid) - # This time, the box predictions will have the coordinate coding from the ONNX box decoder, which matches - # what the BatchedNMS plugin uses. - - if self.api == "AutoML": - # The default boxes tensor has shape [batch_size, number_boxes, 4]. This will insert a "1" dimension - # in the second axis, to become [batch_size, number_boxes, 1, 4], the shape that BatchedNMS expects. - box_decoder_tensor = self.graph.unsqueeze("nms/box_net_reshape", box_decoder_tensor, axes=[2])[0] - if self.api == "TFOD": - # The default boxes tensor has shape [4, number_boxes]. This will transpose and insert a "1" dimension - # in the 0 and 2 axes, to become [1, number_boxes, 1, 4], the shape that BatchedNMS expects. - box_decoder_tensor = self.graph.transpose("nms/box_decoder_transpose", box_decoder_tensor, perm=[1, 0]) - box_decoder_tensor = self.graph.unsqueeze("nms/box_decoder_reshape", box_decoder_tensor, axes=[0, 2])[0] - - # BatchedNMS also expects the classes tensor to be already activated, in the case of EfficientDet, this is - # through a Sigmoid op. - class_net_tensor = self.graph.sigmoid("nms/class_net_sigmoid", class_net_tensor)[0] - - nms_inputs = [box_decoder_tensor, class_net_tensor] - nms_op = "BatchedNMS_TRT" - nms_attrs = { - 'plugin_version': "1", - 'shareLocation': True, - 'backgroundLabelId': -1, - 'numClasses': num_classes, - 'topK': 1024, - 'keepTopK': num_detections, - 'scoreThreshold': score_threshold, - 'iouThreshold': iou_threshold, - 'isNormalized': normalized, - 'clipBoxes': False, - # 'scoreBits': 10, # Some versions of the plugin may need this parameter. If so, uncomment this line. - } - nms_output_classes_dtype = np.float32 + + # EfficientNMS TensorRT Plugin + # Fusing the decoder will always be faster, so this is the default NMS method supported. In this case, + # three inputs are given to the NMS TensorRT node: + # - The box predictions (from the Box Net node found above) + # - The class predictions (from the Class Net node found above) + # - The default anchor coordinates (from the extracted anchor constants) + # As the original tensors from EfficientDet will be used, the NMS code type is set to 1 (Center+Size), + # because this is the internal box coding format used by the network. + anchors_tensor = extract_anchors_tensor(box_net_split) + nms_inputs = [box_net_tensor, class_net_tensor, anchors_tensor] + nms_op = "EfficientNMS_TRT" + nms_attrs = { + 'plugin_version': "1", + 'background_class': -1, + 'max_output_boxes': num_detections, + 'score_threshold': max(0.01, score_threshold), # Keep threshold to at least 0.01 for better efficiency + 'iou_threshold': iou_threshold, + 'score_activation': True, + 'box_coding': 1, + } + nms_output_classes_dtype = np.int32 # NMS Outputs - nms_output_num_detections = gs.Variable(name="num_detections", dtype=np.int32, shape=[self.batch_size, 1]) + nms_output_num_detections = gs.Variable(name="num_detections", dtype=np.int32, shape=['N', 1]) nms_output_boxes = gs.Variable(name="detection_boxes", dtype=np.float32, - shape=[self.batch_size, num_detections, 4]) + shape=['N', num_detections, 4]) nms_output_scores = gs.Variable(name="detection_scores", dtype=np.float32, - shape=[self.batch_size, num_detections]) + shape=['N', num_detections]) nms_output_classes = gs.Variable(name="detection_classes", dtype=nms_output_classes_dtype, - shape=[self.batch_size, num_detections]) + shape=['N', num_detections]) nms_outputs = [nms_output_num_detections, nms_output_boxes, nms_output_scores, nms_output_classes] @@ -415,14 +409,15 @@ def get_anchor_np(output_idx, op): self.graph.outputs = nms_outputs - self.infer() + self.sanitize() def main(args): - effdet_gs = EfficientDetGraphSurgeon(args.saved_model, args.legacy_plugins) + effdet_gs = EfficientDetGraphSurgeon(args.saved_model) if args.tf2onnx: effdet_gs.save(args.tf2onnx) - effdet_gs.update_preprocessor(args.input_shape) + effdet_gs.update_preprocessor(args.input_format, args.input_size, args.preprocessor) + effdet_gs.update_shapes() effdet_gs.update_network() effdet_gs.update_nms(args.nms_threshold, args.nms_detections) effdet_gs.save(args.onnx) @@ -430,22 +425,25 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-m", "--saved_model", help="The TensorFlow saved model directory to load") - parser.add_argument("-o", "--onnx", help="The output ONNX model file to write") - parser.add_argument("-i", "--input_shape", default="1,512,512,3", - help="Set the input shape of the graph, as comma-separated dimensions in NCHW or NHWC format, " - "default: 1,512,512,3") - parser.add_argument("-t", "--nms_threshold", type=float, help="Override the score threshold for the NMS op, " - "default: use the original value in the model") - parser.add_argument("-d", "--nms_detections", type=int, help="Override the max detections for the NMS op, " - "default: use the original value in the model") - parser.add_argument("--legacy_plugins", action="store_true", help="Use legacy plugins for support on TensorRT " - "versions lower than 8.0.1") - parser.add_argument("--tf2onnx", help="The path where to save the intermediate ONNX graph generated by tf2onnx, " - "useful for debugging purposes, default: not saved") + parser.add_argument("-m", "--saved_model", required=True, + help="The TensorFlow saved model directory to load") + parser.add_argument("-o", "--onnx", required=True, + help="The output ONNX model file to write") + parser.add_argument("-f", "--input_format", default="NHWC", choices=["NHWC", "NCHW"], + help="Set the input data format of the graph, either NCHW or NHWC, default: NHWC") + parser.add_argument("-i", "--input_size", default="512,512", + help="Set the input shape of the graph, as a comma-separated dimensions in H,W format, " + "default: 512,512") + parser.add_argument("-p", "--preprocessor", default="imagenet", choices=["imagenet", "scale_range"], + help="Set the preprocessor to apply on the graph, either 'imagenet' for standard mean " + "subtraction and stdev normalization, or 'scale_range' for uniform [-1,+1] " + "normalization as is used in the AdvProp models, default: imagenet") + parser.add_argument("-t", "--nms_threshold", type=float, + help="Override the NMS score threshold, default: use the original value in the model") + parser.add_argument("-d", "--nms_detections", type=int, + help="Override the NMS max detections, default: use the original value in the model") + parser.add_argument("--tf2onnx", + help="The path where to save the intermediate ONNX graph generated by tf2onnx, useful" + "for graph debugging purposes, default: not saved") args = parser.parse_args() - if not all([args.saved_model, args.onnx]): - parser.print_help() - print("\nThese arguments are required: --saved_model and --onnx") - sys.exit(1) main(args) diff --git a/samples/python/efficientdet/eval_coco.py b/samples/python/efficientdet/eval_coco.py index 30e04484..f10888ec 100644 --- a/samples/python/efficientdet/eval_coco.py +++ b/samples/python/efficientdet/eval_coco.py @@ -39,7 +39,7 @@ def main(args): evaluator = coco_metric.EvaluationMetric(filename=args.annotations) for batch, images, scales in batcher.get_batch(): print("Processing Image {} / {}".format(batcher.image_index, batcher.num_images), end="\r") - detections = trt_infer.infer(batch, scales, args.nms_threshold) + detections = trt_infer.process(batch, scales, args.nms_threshold) coco_det = np.zeros((len(images), max([len(d) for d in detections]), 7)) coco_det[:, :, -1] = -1 for i in range(len(images)): diff --git a/samples/python/efficientdet/image_batcher.py b/samples/python/efficientdet/image_batcher.py index d7161371..df50d7f1 100644 --- a/samples/python/efficientdet/image_batcher.py +++ b/samples/python/efficientdet/image_batcher.py @@ -16,6 +16,7 @@ import os import sys +import random import numpy as np from PIL import Image @@ -26,7 +27,7 @@ class ImageBatcher: Creates batches of pre-processed images. """ - def __init__(self, input, shape, dtype, max_num_images=None, exact_batches=False, preprocessor="EfficientDet"): + def __init__(self, input, shape, dtype, max_num_images=None, exact_batches=False, preprocessor="EfficientDet", shuffle_files=False): """ :param input: The input directory to read images from. :param shape: The tensor shape of the batch to prepare, either in NCHW or NHWC format. @@ -36,6 +37,7 @@ def __init__(self, input, shape, dtype, max_num_images=None, exact_batches=False size. If false, it will pad the final batch with zeros to reach the batch size. If true, it will *remove* the last few images in excess of a batch size multiple, to guarantee batches are exact (useful for calibration). :param preprocessor: Set the preprocessor to use, depending on which network is being used. + :param shuffle_files: Shuffle the list of files before batching. """ # Find images in the given input path input = os.path.realpath(input) @@ -49,6 +51,9 @@ def is_image(path): if os.path.isdir(input): self.images = [os.path.join(input, f) for f in os.listdir(input) if is_image(os.path.join(input, f))] self.images.sort() + if shuffle_files: + random.seed(47) + random.shuffle(self.images) elif os.path.isfile(input): if is_image(input): self.images.append(input) @@ -157,7 +162,6 @@ def get_batch(self): for i, batch_images in enumerate(self.batches): batch_data = np.zeros(self.shape, dtype=self.dtype) batch_scales = [None] * len(batch_images) - print("BATCH SCALES: ", batch_scales) for i, image in enumerate(batch_images): self.image_index += 1 batch_data[i], batch_scales[i] = self.preprocess_image(image) diff --git a/samples/python/efficientdet/infer.py b/samples/python/efficientdet/infer.py index 918b9477..c797ae74 100644 --- a/samples/python/efficientdet/infer.py +++ b/samples/python/efficientdet/infer.py @@ -56,26 +56,38 @@ def __init__(self, engine_path): if self.engine.binding_is_input(i): is_input = True name = self.engine.get_binding_name(i) - dtype = self.engine.get_binding_dtype(i) - shape = self.engine.get_binding_shape(i) + dtype = np.dtype(trt.nptype(self.engine.get_binding_dtype(i))) + shape = self.context.get_binding_shape(i) + if is_input and shape[0] < 0: + assert self.engine.num_optimization_profiles > 0 + profile_shape = self.engine.get_profile_shape(0, name) + assert len(profile_shape) == 3 # min,opt,max + # Set the *max* profile as binding shape + self.context.set_binding_shape(i, profile_shape[2]) + shape = self.context.get_binding_shape(i) if is_input: self.batch_size = shape[0] - size = np.dtype(trt.nptype(dtype)).itemsize + size = dtype.itemsize for s in shape: size *= s allocation = cuda.mem_alloc(size) + host_allocation = None if is_input else np.zeros(shape, dtype) binding = { 'index': i, 'name': name, - 'dtype': np.dtype(trt.nptype(dtype)), + 'dtype': dtype, 'shape': list(shape), 'allocation': allocation, + 'host_allocation': host_allocation, } self.allocations.append(allocation) if self.engine.binding_is_input(i): self.inputs.append(binding) else: self.outputs.append(binding) + print("{} '{}' with shape {} and dtype {}".format( + "Input" if is_input else "Output", + binding['name'], binding['shape'], binding['dtype'])) assert self.batch_size > 0 assert len(self.inputs) > 0 @@ -99,7 +111,20 @@ def output_spec(self): specs.append((o['shape'], o['dtype'])) return specs - def infer(self, batch, scales=None, nms_threshold=None): + def infer(self, batch): + """ + Execute inference on a batch of images. + :param batch: A numpy array holding the image batch. + :return A list of outputs as numpy arrays. + """ + # Copy I/O and Execute + cuda.memcpy_htod(self.inputs[0]['allocation'], batch) + self.context.execute_v2(self.allocations) + for o in range(len(self.outputs)): + cuda.memcpy_dtoh(self.outputs[o]['host_allocation'], self.outputs[o]['allocation']) + return [o['host_allocation'] for o in self.outputs] + + def process(self, batch, scales=None, nms_threshold=None): """ Execute inference on a batch of images. The images should already be batched and preprocessed, as prepared by the ImageBatcher class. Memory copying to and from the GPU device will be performed here. @@ -107,16 +132,8 @@ def infer(self, batch, scales=None, nms_threshold=None): :param scales: The image resize scales for each image in this batch. Default: No scale postprocessing applied. :return: A nested list for each image in the batch and each detection in the list. """ - # Prepare the output data - outputs = [] - for shape, dtype in self.output_spec(): - outputs.append(np.zeros(shape, dtype)) - - # Process I/O and execute the network - cuda.memcpy_htod(self.inputs[0]['allocation'], np.ascontiguousarray(batch)) - self.context.execute_v2(self.allocations) - for o in range(len(outputs)): - cuda.memcpy_dtoh(outputs[o], self.outputs[o]['allocation']) + # Run inference + outputs = self.infer(batch) # Process the results nums = outputs[0] @@ -145,8 +162,9 @@ def infer(self, batch, scales=None, nms_threshold=None): def main(args): - output_dir = os.path.realpath(args.output) - os.makedirs(output_dir, exist_ok=True) + if args.output: + output_dir = os.path.realpath(args.output) + os.makedirs(output_dir, exist_ok=True) labels = [] if args.labels: @@ -155,38 +173,59 @@ def main(args): labels.append(label.strip()) trt_infer = TensorRTInfer(args.engine) - batcher = ImageBatcher(args.input, *trt_infer.input_spec()) - for batch, images, scales in batcher.get_batch(): - print("Processing Image {} / {}".format(batcher.image_index, batcher.num_images), end="\r") - detections = trt_infer.infer(batch, scales, args.nms_threshold) - for i in range(len(images)): - basename = os.path.splitext(os.path.basename(images[i]))[0] - # Image Visualizations - output_path = os.path.join(output_dir, "{}.png".format(basename)) - visualize_detections(images[i], output_path, detections[i], labels) - # Text Results - output_results = "" - for d in detections[i]: - line = [d['xmin'], d['ymin'], d['xmax'], d['ymax'], d['score'], d['class']] - output_results += "\t".join([str(f) for f in line]) + "\n" - with open(os.path.join(args.output, "{}.txt".format(basename)), "w") as f: - f.write(output_results) + if args.input: + print("Inferring data in {}".format(args.input)) + batcher = ImageBatcher(args.input, *trt_infer.input_spec()) + for batch, images, scales in batcher.get_batch(): + print("Processing Image {} / {}".format(batcher.image_index, batcher.num_images), end="\r") + detections = trt_infer.process(batch, scales, args.nms_threshold) + if args.output: + for i in range(len(images)): + basename = os.path.splitext(os.path.basename(images[i]))[0] + # Image Visualizations + output_path = os.path.join(output_dir, "{}.png".format(basename)) + visualize_detections(images[i], output_path, detections[i], labels) + # Text Results + output_results = "" + for d in detections[i]: + line = [d['xmin'], d['ymin'], d['xmax'], d['ymax'], d['score'], d['class']] + output_results += "\t".join([str(f) for f in line]) + "\n" + with open(os.path.join(output_dir, "{}.txt".format(basename)), "w") as f: + f.write(output_results) + else: + print("No input provided, running in benchmark mode") + spec = trt_infer.input_spec() + batch = 255 * np.random.rand(*spec[0]).astype(spec[1]) + iterations = 200 + times = [] + for i in range(20): # GPU warmup iterations + trt_infer.infer(batch) + for i in range(iterations): + start = time.time() + trt_infer.infer(batch) + times.append(time.time() - start) + print("Iteration {} / {}".format(i + 1, iterations), end="\r") + print("Benchmark results include time for H2D and D2H memory copies") + print("Average Latency: {:.3f} ms".format( + 1000 * np.average(times))) + print("Average Throughput: {:.1f} ips".format( + trt_infer.batch_size / np.average(times))) + print() print("Finished Processing") if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-e", "--engine", default=None, help="The serialized TensorRT engine") - parser.add_argument("-i", "--input", default=None, help="Path to the image or directory to process") - parser.add_argument("-o", "--output", default=None, help="Directory where to save the visualization results") - parser.add_argument("-l", "--labels", default="./labels_coco.txt", help="File to use for reading the class labels " - "from, default: ./labels_coco.txt") - parser.add_argument("-t", "--nms_threshold", type=float, help="Override the score threshold for the NMS operation, " - "if higher than the threshold in the engine.") + parser.add_argument("-e", "--engine", default=None, required=True, + help="The serialized TensorRT engine") + parser.add_argument("-i", "--input", default=None, + help="Path to the image or directory to process") + parser.add_argument("-o", "--output", default=None, + help="Directory where to save the visualization results") + parser.add_argument("-l", "--labels", default="./labels_coco.txt", + help="File to use for reading the class labels from, default: ./labels_coco.txt") + parser.add_argument("-t", "--nms_threshold", type=float, + help="Override the score threshold for the NMS operation, if higher than the built-in threshold") args = parser.parse_args() - if not all([args.engine, args.input, args.output]): - parser.print_help() - print("\nThese arguments are required: --engine --input and --output") - sys.exit(1) main(args) diff --git a/samples/python/efficientdet/infer_tf.py b/samples/python/efficientdet/infer_tf.py new file mode 100644 index 00000000..45dfbdc2 --- /dev/null +++ b/samples/python/efficientdet/infer_tf.py @@ -0,0 +1,157 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import sys +import time +import argparse + +import numpy as np +import tensorflow as tf + + +class TensorFlowInfer: + """ + Implements TensorFlow inference of a saved model, following the same API as the TensorRTInfer class. + """ + + def __init__(self, saved_model_path): + gpus = tf.config.experimental.list_physical_devices('GPU') + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + + self.model = tf.saved_model.load(saved_model_path) + self.pred_fn = self.model.signatures['serving_default'] + + # Setup I/O bindings + self.batch_size = 1 + self.inputs = [] + fn_inputs = self.pred_fn.structured_input_signature[1] + for i, input in enumerate(list(fn_inputs.values())): + self.inputs.append({ + 'index': i, + 'name': input.name, + 'dtype': np.dtype(input.dtype.as_numpy_dtype()), + 'shape': [1, 512, 512, 3], # This can be overridden later + }) + self.outputs = [] + fn_outputs = self.pred_fn.structured_outputs + for i, output in enumerate(list(fn_outputs.values())): + self.outputs.append({ + 'index': i, + 'name': output.name, + 'dtype': np.dtype(output.dtype.as_numpy_dtype()), + 'shape': output.shape.as_list(), + }) + + def override_input_shape(self, input, shape): + self.inputs[input]['shape'] = shape + self.batch_size = shape[0] + + def input_spec(self): + return self.inputs[0]['shape'], self.inputs[0]['dtype'] + + def output_spec(self): + return self.outputs[0]['shape'], self.outputs[0]['dtype'] + + def infer(self, batch): + # Process I/O and execute the network + input = {self.inputs[0]['name']: tf.convert_to_tensor(batch)} + output = self.pred_fn(**input) + return output + + def process(self, batch, scales=None, nms_threshold=None): + # Infer network + output = self.infer(batch) + + # Extract the results depending on what kind of saved model this is + boxes = None + scores = None + classes = None + if len(self.outputs) == 1: + # Detected as AutoML Saved Model + assert len(self.outputs[0]['shape']) == 3 and self.outputs[0]['shape'][2] == 7 + results = output[self.outputs[0]['name']].numpy() + boxes = results[:, :, 1:5] + scores = results[:, :, 5] + classes = results[:, :, 6].astype(np.int32) + elif len(self.outputs) >= 4: + # Detected as TFOD Saved Model + assert output['num_detections'] + num = int(output['num_detections'].numpy().flatten()[0]) + boxes = output['detection_boxes'].numpy()[:, 0:num, :] + scores = output['detection_scores'].numpy()[:, 0:num] + classes = output['detection_classes'].numpy()[:, 0:num] + + # Process the results + detections = [[]] + normalized = (np.max(boxes) < 2.0) + for n in range(scores.shape[1]): + if scores[0][n] == 0.0: + break + scale = self.inputs[0]['shape'][2] if normalized else 1.0 + if scales: + scale /= scales[0] + if nms_threshold and scores[0][n] < nms_threshold: + continue + detections[0].append({ + 'ymin': boxes[0][n][0] * scale, + 'xmin': boxes[0][n][1] * scale, + 'ymax': boxes[0][n][2] * scale, + 'xmax': boxes[0][n][3] * scale, + 'score': scores[0][n], + 'class': int(classes[0][n]) - 1, + }) + return detections + + +def main(args): + print("Running in benchmark mode") + tf_infer = TensorFlowInfer(args.saved_model) + input_size = [int(v) for v in args.input_size.split(",")] + assert len(input_size) == 2 + tf_infer.override_input_shape(0, [args.batch_size, input_size[0], input_size[1], 3]) + spec = tf_infer.input_spec() + batch = 255 * np.random.rand(*spec[0]).astype(spec[1]) + iterations = 200 + times = [] + for i in range(20): # Warmup iterations + tf_infer.infer(batch) + for i in range(iterations): + start = time.time() + tf_infer.infer(batch) + times.append(time.time() - start) + print("Iteration {} / {}".format(i + 1, iterations), end="\r") + print("Benchmark results include TensorFlow host overhead") + print("Average Latency: {:.3f} ms".format( + 1000 * np.average(times))) + print("Average Throughput: {:.1f} ips".format( + tf_infer.batch_size / np.average(times))) + + print() + print("Finished Processing") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-m", "--saved_model", required=True, + help="The TensorFlow saved model path to validate against") + parser.add_argument("-i", "--input_size", default="512,512", + help="The input size to run the model with, in HEIGHT,WIDTH format") + parser.add_argument("-b", "--batch_size", default=1, type=int, + help="The batch size to run the model with") + args = parser.parse_args() + main(args)