Skip to content

Commit

Permalink
Merge pull request #7 from mrc-ide/mrc-6198
Browse files Browse the repository at this point in the history
Add task info function
  • Loading branch information
weshinsley authored Jan 29, 2025
2 parents 760db73 + 83f534f commit de01484
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
38 changes: 36 additions & 2 deletions src/hipercow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,30 @@ def __str__(self) -> str:
@dataclass
class TaskTimes:
created: float
started: float
finished: float
started: float | None
finished: float | None

def write(self, root: Root, task_id: str):
with root.path_task_times(task_id).open("wb") as f:
pickle.dump(self, f)

@staticmethod
def read(root: Root, task_id: str):
path_times = root.path_task_times(task_id)
if path_times.exists():
with path_times.open("rb") as f:
return pickle.load(f)
created = root.path_task_data(task_id).stat().st_ctime
path_task_running = (
root.path_task(task_id) / STATUS_FILE_MAP[TaskStatus.RUNNING]
)
running = (
path_task_running.stat().st_ctime
if path_task_running.exists()
else None
)
return TaskTimes(created, running, None)


def task_status(root: Root, task_id: str) -> TaskStatus:
# check_task_id(task_id)
Expand Down Expand Up @@ -100,3 +117,20 @@ def task_data_write(root: Root, data: TaskData) -> None:
def task_data_read(root: Root, task_id: str) -> TaskData:
with root.path_task_data(task_id).open("rb") as f:
return pickle.load(f)


@dataclass
class TaskInfo:
status: TaskStatus
data: TaskData
times: TaskTimes


def task_info(root: Root, task_id: str) -> TaskInfo:
status = task_status(root, task_id)
if status == TaskStatus.MISSING:
msg = f"Task '{task_id}' does not exist"
raise Exception(msg)
data = TaskData.read(root, task_id)
times = TaskTimes.read(root, task_id)
return TaskInfo(status, data, times)
48 changes: 47 additions & 1 deletion tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@

from hipercow import root
from hipercow import task_create as tc
from hipercow.task import TaskStatus, set_task_status, task_log, task_status
from hipercow.task import (
TaskData,
TaskStatus,
TaskTimes,
set_task_status,
task_info,
task_log,
task_status,
)
from hipercow.task_eval import task_eval
from hipercow.util import transient_working_directory


Expand Down Expand Up @@ -45,3 +54,40 @@ def test_that_missing_tasks_error_on_log_read(tmp_path):

def test_can_convert_to_nice_string():
assert str(TaskStatus.CREATED) == "created"


def test_read_task_info(tmp_path):
root.init(tmp_path)
r = root.open_root(tmp_path)
with transient_working_directory(tmp_path):
tid = tc.task_create_shell(["echo", "hello world"])
info = task_info(r, tid)
assert info.status == TaskStatus.CREATED
assert info.data == TaskData.read(r, tid)
assert info.times == TaskTimes.read(r, tid)
assert isinstance(info.times.created, float)
assert info.times.started is None
assert info.times.finished is None


def test_that_missing_tasks_error_on_task_info(tmp_path):
root.init(tmp_path)
r = root.open_root(tmp_path)
task_id = "a" * 32
with pytest.raises(Exception, match="Task '.+' does not exist"):
task_info(r, task_id)


def test_that_can_read_info_for_completed_task(tmp_path):
root.init(tmp_path)
r = root.open_root(tmp_path)
with transient_working_directory(tmp_path):
tid = tc.task_create_shell(["echo", "hello world"])
task_eval(r, tid)
info = task_info(r, tid)
assert info.status == TaskStatus.SUCCESS
assert info.data == TaskData.read(r, tid)
assert info.times == TaskTimes.read(r, tid)
assert isinstance(info.times.created, float)
assert isinstance(info.times.started, float)
assert isinstance(info.times.finished, float)

0 comments on commit de01484

Please sign in to comment.