-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathddpg.py
418 lines (355 loc) · 18.5 KB
/
ddpg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
import sys
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from models import model_factory
#from shuffle import shuffle_others_loc, sort_state_batch
def soft_update(target, source, tau):
    """Polyak-average ``source`` into ``target`` in place.

    Every target parameter becomes ``target * (1 - tau) + source * tau``.
    Parameters are paired positionally through ``parameters()``.
    """
    keep = 1.0 - tau
    for dst, src in zip(target.parameters(), source.parameters()):
        blended = dst.data * keep + src.data * tau
        dst.data.copy_(blended)
def hard_update(target, source):
    """Copy every parameter of ``source`` into ``target`` in place.

    Parameters are paired positionally through ``parameters()``.
    """
    pairs = zip(target.parameters(), source.parameters())
    for dst, src in pairs:
        dst.data.copy_(src.data)
"""
Adapted from: https://github.com/pytorch/pytorch/issues/1959
This note applies to the LayerNorm class defined after adjust_lr below.
PyTorch now ships an official LayerNorm implementation, but it had not yet been
included in the pip release when this was written, so this is a temporary version.
It slows training down slightly.
"""
def adjust_lr(optimizer, init_lr, episode_i, num_episode, start_episode):
    """Linearly decay the learning rate of every param group in ``optimizer``.

    Before ``start_episode`` the optimizer is left untouched. From
    ``start_episode`` on, the rate decays linearly from ``init_lr`` to 0 at
    ``num_episode``. Past ``num_episode`` it is held at 0 instead of going
    negative.

    Args:
        optimizer: object with a ``param_groups`` list of dicts (e.g. a
            ``torch.optim.Optimizer``); each group's ``'lr'`` is overwritten.
        init_lr: learning rate at the start of the decay window.
        episode_i: current episode index.
        num_episode: episode at which the rate reaches 0.
        start_episode: episode at which the decay begins.

    Returns:
        The learning rate now in effect (``init_lr`` before the decay window).

    Raises:
        ValueError: if the decay window is empty (``num_episode <= start_episode``)
            and ``episode_i >= start_episode``.
    """
    if episode_i < start_episode:
        return init_lr
    if num_episode <= start_episode:
        raise ValueError(
            'num_episode ({}) must be greater than start_episode ({})'.format(num_episode, start_episode))
    progress = (episode_i - start_episode) / (num_episode - start_episode)
    # Clamp so that running past num_episode never yields a negative rate.
    lr = init_lr * max(0.0, 1.0 - progress)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
