Skip to content

Commit

Permalink
795 Care about where a plan is sourced (#807)
Browse files Browse the repository at this point in the history
Fixes #795
  • Loading branch information
callumforrester authored Jan 31, 2025
1 parent e18b9be commit b605942
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 6 deletions.
4 changes: 3 additions & 1 deletion docs/how-to/write-plans.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def my_plan(
...
```

The type annotations (e.g. `: str`, `: int`, `-> MsgGenerator`) are required as blueapi uses them to detect that this function is intended to be a plan and generate its runtime API.
## Detection

The type annotations in the example above (e.g. `: str`, `: int`, `-> MsgGenerator`) are required as blueapi uses them to detect that this function is intended to be a plan and generate its runtime API. If there is an [`__all__` dunder](https://docs.python.org/3/tutorial/modules.html#importing-from-a-package) present in the module, blueapi will read that and import anything within that qualifies as a plan, per its type annotations. If not it will read everything in the module that hasn't been imported, for example it will ignore a plan imported from another module.

**Input annotations should be as broad as possible**, the least specific implementation that is sufficient to accomplish the requirements of the plan. For example, if a plan is written to drive a specific motor (`MyMotor`), but only uses the general methods on the [`Movable` protocol](https://blueskyproject.io/bluesky/main/hardware.html#bluesky.protocols.Movable), it should take `Movable` as a parameter annotation rather than `MyMotor`.

Expand Down
16 changes: 14 additions & 2 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

from blueapi import utils
from blueapi.config import EnvironmentConfig, SourceKind
from blueapi.utils import BlueapiPlanModelConfig, load_module_all
from blueapi.utils import (
BlueapiPlanModelConfig,
is_function_sourced_from_module,
load_module_all,
)

from .bluesky_types import (
BLUESKY_PROTOCOLS,
Expand Down Expand Up @@ -99,7 +103,15 @@ def plan_2(...) -> MsgGenerator:
"""

for obj in load_module_all(module):
if is_bluesky_plan_generator(obj):
# The rule here is that we only inspect objects defined in the module
# (as opposed to objects imported from other modules) to determine if
# they are valid plans, unless there is an __all__ defined in the module,
# in which case we only inspect objects listed there, regardless of their
# original source module.
if is_bluesky_plan_generator(obj) and (
hasattr(module, "__all__")
or is_function_sourced_from_module(obj, module)
):
self.register_plan(obj)

def with_device_module(self, module: ModuleType) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .connect_devices import connect_devices
from .file_permissions import get_owner_gid, is_sgid_set
from .invalid_config_error import InvalidConfigError
from .modules import load_module_all
from .modules import is_function_sourced_from_module, load_module_all
from .serialization import serialize
from .thread_exception import handle_all_exceptions

Expand All @@ -17,4 +17,5 @@
"connect_devices",
"is_sgid_set",
"get_owner_gid",
"is_function_sourced_from_module",
]
17 changes: 16 additions & 1 deletion src/blueapi/utils/modules.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
import importlib
from collections.abc import Callable, Iterable
from types import ModuleType
from typing import Any

Expand Down Expand Up @@ -34,3 +35,17 @@ def get_named_subset(names: list[str]):
for name, value in mod.__dict__.items():
if not name.startswith("_"):
yield value


def is_function_sourced_from_module(
func: Callable[..., Any], module: ModuleType
) -> bool:
"""
Check if a function is originally from a particular module, useful to detect
whether it actually comes from a nested import.
Args:
func: Object to check
module: Module to check against object
"""
return importlib.import_module(func.__module__) is module
9 changes: 9 additions & 0 deletions tests/unit_tests/core/fake_plan_module_with_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from bluesky.utils import MsgGenerator
from tests.unit_tests.core.fake_plan_module import plan_a, plan_b # noqa: F401


def plan_c(c: bool) -> MsgGenerator[None]: ...
def plan_d(d: int) -> MsgGenerator[int]: ...


__all__ = ["plan_a", "plan_d"]
6 changes: 6 additions & 0 deletions tests/unit_tests/core/fake_plan_module_with_imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from bluesky.utils import MsgGenerator
from tests.unit_tests.core.fake_plan_module import plan_a, plan_b # noqa: F401


def plan_c(c: bool) -> MsgGenerator[None]: ...
def plan_d(d: int) -> MsgGenerator[int]: ...
14 changes: 14 additions & 0 deletions tests/unit_tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,20 @@ def test_add_plan_from_module(empty_context: BlueskyContext) -> None:
assert EXPECTED_PLANS == empty_context.plans.keys()


def test_only_plans_from_source_module_detected(empty_context: BlueskyContext) -> None:
import tests.unit_tests.core.fake_plan_module_with_imports as plan_module

empty_context.with_plan_module(plan_module)
assert {"plan_c", "plan_d"} == empty_context.plans.keys()


def test_only_plans_from_all_in_module_detected(empty_context: BlueskyContext) -> None:
import tests.unit_tests.core.fake_plan_module_with_all as plan_module

empty_context.with_plan_module(plan_module)
assert {"plan_a", "plan_d"} == empty_context.plans.keys()


def test_add_named_device(empty_context: BlueskyContext, sim_motor: SynAxis) -> None:
empty_context.register_device(sim_motor)
assert empty_context.devices[SIM_MOTOR_NAME] is sim_motor
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/utils/functions_a.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def a(): ...


def b(): ...
7 changes: 7 additions & 0 deletions tests/unit_tests/utils/functions_b.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .functions_a import a, b # noqa: F401


def c(): ...


def d(): ...
16 changes: 15 additions & 1 deletion tests/unit_tests/utils/test_modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from importlib import import_module

from blueapi.utils import load_module_all
from blueapi.utils import is_function_sourced_from_module, load_module_all


def test_imports_all():
Expand All @@ -11,3 +11,17 @@ def test_imports_all():
def test_imports_everything_without_all():
module = import_module(".lacksall", package="tests.unit_tests.utils")
assert list(load_module_all(module)) == [3, "hello", 9]


def test_source_is_in_module():
module = import_module(".functions_b", package="tests.unit_tests.utils")
c = module.__dict__["c"]
assert callable(c)
assert is_function_sourced_from_module(c, module)


def test_source_is_not_in_module():
module = import_module(".functions_b", package="tests.unit_tests.utils")
a = module.__dict__["a"]
assert callable(a)
assert not is_function_sourced_from_module(a, module)

0 comments on commit b605942

Please sign in to comment.