forked from bdaiinstitute/vlfm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcustom_vlfm.py
83 lines (66 loc) · 2.83 KB
/
custom_vlfm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Copyright (c) 2023 Boston Dynamics AI Institute LLC. All rights reserved.
import os
import random
# The following imports require habitat to be installed, and despite not being used by
# this script itself, will register several classes and make them discoverable by Hydra.
# This run.py script is expected to only be used when habitat is installed, thus they
# are hidden here instead of in an __init__.py file. This avoids import errors when used
# in an environment without habitat, such as when doing real-world deployment. noqa is
# used to suppress the unused import and unsorted import warnings by ruff.
import frontier_exploration # noqa
import hydra # noqa
import numpy as np
import torch
from habitat import get_config # noqa
from habitat.config import read_write
from habitat.config.default import patch_config
from habitat.config.default_structured_configs import register_hydra_plugin
from habitat_baselines.run import execute_exp
from hydra.core.config_search_path import ConfigSearchPath
from hydra.plugins.search_path_plugin import SearchPathPlugin
from omegaconf import DictConfig
import vlfm.measurements.traveled_stairs # noqa: F401
import vlfm.obs_transformers.resize # noqa: F401
import vlfm.policy.action_replay_policy # noqa: F401
import vlfm.policy.habitat_policies # noqa: F401
import vlfm.utils.vlfm_trainer # noqa: F401
class HabitatConfigPlugin(SearchPathPlugin):
def manipulate_search_path(self, search_path: ConfigSearchPath) -> None:
search_path.append(provider="habitat", path="config/")
register_hydra_plugin(HabitatConfigPlugin)
@hydra.main(
version_base=None,
config_path="config",
config_name="experiments/vlfm_objectnav_hm3d",
)
def main(config: DictConfig) -> None:
assert os.path.isdir("data"), "Missing 'data/' directory!"
if not os.path.isfile("data/dummy_policy.pth"):
print("Dummy policy weights not found! Please run the following command first:")
print("python -m vlfm.utils.generate_dummy_policy")
exit(1)
config = patch_config(config)
with read_write(config):
try:
config.habitat.simulator.agents.main_agent.sim_sensors.pop("semantic_sensor")
except KeyError:
pass
random.seed(config.habitat.seed)
np.random.seed(config.habitat.seed)
torch.manual_seed(config.habitat.seed)
if (
config.habitat_baselines.force_torch_single_threaded
and torch.cuda.is_available()
):
torch.set_num_threads(1)
from habitat_baselines.common.baseline_registry import baseline_registry
trainer_init = baseline_registry.get_trainer(
config.habitat_baselines.trainer_name
)
assert (
trainer_init is not None
), f"{config.habitat_baselines.trainer_name} is not supported"
trainer = trainer_init(config)
trainer.eval()
if __name__ == "__main__":
main()