|
28 | 28 | from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
|
29 | 29 | PoseEstimator,
|
30 | 30 | )
|
| 31 | +from happypose.pose_estimators.megapose.evaluation.meters.modelnet_meters import ( |
| 32 | + ModelNetErrorMeter, |
| 33 | +) |
| 34 | + |
31 | 35 | from happypose.pose_estimators.cosypose.cosypose.scripts.run_cosypose_eval import (
|
32 | 36 | get_pose_meters,
|
33 | 37 | load_pix2pose_results,
|
@@ -89,7 +93,7 @@ def save_checkpoint(model):
|
89 | 93 | logger.info(test_dict)
|
90 | 94 |
|
91 | 95 |
|
92 |
| -def make_eval_bundle(args, model_training): |
| 96 | +def make_eval_bundle(args, model_training, mesh_db): |
93 | 97 | eval_bundle = {}
|
94 | 98 | model_training.cfg = args
|
95 | 99 |
|
@@ -217,10 +221,9 @@ def load_model(run_id):
|
217 | 221 | )
|
218 | 222 |
|
219 | 223 | # Evaluation
|
220 |
| - meters = get_pose_meters(scene_ds, ds_name) |
221 |
| - meters = {k.split("_")[0]: v for k, v in meters.items()} |
222 |
| - list(iter(pred_runner.sampler)) |
223 |
| - print(scene_ds.frame_index) |
| 224 | + meters = { |
| 225 | + "modelnet": ModelNetErrorMeter(mesh_db, sample_n_points=None), |
| 226 | + } |
224 | 227 | # scene_ds_ids = np.concatenate(
|
225 | 228 | # scene_ds.frame_index.loc[mv_group_ids, "scene_ds_ids"].values
|
226 | 229 | # )
|
@@ -335,16 +338,12 @@ def make_datasets(dataset_names):
|
335 | 338 | n_workers=args.n_rendering_workers,
|
336 | 339 | preload_cache=False,
|
337 | 340 | )
|
338 |
| - mesh_db = ( |
339 |
| - MeshDataBase.from_object_ds(object_ds) |
340 |
| - .batched(n_sym=args.n_symmetries_batch) |
341 |
| - .cuda() |
342 |
| - .float() |
343 |
| - ) |
| 341 | + mesh_db = MeshDataBase.from_object_ds(object_ds) |
| 342 | + mesh_db_batched = mesh_db.batched(n_sym=args.n_symmetries_batch).cuda().float() |
344 | 343 |
|
345 |
| - model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db).cuda() |
| 344 | + model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db_batched).cuda() |
346 | 345 |
|
347 |
| - eval_bundle = make_eval_bundle(args, model) |
| 346 | + eval_bundle = make_eval_bundle(args, model, mesh_db) |
348 | 347 |
|
349 | 348 | if args.resume_run_id:
|
350 | 349 | resume_dir = EXP_DIR / args.resume_run_id
|
@@ -413,7 +412,7 @@ def lambd(batch):
|
413 | 412 | model=model,
|
414 | 413 | cfg=args,
|
415 | 414 | n_iterations=args.n_iterations,
|
416 |
| - mesh_db=mesh_db, |
| 415 | + mesh_db=mesh_db_batched, |
417 | 416 | input_generator=args.TCO_input_generator,
|
418 | 417 | )
|
419 | 418 |
|
@@ -476,9 +475,9 @@ def test():
|
476 | 475 | if epoch % args.val_epoch_interval == 0:
|
477 | 476 | validation()
|
478 | 477 |
|
479 |
| - test_dict = None |
480 |
| - if epoch % args.test_epoch_interval == 0: |
481 |
| - test_dict = test() |
| 478 | + #test_dict = None |
| 479 | + #if epoch % args.test_epoch_interval == 0: |
| 480 | + # test_dict = test() |
482 | 481 |
|
483 | 482 | log_dict = {}
|
484 | 483 | log_dict.update(
|
@@ -507,6 +506,6 @@ def test():
|
507 | 506 | model=model,
|
508 | 507 | epoch=epoch,
|
509 | 508 | log_dict=log_dict,
|
510 |
| - test_dict=test_dict, |
| 509 | + test_dict=None, |
511 | 510 | )
|
512 | 511 | dist.barrier()
|
0 commit comments