Skip to content

Commit

Permalink
Merge pull request #2 from mrn-mln/main
Browse files Browse the repository at this point in the history
merge commits together
  • Loading branch information
undefined-references authored Aug 9, 2021
2 parents f52d563 + 5add5ab commit e3e83ca
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 27 deletions.
28 changes: 15 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,21 @@ docker run --runtime nvidia --entrypoint bash --privileged -it -v $PWD/:/repo ne
4. Start Inference:

```
python3 inference/inference.py [--device DEVICE] --input_video INPUT_VIDEO --out_dir
OUT_DIR [--detector_model_path DETECTOR_MODEL_PATH]
[--label_map LABEL_MAP]
[--detector_threshold DETECTOR_THRESHOLD]
[--detector_input_width DETECTOR_INPUT_WIDTH]
[--detector_input_height DETECTOR_INPUT_HEIGHT]
[--pose_input_width POSE_INPUT_WIDTH]
[--pose_input_height POSE_INPUT_HEIGHT]
[--heatmap_width HEATMAP_WIDTH]
[--heatmap_height HEATMAP_HEIGHT] [--out_width OUT_WIDTH]
[--out_height OUT_HEIGHT]
[--batch_size BATCH_SIZE]
[--trt_model_path TRT_MODEL_PATH]
python3 inference/inference.py --device DEVICE --input_video INPUT_VIDEO
--out_dir OUT_DIR
--detector_model_path DETECTOR_MODEL_PATH
--label_map LABEL_MAP
--detector_threshold DETECTOR_THRESHOLD
--detector_input_width DETECTOR_INPUT_WIDTH
--detector_input_height DETECTOR_INPUT_HEIGHT
--pose_input_width POSE_INPUT_WIDTH
--pose_input_height POSE_INPUT_HEIGHT
--heatmap_width HEATMAP_WIDTH
--heatmap_height HEATMAP_HEIGHT
--out_width OUT_WIDTH
--out_height OUT_HEIGHT
--batch_size BATCH_SIZE
--trt_model_path TRT_MODEL_PATH
Expand Down
15 changes: 10 additions & 5 deletions inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def inference(args):
heatmap_height = args.heatmap_height
label_map_file = args.label_map
batch_size = args.batch_size
pose_model_path = args.pose_model_path
if not label_map_file:
label_map_file = "adaptive_object_detection/utils/mscoco_label_map.pbtxt"
label_map = create_category_index_dict(label_map_file)
Expand All @@ -30,15 +31,16 @@ def inference(args):
detector_input_size=(detector_input_height, detector_input_width),
pose_input_size=(pose_input_height, pose_input_width),
heatmap_size=(heatmap_height, heatmap_width))
elif device == "jetson":
elif device == "jetson-tx2":
from models.jetson_pose_estimator import TRTPoseEstimator
pose_estimator = TRTPoseEstimator(detector_thresh=detector_thresh,
detector_input_size=(detector_input_height, detector_input_width),
pose_input_size=(pose_input_height, pose_input_width),
heatmap_size=(heatmap_height, heatmap_width),
batch_size=batch_size)
batch_size=batch_size,
pose_model_path=pose_model_path)
else:
raise ValueError("device should be 'x86' or 'jetson' but you provided {0}".format(device))
raise ValueError("device should be 'x86' or 'jetson-tx2' but you provided {0}".format(device))
video_uri = args.input_video
if not os.path.isfile(video_uri):
raise FileNotFoundError('video file does not exist under: {}'.format(video_uri))
Expand Down Expand Up @@ -87,7 +89,7 @@ def inference(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="This script runs the inference of pose estimation models")
parser.add_argument("--device", type=str, default="x86",
help="supports x86 and trt which is running on x86 device")
help="supports x86 and jetson-tx2 device")
parser.add_argument("--input_video", type=str, required=True, help="input video path")
parser.add_argument("--out_dir", type=str, required=True, help="directory to store output video")
parser.add_argument("--detector_model_path", type=str,
Expand All @@ -102,7 +104,10 @@ def inference(args):
parser.add_argument("--heatmap_height", type=int, default=64, help="height of the pose heatmap")
parser.add_argument("--out_width", type=int, default=960, help="width of the output video")
parser.add_argument("--out_height", type=int, default=540, help="height of the output video")
parser.add_argument("--batch_size", type=int, default=8, help="batch size of pose estimator (it only works for jetson)")
parser.add_argument("--batch_size", type=int, default=8,
help="batch size of pose estimator (it only works for jetson)")
parser.add_argument("--pose_model_path", type=str, default="models/data/fast_pose_fp16_b8.trt",
help="using for jetson, path to the pose estimator model file, if not provided the default the model is loaded by default path")

args = parser.parse_args()

Expand Down
16 changes: 7 additions & 9 deletions models/jetson_pose_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def __init__(self,
detector_input_size=(300, 300),
pose_input_size=(256, 192),
heatmap_size=(64, 48),
batch_size=8
batch_size=8,
pose_model_path=None
):
super().__init__(detector_thresh)
self.detector_height, self.detector_width = detector_input_size
self.pose_input_size = pose_input_size
self.heatmap_size = heatmap_size
self.batch_size = batch_size
self.pose_model_path = pose_model_path
self.h_input = None
self.d_input = None
self.h_ouput = None
Expand Down Expand Up @@ -74,14 +76,12 @@ def inference(self, preprocessed_image):
# Transfer input data to the GPU.
result_raw = self._batch_execute(context, num_detected_objects, batch_inps)
result = result_raw[0:num_detected_objects, :]
# print("DONE!", result)

else:
remainder = num_detected_objects
start_idx = 0
while remainder > 0:
endidx = min(self.batch_size, remainder)
#print('remainder', remainder, 'start_idx', start_idx, 'endidx', endidx)
batch_inps[0:endidx, :] = inps[start_idx: start_idx + endidx, :]
self._load_images_to_buffer(batch_inps)
with self.model.create_execution_context() as context:
Expand Down Expand Up @@ -114,9 +114,9 @@ def post_process(self, hm, cropped_boxes, boxes, scores, ids):

preds_img = np.array(pose_coords)
preds_scores = np.array(pose_scores)

boxes, scores, ids, preds_img, preds_scores, pick_ids = \
pose_nms(boxes, scores, ids, preds_img, preds_scores, 0)
# TODO
#boxes, scores, ids, preds_img, preds_scores, pick_ids = \
# pose_nms(boxes, scores, ids, preds_img, preds_scores, 0)
_result = []
for k in range(len(scores)):
if np.ndim(preds_scores[k] == 2):
Expand Down Expand Up @@ -224,9 +224,7 @@ def _load_images_to_buffer(self, img):
np.copyto(self.h_input, preprocessed)

def _load_engine(self):
base_dir = "models/data/"
model_file = "fast_pose_fp16_b8.trt"
model_path = os.path.join(base_dir, model_file)
model_path =self.pose_model_path
if not os.path.isfile(model_path):
logging.info(
'model does not exist under: {}'.format(str(model_path)))
Expand Down

0 comments on commit e3e83ca

Please sign in to comment.