retro-v2.py
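"""Go-Explore-style exploration of Super Mario World 2 via gym-retro.

Frames are downscaled and quantized into coarse "cells"; the best emulator
state, score, and action trajectory found for each cell are archived, and
cells are resampled in proportion to a count-based score so the agent keeps
returning to promising states. The selection heuristic appears to follow the
CellScore formula from the Go-Explore paper (Ecoffet et al., 2019).

Requires gym-retro with the relevant SNES ROM imported, e.g.:
    python3 -m retro.import /path/to/your/roms
"""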
from collections import defaultdict

import numpy as np
import retro
import cv2

def cellfn(frame):
    # Downscale the RGB frame to 11x8 grayscale and quantize to 8 gray
    # levels, so many similar frames map to the same coarse cell.
    cell = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    cell = cv2.resize(cell, (11, 8), interpolation = cv2.INTER_AREA)
    cell = cell // 32
    return cell

def hashfn(cell):
    # Hash the raw cell bytes for use as an archive key.
    return hash(cell.tobytes())
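# Sanity check of the cell representation: with a 224x256x3 SNES frame (the
# usual gym-retro resolution; an assumption, not checked here), cellfn
# returns an 8x11 uint8 array with values in [0, 7], so the archive key
# space has at most 8**88 distinct cells.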
class Weights:
    times_chosen = 0.1
    times_chosen_since_new = 0
    times_seen = 0.3

class Powers:
    times_chosen = 0.5
    times_chosen_since_new = 0.5
    times_seen = 0.5

class Cell(object):
    def __init__(self):
        self.times_chosen = 0
        self.times_chosen_since_new = 0
        self.times_seen = 0

    def cntscore(self, a):
        # Count-based subscore for attribute a: w / (v + e1) ** p + e2,
        # so rarely seen / rarely chosen cells score higher.
        w = getattr(Weights, a)
        p = getattr(Powers, a)
        v = getattr(self, a)
        return w / (v + e1) ** p + e2

    def cellscore(self):
        return (self.cntscore('times_chosen')
                + self.cntscore('times_chosen_since_new')
                + self.cntscore('times_seen')
                + 1)
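    # Worked example (my numbers, derived from the constants above): a cell
    # just added to the archive has times_chosen = times_chosen_since_new = 0
    # and times_seen = 1, so with e1 = 0.001 and e2 = 0.00001:
    #   times_chosen:           0.1 / 0.001 ** 0.5 ~ 3.162
    #   times_chosen_since_new: 0   / 0.001 ** 0.5 = 0 (weight is zero)
    #   times_seen:             0.3 / 1.001 ** 0.5 ~ 0.300
    # giving a cellscore of roughly 4.46, which decays as the counters grow.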
    def visit(self):
        self.times_seen += 1
        self.score = self.cellscore()
        return self.times_seen == 1

    def choose(self):
        self.times_chosen += 1
        self.times_chosen_since_new += 1
        # Return a copy of the trajectory so the rollout below can append
        # to it without mutating the archived version.
        return self.ram, self.reward, self.trajectory.copy()

archive = defaultdict(lambda: Cell())
highscore = 0
frames = 0
e1 = 0.001
e2 = 0.00001

env = retro.make("SuperMarioWorld2-Snes")
frame = env.reset()
score = 0
action = np.zeros(env.action_space.shape)
trajectory = []
iterations = 0
while True:
    found_new_cell = False
    for i in range(100):
        # Repeat the previous action, resampling it 5% of the time.
        if np.random.random() > 0.95:
            action = env.action_space.sample()
        frame, reward, terminal, info = env.step(action)
        if iterations % 100 == 0:
            env.render()
        score += reward
        # Treat losing a life or taking damage as the end of the rollout.
        terminal |= info['lives'] < 3 or info['health'] < 109
        trajectory.append(action)
        frames += 4
        if score > highscore:
            highscore = score
            cv2.imshow("Best Cell", cv2.cvtColor(np.copy(frame), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)
        if terminal:
            break
        else:
            cell = cellfn(frame)
            cv2.imshow("Cell", cv2.resize(cell * 32, (220, 160), interpolation = cv2.INTER_AREA))
            cellhash = hashfn(cell)
            cell = archive[cellhash]
            first_visit = cell.visit()
            # Archive this state if the cell is new, or if we reached it with
            # a higher score, or with an equal score in fewer steps.
            if first_visit or score > cell.reward or (score == cell.reward and len(trajectory) < len(cell.trajectory)):
                cell.ram = env.em.get_state()
                cell.reward = score
                cell.trajectory = trajectory.copy()
                cell.times_chosen = 0
                cell.times_chosen_since_new = 0
                cell.score = cell.cellscore()
                found_new_cell = True
                cv2.imshow("Newest Cell", cv2.cvtColor(np.copy(frame), cv2.COLOR_RGB2BGR))
                cv2.waitKey(1)
    if found_new_cell and iterations > 0:
        # Reward the cell we restored from for leading somewhere new.
        restore_cell.times_chosen_since_new = 0
        restore_cell.score = restore_cell.cellscore()
    iterations += 1
    # Sample the next cell to explore from, weighted by cellscore, then
    # restore its saved emulator state and continue from there.
    scores = np.array([cell.score for cell in archive.values()])
    hashes = [cellhash for cellhash in archive.keys()]
    probs = scores / scores.sum()
    restore = np.random.choice(hashes, p = probs)
    restore_cell = archive[restore]
    ram, score, trajectory = restore_cell.choose()
    env.reset()
    env.em.set_state(ram)
    print("Iterations: %d, Cells: %d, Frames: %d, Max Reward: %d" % (iterations, len(archive), frames, highscore))
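# Replaying a result (sketch only, never reached since the loop above runs
# forever; `best` is a hypothetical name). Trajectories are accumulated from
# the initial reset, so on a deterministic emulator the best one could be
# replayed directly:
#
#   best = max(archive.values(), key = lambda cell: cell.reward)
#   env.reset()
#   for action in best.trajectory:
#       env.step(action)
#       env.render()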