Skip to content

Commit

Permalink
feature and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-dot committed Feb 20, 2025
1 parent bec48e6 commit a950589
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,20 @@ def get_device(name: str) -> DeviceModel:
return DeviceModel.from_device(context().devices[name])


def get_all_devices_using_interface(interface_name: str) -> list[DeviceModel]:
"""Retrieve device by protocol from the BlueskyContext"""
interface_class = globals().get(interface_name)
if interface_class is None:
return []
devices = context().devices
results: list[DeviceModel] = []
for device in devices.values():
if isinstance(device, interface_class):
results.append(DeviceModel.from_device(device))

return results


def submit_task(task: Task) -> str:
"""Submit a task to be run on begin_task"""
return worker().submit_task(task)
Expand Down
12 changes: 12 additions & 0 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
return runner.run(interface.get_device, name)


@router.get(
"/devices/", # Endpoint to filter devices by protocol
response_model=list[str],
)
@start_as_current_span(TRACER, "protocol_name")
def get_devices_by_protocol(
protocol_name: str, runner: WorkerDispatcher = Depends(_runner)
):
"""Retrieve all devices that implement the given protocol."""
return runner.run(get_devices_by_protocol(protocol_name))


example_task = Task(name="count", params={"detectors": ["x"]})


Expand Down
29 changes: 29 additions & 0 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,35 @@ def test_get_device(context_mock: MagicMock):
assert interface.get_device("non_existing_device")


@patch("blueapi.service.interface.context")
def test_get_devices_by_protocol(context_mock: MagicMock):
context = BlueskyContext()
context.register_device(SynAxis(name="my_axis"))
context_mock.return_value = context

assert interface.get_all_devices_using_interface("Pausable") == [
DeviceModel(
name="my_axis",
protocols=[
"Checkable",
"HasHints",
"HasName",
"HasParent",
"Movable",
"Pausable",
"Readable",
"Stageable",
"Stoppable",
"Subscribable",
"Configurable",
"Triggerable",
],
),
]

assert interface.get_all_devices_using_interface("non_existing_interface") == []


@patch("blueapi.service.interface.context")
def test_submit_task(context_mock: MagicMock):
context = BlueskyContext()
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from ophyd.sim import SynAxis
from pydantic import BaseModel, ValidationError
from pydantic_core import InitErrorDetails
from super_state_machine.errors import TransitionError
Expand Down Expand Up @@ -160,6 +161,36 @@ class MyDevice:
}


def test_get_device_by_protocol(mock_runner: Mock, client: TestClient) -> None:
sya = SynAxis(name="my_axis")
mock_runner.run.return_value = DeviceModel.from_device(sya)
response = client.get("/devices?protocol_name=Pausable")

mock_runner.run.assert_called_once_with(test_get_device_by_protocol, "Pausable")
assert response.status_code == status.HTTP_200_OK
assert response.json() == {
"name": "my-device",
"protocols": ["HasName"],
}
assert response.json() == {
"name": "my_axis",
"protocols": [
"Checkable",
"HasHints",
"HasName",
"HasParent",
"Movable",
"Pausable",
"Readable",
"Stageable",
"Stoppable",
"Subscribable",
"Configurable",
"Triggerable",
],
}


def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None:
mock_runner.run.side_effect = KeyError("my-device")
response = client.get("/devices/my-device")
Expand Down

0 comments on commit a950589

Please sign in to comment.