-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsolve.py
242 lines (191 loc) · 8.42 KB
/
solve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
# standard packages
import warnings
import os
import time
import pickle
import json
from argparse import ArgumentParser, Namespace
# custom modules
from asp import params
from modules.api import FlatlandPlan, FlatlandReplan
from modules.convert import convert_malfunctions_to_clingo, convert_formers_to_clingo, convert_futures_to_clingo
# clingo
import clingo
from clingo.application import Application, clingo_main
# rendering visualizations
from flatland.utils.rendertools import RenderTool
import imageio.v2 as imageio
from PIL import Image, ImageDraw, ImageFont
class MalfunctionManager():
    """Track trains that are currently malfunctioning.

    ``self.malfunctions`` holds ``(train, remaining_steps)`` tuples; a train is
    added by ``check`` and dropped by ``deduct`` once its count reaches zero.
    """

    def __init__(self, num_agents):
        # kept as an attribute; no method of this class reads it
        self.num_agents = num_agents
        self.malfunctions = []

    def get(self) -> list:
        """Return the tracked ``(train, remaining_steps)`` list itself (not a copy)."""
        return self.malfunctions

    def deduct(self) -> None:
        """Decrement every remaining duration by one and forget entries that hit exactly zero."""
        ticked = [(entry[0], entry[1] - 1) for entry in self.malfunctions]
        # slice-assign so the list object returned earlier by get() is updated in place
        self.malfunctions[:] = [entry for entry in ticked if entry[1] != 0]

    def check(self, info) -> set:
        """Register trains that newly report a positive malfunction duration.

        ``info['malfunction']`` is read as a mapping of train -> duration. Every
        train with duration > 0 that is not already tracked is appended together
        with that duration. Returns the set of newly registered trains.
        """
        durations = info['malfunction']
        tracked = {entry[0] for entry in self.malfunctions}
        active = {train for train, duration in durations.items() if duration > 0}
        fresh = active - tracked
        for train in fresh:
            self.malfunctions.append((train, durations[train]))
        return fresh
class SimulationManager():
    """Run clingo planning for a Flatland environment: an initial plan and replans.

    ``primary`` is the clingo argument list used for the initial plan;
    ``secondary`` is the argument list used for replanning and falls back to
    ``primary`` when not given.
    """

    def __init__(self, env, primary, secondary=None):
        self.env = env
        self.primary = primary
        self.secondary = primary if secondary is None else secondary

    def build_actions(self) -> list:
        """Create the initial list of actions by solving with the primary arguments."""
        app = FlatlandPlan(self.env, None)
        clingo_main(app, self.primary)
        return app.action_list

    def provide_context(self, actions, timestep, malfunctions) -> str:
        """Build the extra facts handed to the replanner.

        Concatenates three parts:
        - already-executed actions (``actions[:timestep]``);
        - waits enforced by the current malfunctions at ``timestep``;
        - previously planned future actions (``actions[timestep:]``).
        """
        # NOTE(review): main() calls this after env.step(actions[timestep]), so
        # actions[timestep] has already run but lands in the "future" slice here —
        # confirm this matches the convert_* helpers' timestep convention.
        past = convert_formers_to_clingo(actions[:timestep])
        present = convert_malfunctions_to_clingo(malfunctions, timestep)
        future = convert_futures_to_clingo(actions[timestep:])
        return past + present + future

    def update_actions(self, context) -> list:
        """Replan after a malfunction, using ``context`` and the secondary arguments."""
        app = FlatlandPlan(self.env, context)
        # replanning uses the secondary argument list (which equals primary when
        # no secondary was supplied); previously primary was always used here
        clingo_main(app, self.secondary)
        return app.action_list
class OutputLogManager():
    """Accumulate text rows during a run and write them out as ``paths.csv``."""

    def __init__(self) -> None:
        # rows are written verbatim by save(), so each should carry its own newline
        self.logs = []

    def add(self, info) -> None:
        """Append one row to the in-memory log."""
        self.logs.append(info)

    def save(self, filename) -> None:
        """Write a header line plus every collected row to ``output/<filename>/paths.csv``.

        The ``output/<filename>`` directory must already exist; any existing file
        is overwritten.
        """
        target = f"output/{filename}/paths.csv"
        header = "agent;timestep;position;direction;status;given_command\n"
        with open(target, "w") as out:
            out.write(header)
            out.writelines(self.logs)
def check_params(par):
    """
    Verify that all parameters exist and are well-typed before proceeding.

    par: the params module (or any object) whose attributes hold the configuration.

    Required parameters must be present and of the expected type. Optional
    parameters are type-checked only when present and not None.

    Returns True when every check passes.
    Raises ValueError if a required parameter is missing, and TypeError if any
    checked parameter has the wrong type.
    """
    required_params = {
        "primary": list,
    }
    # `secondary` may be omitted or None: SimulationManager then reuses `primary`
    optional_params = {
        "secondary": list,
    }

    for param, expected_type in required_params.items():
        if not hasattr(par, param):
            raise ValueError(f"Required parameter '{param}' is missing from the params module")
        value = getattr(par, param)
        if not isinstance(value, expected_type):
            raise TypeError(f"Parameter '{param}' should be of type {expected_type.__name__}, but got {type(value).__name__}")

    for param, expected_type in optional_params.items():
        value = getattr(par, param, None)
        if value is not None and not isinstance(value, expected_type):
            raise TypeError(f"Parameter '{param}' should be of type {expected_type.__name__}, but got {type(value).__name__}")

    return True
