Skip to content

Commit 57f8f57

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 53d7212 commit 57f8f57

File tree

4 files changed

+14
-4
lines changed

4 files changed

+14
-4
lines changed

happypose/pose_estimators/cosypose/cosypose/evaluation/evaluation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def load_detector(run_id, ds_name):
6565
cfg = check_update_config_detector(cfg)
6666
label_to_category_id = cfg.label_to_category_id
6767
model = create_model_detector(cfg, len(label_to_category_id))
68-
ckpt = torch.load(run_dir / "checkpoint.pth.tar", map_location=device, weights_only=True)
68+
ckpt = torch.load(
69+
run_dir / "checkpoint.pth.tar", map_location=device, weights_only=True
70+
)
6971
ckpt = ckpt["state_dict"]
7072
model.load_state_dict(ckpt)
7173
model = model.to(device).eval()

happypose/pose_estimators/cosypose/cosypose/training/evaluation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def load_detector(run_id, ds_name):
6969
cfg = check_update_config_detector(cfg)
7070
label_to_category_id = cfg.label_to_category_id
7171
model = create_model_detector(cfg, len(label_to_category_id))
72-
ckpt = torch.load(run_dir / "checkpoint.pth.tar", map_location=device, weights_only=True)
72+
ckpt = torch.load(
73+
run_dir / "checkpoint.pth.tar", map_location=device, weights_only=True
74+
)
7375
ckpt = ckpt["state_dict"]
7476
model.load_state_dict(ckpt)
7577
model = model.to(device).eval()

happypose/pose_estimators/cosypose/cosypose/training/pose_models_cfg.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def load_model_cosypose(
6565
cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.UnsafeLoader)
6666
cfg = check_update_config(cfg)
6767
model = create_pose_model_cosypose(cfg, renderer=renderer, mesh_db=mesh_db_batched)
68-
ckpt = torch.load(run_dir / "checkpoint.pth.tar", map_location=device, weights_only=True)
68+
ckpt = torch.load(
69+
run_dir / "checkpoint.pth.tar", map_location=device, weights_only=True
70+
)
6971
ckpt = ckpt["state_dict"]
7072
model.load_state_dict(ckpt)
7173
model = model.to(device).eval()

happypose/toolbox/inference/utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,11 @@ def load_detector(run_id: str, device="cpu") -> torch.nn.Module:
6464
cfg = check_update_config_detector(cfg)
6565
label_to_category_id = cfg.label_to_category_id
6666
model = create_model_detector(cfg, len(label_to_category_id))
67-
ckpt = torch.load(run_dir / "checkpoint.pth.tar", map_location=torch.device(device), weights_only=True)
67+
ckpt = torch.load(
68+
run_dir / "checkpoint.pth.tar",
69+
map_location=torch.device(device),
70+
weights_only=True,
71+
)
6872
ckpt = ckpt["state_dict"]
6973
model.load_state_dict(ckpt)
7074
model = model.to(device).eval()

0 commit comments

Comments
 (0)