34
34
35
35
# Pose estimator
36
36
from happypose .toolbox .datasets .datasets_cfg import make_object_dataset , make_scene_dataset , get_obj_ds_info
37
- from happypose .toolbox .inference .utils import load_detector
38
37
from happypose .toolbox .lib3d .rigid_mesh_database import MeshDataBase
39
38
from happypose .toolbox .utils .distributed import get_rank , get_tmp_dir
40
39
from happypose .toolbox .utils .logging import get_logger
@@ -69,7 +68,7 @@ def load_detector(run_id, ds_name):
69
68
return model
70
69
71
70
72
- def load_pose_models (object_dataset , coarse_run_id , refiner_run_id , n_workers , renderer_type = "panda3d" ):
71
+ def load_pose_models_cosypose (object_dataset , coarse_run_id , refiner_run_id , n_workers , renderer_type = "panda3d" ):
73
72
run_dir = EXP_DIR / coarse_run_id
74
73
cfg = yaml .load ((run_dir / "config.yaml" ).read_text (), Loader = yaml .UnsafeLoader )
75
74
cfg = check_update_config_pose (cfg )
@@ -161,6 +160,8 @@ def run_eval(
161
160
cfg .save_dir = str (save_dir )
162
161
163
162
logger .info (f"Running eval on ds_name={ cfg .ds_name } with setting={ save_key } " )
163
+ # e.g. "ycbv.bop19" -> "ycbv"
164
+ ds_name_short = cfg .ds_name .split ('.' )[0 ]
164
165
165
166
# Load the dataset
166
167
ds_kwargs = {"load_depth" : False }
@@ -181,8 +182,7 @@ def run_eval(
181
182
# Load detector model
182
183
if cfg .inference .detection_type == "detector" :
183
184
assert cfg .detector_run_id is not None
184
- detector_model = load_detector (cfg .detector_run_id , cfg .ds_name )
185
- # detector_model = load_detector(cfg.detector_run_id, device=device)
185
+ detector_model = load_detector (cfg .detector_run_id , ds_name_short )
186
186
elif cfg .inference .detection_type == "gt" :
187
187
detector_model = None
188
188
else :
@@ -195,14 +195,15 @@ def run_eval(
195
195
assert cfg .coarse_run_id is not None
196
196
assert cfg .refiner_run_id is not None
197
197
198
- object_ds = make_object_dataset ('ycbv' )
199
198
200
- coarse_model , refiner_model , mesh_db = load_pose_models (
199
+ object_ds = make_object_dataset (ds_name_short )
200
+
201
+ coarse_model , refiner_model , mesh_db = load_pose_models_cosypose (
201
202
object_ds ,
202
203
coarse_run_id = cfg .coarse_run_id ,
203
204
refiner_run_id = cfg .refiner_run_id ,
204
- n_workers = 0 ,
205
- renderer_type = "bullet"
205
+ n_workers = cfg . inference . n_workers ,
206
+ renderer_type = cfg . inference . renderer
206
207
)
207
208
208
209
renderer = refiner_model .renderer
0 commit comments