Skip to content

Commit

Permalink
Merge pull request #4 from mrc-ide/mrc-6186
Browse files Browse the repository at this point in the history
Shard task ids, as with the R version
  • Loading branch information
weshinsley authored Jan 24, 2025
2 parents 3c75277 + 48a2abb commit 9d93ea6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
12 changes: 12 additions & 0 deletions src/hipercow/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def __init__(self, path: str | Path) -> None:
raise Exception(msg)
self.path = path

def path_task(self, task_id: str) -> Path:
    """Return the directory that holds the files for one task.

    Task directories are sharded: the first two characters of
    ``task_id`` name an intermediate directory under ``tasks`` and the
    remaining characters name the task's own directory inside it.

    Args:
        task_id: Identifier of the task.

    Returns:
        ``self.path / "tasks" / task_id[:2] / task_id[2:]``.
    """
    # NOTE(review): for an id of two characters or fewer the second slice
    # is empty and pathlib drops empty segments, so the result would be the
    # shard directory itself. Callers are presumably expected to pass
    # full-length ids -- confirm.
    shard, remainder = task_id[:2], task_id[2:]
    return self.path / "tasks" / shard / remainder

def path_task_times(self, task_id: str) -> Path:
    """Return the path of the ``times`` file inside a task's directory.

    Args:
        task_id: Identifier of the task.

    Returns:
        ``self.path_task(task_id) / "times"``.
    """
    task_dir = self.path_task(task_id)
    return task_dir / "times"

def path_task_data(self, task_id: str) -> Path:
    """Return the path of the ``data`` file inside a task's directory.

    Args:
        task_id: Identifier of the task.

    Returns:
        ``self.path_task(task_id) / "data"``.
    """
    task_dir = self.path_task(task_id)
    return task_dir / "data"

def path_task_result(self, task_id: str) -> Path:
    """Return the path of the ``result`` file inside a task's directory.

    Args:
        task_id: Identifier of the task.

    Returns:
        ``self.path_task(task_id) / "result"``.
    """
    task_dir = self.path_task(task_id)
    return task_dir / "result"


def open_root(path: None | str | Path = None) -> Root:
root = find_file_descend("hipercow", path or Path.cwd())
Expand Down
13 changes: 7 additions & 6 deletions src/hipercow/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ class TaskTimes:
finished: float

def write(self, root: Root, task_id: str):
    """Pickle this object into the task's ``times`` file.

    The destination is obtained from ``root.path_task_times(task_id)`` so
    that it follows the same sharded layout as every other task file,
    rather than being rebuilt by hand from ``root.path``.

    Args:
        root: Root object providing ``path_task_times``.
        task_id: Identifier of the task the timings belong to.
    """
    # The parent directory is not created here; opening fails if it is
    # missing.
    with root.path_task_times(task_id).open("wb") as f:
        pickle.dump(self, f)


def task_status(root: Root, task_id: str) -> TaskStatus:
# check_task_id(task_id)
path = root.path / "tasks" / task_id
path = root.path_task(task_id)
if not path.exists():
return TaskStatus.MISSING
for v, p in STATUS_FILE_MAP.items():
Expand All @@ -60,7 +60,7 @@ def task_status(root: Root, task_id: str) -> TaskStatus:


def set_task_status(root: Root, task_id: str, status: TaskStatus):
    """Record `status` for a task by creating its status marker file.

    The marker's file name is looked up in ``STATUS_FILE_MAP`` and the file
    is created inside the sharded task directory returned by
    ``root.path_task(task_id)``. Only the sharded location is written; the
    unsharded ``root.path / "tasks" / task_id`` location is not touched.

    Args:
        root: Root object providing ``path_task``.
        task_id: Identifier of the task.
        status: Status to record; must be a key of ``STATUS_FILE_MAP``.

    Raises:
        KeyError: If `status` has no entry in ``STATUS_FILE_MAP``.
    """
    marker_name = STATUS_FILE_MAP[status]
    # file_create is defined elsewhere; presumably it creates an empty
    # file at the given path -- confirm against its definition.
    file_create(root.path_task(task_id) / marker_name)


@dataclass
Expand All @@ -80,12 +80,13 @@ def read(root: Root, task_id: str):


def task_data_write(root: Root, data: TaskData) -> None:
    """Create the task's directory and pickle `data` into its data file.

    Both the directory and the file location come from the root's path
    helpers (``path_task`` and ``path_task_data``), so the data is written
    once, into the sharded layout.

    Args:
        root: Root object providing ``path_task`` and ``path_task_data``.
        data: Task description to store; its ``task_id`` attribute selects
            the destination.
    """
    task_id = data.task_id
    path_task_dir = root.path_task(task_id)
    # parents=True also creates the intermediate shard directory;
    # exist_ok=True makes a repeated write for the same task harmless.
    path_task_dir.mkdir(parents=True, exist_ok=True)
    with root.path_task_data(task_id).open("wb") as f:
        pickle.dump(data, f)


def task_data_read(root: Root, task_id: str) -> TaskData:
    """Load the pickled task description stored for `task_id`.

    The file is located with ``root.path_task_data(task_id)`` and opened
    exactly once.

    Args:
        root: Root object providing ``path_task_data``.
        task_id: Identifier of the task to read.

    Returns:
        The object unpickled from the task's data file.
    """
    with root.path_task_data(task_id).open("rb") as f:
        # NOTE: unpickling can execute arbitrary code, so this must only be
        # used on files from a trusted root. Presumably these files are
        # only ever produced by task_data_write -- confirm.
        return pickle.load(f)
4 changes: 2 additions & 2 deletions src/hipercow/task_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def task_eval_data(root: Root, data: TaskData) -> None:
msg = f"Can't run '{task_id}', which has status '{status}'"
raise Exception(msg)

t_created = (root.path / "tasks" / task_id / "data").stat().st_ctime
t_created = root.path_task_data(task_id).stat().st_ctime
t_start = time.time()

set_task_status(root, task_id, TaskStatus.RUNNING)
Expand All @@ -43,7 +43,7 @@ def task_eval_data(root: Root, data: TaskData) -> None:
t_end = time.time()

status = TaskStatus.SUCCESS if res.success else TaskStatus.FAILURE
with open(root.path / "tasks" / task_id / "result", "wb") as f:
with root.path_task_result(task_id).open("wb") as f:
pickle.dump(res.data, f)

times = TaskTimes(t_created, t_start, t_end)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_create_simple_task(tmp_path):
with transient_working_directory(tmp_path):
tid = tc.task_create_shell(["echo", "hello world"])
assert re.match("^[0-9a-f]{32}$", tid)
assert (tmp_path / "tasks" / tid / "data").exists()
assert (tmp_path / "tasks" / tid[:2] / tid[2:] / "data").exists()
d = TaskData.read(root.open_root(tmp_path), tid)
assert isinstance(d, TaskData)
assert d.task_id == tid
Expand Down

0 comments on commit 9d93ea6

Please sign in to comment.