From cad224881658267a619cd6da235917f4abf7a27d Mon Sep 17 00:00:00 2001 From: Mattia Almansi Date: Thu, 13 Jun 2024 11:55:31 +0200 Subject: [PATCH 1/3] implement nanny and worker plugins --- cads_broker/config.py | 2 ++ cads_broker/dispatcher.py | 48 +++++++++++++++++++++++++++++++++++-- tests/test_20_dispatcher.py | 33 +++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/cads_broker/config.py b/cads_broker/config.py index 37c6b2c7..ae548ba3 100644 --- a/cads_broker/config.py +++ b/cads_broker/config.py @@ -24,6 +24,8 @@ dbsettings = None +TASKS_SUBDIR = "tasks_working_dir" + class SqlalchemySettings(pydantic_settings.BaseSettings): """Postgres-specific API settings. diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py index f0754c9a..bb4f14f4 100644 --- a/cads_broker/dispatcher.py +++ b/cads_broker/dispatcher.py @@ -2,7 +2,9 @@ import hashlib import io import os +import pathlib import pickle +import shutil import threading import time import traceback @@ -13,10 +15,11 @@ import distributed import sqlalchemy as sa import structlog +from dask.typing import Key from typing_extensions import Iterable try: - from cads_worker import worker + import cads_worker.worker except ModuleNotFoundError: pass @@ -197,6 +200,43 @@ def __init__(self, number_of_workers) -> None: parser.parse_rules(self.rules, self.environment) +def rmtree_if_exists(path: pathlib.Path, **kwargs: Any) -> None: + if path.exists(): + shutil.rmtree(path, **kwargs) + + +class TempDirNannyPlugin(distributed.NannyPlugin): + def setup(self, nanny: distributed.Nanny) -> None: + self.tasks_path = pathlib.Path(nanny.worker_dir) / config.TASKS_SUBDIR + rmtree_if_exists(self.tasks_path) + self.tasks_path.mkdir() + + def teardown(self, nanny: distributed.Nanny) -> None: + rmtree_if_exists(self.tasks_path) + + +class TempDirsWorkerPlugin(distributed.WorkerPlugin): + def setup(self, worker: distributed.Worker) -> None: + self.tasks_path = pathlib.Path(worker.local_directory) / config.TASKS_SUBDIR + + def delete_task_working_dir(self, key: Key) -> None: + rmtree_if_exists(self.tasks_path / str(key)) + + def teardown(self, worker: distributed.Worker) -> None: + for key in worker.state.tasks: + self.delete_task_working_dir(key) + + def transition( + self, + key: Key, + start: distributed.worker_state_machine.TaskStateState, + finish: distributed.worker_state_machine.TaskStateState, + **kwargs: Any, + ) -> None: + if finish in ("memory", "error"): + self.delete_task_working_dir(key) + + @attrs.define class Broker: client: distributed.Client @@ -218,6 +258,10 @@ class Broker: internal_scheduler: Scheduler = Scheduler() queue: Queue = Queue() + def __attrs_post_init__(self): + self.client.register_plugin(TempDirNannyPlugin()) + self.client.register_plugin(TempDirsWorkerPlugin()) + @classmethod def from_address( cls, @@ -563,7 +607,7 @@ def submit_request( ) self.queue.pop(request.request_uid) future = self.client.submit( - worker.submit_workflow, + cads_worker.worker.submit_workflow, key=request.request_uid, setup_code=request.request_body.get("setup_code", ""), entry_point=request.entry_point, diff --git a/tests/test_20_dispatcher.py b/tests/test_20_dispatcher.py index 9aeb38fa..624fc735 100644 --- a/tests/test_20_dispatcher.py +++ b/tests/test_20_dispatcher.py @@ -1,4 +1,5 @@ import datetime +import pathlib import uuid from typing import Any @@ -120,3 +121,35 @@ def mock_get_tasks() -> dict[str, str]: # with pytest.raises(db.NoResultFound): # with session_obj() as session: # db.get_request(dismissed_request_uid, session=session) + + +def test_plugins( + mocker: pytest_mock.plugin.MockerFixture, session_obj: sa.orm.sessionmaker +) -> None: + environment = Environment.Environment() + qos = QoS.QoS(rules=Rule.RuleSet(), environment=environment, rules_hash="") + broker = dispatcher.Broker( + client=CLIENT, + environment=environment, + qos=qos, + address="scheduler-address", + session_maker_read=session_obj, + session_maker_write=session_obj, + ) + + def func() -> pathlib.Path: + worker = distributed.get_worker() + key = worker.get_current_task() + task_path = ( + pathlib.Path(worker.local_directory) / "tasks_working_dir" / str(key) + ) + task_path.mkdir() + return task_path + + future = broker.client.submit(func) + task_path = future.result() + assert not task_path.exists() + + assert task_path.parent.exists() + broker.client.shutdown() + assert not task_path.parent.exists() From 116526a9df2d147cafb3e0e9ad9b945615d5ad4d Mon Sep 17 00:00:00 2001 From: Mattia Almansi Date: Thu, 13 Jun 2024 12:35:16 +0200 Subject: [PATCH 2/3] use utils --- cads_broker/config.py | 2 -- cads_broker/dispatcher.py | 27 ++++++++------------------- cads_broker/utils.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 21 deletions(-) create mode 100644 cads_broker/utils.py diff --git a/cads_broker/config.py b/cads_broker/config.py index ae548ba3..37c6b2c7 100644 --- a/cads_broker/config.py +++ b/cads_broker/config.py @@ -24,8 +24,6 @@ dbsettings = None -TASKS_SUBDIR = "tasks_working_dir" - class SqlalchemySettings(pydantic_settings.BaseSettings): """Postgres-specific API settings. diff --git a/cads_broker/dispatcher.py b/cads_broker/dispatcher.py index bb4f14f4..51235b60 100644 --- a/cads_broker/dispatcher.py +++ b/cads_broker/dispatcher.py @@ -2,9 +2,7 @@ import hashlib import io import os -import pathlib import pickle -import shutil import threading import time import traceback @@ -23,7 +21,7 @@ except ModuleNotFoundError: pass -from cads_broker import Environment, config, factory +from cads_broker import Environment, config, factory, utils from cads_broker import database as db from cads_broker.qos import QoS @@ -200,31 +198,22 @@ def __init__(self, number_of_workers) -> None: parser.parse_rules(self.rules, self.environment) -def rmtree_if_exists(path: pathlib.Path, **kwargs: Any) -> None: - if path.exists(): - shutil.rmtree(path, **kwargs) - - class TempDirNannyPlugin(distributed.NannyPlugin): def setup(self, nanny: distributed.Nanny) -> None: - self.tasks_path = pathlib.Path(nanny.worker_dir) / config.TASKS_SUBDIR - rmtree_if_exists(self.tasks_path) - self.tasks_path.mkdir() + path = utils.rm_task_path(nanny, None) + path.mkdir() def teardown(self, nanny: distributed.Nanny) -> None: - rmtree_if_exists(self.tasks_path) + utils.rm_task_path(nanny, None) class TempDirsWorkerPlugin(distributed.WorkerPlugin): - def setup(self, worker: distributed.Worker) -> None: - self.tasks_path = pathlib.Path(worker.local_directory) / config.TASKS_SUBDIR - - def delete_task_working_dir(self, key: Key) -> None: - rmtree_if_exists(self.tasks_path / str(key)) + def setup(self, worker) -> None: + self.worker = worker def teardown(self, worker: distributed.Worker) -> None: for key in worker.state.tasks: - self.delete_task_working_dir(key) + utils.rm_task_path(worker, key) def transition( self, @@ -234,7 +223,7 @@ def transition( **kwargs: Any, ) -> None: if finish in ("memory", "error"): - self.delete_task_working_dir(key) + utils.rm_task_path(self.worker, key) @attrs.define diff --git a/cads_broker/utils.py b/cads_broker/utils.py new file mode 100644 index 00000000..4a209510 --- /dev/null +++ b/cads_broker/utils.py @@ -0,0 +1,34 @@ +import pathlib +import shutil +from typing import Any + +import distributed +from dask.typing import Key + + +def get_task_path( + worker_or_nanny: distributed.Worker | distributed.Nanny, key: Key | None +) -> pathlib.Path: + if isinstance(worker_or_nanny, distributed.Worker): + root = worker_or_nanny.local_directory + elif isinstance(worker_or_nanny, distributed.Nanny): + root = worker_or_nanny.worker_dir + else: + raise TypeError( + f"`worker_or_nanny` is of the wrong type: {type(worker_or_nanny)}" + ) + path = pathlib.Path(root) / "tasks_working_dir" + if key is not None: + path /= str(key) + return path + + +def rm_task_path( + worker_or_nanny: distributed.Worker | distributed.Nanny, + key: Key | None, + **kwargs: Any, +) -> pathlib.Path: + path = get_task_path(worker_or_nanny, key) + if path.exists(): + shutil.rmtree(path, **kwargs) + return path From 910e1a03b7961d71edfcebdafe159f1be915d030 Mon Sep 17 00:00:00 2001 From: Mattia Almansi Date: Thu, 13 Jun 2024 13:57:54 +0200 Subject: [PATCH 3/3] add comment --- cads_broker/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cads_broker/utils.py b/cads_broker/utils.py index 4a209510..82c80dbb 100644 --- a/cads_broker/utils.py +++ b/cads_broker/utils.py @@ -28,6 +28,7 @@ def rm_task_path( key: Key | None, **kwargs: Any, ) -> pathlib.Path: + # This function is used by cads-worker as well. path = get_task_path(worker_or_nanny, key) if path.exists(): shutil.rmtree(path, **kwargs)