Skip to content

Commit d42fcca

Browse files
committed
1st step of fixing training + debug
1 parent 31be268 commit d42fcca

File tree

11 files changed

+61
-52
lines changed

11 files changed

+61
-52
lines changed

happypose/pose_estimators/cosypose/cosypose/datasets/augmentations.py

-1
Original file line number | Diff line number | Diff line change
@@ -143,7 +143,6 @@ def __call__(self, im, mask, obs):
143143

144144
class VOCBackgroundAugmentation(BackgroundAugmentation):
145145
def __init__(self, voc_root, p=0.3):
146-
print("voc_root =", voc_root)
147146
image_dataset = VOCSegmentation(
148147
root=voc_root,
149148
year="2012",

happypose/pose_estimators/cosypose/cosypose/evaluation/meters/utils.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ def add_inst_num(
2525
group_keys=["scene_id", "view_id", "label"],
2626
key="pred_inst_num",
2727
):
28-
inst_num = np.empty(len(infos), dtype=np.int)
28+
inst_num = np.empty(len(infos), dtype=int)
2929
for _group_name, group_ids in infos.groupby(group_keys).groups.items():
3030
inst_num[group_ids.values] = np.arange(len(group_ids))
3131
infos[key] = inst_num

happypose/pose_estimators/cosypose/cosypose/evaluation/prediction_runner.py

+39-36
Original file line number | Diff line number | Diff line change
@@ -95,18 +95,18 @@ def run_inference_pipeline(
9595
9696
9797
"""
98-
if self.inference_cfg.detection_type == "gt":
98+
if self.inference_cfg['detection_type'] == "gt":
9999
detections = gt_detections
100100
run_detector = False
101-
elif self.inference_cfg.detection_type == "detector":
101+
elif self.inference_cfg['detection_type'] == "detector":
102102
detections = None
103103
run_detector = True
104104
else:
105-
msg = f"Unknown detection type {self.inference_cfg.detection_type}"
105+
msg = f"Unknown detection type {self.inference_cfg['detection_type']}"
106106
raise ValueError(msg)
107107

108108
coarse_estimates = None
109-
if self.inference_cfg.coarse_estimation_type == "external":
109+
if self.inference_cfg['coarse_estimation_type'] == "external":
110110
# TODO (ylabbe): This is hacky, clean this for modelnet eval.
111111
coarse_estimates = initial_estimates
112112
coarse_estimates = happypose.toolbox.inference.utils.add_instance_id(
@@ -137,15 +137,15 @@ def run_inference_pipeline(
137137
all_preds = {}
138138
data_TCO_refiner = extra_data["refiner"]["preds"]
139139

140-
k_0 = f"refiner/iteration={self.inference_cfg.n_refiner_iterations}"
140+
k_0 = f"refiner/iteration={self.inference_cfg['n_refiner_iterations']}"
141141
all_preds = {
142142
"final": preds,
143143
k_0: data_TCO_refiner,
144144
"refiner/final": data_TCO_refiner,
145145
"coarse": extra_data["coarse"]["preds"],
146146
}
147147

148-
if self.inference_cfg.run_depth_refiner:
148+
if self.inference_cfg['run_depth_refiner']:
149149
all_preds["depth_refiner"] = extra_data["depth_refiner"]["preds"]
150150

151151
# Remove any mask tensors
@@ -174,43 +174,46 @@ def get_predictions(
174174
"""
175175
predictions_list = defaultdict(list)
176176
for n, data in enumerate(tqdm(self.dataloader)):
177-
# data is a dict
178-
rgb = data["rgb"]
179-
depth = None
180-
K = data["cameras"].K
181-
gt_detections = data["gt_detections"].cuda()
182-
183-
initial_data = None
184-
if data["initial_data"]:
185-
initial_data = data["initial_data"].cuda()
186-
187-
obs_tensor = ObservationTensor.from_torch_batched(rgb, depth, K)
188-
obs_tensor = obs_tensor.cuda()
189-
190-
# GPU warmup for timing
191-
if n == 0:
177+
if n < 3:
178+
# data is a dict
179+
rgb = data["rgb"]
180+
depth = None
181+
K = data["cameras"].K
182+
gt_detections = data["gt_detections"].cuda()
183+
184+
initial_data = None
185+
if data["initial_data"]:
186+
initial_data = data["initial_data"].cuda()
187+
188+
obs_tensor = ObservationTensor.from_torch_batched(rgb, depth, K)
189+
obs_tensor = obs_tensor.cuda()
190+
191+
# GPU warmup for timing
192+
if n == 0:
193+
with torch.no_grad():
194+
self.run_inference_pipeline(
195+
pose_estimator,
196+
obs_tensor,
197+
gt_detections,
198+
initial_estimates=initial_data,
199+
)
200+
201+
cuda_timer = CudaTimer()
202+
cuda_timer.start()
192203
with torch.no_grad():
193-
self.run_inference_pipeline(
204+
all_preds = self.run_inference_pipeline(
194205
pose_estimator,
195206
obs_tensor,
196207
gt_detections,
197208
initial_estimates=initial_data,
198209
)
210+
cuda_timer.end()
211+
cuda_timer.elapsed()
199212

200-
cuda_timer = CudaTimer()
201-
cuda_timer.start()
202-
with torch.no_grad():
203-
all_preds = self.run_inference_pipeline(
204-
pose_estimator,
205-
obs_tensor,
206-
gt_detections,
207-
initial_estimates=initial_data,
208-
)
209-
cuda_timer.end()
210-
cuda_timer.elapsed()
211-
212-
for k, v in all_preds.items():
213-
predictions_list[k].append(v)
213+
for k, v in all_preds.items():
214+
predictions_list[k].append(v)
215+
else:
216+
break
214217

215218
# Concatenate the lists of PandasTensorCollections
216219
predictions = {}

happypose/pose_estimators/cosypose/cosypose/evaluation/runner_utils.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@
1414
def run_pred_eval(pred_runner, pred_kwargs, eval_runner, eval_preds=None):
1515
all_predictions = {}
1616
for pred_prefix, pred_kwargs_n in pred_kwargs.items():
17-
print("Prediction :", pred_prefix)
18-
preds = pred_runner.get_predictions(**pred_kwargs_n)
17+
pose_predictor = pred_kwargs_n['pose_predictor']
18+
preds = pred_runner.get_predictions(pose_predictor)
1919
for preds_name, preds_n in preds.items():
2020
all_predictions[f"{pred_prefix}/{preds_name}"] = preds_n
2121

happypose/pose_estimators/cosypose/cosypose/models/pose.py

+1
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,7 @@ def forward(self, images, K, labels, TCO, n_iterations=1):
167167
K_crop=K_crop,
168168
boxes_rend=boxes_rend,
169169
boxes_crop=boxes_crop,
170+
model_outputs=model_outputs
170171
)
171172

172173
TCO_input = TCO_output

happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010
logger = get_logger(__name__)
1111

1212

13+
import warnings
14+
warnings.filterwarnings("ignore")
15+
1316
def make_cfg(args):
1417
cfg = argparse.ArgumentParser("").parse_args([])
1518
if args.config:
@@ -56,7 +59,7 @@ def make_cfg(args):
5659
cfg.n_pose_dims = 9
5760
cfg.n_rendering_workers = N_WORKERS
5861
cfg.refiner_run_id_for_test = None
59-
cfg.coarse_run_id_for_test = None
62+
cfg.coarse_run_id_for_test = "coarse-bop-ycbv-pbr--724183"
6063

6164
# Optimizer
6265
cfg.lr = 3e-4
@@ -66,7 +69,7 @@ def make_cfg(args):
6669
cfg.clip_grad_norm = 0.5
6770

6871
# Training
69-
cfg.batch_size = 32
72+
cfg.batch_size = 16
7073
cfg.epoch_size = 115200
7174
cfg.n_epochs = 700
7275
cfg.n_dataloader_workers = N_WORKERS

happypose/pose_estimators/cosypose/cosypose/training/pose_forward_loss.py

+5-4
Original file line number | Diff line number | Diff line change
@@ -63,10 +63,11 @@ def h_pose(model, mesh_db, data, meters, cfg, n_iterations=1, input_generator="f
6363
losses_TCO_iter = []
6464
for n in range(n_iterations):
6565
iter_outputs = outputs[f"iteration={n+1}"]
66-
K_crop = iter_outputs["K_crop"]
67-
TCO_input = iter_outputs["TCO_input"]
68-
TCO_pred = iter_outputs["TCO_output"]
69-
model_outputs = iter_outputs["model_outputs"]
66+
K_crop = iter_outputs.K_crop
67+
TCO_input = iter_outputs.TCO_input
68+
TCO_pred = iter_outputs.TCO_output
69+
model_outputs = iter_outputs.model_outputs
70+
7071

7172
if cfg.loss_disentangled:
7273
if cfg.n_pose_dims == 9:

happypose/pose_estimators/cosypose/cosypose/training/train_pose.py

+5-5
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ def load_model(run_id):
9797
if run_id is None:
9898
return None
9999
run_dir = EXP_DIR / run_id
100-
cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.FullLoader)
100+
cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.Loader)
101101
cfg = check_update_config(cfg)
102102
model = (
103103
create_model_pose(
@@ -422,7 +422,7 @@ def train_epoch():
422422
iterator = tqdm(ds_iter_train, ncols=80)
423423
t = time.time()
424424
for n, sample in enumerate(iterator):
425-
if n < 5:
425+
if n < 3:
426426
if n > 0:
427427
meters_time["data"].add(time.time() - t)
428428

@@ -453,19 +453,19 @@ def train_epoch():
453453
lr_scheduler_warmup.step()
454454
t = time.time()
455455
else:
456-
continue
456+
break
457457
if epoch >= args.n_epochs_warmup:
458458
lr_scheduler.step()
459459

460460
@torch.no_grad()
461461
def validation():
462462
model.eval()
463463
for n, sample in enumerate(tqdm(ds_iter_val, ncols=80)):
464-
if n < 5:
464+
if n < 3:
465465
loss = h(data=sample, meters=meters_val)
466466
meters_val["loss_total"].add(loss.item())
467467
else:
468-
continue
468+
break
469469

470470
@torch.no_grad()
471471
def test():

happypose/pose_estimators/megapose/evaluation/runner_utils.py

-1
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@
3030
def run_pred_eval(pred_runner, pred_kwargs, eval_runner, eval_preds=None):
3131
all_predictions = {}
3232
for pred_prefix, pred_kwargs_n in pred_kwargs.items():
33-
print("Prediction :", pred_prefix)
3433
preds = pred_runner.get_predictions(**pred_kwargs_n)
3534
for preds_name, preds_n in preds.items():
3635
all_predictions[f"{pred_prefix}/{preds_name}"] = preds_n

happypose/pose_estimators/megapose/models/pose_rigid.py

+1
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@ class PosePredictorOutputCosypose:
5959
K_crop: torch.Tensor
6060
boxes_rend: torch.Tensor
6161
boxes_crop: torch.Tensor
62+
model_outputs: torch.Tensor
6263

6364

6465
@dataclass

happypose/toolbox/lib3d/rigid_mesh_database.py

+2
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,8 @@ def n_sym_mapping(self):
146146
return {label: obj["n_sym"] for label, obj in self.infos.items()}
147147

148148
def select(self, labels):
149+
print("label to id =", self.label_to_id)
150+
print("labels = ", labels)
149151
ids = [self.label_to_id[label] for label in labels]
150152
return Meshes(
151153
infos=[self.infos[label] for label in labels],

0 commit comments

Comments
 (0)