-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: eval.py
56 lines (46 loc) · 1.22 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import fire
import yaml
import gym
import torch
import numpy as np
from collections import deque
from ddrl.trainer import Trainer
import time
# Evaluation is CPU-only: checkpointed weights are loaded onto this device.
device = "cpu"
def main(cp=None, config=None):
    """Run one rendered evaluation episode with a trained agent.

    Loads actor/critic weights from the checkpoint directory *cp*,
    builds the environment named in the YAML file *config*, plays a
    single episode with the greedy (best) action, and prints the total
    episode reward.

    Args:
        cp: Directory containing ``actor.pth`` and ``critic.pth``.
        config: Path to a YAML config with an ``env`` -> ``env-name`` entry.
    """
    print(cp, config)
    # Use a context manager so the config file handle is always closed
    # (the original leaked it via yaml.load(open(...))).
    with open(config, "r") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    env = gym.make(config["env"]["env-name"])
    env.seed(2021)  # fixed seed for a reproducible evaluation rollout
    obs_dim = env.observation_space.shape[0]
    # Continuous (Box) spaces expose `shape`; Discrete spaces have no usable
    # shape and expose `n` instead. Catch only the errors that lookup can
    # raise rather than a bare `except:` that would also swallow Ctrl-C.
    try:
        act_dim = env.action_space.shape[0]
    except (AttributeError, IndexError, TypeError):
        act_dim = env.action_space.n
    agent = Trainer(
        state_size=obs_dim,
        action_size=act_dim,
        config=config,
        device=device,
        neptune=None,
    )
    # map_location ensures GPU-trained checkpoints load on this CPU-only run.
    weights = {
        "actor": torch.load(os.path.join(cp, "actor.pth"), map_location=device),
        "critic": torch.load(os.path.join(cp, "critic.pth"), map_location=device),
    }
    agent.sync(weights)
    score = 0
    observation = env.reset()
    while True:
        env.render()
        time.sleep(0.1)  # slow the loop down to a watchable render speed
        # Pure inference: disable autograd bookkeeping.
        with torch.no_grad():
            prob = agent.actor(torch.Tensor(observation).unsqueeze(0))
            action = agent.actor.get_best_action(prob).numpy()
        observation, reward, done, _ = env.step(np.squeeze(action))
        score += reward
        if done:
            break
    env.close()  # release the render window / environment resources
    print(score)
if __name__ == "__main__":
    # Expose `main` as a CLI via python-fire,
    # e.g. `python eval.py --cp=<checkpoint_dir> --config=<config.yaml>`.
    fire.Fire(main)