diff --git a/pyproject.toml b/pyproject.toml index 3503e3a..2c6f253 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,9 @@ classifiers = [ "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ - "click" + "click", + "defusedxml", + "requests" ] dynamic = ["version"] @@ -42,6 +44,7 @@ path = "src/hipercow/__about__.py" dependencies = [ "pytest", "pytest-cov", + "responses", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -65,6 +68,7 @@ extra-dependencies = [ "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", + "types-defusedxml", ] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src tests}" diff --git a/src/hipercow/dide/web.py b/src/hipercow/dide/web.py new file mode 100644 index 0000000..b7b5cd3 --- /dev/null +++ b/src/hipercow/dide/web.py @@ -0,0 +1,286 @@ +import base64 +import datetime +import re +from dataclasses import dataclass +from subprocess import list2cmdline +from urllib.parse import urljoin + +import requests +from defusedxml import ElementTree + +from hipercow.task import TaskStatus + + +def encode64(x: str) -> str: + return base64.b64encode(x.encode("utf-8")).decode("utf-8") + + +def decode64(x: str) -> str: + return base64.b64decode(x).decode("utf-8") + + +@dataclass +class Credentials: + username: str + password: str + + +@dataclass +class DideTaskStatus: + dide_id: str + name: str + status: TaskStatus + resources: str + user: str + time_start: float + time_end: float + time_submit: float + template: str + + @staticmethod + def from_string(entry): + els = entry.strip().split("\t") + els[2] = _parse_dide_status(els[2]) + els[4] = els[4].replace("DIDE\\", "") + for i in range(5, 8): + els[i] = _parse_dide_timestamp(els[i]) + return DideTaskStatus(*els) + + +class DideHTTPClient(requests.Session): + _has_logged_in = False + _credentials: Credentials + + def __init__(self, credentials: Credentials): + super().__init__() + self._credentials = credentials + + def request(self, method, path, *args, public=False, **kwargs): + if not public and not self._has_logged_in: + self.login() + base_url = "https://mrcdata.dide.ic.ac.uk/hpc/" + url = urljoin(base_url, path) + headers = {"Accept": "text/plain"} if method == "POST" else {} + response = super().request( + method, url, *args, headers=headers, **kwargs + ) + # To debug requests, you can do: + # from requests_toolbelt.utils import dump + # print(dump.dump_all(response).decode("utf-8")) + response.raise_for_status() + return response + + def login(self) -> None: + data = { + "us": encode64(self._credentials.username), + "pw": encode64(self._credentials.password), + "hpcfunc": encode64("login"), + } + res = self.request("POST", "index.php", data=data, public=True) + no_access = "You don't seem to have any HPC access" + if no_access in res.text: + msg = "You do not have HPC access - please contact Wes" + raise Exception(msg) + self._has_logged_in = True + + def logout(self) -> None: + self.request("GET", "logout.php", public=True) + self._has_logged_in = False + + def username(self) -> str: + return self._credentials.username + + def logged_in(self) -> bool: + return self._has_logged_in + + +class DideWebClient: + def __init__(self, credentials): + self._client = DideHTTPClient(credentials) + self._cluster = "wpia-hn" + + def login(self): + self._client.login() + + def logout(self): + self._client.logout() + + def headnodes(self) -> list[str]: + data = {"user": encode64("")} + response = self._client.request("POST", "_listheadnodes.php", data=data) + return _client_parse_headnodes(response.text) + + def check_access(self) -> None: + _client_check_access(self._cluster, self.headnodes()) + + def logged_in(self) -> bool: + return self._client.logged_in() + + def submit(self, path: str, name: str) -> str: + data = _client_body_submit(path, name, self._cluster) + response = self._client.request("POST", "submit_1.php", data=data) + return _client_parse_submit(response.text) + + def cancel(self, dide_id: str) -> bool: + data = _client_body_cancel(dide_id, self._cluster) + response = self._client.request("POST", "cancel.php", data=data) + return _client_parse_cancel(response.text) + + def log(self, dide_id: str) -> str: + data = _client_body_log(dide_id, self._cluster) + response = self._client.request("POST", "showjobfail.php", data=data) + return _client_parse_log(response.text) + + def status_user(self, state="*") -> list[DideTaskStatus]: + data = _client_body_status_user( + state, self._client.username(), self._cluster + ) + response = self._client.request("POST", "_listalljobs.php", data=data) + return _client_parse_status_user(response.text) + + def status_job(self, dide_id: str) -> TaskStatus: + query = _client_query_status_job(dide_id, self._cluster) + response = self._client.request("GET", "api/v1/get_job_status/", query) + return _client_parse_status_job(response.text) + + def software(self): + response = self._client.request( + "GET", "api/v1/cluster_software", public=True + ) + return _client_parse_software(response.json()) + + +def _client_check_access(cluster: str, valid: list[str]) -> None: + if cluster in valid: + return + if len(valid) == 0: + msg = "You do not have access to any cluster" + elif len(valid) == 1: + msg = f"You do not have access to '{cluster}'; try '{valid[0]}'" + else: + valid_str = ", ".join(valid) + msg = f"You do not have access to '{cluster}'; try one of {valid_str}" + raise Exception(msg) + + +def _client_body_submit(path: str, name: str, cluster: str) -> dict: + # NOTE: list2cmdline is undocumented but needed. + # not documented https://github.com/conan-io/conan/pull/11553/ + path_call = f"call {list2cmdline([path])}" + data = { + "cluster": encode64(cluster), + "template": encode64("AllNodes"), + "jn": encode64(name or ""), # job name + "wd": encode64(""), # work dir - unset as we do this in the .bat + "se": encode64(""), # stderr + "so": encode64(""), # stdout + "jobs": encode64(path_call), + "dep": encode64(""), # dependencies, eventually + "hpcfunc": "submit", + "ver": encode64("hipercow-py"), + } + + # There's quite a bit more here to do with processing resource + # options (but these need to exist!). For now, just set 1 core + # unconditionally and keep going: + data["rc"] = encode64("1") + data["rt"] = encode64("Cores") + + return data + + +def _client_body_cancel(dide_id: str | list[str], cluster: str) -> dict: + if isinstance(dide_id, str): + dide_id = [dide_id] + return { + "cluster": encode64(cluster), + "hpcfunc": encode64("cancel"), + **{"c" + i: i for i in dide_id}, + } + + +def _client_body_log(dide_id: str, cluster: str) -> dict: + return {"cluster": encode64(cluster), "hpcfunc": "showfail", "id": dide_id} + + +def _client_body_status_user(state: str, username: str, cluster: str) -> dict: + return { + "user": encode64(username), + "scheduler": encode64(cluster), + "state": encode64(state), + "jobs": encode64("-1"), + } + + +def _client_query_status_job(dide_id: str, cluster: str) -> dict: + return {"scheduler": cluster, "jobid": dide_id} + + +def _client_parse_headnodes(txt: str) -> list[str]: + txt = txt.strip() + return txt.split("\n") if txt else [] + + +def _client_parse_submit(txt: str) -> str: + m = re.match("^Job has been submitted. ID: +([0-9]+)\\.$", txt.strip()) + if not m: + msg = "Job submission has failed; could be a login error" + raise Exception(msg) + return m.group(1) + + +def _client_parse_cancel(txt: str): + return dict([x.split("\t") for x in txt.strip().split("\n")]) + + +def _client_parse_log(txt: str) -> str: + res = ElementTree.fromstring(txt).find('.//input[@id="res"]') + assert res is not None # noqa: S101 + output = decode64(res.attrib["value"]) + return re.sub("^Output\\s*:\\s*?\n+", "", output) + + +def _client_parse_status_user(txt: str) -> list[DideTaskStatus]: + return [DideTaskStatus.from_string(x) for x in txt.strip().split("\n")] + + +def _client_parse_status_job(txt: str) -> TaskStatus: + return _parse_dide_status(txt.strip()) + + +def _client_parse_software(json: dict) -> dict: + # Likely to change soon when the portal updates + linux = json["linuxsoftware"] + windows = json["software"] + + def process(x): + ret = {} + for el in x: + name = el["name"].lower() + version = el["version"] + omit = {"name", "version"} + value = {k: v for k, v in el.items() if k not in omit} + if name not in ret: + ret[name] = {} + ret[name][version] = value + return ret + + return {"linux": process(linux), "windows": process(windows)} + + +def _parse_dide_status(status: str) -> TaskStatus: + remap = { + "Running": TaskStatus.RUNNING, + "Finished": TaskStatus.SUCCESS, + "Queued": TaskStatus.SUBMITTED, + "Failed": TaskStatus.FAILURE, + "Canceled": TaskStatus.CANCELLED, + "Cancelled": TaskStatus.CANCELLED, + } + return remap[status] + + +def _parse_dide_timestamp(time: str) -> datetime.datetime: + return datetime.datetime.strptime(time, "%Y%m%d%H%M%S").astimezone( + datetime.timezone.utc + ) diff --git a/tests/dide/test_web.py b/tests/dide/test_web.py new file mode 100644 index 0000000..6ff9627 --- /dev/null +++ b/tests/dide/test_web.py @@ -0,0 +1,275 @@ +import datetime +import json + +import pytest +import responses + +from hipercow.dide import web +from hipercow.task import TaskStatus + + +def create_client(*, logged_in=True): + cl = web.DideWebClient(web.Credentials("", "")) + if logged_in: + cl._client._has_logged_in = True + return cl + + +@responses.activate +def test_list_headnodes(): + ## https://stackoverflow.com/questions/40361308/create-a-functioning-response-object + listheadnodes = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/_listheadnodes.php", + body="foo\nbar\n", + status=200, + ) + cl = create_client() + res = cl.headnodes() + assert res == ["foo", "bar"] + + assert listheadnodes.call_count == 1 + req = listheadnodes.calls[0].request + assert req.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert req.body == "user=" + + +def test_can_parse_headnodes_responses(): + assert web._client_parse_headnodes("") == [] + assert web._client_parse_headnodes("foo\n") == ["foo"] + assert web._client_parse_headnodes("foo\nbar\n") == ["foo", "bar"] + + +@responses.activate +def test_can_get_task_log(): + payload = '
\n' # noqa: E501 + getlog = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/showjobfail.php", + body=payload, + status=200, + ) + cl = create_client() + res = cl.log("1234") + assert res == "log contents!" + + assert getlog.call_count == 1 + req = getlog.calls[0].request + assert req.body == "cluster=d3BpYS1obg%3D%3D&hpcfunc=showfail&id=1234" + + +@responses.activate +def test_can_get_status_for_user(): + payload = """493420 hipercow-py-test Finished 1 core DIDE\\rfitzjoh 20250129120445 20250129120445 20250129120446 AllNodes +489851 Failed 1 core DIDE\\rfitzjoh 20250127160545 20250127160545 20250127160545 LinuxNodes +489823 Finished 1 core DIDE\\rfitzjoh 20250127160453 20250127160453 20250127160454 LinuxNodes +""" # noqa: E501 + status = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/_listalljobs.php", + body=payload, + status=200, + ) + cl = create_client() + res = cl.status_user() + assert len(res) == 3 + assert res[0] == web.DideTaskStatus( + "493420", + "hipercow-py-test", + TaskStatus.SUCCESS, + "1 core", + "rfitzjoh", + datetime.datetime(2025, 1, 29, 12, 4, 45, tzinfo=datetime.timezone.utc), + datetime.datetime(2025, 1, 29, 12, 4, 45, tzinfo=datetime.timezone.utc), + datetime.datetime(2025, 1, 29, 12, 4, 46, tzinfo=datetime.timezone.utc), + "AllNodes", + ) + + assert status.call_count == 1 + req = status.calls[0].request + assert ( + req.body + == "user=&scheduler=d3BpYS1obg%3D%3D&state=Kg%3D%3D&jobs=LTE%3D" + ) + + +@responses.activate +def test_can_get_job_status(): + status = responses.add( + responses.GET, + "https://mrcdata.dide.ic.ac.uk/hpc/api/v1/get_job_status/", + body="Failed", + status=200, + ) + cl = create_client() + res = cl.status_job("1234") + assert res == TaskStatus.FAILURE + assert status.call_count == 1 + + +@responses.activate +def test_can_get_software_list(): + payload = { + "software": [ + {"name": "R", "version": "4.2.3", "call": "setr64_4_2_3.bat"}, + {"name": "R", "version": "4.3.1", "call": "setr64_4_3_1.bat"}, + {"name": "python", "version": "3.11", "call": "python311.bat"}, + ], + "linuxsoftware": [ + {"name": "R", "version": "4.2.1", "module": "r/4.2.1"}, + {"name": "python", "version": "3.12", "module": "python/3.12"}, + ], + } + software = responses.add( + responses.GET, + "https://mrcdata.dide.ic.ac.uk/hpc/api/v1/cluster_software", + body=json.dumps(payload), + status=200, + ) + cl = create_client(logged_in=False) + res = cl.software() + assert res == { + "linux": { + "r": {"4.2.1": {"module": "r/4.2.1"}}, + "python": {"3.12": {"module": "python/3.12"}}, + }, + "windows": { + "r": { + "4.2.3": {"call": "setr64_4_2_3.bat"}, + "4.3.1": {"call": "setr64_4_3_1.bat"}, + }, + "python": {"3.11": {"call": "python311.bat"}}, + }, + } + assert software.call_count == 1 + + +@responses.activate +def test_can_submit_task(): + submit = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/submit_1.php", + body="Job has been submitted. ID: 497979.\n", + status=200, + ) + cl = create_client() + res = cl.submit("1234", "myname") + assert res == "497979" + assert submit.call_count == 1 + + +@responses.activate +def test_can_cancel_task(): + cancel = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/cancel.php", + body="497979\tWRONG_USER.\n", + status=200, + ) + cl = create_client() + res = cl.cancel("497979") + assert res == {"497979": "WRONG_USER."} + assert cancel.call_count == 1 + + +def test_can_check_access(): + with pytest.raises(Exception, match="You do not have access to any"): + web._client_check_access("wpia-hn", []) + with pytest.raises(Exception, match="You do not have access to 'wpia-hn'"): + web._client_check_access("wpia-hn", ["other"]) + with pytest.raises(Exception, match="try one of a, b"): + web._client_check_access("wpia-hn", ["a", "b"]) + assert web._client_check_access("a", ["a"]) is None + assert web._client_check_access("a", ["a", "b"]) is None + + +def test_throw_if_parse_on_submit_fails(): + with pytest.raises(Exception, match="Job submission has failed"): + web._client_parse_submit("") + + +def test_wrap_ids_as_list_for_cancel(): + base = {"cluster": web.encode64("cl"), "hpcfunc": web.encode64("cancel")} + assert web._client_body_cancel("1", "cl") == {"c1": "1", **base} + assert web._client_body_cancel(["1"], "cl") == {"c1": "1", **base} + assert web._client_body_cancel(["1", "2"], "cl") == { + "c1": "1", + "c2": "2", + **base, + } + + +def test_can_check_if_we_are_logged_in_from_web_client(): + cl = create_client() + assert cl.logged_in() + cl = create_client(logged_in=False) + assert not cl.logged_in() + + +@responses.activate +def test_can_check_if_we_have_access(): + responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/_listheadnodes.php", + body="wpia-hn\n", + status=200, + ) + cl = create_client() + assert cl.check_access() is None + + +@responses.activate +def test_can_log_in_and_out(): + login = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/index.php", + body="", + status=200, + ) + logout = responses.add( + responses.GET, + "https://mrcdata.dide.ic.ac.uk/hpc/logout.php", + status=200, + ) + cl = create_client(logged_in=False) + cl.login() + assert login.call_count == 1 + assert cl.logged_in() + cl.logout() + assert login.call_count == 1 + assert logout.call_count == 1 + assert not cl.logged_in() + + +@responses.activate +def test_throw_if_user_has_no_access(): + body = "You don't seem to have any HPC access" + responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/index.php", + body=body, + status=200, + ) + cl = create_client(logged_in=False) + with pytest.raises(Exception, match="You do not have HPC access"): + cl.login() + + +@responses.activate +def test_login_if_using_authenticated_endpoints(): + login = responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/index.php", + body="", + status=200, + ) + responses.add( + responses.POST, + "https://mrcdata.dide.ic.ac.uk/hpc/_listheadnodes.php", + body="foo\nbar\n", + status=200, + ) + cl = create_client(logged_in=False) + assert cl.headnodes() == ["foo", "bar"] + assert login.call_count == 1 + assert cl.logged_in()