diff --git a/src/fastcs/controller.py b/src/fastcs/controller.py index fe6863e4..95c64d31 100755 --- a/src/fastcs/controller.py +++ b/src/fastcs/controller.py @@ -1,7 +1,6 @@ from __future__ import annotations from copy import copy -from typing import Any from .attributes import Attribute diff --git a/src/fastcs/exceptions.py b/src/fastcs/exceptions.py index 5e1f375f..64964cbf 100644 --- a/src/fastcs/exceptions.py +++ b/src/fastcs/exceptions.py @@ -1,2 +1,6 @@ class FastCSException(Exception): pass + + +class LaunchError(FastCSException): + pass diff --git a/src/fastcs/main.py b/src/fastcs/main.py index 42efff05..70589a07 100644 --- a/src/fastcs/main.py +++ b/src/fastcs/main.py @@ -1,7 +1,7 @@ +import inspect import json -from inspect import signature from pathlib import Path -from typing import Annotated, Any, get_type_hints +from typing import Annotated, get_type_hints import typer from pydantic import create_model @@ -9,6 +9,7 @@ from .backend import Backend from .controller import Controller +from .exceptions import LaunchError from .transport.adapter import TransportAdapter from .transport.epics.options import EpicsOptions from .transport.tango.options import TangoOptions @@ -51,10 +52,6 @@ def run(self) -> None: self._transport.run() -class LaunchError(Exception): - pass - - def launch(controller_class: type[Controller]) -> None: """ Serves as an entry point for starting FastCS applications. @@ -79,24 +76,29 @@ def __init__(self, my_arg: MyControllerOptions) -> None: if __name__ == "__main__": launch(MyController) """ - args_len = controller_class.__init__.__code__.co_argcount - sig = signature(controller_class.__init__) - if args_len == 1: + _launch(controller_class)() + + +def _launch(controller_class: type[Controller]) -> typer.Typer: + # args_len = controller_class.__init__.__code__.co_argcount + sig = inspect.signature(controller_class.__init__) + args = inspect.getfullargspec(controller_class.__init__)[0] + if len(args) == 1: fastcs_options = create_model( f"{controller_class.__name__}", transport=(EpicsOptions | TangoOptions, ...), __config__={"extra": "forbid"}, ) - elif args_len == 2: + elif len(args) == 2: hints = get_type_hints(controller_class.__init__) - if "self" in hints: - del hints["self"] + # if "self" in hints: + # del hints["self"] if hints: options_type = list(hints.values())[-1] else: raise LaunchError( - f"Expected typehinting in {controller_class.__name__}" - f".__init__ but received {sig}" + f"Expected typehinting in '{controller_class.__name__}" + f".__init__' but received {sig}. Add a typehint for `{args[-1]}`." ) fastcs_options = create_model( f"{controller_class.__name__}", @@ -106,32 +108,30 @@ def __init__(self, my_arg: MyControllerOptions) -> None: ) else: raise LaunchError( - f"Expected up to 2 arguments for {controller_class.__name__}.__init__ " - f"but received {args_len} {sig}" + f"Expected no more than 2 arguments for '{controller_class.__name__}" + f".__init__' but received {len(args)} as `{sig}`" ) - _launch_typer = typer.Typer() + launch_typer = typer.Typer() class _LaunchContext: def __init__(self, controller_class, fastcs_options): self.controller_class = controller_class self.fastcs_options = fastcs_options - @_launch_typer.callback() + @launch_typer.callback() def create_context(ctx: typer.Context): ctx.obj = _LaunchContext( controller_class, fastcs_options, ) - @_launch_typer.command( - help=f"Produce json schema for a {controller_class.__name__}" - ) + @launch_typer.command(help=f"Produce json schema for a {controller_class.__name__}") def schema(ctx: typer.Context): system_schema = ctx.obj.fastcs_options.model_json_schema() print(json.dumps(system_schema, indent=2)) - @_launch_typer.command(help=f"Start up a {controller_class.__name__}") + @launch_typer.command(help=f"Start up a {controller_class.__name__}") def run( ctx: typer.Context, config: Annotated[ @@ -167,4 +167,4 @@ def run( instance.create_docs() instance.run() - _launch_typer() + return launch_typer diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 00000000..7c3387c3 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,80 @@ +import json +from dataclasses import dataclass + +import pytest +from pydantic import create_model +from typer.testing import CliRunner + +from fastcs.controller import Controller +from fastcs.exceptions import LaunchError +from fastcs.main import _launch, launch +from fastcs.transport.epics.options import EpicsOptions +from fastcs.transport.tango.options import TangoOptions + + +@dataclass +class SomeConfig: + name: str + + +class SingleArg(Controller): + def __init__(self): + super().__init__() + + +class NotHinted(Controller): + def __init__(self, arg): + super().__init__() + + +class IsHinted(Controller): + def __init__(self, arg: SomeConfig): + super().__init__() + + +class ManyArgs(Controller): + def __init__(self, arg: SomeConfig, too_many): + super().__init__() + + +runner = CliRunner() + + +def test_is_hinted_schema(): + target_model = create_model( + "IsHinted", + controller=(SomeConfig, ...), + transport=(EpicsOptions | TangoOptions, ...), + __config__={"extra": "forbid"}, + ) + target_dict = target_model.model_json_schema() + + app = _launch(IsHinted) + result = runner.invoke(app, ["schema"]) + assert result.exit_code == 0 + result_dict = json.loads(result.stdout) + + assert result_dict == target_dict + + +def test_not_hinted_schema(): + error = ( + "Expected typehinting in 'NotHinted.__init__' but received " + "(self, arg). Add a typehint for `arg`." + ) + + with pytest.raises(LaunchError) as exc_info: + launch(NotHinted) + assert str(exc_info.value) == error + + +def test_over_defined_schema(): + error = ( + "" + "Expected no more than 2 arguments for 'ManyArgs.__init__' " + "but received 3 as `(self, arg: test_main.SomeConfig, too_many)`" + ) + + with pytest.raises(LaunchError) as exc_info: + launch(ManyArgs) + assert str(exc_info.value) == error