-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathAgent.py
79 lines (62 loc) · 2.62 KB
/
Agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from dotmap import DotMap
from gym.monitoring import VideoRecorder
class Agent:
    """A general class for RL agents.

    The agent wraps an environment and rolls out a given policy in it,
    optionally recording the rollout to a video file.
    """

    def __init__(self, params):
        """Initializes an agent.

        Arguments:
            params: (DotMap) A DotMap of agent parameters.
                .env: (OpenAI gym environment) The environment for this agent.
                .noisy_actions: (bool) Action noise is not supported by this
                    implementation; the key must be absent or exactly False.

        Raises:
            ValueError: If params.noisy_actions is anything other than False,
                or if no environment was provided.
        """
        # Explicit raise rather than `assert`: asserts are stripped when Python
        # runs with -O, which would let an unsupported option through silently.
        if params.get("noisy_actions", False) is not False:
            raise ValueError("Noisy actions are not supported by this agent.")

        self.env = params.env
        # NOTE(review): a DotMap is expected to hand back a new empty DotMap for
        # a missing attribute, so getting one here means `env` was never set.
        if isinstance(self.env, DotMap):
            raise ValueError("Environment must be provided to the agent at initialization.")

    def sample(self, horizon, policy, record_fname=None):
        """Samples a rollout from the agent.

        Arguments:
            horizon: (int) The maximum length of the rollout to generate. The
                rollout stops early if the environment reports `done`.
            policy: (policy) The policy that the agent will use for actions.
                Must provide `reset()` and `act(obs, t)`.
            record_fname: (str/None) The name of the file to which a recording of the rollout
                will be saved. If None, the rollout will not be recorded.

        Returns: (dict) A dictionary containing data from the rollout.
            The keys of the dictionary are:
                'obs': array of observations (initial observation first, so one
                    more entry than 'ac').
                'ac': array of the actions taken.
                'reward_sum': sum of the rewards received.
                'rewards': array of the per-step rewards.
        """
        video_record = record_fname is not None
        recorder = VideoRecorder(self.env, record_fname) if video_record else None

        times, rewards = [], []
        O, A, reward_sum = [], [], 0

        # Everything after the recorder is created runs under try/finally so
        # the recorder is closed even if the env or the policy raises.
        try:
            O.append(self.env.reset())
            policy.reset()
            for t in range(horizon):
                if video_record:
                    recorder.capture_frame()

                # Time only the policy's action selection, not the env step.
                start = time.time()
                A.append(policy.act(O[t], t))
                times.append(time.time() - start)

                obs, reward, done, _ = self.env.step(A[t])
                O.append(obs)
                reward_sum += reward
                rewards.append(reward)
                if done:
                    break

            if video_record:
                # Capture the frame for the final state reached by the rollout.
                recorder.capture_frame()
        finally:
            if video_record:
                recorder.close()

        # np.mean([]) warns at runtime; with no timed actions report nan directly.
        mean_time = np.mean(times) if times else float("nan")
        print("Average action selection time: ", mean_time)
        print("Rollout length: ", len(A))

        return {
            "obs": np.array(O),
            "ac": np.array(A),
            "reward_sum": reward_sum,
            "rewards": np.array(rewards),
        }