1
1
# Standard Library
2
2
import argparse
3
3
import os
4
+ import time
4
5
from pathlib import Path
5
6
6
7
# Third Party
7
8
import torch
8
9
9
- from happypose .pose_estimators .cosypose .cosypose .integrated .pose_estimator import (
10
- PoseEstimator ,
11
- )
12
-
13
10
# CosyPose
14
11
from happypose .pose_estimators .cosypose .cosypose .utils .cosypose_wrapper import (
15
12
CosyPoseWrapper ,
16
13
)
17
14
18
15
# HappyPose
19
- from happypose .toolbox .datasets .object_dataset import RigidObjectDataset
20
16
from happypose .toolbox .inference .example_inference_utils import (
21
17
load_detections ,
22
18
load_object_data ,
26
22
make_poses_visualization ,
27
23
save_predictions ,
28
24
)
29
- from happypose .toolbox .inference .types import DetectionsType , ObservationTensor
30
- from happypose .toolbox .inference .utils import filter_detections , load_detector
25
+ from happypose .toolbox .inference .types import ObservationTensor
26
+ from happypose .toolbox .inference .utils import filter_detections
31
27
from happypose .toolbox .utils .logging import get_logger , set_logging_level
32
28
33
29
logger = get_logger (__name__ )
34
30
35
31
device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
36
32
37
33
38
def setup_pose_estimator(dataset_to_use: str, object_dataset: RigidObjectDataset):
    """Return a CosyPose pose predictor configured for *dataset_to_use*.

    Thin convenience around :class:`CosyPoseWrapper`: builds the wrapper
    and exposes only its ``pose_predictor`` attribute.
    """
    # TODO: remove this wrapper from code base
    wrapper = CosyPoseWrapper(
        dataset_name=dataset_to_use,
        object_dataset=object_dataset,
        n_workers=1,
    )
    return wrapper.pose_predictor
46
-
47
def run_inference(
    pose_estimator: PoseEstimator,
    observation: ObservationTensor,
    detections: DetectionsType,
):
    """Run the CosyPose inference pipeline and return pose estimates on CPU.

    Bug fix: the original signature was annotated ``-> None`` although the
    function returns ``data_TCO.cpu()``; the incorrect annotation is removed.

    Args:
        pose_estimator: configured CosyPose pose predictor.
        observation: RGB(-D) observation tensor; moved to the module-level
            ``device`` before inference.
        detections: object detections used to seed the pose estimates.

    Returns:
        The refined pose estimates (``data_TCO``) moved to CPU.
    """
    # NOTE(review): the return value of .to() is discarded, so this assumes
    # ObservationTensor.to() moves its tensors in place — confirm against
    # its implementation.
    observation.to(device)

    # Three refiner iterations, matching the example's default setting.
    data_TCO, extra_data = pose_estimator.run_inference_pipeline(
        observation=observation, detections=detections, n_refiner_iterations=3
    )
    print("Timings:")
    print(extra_data["timing_str"])

    return data_TCO.cpu()
63
34
def main ():
64
35
set_logging_level ("info" )
65
36
parser = argparse .ArgumentParser ()
66
37
parser .add_argument ("example_name" )
67
38
parser .add_argument ("--dataset" , type = str , default = "hope" )
68
39
parser .add_argument ("--run-detections" , action = "store_true" )
69
40
parser .add_argument ("--run-inference" , action = "store_true" )
41
+ parser .add_argument ("--run-depth-refiner" , action = "store_true" )
42
+ parser .add_argument ("--depth-refiner-type" , type = str , default = "icp" )
70
43
parser .add_argument ("--vis-detections" , action = "store_true" )
71
44
parser .add_argument ("--vis-poses" , action = "store_true" )
72
45
args = parser .parse_args ()
@@ -82,29 +55,42 @@ def main():
82
55
# Load data
83
56
detections = load_detections (example_dir ).to (device )
84
57
object_dataset = make_example_object_dataset (example_dir )
85
- rgb , depth , camera_data = load_observation_example (example_dir , load_depth = False )
86
- # TODO: cosypose forward does not work if depth is loaded detection
87
- # contrary to megapose
88
- observation = ObservationTensor .from_numpy (rgb , depth = None , K = camera_data .K ).to (
89
- device
90
- )
58
+ rgb , depth , camera_data = load_observation_example (example_dir , load_depth = True )
59
+ observation = ObservationTensor .from_numpy (rgb , depth , camera_data .K ).to (device )
91
60
92
61
# Load models
93
- pose_estimator = setup_pose_estimator (args .dataset , object_dataset )
62
+ cosy = CosyPoseWrapper (
63
+ dataset_name = args .dataset ,
64
+ object_dataset = object_dataset ,
65
+ depth_refiner_type = args .depth_refiner_type ,
66
+ n_workers = 1 ,
67
+ )
94
68
95
69
if args .run_detections :
96
- # TODO: hardcoded detector
97
- detector = load_detector (run_id = "detector-bop-hope-pbr--15246" , device = device )
98
70
# Masks are not used for pose prediction, but are computed by Mask-RCNN anyway
99
- detections = detector .get_detections (observation , output_masks = True )
100
- available_labels = [obj .label for obj in object_dataset .list_objects ]
101
- detections = filter_detections (detections , available_labels )
71
+ detections = cosy .detector .get_detections (observation , output_masks = True )
102
72
else :
103
73
detections = load_detections (example_dir ).to (device )
74
+ available_labels = [obj .label for obj in object_dataset .list_objects ]
75
+ detections = filter_detections (detections , available_labels )
104
76
105
77
if args .run_inference :
106
- output = run_inference (pose_estimator , observation , detections )
107
- save_predictions (output , example_dir )
78
+ data_TCO , extra_data = cosy .pose_predictor .run_inference_pipeline (
79
+ observation = observation ,
80
+ detections = detections ,
81
+ run_detector = False ,
82
+ n_refiner_iterations = 3 ,
83
+ )
84
+ print ("run_inference_pipeline timings:" )
85
+ print (extra_data ["timing_str" ])
86
+ if args .run_depth_refiner :
87
+ t1 = time .perf_counter ()
88
+ data_TCO , _ = cosy .depth_refiner .refine_poses (
89
+ predictions = data_TCO , depth = observation .depth , K = observation .K
90
+ )
91
+ print (f"Depth refiner took: { time .perf_counter () - t1 } " )
92
+
93
+ save_predictions (data_TCO .cpu (), example_dir )
108
94
109
95
if args .vis_detections :
110
96
make_detections_visualization (rgb , detections , example_dir )
0 commit comments