Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Properly jsanitize fireworks Task #544

Merged
merged 2 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/jobflow/managers/fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import typing

from fireworks import FiretaskBase, Firework, FWAction, Workflow, explicit_serialize
from fireworks.utilities.fw_serializers import recursive_serialize, serialize_fw
from monty.json import jsanitize

if typing.TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -197,3 +199,16 @@ def run_task(self, fw_spec):
defuse_workflow=response.stop_jobflow,
defuse_children=response.stop_children,
)

@serialize_fw
@recursive_serialize
def to_dict(self) -> dict:
"""
Serialize version of the FireTask.

Overrides the original method to explicitly jsanitize the Job
to handle cases not properly handled by fireworks, like a Callable.
"""
d = dict(self)
d["job"] = jsanitize(d["job"].as_dict())
return d
22 changes: 22 additions & 0 deletions tests/managers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,3 +399,25 @@ def _gen():
return Flow([replace, simple], simple.output, order=JobOrder.LINEAR)

return _gen


@pytest.fixture(scope="session")
def maker_with_callable():
from dataclasses import dataclass
from typing import Callable

from jobflow.core.job import job
from jobflow.core.maker import Maker

global TestCallableMaker

@dataclass
class TestCallableMaker(Maker):
f: Callable
name: str = "TestCallableMaker"

@job
def make(self, a, b):
return self.f([a, b])

return TestCallableMaker
33 changes: 33 additions & 0 deletions tests/managers/test_fireworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,36 @@ def test_external_reference(lpad, mongo_jobstore, fw_dir, simple_job, capsys):
# check response
result2 = mongo_jobstore.query_one({"uuid": uuid2})
assert result2["output"] == "12345_end_end"


def test_maker_flow(lpad, mongo_jobstore, fw_dir, maker_with_callable, capsys):
from fireworks.core.rocket_launcher import rapidfire

from jobflow.core.flow import Flow
from jobflow.managers.fireworks import flow_to_workflow

j = maker_with_callable(f=sum).make(a=1, b=2)

flow = Flow([j])
uuid = flow[0].uuid

wf = flow_to_workflow(flow, mongo_jobstore)
fw_ids = lpad.add_wf(wf)

# run the workflow
rapidfire(lpad)

# check workflow completed
fw_id = next(iter(fw_ids.values()))
wf = lpad.get_wf_by_fw_id(fw_id)

assert all(s == "COMPLETED" for s in wf.fw_states.values())

# check store has the activity output
result = mongo_jobstore.query_one({"uuid": uuid})
assert result["output"] == 3

# check logs printed
captured = capsys.readouterr()
assert "INFO Starting job - TestCallableMaker" in captured.out
assert "INFO Finished job - TestCallableMaker" in captured.out
Loading