Skip to content

Commit 97f968e

Browse files
committed
some fixes and unit tests
1 parent 73df5d7 commit 97f968e

File tree

3 files changed

+54
-9
lines changed

3 files changed

+54
-9
lines changed

fireworks/core/rocket.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def run(self):
201201

202202
if my_spec.get("_add_launchpad_and_fw_id"):
203203
t.launchpad = self.launchpad
204-
t.fw_id = self.fw_id
204+
t.fw_id = m_fw.fw_id
205205

206206
try:
207207
m_action = t.run_task(my_spec)
@@ -337,7 +337,9 @@ def run(self):
337337
def decorate_fwaction(self, fwaction, my_spec, m_fw, launch_dir):
338338

339339
if my_spec.get("_pass_job_info"):
340-
fwaction.mod_spec.append({"_push": {"_job_info": {"fw_id": m_fw.fw_id, "name": m_fw.name, "launch_dir": launch_dir}}})
340+
job_info = my_spec.get("_job_info", [])
341+
job_info.append({"fw_id": m_fw.fw_id, "name": m_fw.name, "launch_dir": launch_dir})
342+
fwaction.mod_spec.append({"_push_all": {"_job_info": job_info}})
341343

342344
if my_spec.get("_preserve_fworker"):
343345
fwaction.update_spec['_fworker'] = self.fworker.name

fireworks/tests/mongo_tests.py

+36-6
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from fireworks.user_objects.firetasks.templatewriter_task import TemplateWriterTask
2525
from fw_tutorials.dynamic_wf.fibadd_task import FibonacciAdderTask
2626
from fw_tutorials.firetask.addition_task import AdditionTask
27-
from fireworks.tests.tasks import DummyTask
27+
from fireworks.tests.tasks import DummyFWEnvTask, DummyJobPassTask, DummyLPTask
2828
from fireworks.features.stats import FWStats
2929
import six
3030

@@ -266,20 +266,50 @@ def test_multi_detour(self):
266266
self.assertEqual(set(links[4]), set([2]))
267267
self.assertEqual(set(links[5]), set([2]))
268268

269-
def test_fworkerenv(self):
270-
t = DummyTask()
269+
def test_fw_env(self):
270+
t = DummyFWEnvTask()
271271
fw = Firework(t)
272272
self.lp.add_wf(fw)
273273
launch_rocket(self.lp, self.fworker)
274-
self.assertEqual(self.lp.get_launch_by_id(1).action.stored_data[
275-
'data'],
276-
"hello")
274+
self.assertEqual(self.lp.get_launch_by_id(1).action.stored_data['data'], "hello")
277275
self.lp.add_wf(fw)
278276
launch_rocket(self.lp, FWorker(env={"hello": "world"}))
279277
self.assertEqual(self.lp.get_launch_by_id(2).action.stored_data[
280278
'data'],
281279
"world")
282280

281+
def test_job_info(self):
282+
fw1 = Firework([ScriptTask.from_str('echo "Testing job info"')], spec={"_pass_job_info": True}, fw_id=1)
283+
fw2 = Firework([DummyJobPassTask()], parents=[fw1], spec={"_pass_job_info": True}, fw_id=2)
284+
fw3 = Firework([DummyJobPassTask()], parents=[fw2], fw_id=3)
285+
self.lp.add_wf(Workflow([fw1, fw2, fw3]))
286+
launch_rocket(self.lp, self.fworker)
287+
modified_spec = self.lp.get_fw_by_id(2).spec
288+
self.assertIsNotNone(modified_spec['_job_info'])
289+
self.assertTrue(modified_spec['_job_info'][0].has_key("launch_dir"))
290+
self.assertEqual(modified_spec['_job_info'][0]['name'], 'Unnamed FW')
291+
self.assertEqual(modified_spec['_job_info'][0]['fw_id'], 1)
292+
293+
launch_rocket(self.lp, self.fworker)
294+
modified_spec = self.lp.get_fw_by_id(3).spec
295+
print modified_spec
296+
self.assertEqual(len(modified_spec['_job_info']), 2)
297+
298+
def test_preserve_fworker(self):
299+
fw1 = Firework([ScriptTask.from_str('echo "Testing preserve FWorker"')], spec={"_preserve_fworker": True}, fw_id=1)
300+
fw2 = Firework([DummyJobPassTask()], parents=[fw1], fw_id=2)
301+
self.lp.add_wf(Workflow([fw1, fw2]))
302+
launch_rocket(self.lp, self.fworker)
303+
modified_spec = self.lp.get_fw_by_id(2).spec
304+
self.assertIsNotNone(modified_spec['_fworker'])
305+
306+
def test_add_lp_and_fw_id(self):
307+
fw1 = Firework([DummyLPTask()], spec={"_add_launchpad_and_fw_id": True})
308+
self.lp.add_wf(fw1)
309+
launch_rocket(self.lp, self.fworker)
310+
self.assertEqual(self.lp.get_launch_by_id(1).action.stored_data['fw_id'], 1)
311+
self.assertIsNotNone(self.lp.get_launch_by_id(1).action.stored_data['host'])
312+
283313
def test_spec_copy(self):
284314
task1 = ScriptTask.from_str('echo "Task 1"')
285315
task2 = ScriptTask.from_str('echo "Task 2"')

fireworks/tests/tasks.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,21 @@
2020

2121

2222
@explicit_serialize
23-
class DummyTask(FireTaskBase):
23+
class DummyFWEnvTask(FireTaskBase):
2424

2525
def run_task(self, fw_spec):
2626
data = fw_spec["_fw_env"].get("hello", "hello")
2727
return FWAction(stored_data={"data": data})
28+
29+
30+
@explicit_serialize
31+
class DummyJobPassTask(FireTaskBase):
32+
33+
def run_task(self, fw_spec):
34+
return FWAction(stored_data={"data": fw_spec['_job_info']})
35+
36+
@explicit_serialize
37+
class DummyLPTask(FireTaskBase):
38+
39+
def run_task(self, fw_spec):
40+
return FWAction(stored_data={"fw_id": self.fw_id, "host": self.launchpad.host})

0 commit comments

Comments
 (0)