class LayerNorm(nn.Module):
    """Per-sample normalisation over all non-batch dimensions.

    Each sample ``x[i]`` is centred by the mean of all its elements and divided
    by their standard deviation (as computed by ``Tensor.std``) plus ``eps``.
    If ``affine`` is set, a learnable scale ``gamma`` and shift ``beta`` of
    length ``num_features`` are broadcast along dimension 1 of the input.
    """

    def __init__(self, num_features, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        if self.affine:
            # gamma starts from U[0, 1) draws, as in the original implementation.
            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # (batch, 1, 1, ...) so the per-sample statistics broadcast over x.
        shape = [-1] + [1] * (x.dim() - 1)
        # reshape (not view) so that non-contiguous inputs, e.g. transposed
        # tensors, are accepted; view() raises on them.
        flat = x.reshape(x.size(0), -1)
        mean = flat.mean(1).view(*shape)
        std = flat.std(1).view(*shape)
        y = (x - mean) / (std + self.eps)
        if self.affine:
            # (1, num_features, 1, ...): per-channel scale/shift along dim 1.
            shape = [1, -1] + [1] * (x.dim() - 2)
            y = self.gamma.view(*shape) * y + self.beta.view(*shape)
        return y


# NOTE(review): this rebinds torch.nn.LayerNorm for the whole process, so any
# module that later uses nn.LayerNorm gets this class instead of PyTorch's.
nn.LayerNorm = LayerNorm
class Actor(nn.Module):
    """Two-hidden-layer MLP mapping observations to raw (pre-softmax) outputs.

    Architecture: Linear(num_inputs -> hidden) -> ReLU ->
    Linear(hidden -> hidden) -> ReLU -> Linear(hidden -> num_outputs).
    """

    def __init__(self, hidden_size, num_inputs, num_outputs):
        super(Actor, self).__init__()
        # Attribute names define the state_dict keys; keep them stable.
        self.linear1 = nn.Linear(num_inputs, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.mu = nn.Linear(hidden_size, num_outputs)
        # Scale the output layer's initial weights and bias down by 10x.
        for tensor in (self.mu.weight, self.mu.bias):
            tensor.data.mul_(0.1)

    def forward(self, inputs):
        hidden = F.relu(self.linear1(inputs))
        hidden = F.relu(self.linear2(hidden))
        return self.mu(hidden)
class Critic(nn.Module):
    """Value network over per-agent (observation, action) pairs.

    The flat observation and action batches are each reshaped to
    ``(batch, num_agents, -1)``, concatenated along the last dimension, and
    passed to the network built by ``model_factory.get_model_fn(critic_type)``.
    """

    def __init__(self, hidden_size, num_inputs, num_outputs, num_agents, critic_type='mlp', agent_id=0):
        super(Critic, self).__init__()
        self.num_agents = num_agents
        self.critic_type = critic_type
        self.agent_id = agent_id
        # Width of one agent's concatenated (observation, action) slice.
        per_agent_dim = int((num_inputs + num_outputs) / num_agents)
        self.net_fn = model_factory.get_model_fn(critic_type)
        self.net = self.net_fn(per_agent_dim, num_agents, hidden_size)

    def forward(self, inputs, actions):
        batch_size = inputs.size(0)
        per_agent_obs = inputs.view(batch_size, self.num_agents, -1)
        per_agent_act = actions.view(batch_size, self.num_agents, -1)
        joint = torch.cat([per_agent_obs, per_agent_act], dim=2)
        return self.net(joint)
class DDPG(object):
    """Actor-critic learner for one agent in a multi-agent setting.

    Holds an actor, a critic, their target copies and a parameter-noise copy
    of the actor. ``critic_dec_cen == 'decen'`` gives a decentralised critic
    that only sees this agent's observation and action slice. Any other value
    gives a centralised critic over all agents' observations and actions.
    Actions are a softmax over the actor's outputs, optionally after adding
    Gumbel noise (``train_noise``).

    Assumptions not checked here (TODO confirm against callers):
    * ``obs_dims`` appears to start with a leading 0, so that
      ``cumsum(obs_dims)[k]:cumsum(obs_dims)[k + 1]`` selects agent k's
      observation columns and ``obs_dims[agent_id + 1]`` is this agent's width.
    * ``batch`` exposes ``state``/``action``/``reward``/``mask``/``next_state``
      sequences of tensors that are concatenated along dim 0.
    """

    def __init__(self, gamma, tau, hidden_size, obs_dim, n_action, n_agent, obs_dims, agent_id, actor_lr, critic_lr,
                 fixed_lr, critic_type, train_noise, num_episodes, num_steps,
                 critic_dec_cen, target_update_mode='soft', device='cpu'):
        self.device = device
        self.obs_dim = obs_dim
        self.n_agent = n_agent
        self.n_action = n_action

        self.actor = Actor(hidden_size, obs_dim, n_action).to(self.device)
        self.actor_target = Actor(hidden_size, obs_dim, n_action).to(self.device)
        # Parameter-noise copy of the actor. It lives on the same device as the
        # actor so select_action(param_noise=True) accepts the same inputs.
        self.actor_perturbed = Actor(hidden_size, obs_dim, n_action).to(self.device)
        self.actor_optim = Adam(self.actor.parameters(),
                                lr=actor_lr, weight_decay=0)

        if critic_dec_cen == 'decen':
            # Decentralised: this agent's observation and action only.
            critic_args = (hidden_size, obs_dims[agent_id + 1], n_action, 1, critic_type, agent_id)
        else:
            # Centralised: all agents' observations and actions.
            critic_args = (hidden_size, np.sum(obs_dims), n_action * n_agent, n_agent, critic_type, agent_id)
        self.critic = Critic(*critic_args).to(self.device)
        self.critic_target = Critic(*critic_args).to(self.device)
        critic_n_params = sum(p.numel() for p in self.critic.parameters())
        print('# of critic params', critic_n_params)
        self.critic_optim = Adam(self.critic.parameters(), lr=critic_lr)

        self.fixed_lr = fixed_lr
        self.init_act_lr = actor_lr
        self.init_critic_lr = critic_lr
        self.num_episodes = num_episodes
        self.start_episode = 0
        self.num_steps = num_steps
        # lambda1 reads num_episodes/num_steps, which are already set above.
        self.actor_scheduler = LambdaLR(self.actor_optim, lr_lambda=self.lambda1)
        self.critic_scheduler = LambdaLR(self.critic_optim, lr_lambda=self.lambda1)
        self.gamma = gamma
        self.tau = tau
        self.train_noise = train_noise
        self.obs_dims_cumsum = np.cumsum(obs_dims)
        self.critic_dec_cen = critic_dec_cen
        self.agent_id = agent_id
        self.debug = False
        self.target_update_mode = target_update_mode
        # Start the target networks from the same weights as the online ones.
        hard_update(self.actor_target, self.actor)
        hard_update(self.critic_target, self.critic)

    # ------------------------------------------------------------------ helpers

    def _agent_obs(self, flat_obs, k):
        """Return agent k's observation columns of a concatenated observation batch."""
        return flat_obs[:, self.obs_dims_cumsum[k]: self.obs_dims_cumsum[k + 1]]

    def _agent_act(self, flat_act, k):
        """Return agent k's action columns of a concatenated action batch."""
        return flat_act[:, k * self.n_action: (k + 1) * self.n_action]

    def _permute_agents(self, flat_batch, rand_idx, per_agent_dim):
        """Reorder the equal-width per-agent column blocks of flat_batch by rand_idx."""
        blocks = flat_batch.view(-1, self.n_agent, per_agent_dim)
        return blocks[:, rand_idx, :].view(-1, per_agent_dim * self.n_agent)

    def _q(self, critic, state_batch, action_batch):
        """Evaluate critic, restricted to this agent's columns for a 'decen' critic."""
        if self.critic_dec_cen == 'decen':
            return critic(self._agent_obs(state_batch, self.agent_id),
                          self._agent_act(action_batch, self.agent_id))
        return critic(state_batch, action_batch)

    def _next_joint_action(self, next_state_batch):
        """Gradient-free actions of all agents for next_state_batch, concatenated on dim 1."""
        return torch.cat([
            self.select_action(self._agent_obs(next_state_batch, k), action_noise=self.train_noise)
            for k in range(self.n_agent)
        ], dim=1)

    def _policy_joint_action(self, state_batch, agent_id):
        """Joint action in which only agent ``agent_id`` keeps its graph to the actor.

        Returns:
            (joint_action, logits_of_agent_id)

        Raises:
            ValueError: if no k in range(n_agent) equals agent_id.
        """
        actions = []
        logit = None
        for k in range(self.n_agent):
            obs_batch = self._agent_obs(state_batch, k)
            if k == agent_id:
                action, logit = self.select_action(
                    obs_batch, action_noise=self.train_noise, grad=True)
            else:
                action = self.select_action(
                    obs_batch, action_noise=self.train_noise, grad=False)
            actions.append(action)
        if logit is None:
            raise ValueError('agent_id {} is not in range(n_agent={})'.format(agent_id, self.n_agent))
        return torch.cat(actions, dim=1), logit

    # --------------------------------------------------------------- schedules

    def adjust_lr(self, i_episode):
        """Apply the module-level linear decay to both optimizers for episode i_episode."""
        adjust_lr(self.actor_optim, self.init_act_lr, i_episode, self.num_episodes, self.start_episode)
        adjust_lr(self.critic_optim, self.init_critic_lr, i_episode, self.num_episodes, self.start_episode)

    def lambda1(self, step):
        """LambdaLR multiplier: 1 up to the midpoint, then linear decay to 0.

        The midpoint is (num_episodes / 2) * num_steps / 100 and the end is
        num_episodes * num_steps / 100. Presumably the schedulers are stepped
        once every 100 environment steps (TODO confirm against the caller).
        The factor is clamped at 0 so the learning rate never becomes negative
        past the end.
        """
        start_decrease_step = ((self.num_episodes / 2)
                               * self.num_steps) / 100
        max_step = (self.num_episodes * self.num_steps) / 100
        if step > start_decrease_step:
            return max(0.0, 1 - ((step - start_decrease_step) / (max_step - start_decrease_step)))
        return 1

    # ---------------------------------------------------------------- acting

    def select_action(self, state, action_noise=None, param_noise=False, grad=False):
        """Return softmax action probabilities for a batch of observations.

        Args:
            state: observation tensor; the softmax is taken over dim 1, so
                presumably (batch, obs_dim) (TODO confirm).
            action_noise: if truthy, Gumbel noise is added to the logits.
            param_noise: use the parameter-perturbed actor copy.
            grad: keep the autograd graph and also return the (possibly noisy)
                logits.

        Returns:
            ``action`` if not grad, else ``(action, logits)``.
        """
        self.actor.eval()
        if param_noise:
            mu = self.actor_perturbed((Variable(state)))
        else:
            mu = self.actor((Variable(state)))
        self.actor.train()
        if not grad:
            mu = mu.data
        if action_noise:
            # -log(-log U) is a standard Gumbel sample, so subtracting
            # log(-log U) adds Gumbel noise. Build it on mu's device/dtype.
            noise = np.log(-np.log(np.random.uniform(0, 1, mu.size())))
            mu -= torch.as_tensor(noise, dtype=mu.dtype, device=mu.device)
        action = F.softmax(mu, dim=1)
        if not grad:
            return action
        return action, mu

    # -------------------------------------------------------------- learning

    def update_critic_parameters(self, batch, agent_id, shuffle=None, eval=False):
        """One TD(0) regression step of the critic on agent ``agent_id``'s reward.

        Args:
            batch: replay sample (see class docstring).
            agent_id: column of the reward/mask batches to learn from.
            shuffle: 'shuffle' randomly permutes the agent order of states,
                next states and actions before the update.
            eval: if True, only compute and return the loss (no update).

        Returns:
            ``(value_loss, perturb_out)`` when eval, otherwise
            ``(value_loss, perturb_out, unclipped_grad_norm)``.
        """
        state_batch = Variable(torch.cat(batch.state)).to(self.device)
        action_batch = Variable(torch.cat(batch.action)).to(self.device)
        reward_batch = Variable(torch.cat(batch.reward)).to(self.device)
        mask_batch = Variable(torch.cat(batch.mask)).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)
        if shuffle == 'shuffle':
            # The view in _permute_agents requires every agent to have an
            # obs_dim-wide observation.
            rand_idx = np.random.permutation(self.n_agent)
            state_batch = self._permute_agents(state_batch, rand_idx, self.obs_dim)
            next_state_batch = self._permute_agents(next_state_batch, rand_idx, self.obs_dim)
            action_batch = self._permute_agents(action_batch, rand_idx, self.n_action)

        # NOTE(review): next actions come from the online actor, not actor_target.
        next_action_batch = self._next_joint_action(next_state_batch)
        # The bootstrap target is a constant for this regression, so no graph
        # is built through critic_target.
        with torch.no_grad():
            next_state_action_values = self._q(self.critic_target, next_state_batch, next_action_batch)
        reward_batch = reward_batch[:, agent_id].unsqueeze(1)
        mask_batch = mask_batch[:, agent_id].unsqueeze(1)
        expected_state_action_batch = reward_batch + (self.gamma * mask_batch * next_state_action_values)

        self.critic_optim.zero_grad()
        state_action_batch = self._q(self.critic, state_batch, action_batch)
        perturb_out = 0
        if self.debug:
            # Sensitivity probe: how much the critic output moves under a 1%
            # state scaling. critic_clone is the same module, not a copy.
            critic_clone = self.critic
            state_action_batch = critic_clone(state_batch, action_batch)
            state_action_batch_perturb = critic_clone(1.01 * state_batch, action_batch)
            perturb_out = F.mse_loss(state_action_batch, state_action_batch_perturb).item()
            print('ori - perturb ', perturb_out)
        value_loss = (
            (state_action_batch - expected_state_action_batch) ** 2).mean()
        if eval:
            return value_loss.item(), perturb_out
        value_loss.backward()
        unclipped_norm = clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optim.step()
        if self.target_update_mode == 'soft':
            soft_update(self.critic_target, self.critic, self.tau)
        elif self.target_update_mode == 'hard':
            hard_update(self.critic_target, self.critic)
        return value_loss.item(), perturb_out, unclipped_norm

    def update_actor_parameters(self, batch, agent_id, shuffle=None):
        """One policy-gradient step for agent ``agent_id`` through the current critic.

        Returns:
            The scalar policy loss (negated mean Q plus a 1e-3 L2 penalty on
            this agent's logits).
        """
        state_batch = Variable(torch.cat(batch.state)).to(self.device)
        if shuffle == 'shuffle':
            rand_idx = np.random.permutation(self.n_agent)
            state_batch = self._permute_agents(state_batch, rand_idx, self.obs_dim)

        self.actor_optim.zero_grad()
        action_batch_n, logit = self._policy_joint_action(state_batch, agent_id)
        policy_loss = -self._q(self.critic, state_batch, action_batch_n)
        # The logits include the Gumbel noise when train_noise is set.
        policy_loss = policy_loss.mean() + 1e-3 * (logit ** 2).mean()
        policy_loss.backward()
        clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optim.step()
        soft_update(self.actor_target, self.actor, self.tau)
        # NOTE(review): critic_target is also updated in update_critic_parameters,
        # so calling both per step moves it twice (and ignores
        # target_update_mode here). Confirm this is intended.
        soft_update(self.critic_target, self.critic, self.tau)
        return policy_loss.item()

    def update_parameters(self, batch, agent_id, shuffle=None):
        """Combined centralised-critic update followed by an actor update.

        Returns:
            (value_loss, policy_loss) as Python floats.
        """
        state_batch = Variable(torch.cat(batch.state)).to(self.device)
        action_batch = Variable(torch.cat(batch.action)).to(self.device)
        reward_batch = Variable(torch.cat(batch.reward)).to(self.device)
        mask_batch = Variable(torch.cat(batch.mask)).to(self.device)
        next_state_batch = torch.cat(batch.next_state).to(self.device)
        if shuffle == 'shuffle':
            rand_idx = np.random.permutation(self.n_agent)
            state_batch = self._permute_agents(state_batch, rand_idx, self.obs_dim)
            next_state_batch = self._permute_agents(next_state_batch, rand_idx, self.obs_dim)
            action_batch = self._permute_agents(action_batch, rand_idx, self.n_action)

        next_action_batch = self._next_joint_action(next_state_batch)
        with torch.no_grad():
            next_state_action_values = self.critic_target(
                next_state_batch, next_action_batch)
        reward_batch = reward_batch[:, agent_id].unsqueeze(1)
        mask_batch = mask_batch[:, agent_id].unsqueeze(1)
        expected_state_action_batch = reward_batch + (self.gamma * mask_batch * next_state_action_values)

        self.critic_optim.zero_grad()
        state_action_batch = self.critic(state_batch, action_batch)
        value_loss = (
            (state_action_batch - expected_state_action_batch) ** 2).mean()
        value_loss.backward()
        clip_grad_norm_(self.critic.parameters(), 0.5)
        self.critic_optim.step()

        self.actor_optim.zero_grad()
        action_batch_n, logit = self._policy_joint_action(state_batch, agent_id)
        policy_loss = -self.critic(state_batch, action_batch_n)
        policy_loss = policy_loss.mean() + 1e-3 * (logit ** 2).mean()
        policy_loss.backward()
        clip_grad_norm_(self.actor.parameters(), 0.5)
        self.actor_optim.step()
        soft_update(self.actor_target, self.actor, self.tau)
        soft_update(self.critic_target, self.critic, self.tau)
        return value_loss.item(), policy_loss.item()

    def perturb_actor_parameters(self, param_noise):
        """Apply parameter noise to actor model, for exploration.

        Copies the actor into actor_perturbed, then adds Gaussian noise scaled
        by ``param_noise.current_stddev`` to every entry whose name does not
        contain 'ln' (layer-norm parameters are left unperturbed).
        """
        hard_update(self.actor_perturbed, self.actor)
        params = self.actor_perturbed.state_dict()
        for name, param in params.items():
            if 'ln' in name:
                continue
            # state_dict tensors share storage with the module, so this
            # in-place add perturbs actor_perturbed directly.
            param += torch.randn_like(param) * param_noise.current_stddev

    # ----------------------------------------------------------- persistence

    def save_model(self, env_name, suffix="", actor_path=None, critic_path=None):
        """Save actor and critic state_dicts (defaults under 'models/')."""
        import os  # os is not among this module's top-level imports

        os.makedirs('models/', exist_ok=True)
        if actor_path is None:
            actor_path = "models/ddpg_actor_{}_{}".format(env_name, suffix)
        if critic_path is None:
            critic_path = "models/ddpg_critic_{}_{}".format(env_name, suffix)
        print('Saving models to {} and {}'.format(actor_path, critic_path))
        torch.save(self.actor.state_dict(), actor_path)
        torch.save(self.critic.state_dict(), critic_path)

    def load_model(self, actor_path, critic_path):
        """Load actor/critic state_dicts; a None path skips that network.

        Tensors are mapped onto ``self.device``, so checkpoints saved on a GPU
        can be loaded on a CPU-only machine.
        """
        print('Loading models from {} and {}'.format(actor_path, critic_path))
        if actor_path is not None:
            self.actor.load_state_dict(torch.load(actor_path, map_location=self.device))
        if critic_path is not None:
            self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))