From ef2c5c46c40f03db4a5348af22f6a938d516d22e Mon Sep 17 00:00:00 2001 From: Marcell Nagy Date: Mon, 4 Nov 2024 13:34:18 +0000 Subject: [PATCH] Attempt graphql backend Apply upstream feedback Remove fastapi layer from gql Process review comments Rebase create_type -> create_strawberry_type --- pyproject.toml | 3 +- src/fastcs/launch.py | 10 +- src/fastcs/transport/__init__.py | 2 + src/fastcs/transport/graphQL/__init__.py | 0 src/fastcs/transport/graphQL/adapter.py | 24 +++ src/fastcs/transport/graphQL/graphQL.py | 198 +++++++++++++++++++++++ src/fastcs/transport/graphQL/options.py | 13 ++ tests/transport/graphQL/test_graphQL.py | 158 ++++++++++++++++++ 8 files changed, 406 insertions(+), 2 deletions(-) create mode 100644 src/fastcs/transport/graphQL/__init__.py create mode 100644 src/fastcs/transport/graphQL/adapter.py create mode 100644 src/fastcs/transport/graphQL/graphQL.py create mode 100644 src/fastcs/transport/graphQL/options.py create mode 100644 tests/transport/graphQL/test_graphQL.py diff --git a/pyproject.toml b/pyproject.toml index 2b7722ca..c1663382 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "pvi~=0.10.0", "pytango", "softioc>=4.5.0", + "strawberry-graphql", ] dynamic = ["version"] license.file = "LICENSE" @@ -63,7 +64,7 @@ version_file = "src/fastcs/_version.py" [tool.pyright] typeCheckingMode = "standard" -reportMissingImports = false # Ignore missing stubs in imported modules +reportMissingImports = false # Ignore missing stubs in imported modules [tool.pytest.ini_options] # Run pytest with all our checkers, and don't spam us with massive tracebacks on error diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index d3d3fcf7..62759423 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -12,11 +12,12 @@ from .exceptions import LaunchError from .transport.adapter import TransportAdapter from .transport.epics.options import EpicsOptions +from .transport.graphQL.options import GraphQLOptions from .transport.rest.options import RestOptions from .transport.tango.options import TangoOptions # Define a type alias for transport options -TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions +TransportOptions: TypeAlias = EpicsOptions | TangoOptions | RestOptions | GraphQLOptions class FastCS: @@ -36,6 +37,13 @@ def __init__( self._backend.dispatcher, transport_options, ) + case GraphQLOptions(): + from .transport.graphQL.adapter import GraphQLTransport + + self._transport = GraphQLTransport( + controller, + transport_options, + ) case TangoOptions(): from .transport.tango.adapter import TangoTransport diff --git a/src/fastcs/transport/__init__.py b/src/fastcs/transport/__init__.py index 0ca90d43..36f3470e 100644 --- a/src/fastcs/transport/__init__.py +++ b/src/fastcs/transport/__init__.py @@ -2,6 +2,8 @@ from .epics.options import EpicsGUIOptions as EpicsGUIOptions from .epics.options import EpicsIOCOptions as EpicsIOCOptions from .epics.options import EpicsOptions as EpicsOptions +from .graphQL.options import GraphQLOptions as GraphQLOptions +from .graphQL.options import GraphQLServerOptions as GraphQLServerOptions from .rest.options import RestOptions as RestOptions from .rest.options import RestServerOptions as RestServerOptions from .tango.options import TangoDSROptions as TangoDSROptions diff --git a/src/fastcs/transport/graphQL/__init__.py b/src/fastcs/transport/graphQL/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/fastcs/transport/graphQL/adapter.py b/src/fastcs/transport/graphQL/adapter.py new file mode 100644 index 00000000..5b573c02 --- /dev/null +++ b/src/fastcs/transport/graphQL/adapter.py @@ -0,0 +1,24 @@ +from fastcs.controller import Controller +from fastcs.transport.adapter import TransportAdapter + +from .graphQL import GraphQLServer +from .options import GraphQLOptions + + +class GraphQLTransport(TransportAdapter): + def __init__( + self, + controller: Controller, + options: GraphQLOptions | None = None, + ): + self.options = options or GraphQLOptions() + self._server = GraphQLServer(controller) + + def create_docs(self) -> None: + raise NotImplementedError + + def create_gui(self) -> None: + raise NotImplementedError + + def run(self) -> None: + self._server.run(self.options.gql) diff --git a/src/fastcs/transport/graphQL/graphQL.py b/src/fastcs/transport/graphQL/graphQL.py new file mode 100644 index 00000000..f1a6777d --- /dev/null +++ b/src/fastcs/transport/graphQL/graphQL.py @@ -0,0 +1,198 @@ +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any + +import strawberry +import uvicorn +from strawberry.asgi import GraphQL +from strawberry.tools import create_type +from strawberry.types.field import StrawberryField + +from fastcs.attributes import AttrR, AttrRW, AttrW, T +from fastcs.controller import BaseController, Controller + +from .options import GraphQLServerOptions + + +class GraphQLServer: + def __init__(self, controller: Controller): + self._controller = controller + self._fields_tree: FieldTree = FieldTree("") + self._app = self._create_app() + + def _create_app(self) -> GraphQL: + _add_attribute_operations(self._fields_tree, self._controller) + _add_command_mutations(self._fields_tree, self._controller) + + schema_kwargs = {} + for key in ["query", "mutation"]: + if s_type := self._fields_tree.create_strawberry_type(key): + schema_kwargs[key] = s_type + schema = strawberry.Schema(**schema_kwargs) # type: ignore + app = GraphQL(schema) + + return app + + def run(self, options: GraphQLServerOptions | None = None) -> None: + if options is None: + options = GraphQLServerOptions() + + uvicorn.run( + self._app, + host=options.host, + port=options.port, + log_level=options.log_level, + ) + + +def _wrap_attr_set( + attr_name: str, + attribute: AttrW[T], +) -> Callable[[T], Coroutine[Any, Any, None]]: + async def _dynamic_f(value): + await attribute.process(value) + return value + + # Add type annotations for validation, schema, conversions + _dynamic_f.__name__ = attr_name + _dynamic_f.__annotations__["value"] = attribute.datatype.dtype + _dynamic_f.__annotations__["return"] = attribute.datatype.dtype + + return _dynamic_f + + +def _wrap_attr_get( + attr_name: str, + attribute: AttrR[T], +) -> Callable[[], Coroutine[Any, Any, Any]]: + async def _dynamic_f() -> Any: + return attribute.get() + + _dynamic_f.__name__ = attr_name + _dynamic_f.__annotations__["return"] = attribute.datatype.dtype + + return _dynamic_f + + +def _wrap_as_field( + field_name: str, + strawberry_type: type, +) -> StrawberryField: + def _dynamic_field(): + return strawberry_type() + + _dynamic_field.__name__ = field_name + _dynamic_field.__annotations__["return"] = strawberry_type + + return strawberry.field(_dynamic_field) + + +class FieldTree: + def __init__(self, name: str): + self.name = name + self.children: dict[str, FieldTree] = {} + self.fields: dict[str, list[StrawberryField]] = { + "query": [], + "mutation": [], + } + + def insert(self, path: list[str]) -> "FieldTree": + # Create child if not exist + name = path.pop(0) + if child := self.get_child(name): + pass + else: + child = FieldTree(name) + self.children[name] = child + + # Recurse if needed + if path: + return child.insert(path) + else: + return child + + def get_child(self, name: str) -> "FieldTree | None": + if name in self.children: + return self.children[name] + else: + return None + + def create_strawberry_type(self, strawberry_type: str) -> type | None: + for child in self.children.values(): + if new_type := child.create_strawberry_type(strawberry_type): + child_field = _wrap_as_field( + child.name, + new_type, + ) + self.fields[strawberry_type].append(child_field) + + if self.fields[strawberry_type]: + return create_type( + f"{self.name}{strawberry_type}", self.fields[strawberry_type] + ) + else: + return None + + +def _add_attribute_operations( + fields_tree: FieldTree, + controller: Controller, +) -> None: + for single_mapping in controller.get_controller_mappings(): + path = single_mapping.controller.path + if path: + node = fields_tree.insert(path) + else: + node = fields_tree + + if node is not None: + for attr_name, attribute in single_mapping.attributes.items(): + match attribute: + # mutation for server changes https://graphql.org/learn/queries/ + case AttrRW(): + node.fields["query"].append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + node.fields["mutation"].append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + case AttrR(): + node.fields["query"].append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + case AttrW(): + node.fields["mutation"].append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + + +def _wrap_command( + method_name: str, method: Callable, controller: BaseController +) -> Callable[..., Awaitable[bool]]: + async def _dynamic_f() -> bool: + await getattr(controller, method.__name__)() + return True + + _dynamic_f.__name__ = method_name + + return _dynamic_f + + +def _add_command_mutations(fields_tree: FieldTree, controller: Controller) -> None: + for single_mapping in controller.get_controller_mappings(): + path = single_mapping.controller.path + if path: + node = fields_tree.insert(path) + else: + node = fields_tree + + if node is not None: + for cmd_name, method in single_mapping.command_methods.items(): + node.fields["mutation"].append( + strawberry.mutation( + _wrap_command( + cmd_name, + method.fn, + single_mapping.controller, + ) + ) + ) diff --git a/src/fastcs/transport/graphQL/options.py b/src/fastcs/transport/graphQL/options.py new file mode 100644 index 00000000..b1ce2e83 --- /dev/null +++ b/src/fastcs/transport/graphQL/options.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass, field + + +@dataclass +class GraphQLServerOptions: + host: str = "localhost" + port: int = 8080 + log_level: str = "info" + + +@dataclass +class GraphQLOptions: + gql: GraphQLServerOptions = field(default_factory=GraphQLServerOptions) diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py new file mode 100644 index 00000000..05ba00e0 --- /dev/null +++ b/tests/transport/graphQL/test_graphQL.py @@ -0,0 +1,158 @@ +import copy +import json +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from fastcs.transport.graphQL.adapter import GraphQLTransport + + +def nest_query(path: list[str]) -> str: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_query(queue) + return f"{field} {{ {nesting} }} " + else: + return field + + +def nest_mutation(path: list[str], value: Any) -> str: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_query(queue) + return f"{field} {{ {nesting} }} " + else: + return f"{field}(value: {json.dumps(value)})" + + +def nest_responce(path: list[str], value: Any) -> dict: + queue = copy.deepcopy(path) + field = queue.pop(0) + + if queue: + nesting = nest_responce(queue, value) + return {field: nesting} + else: + return {field: value} + + +class TestGraphQLServer: + @pytest.fixture(scope="class") + def client(self, assertable_controller): + app = GraphQLTransport(assertable_controller)._server._app + return TestClient(app) + + def test_read_int(self, assertable_controller, client): + expect = 0 + path = ["readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_read_write_int(self, assertable_controller, client): + expect = 0 + path = ["readWriteInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_write_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = 9 + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["read_write_int"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_read_write_float(self, assertable_controller, client): + expect = 0 + path = ["readWriteFloat"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_write_float"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = 0.5 + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["read_write_float"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_read_bool(self, assertable_controller, client): + expect = False + path = ["readBool"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["read_bool"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_write_bool(self, assertable_controller, client): + value = True + path = ["writeBool"] + mutation = f"mutation {{ {nest_mutation(path, value)} }}" + with assertable_controller.assert_write_here(["write_bool"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, value) + + def test_string_enum(self, assertable_controller, client): + expect = "" + path = ["stringEnum"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["string_enum"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + new = "new" + mutation = f"mutation {{ {nest_mutation(path, new)} }}" + with assertable_controller.assert_write_here(["string_enum"]): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, new) + + def test_big_enum(self, assertable_controller, client): + expect = 0 + path = ["bigEnum"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["big_enum"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_go(self, assertable_controller, client): + path = ["go"] + mutation = f"mutation {{ {nest_query(path)} }}" + with assertable_controller.assert_execute_here(path): + response = client.post("/graphql", json={"query": mutation}) + assert response.status_code == 200 + assert response.json()["data"] == {path[-1]: True} + + def test_read_child1(self, assertable_controller, client): + expect = 0 + path = ["SubController01", "readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["SubController01", "read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect) + + def test_read_child2(self, assertable_controller, client): + expect = 0 + path = ["SubController02", "readInt"] + query = f"query {{ {nest_query(path)} }}" + with assertable_controller.assert_read_here(["SubController02", "read_int"]): + response = client.post("/graphql", json={"query": query}) + assert response.status_code == 200 + assert response.json()["data"] == nest_responce(path, expect)