-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
93 lines (71 loc) · 2.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import random
import collections
import yaml
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
class ReplayMemory:
    """Fixed-capacity buffer that overwrites its oldest entries once full.

    Entries are kept in the ``memory`` list. ``position`` is the index that
    the next overwrite will target after the list has reached ``capacity``.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` entries."""
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, e):
        """Store ``e``, replacing the entry at ``position`` when full."""
        is_full = len(self.memory) >= self.capacity
        if is_full:
            self.memory[self.position] = e
        else:
            self.memory.append(e)
        # Advance the write cursor, wrapping back to 0 at capacity.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` stored entries picked via ``random.sample``.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds
        the number of stored entries.
        """
        stored = self.memory
        return random.sample(stored, batch_size)

    def __len__(self):
        """Return the number of entries currently stored."""
        return len(self.memory)
def load_yaml_configs(filename, model_name):
    """Load one model's settings from a YAML file as a namedtuple.

    The YAML document is expected to be a mapping; the entry stored under
    ``model_name`` must itself be a mapping of setting name to value.

    String values containing the letter ``e`` that parse as floats (for
    example ``'1e-4'``, which YAML 1.1 loaders keep as a string because it
    has no decimal point) are converted to ``float``. All other values are
    passed through unchanged.

    Args:
        filename: Path of the YAML file to read.
        model_name: Top-level key selecting which model's settings to load.

    Returns:
        A ``ModelConfigs`` namedtuple whose fields are the setting names.

    Raises:
        KeyError: If ``model_name`` is not a key of the loaded document.
        ValueError: If a setting name is not a valid namedtuple field name.
    """
    with open(filename, 'r') as f:
        # yaml.load() without an explicit Loader is rejected by PyYAML >= 6
        # and can construct arbitrary Python objects on older releases;
        # safe_load only builds plain YAML types, which is all a config needs.
        all_configs = yaml.safe_load(f)
    raw_configs = all_configs[model_name]

    coerced = {}
    for key, value in raw_configs.items():
        # Only strings can be scientific-notation numbers that the loader
        # left unparsed; everything else is kept as loaded.
        if isinstance(value, str) and 'e' in value:
            try:
                value = float(value)
            except ValueError:
                # An ordinary string that merely contains an 'e'.
                pass
        coerced[key] = value

    ModelConfigs = collections.namedtuple('ModelConfigs', list(coerced))
    return ModelConfigs(*coerced.values())
def enumerate(num_samples_per_dim, low, high):
    """Build a regular grid of points spanning the box ``[low, high]``.

    Each dimension ``d`` is sampled at ``num_samples_per_dim`` evenly spaced
    values from ``low[d]`` to ``high[d]`` inclusive, and every combination
    across dimensions is returned. Points are ordered with the last
    dimension varying fastest.

    NOTE: this function shadows the builtin ``enumerate`` within this
    module, so the builtin must not be relied on here.

    Args:
        num_samples_per_dim: Number of samples along each dimension (>= 2).
        low: Sequence of per-dimension lower bounds (non-empty).
        high: Sequence of per-dimension upper bounds, indexed like ``low``.

    Returns:
        A list of points, each a list with one coordinate per dimension.
    """
    num_dim = len(low)
    assert num_samples_per_dim >= 2, 'Enumeration error. Number of samples per dimension is {}.'.format(num_samples_per_dim)
    assert num_dim > 0, 'Enumeration error. Number of dimension is {}.'.format(num_dim)

    intervals = num_samples_per_dim - 1

    def axis_values(dim):
        # Evenly spaced coordinates along one dimension, endpoints included.
        start = low[dim]
        span = high[dim] - start
        return [start + span * step / intervals for step in range(num_samples_per_dim)]

    grid = [[coord] for coord in axis_values(0)]
    for dim in range(1, num_dim):
        axis = axis_values(dim)
        # Extend every partial point with each coordinate of the new axis.
        grid = [point + [coord] for point in grid for coord in axis]
    return grid
def get_state(observation):
    """Flatten the values of an observation mapping into a single list.

    Each value of ``observation`` is iterated and its items are appended, in
    the mapping's own iteration order.

    Args:
        observation: Mapping whose values are iterables.

    Returns:
        A flat list containing the items of every value.
    """
    return [item for part in observation.values() for item in part]
def plot_figure(y, x=None, title=None, xlabel=None, ylabel=None, figure_num=None, display=True, save=False, filename=None):
    """Draw a line plot of ``y`` (optionally against ``x``) with pyplot.

    Args:
        y: Values to plot.
        x: Optional x coordinates; when omitted, ``y`` is plotted alone.
        title: Optional figure title.
        xlabel: Optional x-axis label.
        ylabel: Optional y-axis label.
        figure_num: Figure to reuse; it is cleared before drawing. When
            ``None`` a new figure is created.
        display: When true, call ``plt.show()`` after drawing.
        save: When true, write the figure to disk.
        filename: Output path used when ``save`` is true; defaults to
            ``'figure.png'``.
    """
    if figure_num is not None:
        # Reuse the numbered figure and wipe whatever was drawn on it.
        plt.figure(figure_num)
        plt.clf()
    else:
        plt.figure()

    labels = (
        (plt.title, title),
        (plt.xlabel, xlabel),
        (plt.ylabel, ylabel),
    )
    for apply_label, text in labels:
        if text is not None:
            apply_label(text)

    plot_args = (y,) if x is None else (x, y)
    plt.plot(*plot_args)

    if save:
        plt.savefig('figure.png' if filename is None else filename)
    if display:
        plt.show()