montezuma_train.py
import typing

import gym
import torch
from torch.multiprocessing import Process, Pipe, set_start_method, get_context
from multiprocessing import Lock
from tensorboardX import SummaryWriter

from util.config import default_config
from environments.atari_env import AtariEnvironmentWrapper
from training.trainer import run_rnd_trainer
from models.rnd_agent import RNDPPOAgent
from training.env_runner import ParallelEnvironmentRunner, get_default_stored_data

# Shared CUDA tensors require the 'spawn' start method; ignore the error if a
# start method has already been set.
try:
    set_start_method('spawn')
except RuntimeError:
    pass

USETPU = default_config.get("UseTPU", False)
OPT_DEVICE = default_config["OptimDevice"]
RUN_DEVICE = default_config["RunDevice"]
NUM_WORKERS = default_config["NumWorkers"]
ENV_NAME = default_config["EnvName"]
EPOCHS = default_config["NumEpochs"]
ROLLOUT_STEPS = default_config["RolloutSteps"]
STATE_DICT = default_config.get("StateDict", None)
IMAGE_HEIGHT = default_config["ImageHeight"]
IMAGE_WIDTH = default_config["ImageWidth"]
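
# The constants above are read from util.config.default_config. The values
# below are only an illustrative assumption of what such a config might hold;
# the actual defaults live in util/config:
#   UseTPU: False
#   OptimDevice: "cuda:0"   # device the learner optimizes on
#   RunDevice: "cpu"        # device the shared rollout model runs on
#   NumWorkers: 32
#   EnvName: "MontezumaRevengeNoFrameskip-v4"
#   NumEpochs: 10000
#   RolloutSteps: 128
#   StateDict: None         # optional checkpoint path to resume from
#   ImageHeight: 84
#   ImageWidth: 84
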
Buffers = typing.Dict[str, typing.List[torch.Tensor]]


def create_buffer(W, T, action_dim) -> Buffers:
    """Allocate shared-memory rollout buffers for W workers over T steps."""
    specs = dict(
        states=dict(size=(W, T, 4, IMAGE_HEIGHT, IMAGE_WIDTH), dtype=torch.uint8),
        next_states=dict(size=(W, T, 4, IMAGE_HEIGHT, IMAGE_WIDTH), dtype=torch.uint8),
        actions=dict(size=(W, T), dtype=torch.long),
        rewards=dict(size=(W, T), dtype=torch.float32),
        dones=dict(size=(W, T), dtype=torch.bool),
        real_dones=dict(size=(W, T), dtype=torch.bool),
        ext_values=dict(size=(W, T), dtype=torch.float32),
        int_values=dict(size=(W, T), dtype=torch.float32),
        policies=dict(size=(W, T, action_dim), dtype=torch.float32),
        log_prob_policies=dict(size=(W, T), dtype=torch.float32),
        obs_stats=dict(size=(2, IMAGE_HEIGHT, IMAGE_WIDTH), dtype=torch.float32),
    )
    buffers: Buffers = {key: [] for key in specs}
    for key in buffers:
        buffers[key].append(torch.empty(**specs[key]).share_memory_())
    return buffers
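
# A sketch of the resulting layout (an assumption based on the specs above):
# each key maps to a one-element list holding a single shared tensor, indexed
# per worker and rollout step, e.g.
#   buffer["states"][0][w, t]   -> uint8 frame stack of shape (4, H, W)
#   buffer["actions"][0][w, t]  -> action taken by worker w at step t
# "obs_stats" is the exception: its (2, H, W) tensor presumably stores the
# running mean and std of observations rather than per-worker data.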


def train_montezuma():
    # Probe the environment once for the action space and initial observation,
    # then close it; the parallel workers construct their own environments.
    env = AtariEnvironmentWrapper(ENV_NAME, False, 0, None)
    action_dim = env.env.action_space.n
    init_state = torch.from_numpy(env.reset())
    env.env.close()
    del env

    if not USETPU:
        opt_device = OPT_DEVICE
        run_device = RUN_DEVICE
        print_fn = print
    else:
        import torch_xla
        import torch_xla.core.xla_model as xm
        opt_device = xm.xla_device()
        run_device = xm.xla_device()
        print_fn = xm.master_print

    if STATE_DICT is not None:
        state_dict = torch.load(STATE_DICT, map_location='cpu')
    else:
        state_dict = None

    print_fn("Initializing agent...")
    agent = RNDPPOAgent(action_dim, device=opt_device)
    writer = SummaryWriter()

    print_fn("Initializing buffer and shared state...")
    with torch.no_grad():
        buffer = create_buffer(NUM_WORKERS, ROLLOUT_STEPS, action_dim)
        shared_model = RNDPPOAgent(action_dim, device=run_device)
        shared_model.share_memory()
    parent_conn, child_conn = Pipe()

    print_fn("Initializing Environment Runner...")
    env_runner = ParallelEnvironmentRunner(
        NUM_WORKERS, action_dim, ROLLOUT_STEPS, shared_model, init_state,
        buffer, EPOCHS,
        parent_conn, writer,
    )
    if state_dict and "N_Episodes" in state_dict.keys():
        env_runner.log_episode = state_dict["N_Episodes"]

    print_fn("Done, initializing RNDTrainer...")
    # The learner runs in a separate process and communicates with the
    # environment runner through the pipe and the shared buffer/model.
    learner = Process(
        target=run_rnd_trainer,
        args=(
            NUM_WORKERS, 4, child_conn, agent,
            buffer, shared_model, EPOCHS, state_dict
        )
    )
    learner.start()

    print_fn("Done, training")
    env_runner.run_agent()
    learner.join()
    print_fn("Finished!")


if __name__ == '__main__':
    train_montezuma()