Skip to content

Commit

Permalink
Add launcher testing, tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
marcelldls committed Nov 22, 2024
1 parent 6426fea commit 604d9c3
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 24 deletions.
1 change: 0 additions & 1 deletion src/fastcs/controller.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from copy import copy
from typing import Any

from .attributes import Attribute

Expand Down
4 changes: 4 additions & 0 deletions src/fastcs/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
class FastCSException(Exception):
pass


class LaunchError(FastCSException):
pass
46 changes: 23 additions & 23 deletions src/fastcs/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import inspect
import json
from inspect import signature
from pathlib import Path
from typing import Annotated, Any, get_type_hints
from typing import Annotated, get_type_hints

import typer
from pydantic import create_model
from ruamel.yaml import YAML

from .backend import Backend
from .controller import Controller
from .exceptions import LaunchError
from .transport.adapter import TransportAdapter
from .transport.epics.options import EpicsOptions
from .transport.tango.options import TangoOptions
Expand Down Expand Up @@ -51,10 +52,6 @@ def run(self) -> None:
self._transport.run()


class LaunchError(Exception):
pass


def launch(controller_class: type[Controller]) -> None:
"""
Serves as an entry point for starting FastCS applications.
Expand All @@ -79,24 +76,29 @@ def __init__(self, my_arg: MyControllerOptions) -> None:
if __name__ == "__main__":
launch(MyController)
"""
args_len = controller_class.__init__.__code__.co_argcount
sig = signature(controller_class.__init__)
if args_len == 1:
_launch(controller_class)()


def _launch(controller_class: type[Controller]) -> typer.Typer:
# args_len = controller_class.__init__.__code__.co_argcount
sig = inspect.signature(controller_class.__init__)
args = inspect.getfullargspec(controller_class.__init__)[0]
if len(args) == 1:
fastcs_options = create_model(
f"{controller_class.__name__}",
transport=(EpicsOptions | TangoOptions, ...),
__config__={"extra": "forbid"},
)
elif args_len == 2:
elif len(args) == 2:
hints = get_type_hints(controller_class.__init__)
if "self" in hints:
del hints["self"]
# if "self" in hints:
# del hints["self"]
if hints:
options_type = list(hints.values())[-1]
else:
raise LaunchError(
f"Expected typehinting in {controller_class.__name__}"
f".__init__ but received {sig}"
f"Expected typehinting in '{controller_class.__name__}"
f".__init__' but received {sig}. Add a typehint for `{args[-1]}`."
)
fastcs_options = create_model(
f"{controller_class.__name__}",
Expand All @@ -106,32 +108,30 @@ def __init__(self, my_arg: MyControllerOptions) -> None:
)
else:
raise LaunchError(
f"Expected up to 2 arguments for {controller_class.__name__}.__init__ "
f"but received {args_len} {sig}"
f"Expected no more than 2 arguments for '{controller_class.__name__}"
f".__init__' but received {len(args)} as `{sig}`"
)

_launch_typer = typer.Typer()
launch_typer = typer.Typer()

class _LaunchContext:
def __init__(self, controller_class, fastcs_options):
self.controller_class = controller_class
self.fastcs_options = fastcs_options

@_launch_typer.callback()
@launch_typer.callback()
def create_context(ctx: typer.Context):
ctx.obj = _LaunchContext(
controller_class,
fastcs_options,
)

@_launch_typer.command(
help=f"Produce json schema for a {controller_class.__name__}"
)
@launch_typer.command(help=f"Produce json schema for a {controller_class.__name__}")
def schema(ctx: typer.Context):
system_schema = ctx.obj.fastcs_options.model_json_schema()
print(json.dumps(system_schema, indent=2))

@_launch_typer.command(help=f"Start up a {controller_class.__name__}")
@launch_typer.command(help=f"Start up a {controller_class.__name__}")
def run(
ctx: typer.Context,
config: Annotated[
Expand Down Expand Up @@ -167,4 +167,4 @@ def run(
instance.create_docs()
instance.run()

_launch_typer()
return launch_typer
80 changes: 80 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import json
from dataclasses import dataclass

import pytest
from pydantic import create_model
from typer.testing import CliRunner

from fastcs.controller import Controller
from fastcs.exceptions import LaunchError
from fastcs.main import _launch, launch
from fastcs.transport.epics.options import EpicsOptions
from fastcs.transport.tango.options import TangoOptions


@dataclass
class SomeConfig:
name: str


class SingleArg(Controller):
def __init__(self):
super().__init__()


class NotHinted(Controller):
def __init__(self, arg):
super().__init__()


class IsHinted(Controller):
def __init__(self, arg: SomeConfig):
super().__init__()


class ManyArgs(Controller):
def __init__(self, arg: SomeConfig, too_many):
super().__init__()


runner = CliRunner()


def test_is_hinted_schema():
target_model = create_model(
"IsHinted",
controller=(SomeConfig, ...),
transport=(EpicsOptions | TangoOptions, ...),
__config__={"extra": "forbid"},
)
target_dict = target_model.model_json_schema()

app = _launch(IsHinted)
result = runner.invoke(app, ["schema"])
assert result.exit_code == 0
result_dict = json.loads(result.stdout)

assert result_dict == target_dict


def test_not_hinted_schema():
error = (
"Expected typehinting in 'NotHinted.__init__' but received "
"(self, arg). Add a typehint for `arg`."
)

with pytest.raises(LaunchError) as exc_info:
launch(NotHinted)
assert str(exc_info.value) == error


def test_over_defined_schema():
error = (
""
"Expected no more than 2 arguments for 'ManyArgs.__init__' "
"but received 3 as `(self, arg: test_main.SomeConfig, too_many)`"
)

with pytest.raises(LaunchError) as exc_info:
launch(ManyArgs)
assert str(exc_info.value) == error

0 comments on commit 604d9c3

Please sign in to comment.