-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathNetwork.py
40 lines (35 loc) · 1.15 KB
/
Network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Network for the A2C agent. An A2C agent
has a value network and a policy network.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial, reduce
import Config as C
class AcrobotNet (nn.Module) :
    """
    Actor-critic network for the Acrobot task.

    Two independent feed-forward stacks share one layout:
    a policy head producing action probabilities and a
    value head producing a scalar state value. Hidden
    layers use dropout followed by ReLU.
    """
    def __init__ (self, s=C.OBSERVATION_SPACE,
            a=C.ACTION_SPACE, hdims=C.HDIMS) :
        super().__init__()
        # Policy head: s -> hdims -> a logits.
        self.pi = self.buildff([s, *hdims, a])
        # Value head: s -> hdims -> single scalar.
        self.v = self.buildff([s, *hdims, 1])
        # One dropout module shared by both heads.
        self.dropout = nn.Dropout(p=0.5)
    def buildff (self, lst) :
        """
        Build a ModuleList of Linear layers whose sizes
        connect consecutive entries of lst.
        """
        layers = []
        for inDim, outDim in zip(lst, lst[1:]) :
            layers.append(nn.Linear(inDim, outDim))
        return nn.ModuleList(layers)
    def withReluDropout (self, y, f) :
        """Apply layer f to y, then dropout, then ReLU."""
        return F.relu(self.dropout(f(y)))
    def forward (self, x, previous_action, previous_reward) :
        """
        Run both heads on observation x and return
        (action probabilities, state value).

        previous_action / previous_reward are accepted for
        interface compatibility but not used by this network.
        """
        # Policy head: hidden layers with dropout + ReLU,
        # final layer feeds a softmax over actions.
        pi = x
        for layer in self.pi[:-1] :
            pi = self.withReluDropout(pi, layer)
        pi = F.softmax(self.pi[-1](pi), dim=-1)
        # Value head: same hidden treatment, then collapse
        # the trailing singleton dimension to a scalar value.
        v = x
        for layer in self.v[:-1] :
            v = self.withReluDropout(v, layer)
        v = self.v[-1](v).squeeze(-1)
        return pi, v