Skip to content

Commit f396b05

Browse files
Merge pull request #213 from agimus-project/feat/cosypose_icp
Feat/cosypose icp
2 parents 7957c27 + 3e3b9ee; commit f396b05

File tree

4 files changed

+160
-104
lines changed

4 files changed

+160
-104
lines changed

happypose/pose_estimators/cosypose/cosypose/integrated/pose_estimator.py

+10 -2
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@
2727

2828
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2929

30+
RGB_DIMS = [0, 1, 2]
31+
3032

3133
class PoseEstimator(PoseEstimationModule):
3234
"""Performs inference for pose estimation."""
@@ -290,6 +292,9 @@ def forward_coarse_model(
290292

291293
model_time = 0.0
292294

295+
# [B,3,H,W]
296+
images = observation.images[:, RGB_DIMS]
297+
293298
for batch_idx, (batch_ids,) in enumerate(dl):
294299
data_TCO_input_ = data_TCO_input[batch_ids]
295300
df_ = data_TCO_input_.infos
@@ -302,7 +307,7 @@ def forward_coarse_model(
302307
labels_ = df_["label"].tolist()
303308
batch_im_ids_ = torch.as_tensor(df_["batch_im_id"].values, device=device)
304309

305-
images_ = observation.images[batch_im_ids_]
310+
images_ = images[batch_im_ids_]
306311
K_ = observation.K[batch_im_ids_]
307312
if torch.cuda.is_available():
308313
timer_ = CudaTimer(enabled=cuda_timer)
@@ -399,6 +404,9 @@ def forward_refiner(
399404

400405
model_time = 0.0
401406

407+
# [B,3,H,W]
408+
images = observation.images[:, RGB_DIMS]
409+
402410
for batch_idx, (batch_ids,) in enumerate(dl):
403411
data_TCO_input_ = data_TCO_input[batch_ids]
404412
df_ = data_TCO_input_.infos
@@ -411,7 +419,7 @@ def forward_refiner(
411419
labels_ = df_["label"].tolist()
412420
batch_im_ids_ = torch.as_tensor(df_["batch_im_id"].values, device=device)
413421

414-
images_ = observation.images[batch_im_ids_]
422+
images_ = images[batch_im_ids_]
415423
K_ = observation.K[batch_im_ids_]
416424
if torch.cuda.is_available():
417425
timer_ = CudaTimer(enabled=cuda_timer)

happypose/pose_estimators/cosypose/cosypose/scripts/run_inference_on_example.py

+32 -46
Original file line number | Diff line number | Diff line change
@@ -1,22 +1,18 @@
11
# Standard Library
22
import argparse
33
import os
4+
import time
45
from pathlib import Path
56

67
# Third Party
78
import torch
89

9-
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
10-
PoseEstimator,
11-
)
12-
1310
# CosyPose
1411
from happypose.pose_estimators.cosypose.cosypose.utils.cosypose_wrapper import (
1512
CosyPoseWrapper,
1613
)
1714

1815
# HappyPose
19-
from happypose.toolbox.datasets.object_dataset import RigidObjectDataset
2016
from happypose.toolbox.inference.example_inference_utils import (
2117
load_detections,
2218
load_object_data,
@@ -26,47 +22,24 @@
2622
make_poses_visualization,
2723
save_predictions,
2824
)
29-
from happypose.toolbox.inference.types import DetectionsType, ObservationTensor
30-
from happypose.toolbox.inference.utils import filter_detections, load_detector
25+
from happypose.toolbox.inference.types import ObservationTensor
26+
from happypose.toolbox.inference.utils import filter_detections
3127
from happypose.toolbox.utils.logging import get_logger, set_logging_level
3228

3329
logger = get_logger(__name__)
3430

3531
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3632

3733

38-
def setup_pose_estimator(dataset_to_use: str, object_dataset: RigidObjectDataset):
39-
# TODO: remove this wrapper from code base
40-
cosypose = CosyPoseWrapper(
41-
dataset_name=dataset_to_use, object_dataset=object_dataset, n_workers=1
42-
)
43-
44-
return cosypose.pose_predictor
45-
46-
47-
def run_inference(
48-
pose_estimator: PoseEstimator,
49-
observation: ObservationTensor,
50-
detections: DetectionsType,
51-
) -> None:
52-
observation.to(device)
53-
54-
data_TCO, extra_data = pose_estimator.run_inference_pipeline(
55-
observation=observation, detections=detections, n_refiner_iterations=3
56-
)
57-
print("Timings:")
58-
print(extra_data["timing_str"])
59-
60-
return data_TCO.cpu()
61-
62-
6334
def main():
6435
set_logging_level("info")
6536
parser = argparse.ArgumentParser()
6637
parser.add_argument("example_name")
6738
parser.add_argument("--dataset", type=str, default="hope")
6839
parser.add_argument("--run-detections", action="store_true")
6940
parser.add_argument("--run-inference", action="store_true")
41+
parser.add_argument("--run-depth-refiner", action="store_true")
42+
parser.add_argument("--depth-refiner-type", type=str, default="icp")
7043
parser.add_argument("--vis-detections", action="store_true")
7144
parser.add_argument("--vis-poses", action="store_true")
7245
args = parser.parse_args()
@@ -82,29 +55,42 @@ def main():
8255
# Load data
8356
detections = load_detections(example_dir).to(device)
8457
object_dataset = make_example_object_dataset(example_dir)
85-
rgb, depth, camera_data = load_observation_example(example_dir, load_depth=False)
86-
# TODO: cosypose forward does not work if depth is loaded detection
87-
# contrary to megapose
88-
observation = ObservationTensor.from_numpy(rgb, depth=None, K=camera_data.K).to(
89-
device
90-
)
58+
rgb, depth, camera_data = load_observation_example(example_dir, load_depth=True)
59+
observation = ObservationTensor.from_numpy(rgb, depth, camera_data.K).to(device)
9160

9261
# Load models
93-
pose_estimator = setup_pose_estimator(args.dataset, object_dataset)
62+
cosy = CosyPoseWrapper(
63+
dataset_name=args.dataset,
64+
object_dataset=object_dataset,
65+
depth_refiner_type=args.depth_refiner_type,
66+
n_workers=1,
67+
)
9468

9569
if args.run_detections:
96-
# TODO: hardcoded detector
97-
detector = load_detector(run_id="detector-bop-hope-pbr--15246", device=device)
9870
# Masks are not used for pose prediction, but are computed by Mask-RCNN anyway
99-
detections = detector.get_detections(observation, output_masks=True)
100-
available_labels = [obj.label for obj in object_dataset.list_objects]
101-
detections = filter_detections(detections, available_labels)
71+
detections = cosy.detector.get_detections(observation, output_masks=True)
10272
else:
10373
detections = load_detections(example_dir).to(device)
74+
available_labels = [obj.label for obj in object_dataset.list_objects]
75+
detections = filter_detections(detections, available_labels)
10476

10577
if args.run_inference:
106-
output = run_inference(pose_estimator, observation, detections)
107-
save_predictions(output, example_dir)
78+
data_TCO, extra_data = cosy.pose_predictor.run_inference_pipeline(
79+
observation=observation,
80+
detections=detections,
81+
run_detector=False,
82+
n_refiner_iterations=3,
83+
)
84+
print("run_inference_pipeline timings:")
85+
print(extra_data["timing_str"])
86+
if args.run_depth_refiner:
87+
t1 = time.perf_counter()
88+
data_TCO, _ = cosy.depth_refiner.refine_poses(
89+
predictions=data_TCO, depth=observation.depth, K=observation.K
90+
)
91+
print(f"Depth refiner took: {time.perf_counter() - t1}")
92+
93+
save_predictions(data_TCO.cpu(), example_dir)
10894

10995
if args.vis_detections:
11096
make_detections_visualization(rgb, detections, example_dir)

0 commit comments

Comments
 (0)