def get_args():
    """Parse command-line arguments.

    ``env`` is a required positional path to the pickled Flatland environment;
    because of ``nargs=1`` it is returned as a one-element list.
    ``--no-render`` sets ``no_render`` to True.
    """
    cli = ArgumentParser()
    cli.add_argument(
        'env',
        nargs=1,
        type=str,
        default='',
        help='the Flatland environment as a .pkl file',
    )
    cli.add_argument(
        '--no-render',
        action='store_true',
        help='if included, run the Flatland simulation but do not render a GIF',
    )
    parsed = cli.parse_args()
    return parsed
def _stamp_timestep(filename, timestep):
    """Draw ``timestep`` in red with a black outline in the top-right corner of the PNG at ``filename``, in place."""
    with Image.open(filename) as img:
        draw = ImageDraw.Draw(img)
        padding = 10
        # font size is 10% of the smaller frame dimension
        font_size = int(min(img.width, img.height) * 0.10)
        try:
            font = ImageFont.truetype("modules/LiberationMono-Regular.ttf", font_size)
        except IOError:
            # bundled TTF missing or unreadable: fall back to Pillow's default font
            font = ImageFont.load_default()
        text = f"{timestep}"
        left, _, right, _ = font.getbbox(text)
        x, y = img.width - (right - left) - padding, padding
        # outline: repeat the text in black, offset by one pixel in all 8 directions
        for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
            draw.text((x + dx, y + dy), text, fill="black", font=font)
        draw.text((x, y), text, fill="red", font=font)
        img.save(filename)


def main():
    """Plan a Flatland run with clingo, simulate it step by step, and save the results.

    Loads the environment pickle named on the command line and builds the initial
    plan. It then steps the environment, logging each agent before every step, and
    replans whenever a train newly reports a malfunction. Finally it writes
    ``output/<timestamp>/paths.csv``, plus ``animation.gif`` unless ``--no-render``
    was given.
    """
    if check_params(params):
        args: Namespace = get_args()
        # SECURITY: pickle.load can execute arbitrary code; only load trusted .pkl files.
        with open(args.env[0], "rb") as env_file:
            env = pickle.load(env_file)
        no_render = args.no_render

        # create manager objects
        mal = MalfunctionManager(env.get_num_agents())
        # `secondary` is optional (check_params only requires `primary`);
        # SimulationManager substitutes `primary` when it receives None.
        sim = SimulationManager(env, params.primary, getattr(params, "secondary", None))
        log = OutputLogManager()

        # environment rendering
        env_renderer = None
        images = []
        if not no_render:
            env_renderer = RenderTool(env, gl="PILSVG")
            env_renderer.reset()
            os.makedirs("tmp/frames", exist_ok=True)

        # 0 (presumably Flatland's DO_NOTHING) is mapped so logging it cannot raise KeyError
        action_map = {0: 'do_nothing', 1: 'move_left', 2: 'move_forward', 3: 'move_right', 4: 'wait'}
        state_map = {0: 'waiting', 1: 'ready to depart', 2: 'malfunction (off map)', 3: 'moving', 4: 'stopped', 5: 'malfunction (on map)', 6: 'done'}
        dir_map = {0: 'n', 1: 'e', 2: 's', 3: 'w'}

        actions = sim.build_actions()
        timestep = 0
        while timestep < len(actions):
            step_actions = actions[timestep]
            # log every agent's state *before* this timestep's actions are applied
            for a in step_actions:
                agent = env.agents[a]
                log.add(f'{a};{timestep};{agent.position};{dir_map[agent.direction]};{state_map[agent.state]};{action_map[step_actions[a]]}\n')
            _, _, done, info = env.step(step_actions)

            # the environment finished although planned actions remain
            if done['__all__'] and timestep < len(actions) - 1:
                warnings.warn('Simulation has reached its end before actions list has been exhausted.')
                break

            # replan whenever a train newly reports a malfunction
            new_malfs = mal.check(info)
            if new_malfs:
                context = sim.provide_context(actions, timestep, mal.get())
                actions = sim.update_actions(context)
            # TODO(review): open question from the original author whether deduct()
            # should run before the replanning context is built.
            mal.deduct()

            # render an image with the timestep stamped in the corner
            if env_renderer is not None:
                filename = 'tmp/frames/flatland_frame_{:04d}.png'.format(timestep)
                env_renderer.render_env(show=True, show_observations=False, show_predictions=False)
                env_renderer.gl.save_image(filename)
                env_renderer.reset()
                _stamp_timestep(filename, timestep)
                images.append(imageio.imread(filename))

            timestep += 1

        # get time stamp for gif and output log
        stamp = time.time()
        os.makedirs(f"output/{stamp}", exist_ok=True)
        # combine images into gif
        if not no_render:
            imageio.mimsave(f"output/{stamp}/animation.gif", images, format='GIF', loop=0, duration=240)
        # save output log
        log.save(stamp)


if __name__ == "__main__":
    main()