diff --git a/src/fastcs/transport/graphQL/graphQL.py b/src/fastcs/transport/graphQL/graphQL.py index f1a6777d..85bde07b 100644 --- a/src/fastcs/transport/graphQL/graphQL.py +++ b/src/fastcs/transport/graphQL/graphQL.py @@ -8,7 +8,13 @@ from strawberry.types.field import StrawberryField from fastcs.attributes import AttrR, AttrRW, AttrW, T -from fastcs.controller import BaseController, Controller +from fastcs.controller import ( + BaseController, + Controller, + SingleMapping, + _get_single_mapping, +) +from fastcs.exceptions import FastCSException from .options import GraphQLServerOptions @@ -16,18 +22,11 @@ class GraphQLServer: def __init__(self, controller: Controller): self._controller = controller - self._fields_tree: FieldTree = FieldTree("") self._app = self._create_app() def _create_app(self) -> GraphQL: - _add_attribute_operations(self._fields_tree, self._controller) - _add_command_mutations(self._fields_tree, self._controller) - - schema_kwargs = {} - for key in ["query", "mutation"]: - if s_type := self._fields_tree.create_strawberry_type(key): - schema_kwargs[key] = s_type - schema = strawberry.Schema(**schema_kwargs) # type: ignore + api = GraphQLAPI(self._controller) + schema = api.create_schema() app = GraphQL(schema) return app @@ -44,10 +43,83 @@ def run(self, options: GraphQLServerOptions | None = None) -> None: ) +class GraphQLAPI: + """A Strawberry API built dynamically from a Controller""" + + def __init__(self, controller: BaseController): + self.queries: list[StrawberryField] = [] + self.mutations: list[StrawberryField] = [] + + api = _get_single_mapping(controller) + + self._process_attributes(api) + self._process_commands(api) + self._process_sub_controllers(api) + + def _process_attributes(self, api: SingleMapping): + """Create queries and mutations from api attributes.""" + for attr_name, attribute in api.attributes.items(): + match attribute: + # mutation for server changes https://graphql.org/learn/queries/ + case AttrRW(): + self.queries.append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + self.mutations.append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + case AttrR(): + self.queries.append( + strawberry.field(_wrap_attr_get(attr_name, attribute)) + ) + case AttrW(): + self.mutations.append( + strawberry.mutation(_wrap_attr_set(attr_name, attribute)) + ) + + def _process_commands(self, api: SingleMapping): + """Create mutations from api commands""" + for cmd_name, method in api.command_methods.items(): + self.mutations.append( + strawberry.mutation(_wrap_command(cmd_name, method.fn, api.controller)) + ) + + def _process_sub_controllers(self, api: SingleMapping): + """Recursively add fields from the queries and mutations of sub controllers""" + for sub_controller in api.controller.get_sub_controllers().values(): + name = "".join(sub_controller.path) + child_tree = GraphQLAPI(sub_controller) + if child_tree.queries: + self.queries.append( + _wrap_as_field( + name, create_type(f"{name}Query", child_tree.queries) + ) + ) + if child_tree.mutations: + self.mutations.append( + _wrap_as_field( + name, create_type(f"{name}Mutation", child_tree.mutations) + ) + ) + + def create_schema(self) -> strawberry.Schema: + """Create a Strawberry Schema to load into a GraphQL application.""" + if not self.queries: + raise FastCSException( + "Can't create GraphQL transport from Controller with no read attributes" + ) + + query = create_type("Query", self.queries) + mutation = create_type("Mutation", self.mutations) if self.mutations else None + + return strawberry.Schema(query=query, mutation=mutation) + + def _wrap_attr_set( - attr_name: str, - attribute: AttrW[T], + attr_name: str, attribute: AttrW[T] ) -> Callable[[T], Coroutine[Any, Any, None]]: + """Wrap an attribute in a function with annotations for strawberry""" + async def _dynamic_f(value): await attribute.process(value) return value @@ -61,9 +133,10 @@ async def _dynamic_f(value): def _wrap_attr_get( - attr_name: str, - attribute: AttrR[T], + attr_name: str, attribute: AttrR[T] ) -> Callable[[], Coroutine[Any, Any, Any]]: + """Wrap an attribute in a function with annotations for strawberry""" + async def _dynamic_f() -> Any: return attribute.get() @@ -73,101 +146,23 @@ async def _dynamic_f() -> Any: return _dynamic_f -def _wrap_as_field( - field_name: str, - strawberry_type: type, -) -> StrawberryField: +def _wrap_as_field(field_name: str, operation: type) -> StrawberryField: + """Wrap a strawberry type as a field of a parent type""" + def _dynamic_field(): - return strawberry_type() + return operation() _dynamic_field.__name__ = field_name - _dynamic_field.__annotations__["return"] = strawberry_type + _dynamic_field.__annotations__["return"] = operation return strawberry.field(_dynamic_field) -class FieldTree: - def __init__(self, name: str): - self.name = name - self.children: dict[str, FieldTree] = {} - self.fields: dict[str, list[StrawberryField]] = { - "query": [], - "mutation": [], - } - - def insert(self, path: list[str]) -> "FieldTree": - # Create child if not exist - name = path.pop(0) - if child := self.get_child(name): - pass - else: - child = FieldTree(name) - self.children[name] = child - - # Recurse if needed - if path: - return child.insert(path) - else: - return child - - def get_child(self, name: str) -> "FieldTree | None": - if name in self.children: - return self.children[name] - else: - return None - - def create_strawberry_type(self, strawberry_type: str) -> type | None: - for child in self.children.values(): - if new_type := child.create_strawberry_type(strawberry_type): - child_field = _wrap_as_field( - child.name, - new_type, - ) - self.fields[strawberry_type].append(child_field) - - if self.fields[strawberry_type]: - return create_type( - f"{self.name}{strawberry_type}", self.fields[strawberry_type] - ) - else: - return None - - -def _add_attribute_operations( - fields_tree: FieldTree, - controller: Controller, -) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - if path: - node = fields_tree.insert(path) - else: - node = fields_tree - - if node is not None: - for attr_name, attribute in single_mapping.attributes.items(): - match attribute: - # mutation for server changes https://graphql.org/learn/queries/ - case AttrRW(): - node.fields["query"].append( - strawberry.field(_wrap_attr_get(attr_name, attribute)) - ) - node.fields["mutation"].append( - strawberry.mutation(_wrap_attr_set(attr_name, attribute)) - ) - case AttrR(): - node.fields["query"].append( - strawberry.field(_wrap_attr_get(attr_name, attribute)) - ) - case AttrW(): - node.fields["mutation"].append( - strawberry.mutation(_wrap_attr_set(attr_name, attribute)) - ) - - def _wrap_command( method_name: str, method: Callable, controller: BaseController ) -> Callable[..., Awaitable[bool]]: + """Wrap a command in a function with annotations for strawberry""" + async def _dynamic_f() -> bool: await getattr(controller, method.__name__)() return True @@ -175,24 +170,3 @@ async def _dynamic_f() -> bool: _dynamic_f.__name__ = method_name return _dynamic_f - - -def _add_command_mutations(fields_tree: FieldTree, controller: Controller) -> None: - for single_mapping in controller.get_controller_mappings(): - path = single_mapping.controller.path - if path: - node = fields_tree.insert(path) - else: - node = fields_tree - - if node is not None: - for cmd_name, method in single_mapping.command_methods.items(): - node.fields["mutation"].append( - strawberry.mutation( - _wrap_command( - cmd_name, - method.fn, - single_mapping.controller, - ) - ) - )