-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathlearn.py
34 lines (30 loc) · 845 Bytes
/
learn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import DQN
import torch
from torch.utils.tensorboard import SummaryWriter
# Training driver: run 400 DQN episodes, print per-episode results, log the
# rolling success rate to TensorBoard, and checkpoint the policy network
# every 50 episodes.
vic = []      # terminal status per episode (e.g. 'GOAL'); feeds the success rate
scores = []   # raw episode scores, kept for post-hoc inspection
render = False  # rendering disabled for every episode; flip to True to watch runs
save = False    # NOTE(review): unused flag — confirm whether checkpointing should be gated on it

agent = DQN.Pw_Agent(gamma=0.93, epsilon=1.0)
# Hyperparameter string doubles as the TensorBoard run comment so runs are
# distinguishable in the dashboard.
tb = SummaryWriter(comment=agent.get_hyperparams())
print(agent.get_hyperparams())
print()

for i in range(400):
    # Checkpoint the policy network every 50 episodes (including episode 0).
    if i % 50 == 0:
        torch.save(agent.policy_NN.state_dict(), 'episode {}'.format(i))

    # NOTE(review): loss is returned but never used — presumably logged
    # inside DQNepisode; verify before removing.
    reason, score, loss = agent.DQNepisode(N_steps=10000, vid=render)
    scores.append(score)
    print("Episode ", i, "Score ", score, 'Status: ', reason)
    vic.append(reason)

    # Rolling success rate over the most recent 50 episodes.
    # Fix: was `len(vic) > 50` (off by one — skipped the first full window).
    if len(vic) >= 50:
        window = vic[-50:]
        # Fix: counting directly avoids the KeyError the old
        # `{i: a.count(i) for i in a}['GOAL']` raised whenever no episode
        # in the window reached the goal.
        suc = window.count('GOAL') / 50 * 100
        print('success rate = ', suc)
        tb.add_scalar('Success Rate', suc, i)
        tb.flush()

tb.close()