Skip to content

Commit 05c0e2e

Browse files
author
centos Cloud User
committed
first working version of pose training
1 parent d42fcca commit 05c0e2e

File tree

3 files changed

+21
-23
lines changed

3 files changed

+21
-23
lines changed

happypose/pose_estimators/cosypose/cosypose/scripts/run_cosypose_eval.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
RESULTS_DIR,
2424
)
2525
from happypose.pose_estimators.cosypose.cosypose.datasets.bop import remap_bop_targets
26-
from happypose.pose_estimators.cosypose.cosypose.datasets.datasets_cfg import (
26+
from happypose.toolbox.datasets.datasets_cfg import (
2727
make_object_dataset,
2828
make_scene_dataset,
2929
)
@@ -50,9 +50,8 @@
5050
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_predictor import (
5151
CoarseRefinePosePredictor,
5252
)
53-
from happypose.pose_estimators.cosypose.cosypose.lib3d.rigid_mesh_database import (
54-
MeshDataBase,
55-
)
53+
from happypose.toolbox.lib3d.rigid_mesh_database import MeshDataBase
54+
5655
from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_batch_renderer import ( # noqa: E501
5756
BulletBatchRenderer,
5857
)

happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def make_cfg(args):
7171
# Training
7272
cfg.batch_size = 16
7373
cfg.epoch_size = 115200
74-
cfg.n_epochs = 700
74+
cfg.n_epochs = 3
7575
cfg.n_dataloader_workers = N_WORKERS
7676

7777
# Method

happypose/pose_estimators/cosypose/cosypose/training/train_pose.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
2929
PoseEstimator,
3030
)
31+
from happypose.pose_estimators.megapose.evaluation.meters.modelnet_meters import (
32+
ModelNetErrorMeter,
33+
)
34+
3135
from happypose.pose_estimators.cosypose.cosypose.scripts.run_cosypose_eval import (
3236
get_pose_meters,
3337
load_pix2pose_results,
@@ -89,7 +93,7 @@ def save_checkpoint(model):
8993
logger.info(test_dict)
9094

9195

92-
def make_eval_bundle(args, model_training):
96+
def make_eval_bundle(args, model_training, mesh_db):
9397
eval_bundle = {}
9498
model_training.cfg = args
9599

@@ -217,10 +221,9 @@ def load_model(run_id):
217221
)
218222

219223
# Evaluation
220-
meters = get_pose_meters(scene_ds, ds_name)
221-
meters = {k.split("_")[0]: v for k, v in meters.items()}
222-
list(iter(pred_runner.sampler))
223-
print(scene_ds.frame_index)
224+
meters = {
225+
"modelnet": ModelNetErrorMeter(mesh_db, sample_n_points=None),
226+
}
224227
# scene_ds_ids = np.concatenate(
225228
# scene_ds.frame_index.loc[mv_group_ids, "scene_ds_ids"].values
226229
# )
@@ -335,16 +338,12 @@ def make_datasets(dataset_names):
335338
n_workers=args.n_rendering_workers,
336339
preload_cache=False,
337340
)
338-
mesh_db = (
339-
MeshDataBase.from_object_ds(object_ds)
340-
.batched(n_sym=args.n_symmetries_batch)
341-
.cuda()
342-
.float()
343-
)
341+
mesh_db = MeshDataBase.from_object_ds(object_ds)
342+
mesh_db_batched = mesh_db.batched(n_sym=args.n_symmetries_batch).cuda().float()
344343

345-
model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db).cuda()
344+
model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db_batched).cuda()
346345

347-
eval_bundle = make_eval_bundle(args, model)
346+
eval_bundle = make_eval_bundle(args, model, mesh_db)
348347

349348
if args.resume_run_id:
350349
resume_dir = EXP_DIR / args.resume_run_id
@@ -413,7 +412,7 @@ def lambd(batch):
413412
model=model,
414413
cfg=args,
415414
n_iterations=args.n_iterations,
416-
mesh_db=mesh_db,
415+
mesh_db=mesh_db_batched,
417416
input_generator=args.TCO_input_generator,
418417
)
419418

@@ -476,9 +475,9 @@ def test():
476475
if epoch % args.val_epoch_interval == 0:
477476
validation()
478477

479-
test_dict = None
480-
if epoch % args.test_epoch_interval == 0:
481-
test_dict = test()
478+
#test_dict = None
479+
#if epoch % args.test_epoch_interval == 0:
480+
# test_dict = test()
482481

483482
log_dict = {}
484483
log_dict.update(
@@ -507,6 +506,6 @@ def test():
507506
model=model,
508507
epoch=epoch,
509508
log_dict=log_dict,
510-
test_dict=test_dict,
509+
test_dict=None,
511510
)
512511
dist.barrier()

0 commit comments

Comments
 (0)