diff --git a/.github/workflows/_container.yml b/.github/workflows/_container.yml index 8688011c3..cdfc9a7d8 100644 --- a/.github/workflows/_container.yml +++ b/.github/workflows/_container.yml @@ -106,7 +106,6 @@ jobs: - name: package chart and push it run: | - sed -i "$ a appVersion: ${GITHUB_REF##*/}" helm/blueapi/Chart.yaml helm dependencies update helm/blueapi - helm package helm/blueapi --version ${GITHUB_REF##*/} -d /tmp/ + helm package helm/blueapi --version ${GITHUB_REF##*/} --app-version ${GITHUB_REF##*/} -d /tmp/ helm push /tmp/blueapi-${GITHUB_REF##*/}.tgz oci://ghcr.io/diamondlightsource/charts diff --git a/.github/workflows/_system_test.yml b/.github/workflows/_system_test.yml new file mode 100644 index 000000000..71a14ae84 --- /dev/null +++ b/.github/workflows/_system_test.yml @@ -0,0 +1,32 @@ +on: + workflow_call: + +env: + # https://github.com/pytest-dev/pytest/issues/2042 + PY_IGNORE_IMPORTMISMATCH: "1" + +jobs: + run: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + # Need this to get version number from last tag + fetch-depth: 0 + + - name: Install python packages + uses: ./.github/actions/install_requirements + + - name: Start RabbitMQ + uses: namoshek/rabbitmq-github-action@v1 + with: + ports: "61613:61613" + plugins: rabbitmq_stomp + + - name: Start Blueapi Server + run: blueapi -c ${{ github.workspace }}/tests/unit_tests/example_yaml/valid_stomp_config.yaml serve & + + - name: Run tests + run: tox -e system-test diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml index 8fc0c8ee3..f652d4145 100644 --- a/.github/workflows/_test.yml +++ b/.github/workflows/_test.yml @@ -16,9 +16,6 @@ on: env: # https://github.com/pytest-dev/pytest/issues/2042 PY_IGNORE_IMPORTMISMATCH: "1" - BLUEAPI_TEST_STOMP_PORTS: "[61613,61614]" - - jobs: run: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 485b82b66..53f66b6d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,6 +34,11 @@ jobs: secrets: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + system-test: + needs: check + if: needs.check.outputs.branch-pr == '' + uses: ./.github/workflows/_system_test.yml + container: needs: check if: needs.check.outputs.branch-pr == '' diff --git a/docs/how-to/edit-live.md b/docs/how-to/edit-live.md index 8f1eb8c7c..af8695585 100644 --- a/docs/how-to/edit-live.md +++ b/docs/how-to/edit-live.md @@ -11,6 +11,8 @@ Blueapi can be configured to install editable Python packages from a chosen dire scratch: root: /path/to/my/scratch/directory + # Required GID for the scratch area + required_gid: 12345 repositories: # Repository for DLS devices - name: dodal @@ -21,6 +23,9 @@ scratch: remote_url: https://github.com/DiamondLightSource/mx-bluesky.git ``` +Note the `required_gid` field, which is useful for stopping blueapi from locking the files it clones +to a particular owner. + ## Synchronization Blueapi will synchronize reality with the configuration if you run diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 6a0ab67a4..2697ec63d 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -1,5 +1,7 @@ import json import logging +import os +import stat import sys from functools import wraps from pathlib import Path @@ -39,6 +41,9 @@ def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None: # if no command is supplied, run with the options passed + # Set umask to DLS standard + os.umask(stat.S_IWOTH) + config_loader = ConfigLoader(ApplicationConfig) if config is not None: configs = (config,) if isinstance(config, Path) else config diff --git a/src/blueapi/cli/scratch.py b/src/blueapi/cli/scratch.py index d731eeeb5..e2161799b 100644 --- a/src/blueapi/cli/scratch.py +++ b/src/blueapi/cli/scratch.py @@ -1,12 +1,14 @@ import logging import os import stat +import textwrap from pathlib import Path from subprocess import Popen from git import Repo from blueapi.config import ScratchConfig +from blueapi.utils import get_owner_gid, is_sgid_set _DEFAULT_INSTALL_TIMEOUT: float = 300.0 @@ -23,7 +25,7 @@ def setup_scratch( install_timeout: Timeout for installing packages """ - _validate_directory(config.root) + _validate_root_directory(config.root, config.required_gid) logging.info(f"Setting up scratch area: {config.root}") @@ -74,9 +76,6 @@ def scratch_install(path: Path, timeout: float = _DEFAULT_INSTALL_TIMEOUT) -> No _validate_directory(path) - # Set umask to DLS standard - os.umask(stat.S_IWOTH) - logging.info(f"Installing {path}") process = Popen( [ @@ -94,6 +93,37 @@ def scratch_install(path: Path, timeout: float = _DEFAULT_INSTALL_TIMEOUT) -> No raise RuntimeError(f"Failed to install {path}: Exit Code: {process.returncode}") +def _validate_root_directory(root_path: Path, required_gid: int | None) -> None: + _validate_directory(root_path) + + if not is_sgid_set(root_path): + raise PermissionError( + textwrap.dedent(f""" + The scratch area root directory ({root_path}) needs to have the + SGID permission bit set. This allows blueapi to clone + repositories into it while retaining the ability for + other users in an approved group to edit/delete them. + + See https://www.redhat.com/en/blog/suid-sgid-sticky-bit for how to + set the SGID bit. + """) + ) + elif required_gid is not None and get_owner_gid(root_path) != required_gid: + raise PermissionError( + textwrap.dedent(f""" + The configuration requires that {root_path} be owned by the group with + ID {required_gid}. + You may be able to find this group's name by running the following + in the terminal. + + getent group 1000 | cut -d: -f1 + + You can transfer ownership, if you have sufficient permissions, with the chgrp + command. + """) + ) + + def _validate_directory(path: Path) -> None: if not path.exists(): raise KeyError(f"{path}: No such file or directory") diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 016925216..e99d02b15 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -1,3 +1,4 @@ +import textwrap from collections.abc import Mapping from enum import Enum from functools import cached_property @@ -84,13 +85,34 @@ class RestConfig(BlueapiBaseModel): class ScratchRepository(BlueapiBaseModel): - name: str = "example" - remote_url: str = "https://github.com/example/example.git" + name: str = Field( + description="Unique name for this repository in the scratch directory", + default="example", + ) + remote_url: str = Field( + description="URL to clone from", + default="https://github.com/example/example.git", + ) class ScratchConfig(BlueapiBaseModel): - root: Path = Path("/tmp/scratch/blueapi") - repositories: list[ScratchRepository] = Field(default_factory=list) + root: Path = Field( + description="The root directory of the scratch area, all repositories will " + "be cloned under this directory.", + default=Path("/tmp/scratch/blueapi"), + ) + required_gid: int | None = Field( + description=textwrap.dedent(""" + Required owner GID for the scratch directory. If supplied the setup-scratch + command will check the scratch area ownership and raise an error if it is + not owned by . + """), + default=None, + ) + repositories: list[ScratchRepository] = Field( + description="Details of repositories to be cloned and imported into blueapi", + default_factory=list, + ) class OIDCConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 3bc586034..37a41f30d 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -59,9 +59,12 @@ def _runner() -> WorkerDispatcher: return RUNNER -def setup_runner(config: ApplicationConfig | None = None, use_subprocess: bool = True): +def setup_runner( + config: ApplicationConfig | None = None, + runner: WorkerDispatcher | None = None, +): global RUNNER - runner = WorkerDispatcher(config, use_subprocess) + runner = runner or WorkerDispatcher(config) runner.start() RUNNER = runner diff --git a/src/blueapi/service/runner.py b/src/blueapi/service/runner.py index be4c49ba6..d8d957055 100644 --- a/src/blueapi/service/runner.py +++ b/src/blueapi/service/runner.py @@ -8,7 +8,6 @@ from typing import Any, ParamSpec, TypeVar from observability_utils.tracing import ( - add_span_attributes, get_context_propagator, get_tracer, start_as_current_span, @@ -44,17 +43,19 @@ class WorkerDispatcher: _config: ApplicationConfig _subprocess: PoolClass | None - _use_subprocess: bool _state: EnvironmentResponse def __init__( self, config: ApplicationConfig | None = None, - use_subprocess: bool = True, + subprocess_factory: Callable[[], PoolClass] | None = None, ) -> None: + def default_subprocess_factory(): + return Pool(initializer=_init_worker, processes=1) + self._config = config or ApplicationConfig() self._subprocess = None - self._use_subprocess = use_subprocess + self._subprocess_factory = subprocess_factory or default_subprocess_factory self._state = EnvironmentResponse( initialized=False, ) @@ -68,12 +69,8 @@ def reload(self): @start_as_current_span(TRACER) def start(self): - add_span_attributes( - {"_use_subprocess": self._use_subprocess, "_config": str(self._config)} - ) try: - if self._use_subprocess: - self._subprocess = Pool(initializer=_init_worker, processes=1) + self._subprocess = self._subprocess_factory() self.run(setup, self._config) self._state = EnvironmentResponse(initialized=True) except Exception as e: @@ -107,40 +104,25 @@ def run( function: Callable[P, T], *args: P.args, **kwargs: P.kwargs, - ) -> T: - """Calls the supplied function, which is modified to accept a dict as it's new - first param, before being passed to the subprocess runner, or just run in place. - """ - add_span_attributes({"use_subprocess": self._use_subprocess}) - if self._use_subprocess: - return self._run_in_subprocess(function, *args, **kwargs) - else: - return function(*args, **kwargs) - - @start_as_current_span(TRACER, "function", "args", "kwargs") - def _run_in_subprocess( - self, - function: Callable[P, T], - *args: P.args, - **kwargs: P.kwargs, ) -> T: """Call the supplied function, passing the current Span ID, if one - exists,from the observability context inro the _rpc caller function. + exists,from the observability context into the import_and_run_function + caller function. + When this is deserialized in and run by the subprocess, this will allow its functions to use the corresponding span as their parent span.""" + if self._subprocess is None: raise InvalidRunnerStateError("Subprocess runner has not been started") if not (hasattr(function, "__name__") and hasattr(function, "__module__")): raise RpcError(f"{function} is anonymous, cannot be run in subprocess") - if not callable(function): - raise RpcError(f"{function} is not Callable, cannot be run in subprocess") try: return_type = inspect.signature(function).return_annotation except TypeError: return_type = None return self._subprocess.apply( - _rpc, + import_and_run_function, ( function.__module__, function.__name__, @@ -164,7 +146,7 @@ def __init__(self, message): class RpcError(Exception): ... -def _rpc( +def import_and_run_function( module_name: str, function_name: str, expected_type: type[T] | None, diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index cf45f6936..602cea9bb 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,5 +1,6 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .connect_devices import connect_devices +from .file_permissions import get_owner_gid, is_sgid_set from .invalid_config_error import InvalidConfigError from .modules import load_module_all from .serialization import serialize @@ -14,4 +15,6 @@ "BlueapiPlanModelConfig", "InvalidConfigError", "connect_devices", + "is_sgid_set", + "get_owner_gid", ] diff --git a/src/blueapi/utils/file_permissions.py b/src/blueapi/utils/file_permissions.py new file mode 100644 index 000000000..0e1921f6d --- /dev/null +++ b/src/blueapi/utils/file_permissions.py @@ -0,0 +1,32 @@ +import stat +from pathlib import Path + + +def is_sgid_set(path: Path) -> bool: + """Check if the SGID bit is set so that new files created + under a directory owned by a group are owned by that same group. + + See https://www.redhat.com/en/blog/suid-sgid-sticky-bit + + Args: + path: Path to the file to check + + Returns: + bool: True if the SGID bit is set + """ + + mask = path.stat().st_mode + return bool(mask & stat.S_ISGID) + + +def get_owner_gid(path: Path) -> int: + """Get the GID of the owner of a file + + Args: + path: Path to the file to check + + Returns: + bool: The GID of the file owner + """ + + return path.stat().st_gid diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 477a637c3..61656d253 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -5,6 +5,7 @@ import pytest from bluesky_stomp.models import BasicAuthentication from pydantic import TypeAdapter +from requests.exceptions import ConnectionError from blueapi.client.client import ( BlueapiClient, @@ -32,6 +33,13 @@ _DATA_PATH = Path(__file__).parent +_REQUIRES_AUTH_MESSAGE = """ +Authentication credentials are required to run this test. +The test has been skipped because authentication is currently disabled. +For more details, see: https://github.com/DiamondLightSource/blueapi/issues/676. +To enable and execute these tests, set `REQUIRES_AUTH=1` and provide valid credentials. +""" + # Step 1: Ensure a message bus that supports stomp is running and available: # src/script/start_rabbitmq.sh # @@ -43,7 +51,7 @@ @pytest.fixture -def client_without_auth(tmp_path) -> BlueapiClient: +def client_without_auth(tmp_path: Path) -> BlueapiClient: return BlueapiClient.from_config(config=ApplicationConfig(auth_token_path=tmp_path)) @@ -58,6 +66,20 @@ def client_with_stomp() -> BlueapiClient: ) +@pytest.fixture(scope="module", autouse=True) +def wait_for_server(): + client = BlueapiClient.from_config(config=ApplicationConfig()) + + for _ in range(20): + try: + client.get_environment() + return + except ConnectionError: + ... + time.sleep(0.5) + raise TimeoutError("No connection to the blueapi server") + + # This client will have auth enabled if it finds cached valid token @pytest.fixture def client() -> BlueapiClient: @@ -101,6 +123,7 @@ def clean_existing_tasks(client: BlueapiClient): yield +@pytest.mark.xfail(reason=_REQUIRES_AUTH_MESSAGE) def test_cannot_access_endpoints( client_without_auth: BlueapiClient, blueapi_client_get_methods: list[str] ): @@ -112,6 +135,7 @@ def test_cannot_access_endpoints( getattr(client_without_auth, get_method)() +@pytest.mark.xfail(reason=_REQUIRES_AUTH_MESSAGE) def test_can_get_oidc_config_without_auth(client_without_auth: BlueapiClient): assert client_without_auth.get_oidc_config() == OIDCConfig( well_known_url="https://example.com/realms/master/.well-known/openid-configuration", diff --git a/tests/unit_tests/cli/test_scratch.py b/tests/unit_tests/cli/test_scratch.py index 29c5f282d..2af956c88 100644 --- a/tests/unit_tests/cli/test_scratch.py +++ b/tests/unit_tests/cli/test_scratch.py @@ -10,6 +10,7 @@ from blueapi.cli.scratch import ensure_repo, scratch_install, setup_scratch from blueapi.config import ScratchConfig, ScratchRepository +from blueapi.utils import get_owner_gid @pytest.fixture @@ -20,8 +21,17 @@ def directory_path() -> Generator[Path]: @pytest.fixture -def file_path(directory_path: Path) -> Generator[Path]: - file_path = directory_path / str(uuid.uuid4()) +def directory_path_with_sgid(directory_path: Path) -> Path: + os.chmod( + directory_path, + os.stat(directory_path).st_mode + stat.S_ISGID, + ) + return directory_path + + +@pytest.fixture +def file_path(directory_path_with_sgid: Path) -> Generator[Path]: + file_path = directory_path_with_sgid / str(uuid.uuid4()) with file_path.open("w") as stream: stream.write("foo") yield file_path @@ -29,8 +39,8 @@ def file_path(directory_path: Path) -> Generator[Path]: @pytest.fixture -def nonexistant_path(directory_path: Path) -> Path: - file_path = directory_path / str(uuid.uuid4()) +def nonexistant_path(directory_path_with_sgid: Path) -> Path: + file_path = directory_path_with_sgid / str(uuid.uuid4()) assert not file_path.exists() return file_path @@ -38,13 +48,13 @@ def nonexistant_path(directory_path: Path) -> Path: @patch("blueapi.cli.scratch.Popen") def test_scratch_install_installs_path( mock_popen: Mock, - directory_path: Path, + directory_path_with_sgid: Path, ): mock_process = Mock() mock_process.returncode = 0 mock_popen.return_value = mock_process - scratch_install(directory_path, timeout=1.0) + scratch_install(directory_path_with_sgid, timeout=1.0) mock_popen.assert_called_once_with( [ @@ -54,7 +64,7 @@ def test_scratch_install_installs_path( "install", "--no-deps", "-e", - str(directory_path), + str(directory_path_with_sgid), ] ) @@ -73,7 +83,7 @@ def test_scratch_install_fails_on_nonexistant_path(nonexistant_path: Path): @pytest.mark.parametrize("code", [1, 2, 65536]) def test_scratch_install_fails_on_non_zero_exit_code( mock_popen: Mock, - directory_path: Path, + directory_path_with_sgid: Path, code: int, ): mock_process = Mock() @@ -81,16 +91,16 @@ def test_scratch_install_fails_on_non_zero_exit_code( mock_popen.return_value = mock_process with pytest.raises(RuntimeError): - scratch_install(directory_path, timeout=1.0) + scratch_install(directory_path_with_sgid, timeout=1.0) @patch("blueapi.cli.scratch.Repo") def test_repo_not_cloned_and_validated_if_found_locally( mock_repo: Mock, - directory_path: Path, + directory_path_with_sgid: Path, ): - ensure_repo("http://example.com/foo.git", directory_path) - mock_repo.assert_called_once_with(directory_path) + ensure_repo("http://example.com/foo.git", directory_path_with_sgid) + mock_repo.assert_called_once_with(directory_path_with_sgid) mock_repo.clone_from.assert_not_called() @@ -109,9 +119,9 @@ def test_repo_cloned_if_not_found_locally( @patch("blueapi.cli.scratch.Repo") def test_repo_cloned_with_correct_umask( mock_repo: Mock, - directory_path: Path, + directory_path_with_sgid: Path, ): - repo_root = directory_path / "foo" + repo_root = directory_path_with_sgid / "foo" file_path = repo_root / "a" def write_repo_files(): @@ -149,15 +159,61 @@ def test_setup_scratch_fails_on_non_directory_root( setup_scratch(config) +def test_setup_scratch_fails_on_non_sgid_root( + directory_path: Path, +): + config = ScratchConfig(root=directory_path, repositories=[]) + with pytest.raises(PermissionError): + setup_scratch(config) + + +def test_setup_scratch_fails_on_wrong_gid( + directory_path_with_sgid: Path, +): + config = ScratchConfig( + root=directory_path_with_sgid, + required_gid=12345, + repositories=[], + ) + assert get_owner_gid(directory_path_with_sgid) != 12345 + with pytest.raises(PermissionError): + setup_scratch(config) + + +@pytest.mark.skip( + reason=""" +We can't chown a tempfile in all environments, in particular it +seems to be broken in GH actions at the moment. We should +rewrite these tests to use mocks. + +See https://github.com/DiamondLightSource/blueapi/issues/770 +""" +) +def test_setup_scratch_succeeds_on_required_gid( + directory_path_with_sgid: Path, +): + # We may not own the temp root in some environments + root = directory_path_with_sgid / "a-root" + os.makedirs(root) + os.chown(root, uid=12345, gid=12345) + config = ScratchConfig( + root=root, + required_gid=12345, + repositories=[], + ) + assert get_owner_gid(root) == 12345 + setup_scratch(config) + + @patch("blueapi.cli.scratch.ensure_repo") @patch("blueapi.cli.scratch.scratch_install") def test_setup_scratch_iterates_repos( mock_scratch_install: Mock, mock_ensure_repo: Mock, - directory_path: Path, + directory_path_with_sgid: Path, ): config = ScratchConfig( - root=directory_path, + root=directory_path_with_sgid, repositories=[ ScratchRepository( name="foo", @@ -173,15 +229,15 @@ def test_setup_scratch_iterates_repos( mock_ensure_repo.assert_has_calls( [ - call("http://example.com/foo.git", directory_path / "foo"), - call("http://example.com/bar.git", directory_path / "bar"), + call("http://example.com/foo.git", directory_path_with_sgid / "foo"), + call("http://example.com/bar.git", directory_path_with_sgid / "bar"), ] ) mock_scratch_install.assert_has_calls( [ - call(directory_path / "foo", timeout=120.0), - call(directory_path / "bar", timeout=120.0), + call(directory_path_with_sgid / "foo", timeout=120.0), + call(directory_path_with_sgid / "bar", timeout=120.0), ] ) @@ -191,10 +247,10 @@ def test_setup_scratch_iterates_repos( def test_setup_scratch_continues_after_failure( mock_scratch_install: Mock, mock_ensure_repo: Mock, - directory_path: Path, + directory_path_with_sgid: Path, ): config = ScratchConfig( - root=directory_path, + root=directory_path_with_sgid, repositories=[ ScratchRepository( name="foo", diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 94c489b03..e86dbc490 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -95,7 +95,7 @@ def test_poll_for_token( assert token == valid_token -@patch("time.sleep") +@patch("blueapi.service.authentication.time.sleep", return_value=None) def test_poll_for_token_timeout( mock_sleep, oidc_well_known: dict[str, Any], @@ -111,7 +111,7 @@ def test_poll_for_token_timeout( status=HTTP_403_FORBIDDEN, ) with pytest.raises(TimeoutError), mock_authn_server: - session_manager.poll_for_token(device_code, 1, 2) + session_manager.poll_for_token(device_code, 0.01, 0.01) def test_server_raises_exception_for_invalid_token( diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 6e5fa6ee0..c18e98a68 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -1,7 +1,7 @@ import uuid from collections.abc import Iterator from dataclasses import dataclass -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import jwt import pytest @@ -14,40 +14,62 @@ from blueapi.config import ApplicationConfig, OIDCConfig from blueapi.core.bluesky_types import Plan from blueapi.service import main +from blueapi.service.interface import ( + cancel_active_task, + get_device, + get_plan, + pause_worker, + resume_worker, + submit_task, +) from blueapi.service.model import ( DeviceModel, + EnvironmentResponse, PlanModel, StateChangeRequest, WorkerTask, ) +from blueapi.service.runner import WorkerDispatcher from blueapi.worker.event import WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask +class MockCountModel(BaseModel): ... + + +COUNT = Plan(name="count", model=MockCountModel) + + +@pytest.fixture +def mock_runner() -> Mock: + return Mock(spec=WorkerDispatcher) + + @pytest.fixture -def client() -> Iterator[TestClient]: +def client(mock_runner: Mock) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): - main.setup_runner(use_subprocess=False) + main.setup_runner(runner=mock_runner) yield TestClient(main.get_app(ApplicationConfig())) main.teardown_runner() @pytest.fixture -def client_with_auth(oidc_config: OIDCConfig) -> Iterator[TestClient]: +def client_with_auth( + mock_runner: Mock, oidc_config: OIDCConfig +) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): - main.setup_runner(use_subprocess=False) + main.setup_runner(runner=mock_runner) yield TestClient(main.get_app(ApplicationConfig(oidc=oidc_config))) main.teardown_runner() -@patch("blueapi.service.interface.get_plans") -def test_get_plans(get_plans_mock: MagicMock, client: TestClient) -> None: +def test_get_plans(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plans_mock.return_value = [PlanModel.from_plan(plan)] + mock_runner.run.return_value = [PlanModel.from_plan(plan)] response = client.get("/plans") @@ -68,17 +90,16 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.get_plan") -def test_get_plan_by_name(get_plan_mock: MagicMock, client: TestClient) -> None: +def test_get_plan_by_name(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plan_mock.return_value = PlanModel.from_plan(plan) + mock_runner.run.return_value = PlanModel.from_plan(plan) response = client.get("/plans/my-plan") - get_plan_mock.assert_called_once_with("my-plan") + mock_runner.run.assert_called_once_with(get_plan, "my-plan") assert response.status_code == status.HTTP_200_OK assert response.json() == { "description": None, @@ -92,25 +113,21 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.get_plan") -def test_get_non_existent_plan_by_name( - get_plan_mock: MagicMock, client: TestClient -) -> None: - get_plan_mock.side_effect = KeyError("my-plan") +def test_get_non_existent_plan_by_name(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = KeyError("my-plan") response = client.get("/plans/my-plan") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.get_devices") -def test_get_devices(get_devices_mock: MagicMock, client: TestClient) -> None: +def test_get_devices(mock_runner: Mock, client: TestClient) -> None: @dataclass class MyDevice: name: str device = MyDevice("my-device") - get_devices_mock.return_value = [DeviceModel.from_device(device)] + mock_runner.run.return_value = [DeviceModel.from_device(device)] response = client.get("/devices") @@ -125,18 +142,17 @@ class MyDevice: } -@patch("blueapi.service.interface.get_device") -def test_get_device_by_name(get_device_mock: MagicMock, client: TestClient) -> None: +def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: @dataclass class MyDevice: name: str device = MyDevice("my-device") - get_device_mock.return_value = DeviceModel.from_device(device) + mock_runner.run.return_value = DeviceModel.from_device(device) response = client.get("/devices/my-device") - get_device_mock.assert_called_once_with("my-device") + mock_runner.run.assert_called_once_with(get_device, "my-device") assert response.status_code == status.HTTP_200_OK assert response.json() == { "name": "my-device", @@ -144,51 +160,44 @@ class MyDevice: } -@patch("blueapi.service.interface.get_device") -def test_get_non_existent_device_by_name( - get_device_mock: MagicMock, client: TestClient -) -> None: - get_device_mock.side_effect = KeyError("my-device") +def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = KeyError("my-device") response = client.get("/devices/my-device") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task(mock_runner: Mock, client: TestClient) -> None: task = Task(name="count", params={"detectors": ["x"]}) task_id = str(uuid.uuid4()) - submit_task_mock.return_value = task_id + mock_runner.run.side_effect = [COUNT, task_id] response = client.post("/tasks", json=task.model_dump()) - submit_task_mock.assert_called_once_with(task) + mock_runner.run.assert_called_with(submit_task, task) assert response.json() == {"task_id": task_id} -@patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task_validation_error( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - get_plan_mock.return_value = PlanModel.from_plan(plan) - submit_task_mock.side_effect = ValidationError.from_exception_data( - title="ValueError", - line_errors=[ - InitErrorDetails( - type="missing", loc=("id",), msg="value is required for Identifier" - ) # type: ignore - ], - ) + + mock_runner.run.side_effect = [ + PlanModel.from_plan(plan), + ValidationError.from_exception_data( + title="ValueError", + line_errors=[ + InitErrorDetails( + type="missing", loc=("id",), msg="value is required for Identifier" + ) # type: ignore + ], + ), + ] + response = client.post("/tasks", json={"name": "my-plan"}) assert response.status_code == 422 assert response.json() == { @@ -202,32 +211,21 @@ class MyModel(BaseModel): } -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") -def test_put_plan_begins_task( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient -) -> None: +def test_put_plan_begins_task(client: TestClient) -> None: task_id = "04cd9aa6-b902-414b-ae4b-49ea4200e957" - # Set to idle - get_active_task_mock.return_value = None - begin_task_mock.return_value = WorkerTask(task_id=task_id) - resp = client.put("/worker/task", json={"task_id": task_id}) assert resp.status_code == status.HTTP_200_OK assert resp.json() == {"task_id": task_id} -@patch("blueapi.service.interface.get_active_task") -def test_put_plan_fails_if_not_idle( - get_active_task_mock: MagicMock, client: TestClient -) -> None: +def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> None: task_id_current = "260f7de3-b608-4cdc-a66c-257e95809792" task_id_new = "07e98d68-21b5-4ad7-ac34-08b2cb992d42" # Set to non idle - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task=None, task_id=task_id_current, is_complete=False ) @@ -237,8 +235,7 @@ def test_put_plan_fails_if_not_idle( assert resp.json() == {"detail": "Worker already active"} -@patch("blueapi.service.interface.get_tasks") -def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: +def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: tasks = [ TrackableTask(task_id="0", task=Task(name="sleep", params={"time": 0.0})), TrackableTask( @@ -249,7 +246,7 @@ def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: ), ] - get_tasks_mock.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK @@ -276,10 +273,7 @@ def test_get_tasks(get_tasks_mock: MagicMock, client: TestClient) -> None: } -@patch("blueapi.service.interface.get_tasks_by_status") -def test_get_tasks_by_status( - get_tasks_by_status_mock: MagicMock, client: TestClient -) -> None: +def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: tasks = [ TrackableTask( task_id="3", @@ -289,7 +283,7 @@ def test_get_tasks_by_status( ), ] - get_tasks_by_status_mock.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks", params={"task_status": "PENDING"}) assert response.json() == { @@ -311,19 +305,14 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST -@patch("blueapi.service.interface.clear_task") -def test_delete_submitted_task(clear_task_mock: MagicMock, client: TestClient) -> None: +def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) - clear_task_mock.return_value = task_id + mock_runner.run.return_value = task_id response = client.delete(f"/tasks/{task_id}") assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") -def test_set_active_task( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient -) -> None: +def test_set_active_task(client: TestClient) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) @@ -333,15 +322,13 @@ def test_set_active_task( assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") def test_set_active_task_active_task_complete( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient + mock_runner: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task_id="1", task=Task(name="a_completed_task"), is_complete=True, @@ -354,15 +341,13 @@ def test_set_active_task_active_task_complete( assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.begin_task") -@patch("blueapi.service.interface.get_active_task") def test_set_active_task_worker_already_running( - get_active_task_mock: MagicMock, begin_task_mock: MagicMock, client: TestClient + mock_runner: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - get_active_task_mock.return_value = TrackableTask( + mock_runner.run.return_value = TrackableTask( task_id="1", task=Task(name="a_running_task"), is_complete=False, @@ -375,15 +360,14 @@ def test_set_active_task_worker_already_running( assert response.json() == {"detail": "Worker already active"} -@patch("blueapi.service.interface.get_task_by_id") -def test_get_task(get_task_by_id: MagicMock, client: TestClient): +def test_get_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), ) - get_task_by_id.return_value = task + mock_runner.run.return_value = task response = client.get(f"/tasks/{task_id}") assert response.json() == { @@ -396,8 +380,7 @@ def test_get_task(get_task_by_id: MagicMock, client: TestClient): } -@patch("blueapi.service.interface.get_tasks") -def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): +def test_get_all_tasks(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) tasks = [ TrackableTask( @@ -406,7 +389,7 @@ def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): ) ] - get_all_tasks.return_value = tasks + mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK assert response.json() == { @@ -423,138 +406,108 @@ def test_get_all_tasks(get_all_tasks: MagicMock, client: TestClient): } -@patch("blueapi.service.interface.get_task_by_id") -def test_get_task_error(get_task_by_id_mock: MagicMock, client: TestClient): +def test_get_task_error(mock_runner: Mock, client: TestClient): task_id = 567 - get_task_by_id_mock.return_value = None + mock_runner.run.return_value = None response = client.get(f"/tasks/{task_id}") assert response.json() == {"detail": "Item not found"} -@patch("blueapi.service.interface.get_active_task") -def test_get_active_task(get_active_task_mock: MagicMock, client: TestClient): +def test_get_active_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), ) - get_active_task_mock.return_value = task + mock_runner.run.return_value = task response = client.get("/worker/task") assert response.json() == {"task_id": f"{task_id}"} -@patch("blueapi.service.interface.get_active_task") -def test_get_active_task_none(get_active_task_mock: MagicMock, client: TestClient): - get_active_task_mock.return_value = None +def test_get_active_task_none(mock_runner: Mock, client: TestClient): + mock_runner.run.return_value = None response = client.get("/worker/task") assert response.json() == {"task_id": None} -@patch("blueapi.service.interface.get_worker_state") -def test_get_state(get_worker_state_mock: MagicMock, client: TestClient): +def test_get_state(mock_runner: Mock, client: TestClient): state = WorkerState.SUSPENDING - get_worker_state_mock.return_value = state + mock_runner.run.return_value = state response = client.get("/worker/state") assert response.json() == state -@patch("blueapi.service.interface.pause_worker") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_running_to_paused( - get_worker_state_mock: MagicMock, pause_worker_mock: MagicMock, client: TestClient -): +def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - pause_worker_mock.assert_called_once_with(False) + mock_runner.run.assert_any_call(pause_worker, False) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.resume_worker") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_paused_to_running( - get_worker_state_mock: MagicMock, resume_worker_mock: MagicMock, client: TestClient -): +def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - resume_worker_mock.assert_called_once() + mock_runner.run.assert_any_call(resume_worker) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_running_to_aborting( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, -): +def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - cancel_active_task_mock.assert_called_once_with(True, None) + mock_runner.run.assert_any_call(cancel_active_task, True, None) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") def test_set_state_running_to_stopping_including_reason( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, + mock_runner: Mock, client: TestClient ): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, None, final_state] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state, reason=reason).model_dump(), ) - cancel_active_task_mock.assert_called_once_with(False, reason) + mock_runner.run.assert_any_call(cancel_active_task, False, reason) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state -@patch("blueapi.service.interface.cancel_active_task") -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_transition_error( - get_worker_state_mock: MagicMock, - cancel_active_task_mock: MagicMock, - client: TestClient, -): +def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - get_worker_state_mock.side_effect = [current_state, final_state] - - cancel_active_task_mock.side_effect = TransitionError() + mock_runner.run.side_effect = [current_state, TransitionError(), final_state] response = client.put( "/worker/state", @@ -565,15 +518,12 @@ def test_set_state_transition_error( assert response.json() == final_state -@patch("blueapi.service.interface.get_worker_state") -def test_set_state_invalid_transition( - get_worker_state_mock: MagicMock, client: TestClient -): +def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): current_state = WorkerState.STOPPING requested_state = WorkerState.PAUSED final_state = WorkerState.STOPPING - get_worker_state_mock.side_effect = [current_state, final_state] + mock_runner.run.side_effect = [current_state, final_state] response = client.put( "/worker/state", @@ -584,14 +534,19 @@ def test_set_state_invalid_transition( assert response.json() == final_state -def test_get_environment_idle(client: TestClient) -> None: +def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None: + mock_runner.state = EnvironmentResponse( + initialized=True, + error_message=None, + ) + assert client.get("/environment").json() == { "initialized": True, "error_message": None, } -def test_delete_environment(client: TestClient) -> None: +def test_delete_environment(mock_runner: Mock, client: TestClient) -> None: response = client.delete("/environment") assert response.status_code is status.HTTP_200_OK @@ -604,11 +559,8 @@ def test_subprocess_enabled_by_default(mp_pool_mock: MagicMock): main.teardown_runner() -@patch("blueapi.service.interface.get_device") -def test_get_without_authentication( - get_device_mock: MagicMock, client: TestClient -) -> None: - get_device_mock.side_effect = jwt.PyJWTError +def test_get_without_authentication(mock_runner: Mock, client: TestClient) -> None: + mock_runner.run.side_effect = jwt.PyJWTError response = client.get("/devices/my-device") assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -621,14 +573,13 @@ def test_oidc_config_not_found_when_auth_is_disabled(client: TestClient): assert response.json() == {"detail": "Not Found"} -@patch("blueapi.service.interface.get_oidc_config") def test_get_oidc_config( - get_oidc_config: MagicMock, + mock_runner: Mock, oidc_config: OIDCConfig, mock_authn_server, client_with_auth: TestClient, ): - get_oidc_config.return_value = oidc_config + mock_runner.run.return_value = oidc_config response = client_with_auth.get("/config/oidc") assert response.status_code == status.HTTP_200_OK assert response.json() == oidc_config.model_dump() diff --git a/tests/unit_tests/service/test_runner.py b/tests/unit_tests/service/test_runner.py index 1162d8108..2c7c4b7b0 100644 --- a/tests/unit_tests/service/test_runner.py +++ b/tests/unit_tests/service/test_runner.py @@ -1,6 +1,6 @@ +from multiprocessing.pool import Pool as PoolClass from typing import Any, Generic, TypeVar -from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from observability_utils.tracing import ( @@ -16,17 +16,19 @@ InvalidRunnerStateError, RpcError, WorkerDispatcher, + import_and_run_function, ) @pytest.fixture -def local_runner(): - return WorkerDispatcher(use_subprocess=False) +def mock_subprocess() -> Mock: + subprocess = Mock(spec=PoolClass) + return subprocess @pytest.fixture -def runner(): - return WorkerDispatcher() +def runner(mock_subprocess: Mock): + return WorkerDispatcher(subprocess_factory=lambda: mock_subprocess) @pytest.fixture @@ -36,13 +38,22 @@ def started_runner(runner: WorkerDispatcher): runner.stop() -def test_initialize(runner: WorkerDispatcher): +def test_initialize(runner: WorkerDispatcher, mock_subprocess: Mock): + mock_subprocess.apply.return_value = None + + assert runner.state.error_message is None assert not runner.state.initialized runner.start() + + assert runner.state.error_message is None assert runner.state.initialized + # Run a single call to the runner for coverage of dispatch to subprocess - assert runner.run(interface.get_worker_state) + mock_subprocess.apply.return_value = 123 + assert runner.run(interface.get_worker_state) == 123 runner.stop() + + assert runner.state.error_message is None assert not runner.state.initialized @@ -59,22 +70,20 @@ def test_raises_if_used_before_started(runner: WorkerDispatcher): runner.run(interface.get_plans) -def test_error_on_runner_setup(local_runner: WorkerDispatcher): +def test_error_on_runner_setup(runner: WorkerDispatcher, mock_subprocess: Mock): + error_message = "Intentional start_worker exception" expected_state = EnvironmentResponse( initialized=False, - error_message="Intentional start_worker exception", + error_message=error_message, ) + mock_subprocess.apply.side_effect = Exception(error_message) - with mock.patch( - "blueapi.service.runner.setup", - side_effect=Exception("Intentional start_worker exception"), - ): - # Calling reload here instead of start also indirectly - # tests that stop() doesn't raise if there is no error message - # and the runner is not yet initialised - local_runner.reload() - state = local_runner.state - assert state == expected_state + # Calling reload here instead of start also indirectly + # tests that stop() doesn't raise if there is no error message + # and the runner is not yet initialised + runner.reload() + state = runner.state + assert state == expected_state def start_worker_mock(): @@ -99,7 +108,7 @@ def test_can_reload_after_an_error(pool_mock: MagicMock): another_mock.apply.side_effect = subprocess_calls_return_values - runner = WorkerDispatcher(use_subprocess=True) + runner = WorkerDispatcher() runner.start() assert runner.state == EnvironmentResponse( @@ -111,55 +120,59 @@ def test_can_reload_after_an_error(pool_mock: MagicMock): assert runner.state == EnvironmentResponse(initialized=True, error_message=None) -def test_function_not_findable_on_subprocess(started_runner: WorkerDispatcher): - from tests.unit_tests.core.fake_device_module import fake_motor_y - - # Valid target on main but not sub process - # Change in this process not reflected in subprocess - fake_motor_y.__name__ = "not_exported" - - with pytest.raises( - RpcError, match="not_exported: No such function in subprocess API" - ): - started_runner.run(fake_motor_y) +@patch("blueapi.service.runner.Pool") +def test_subprocess_enabled_by_default(pool_mock: MagicMock): + runner = WorkerDispatcher() + runner.start() + pool_mock.assert_called_once() + runner.stop() -def test_non_callable_excepts_in_main_process(started_runner: WorkerDispatcher): - # Not a valid target on main or sub process - from tests.unit_tests.core.fake_device_module import fetchable_non_callable +def test_clear_message_for_anonymous_function(started_runner: WorkerDispatcher): + non_fetchable_callable = MagicMock() with pytest.raises( RpcError, - match=" is not Callable, " - + "cannot be run in subprocess", + match=" is anonymous, cannot be run in subprocess", ): - started_runner.run(fetchable_non_callable) + started_runner.run(non_fetchable_callable) -def test_non_callable_excepts_in_sub_process(started_runner: WorkerDispatcher): - # Valid target on main but finds non-callable in sub process - from tests.unit_tests.core.fake_device_module import ( - fetchable_callable, - fetchable_non_callable, - ) +def test_function_not_findable_on_subprocess(): + with pytest.raises(RpcError, match="unknown: No such function in subprocess API"): + import_and_run_function("blueapi", "unknown", None, {}) - fetchable_callable.__name__ = fetchable_non_callable.__name__ - with pytest.raises( - RpcError, - match="fetchable_non_callable: Object in subprocess is not a function", - ): - started_runner.run(fetchable_callable) +def test_module_not_findable_on_subprocess(): + with pytest.raises(ModuleNotFoundError): + import_and_run_function("unknown", "unknown", None, {}) -def test_clear_message_for_anonymous_function(started_runner: WorkerDispatcher): - non_fetchable_callable = MagicMock() +def run_rpc_function( + func: Callable[..., Any], + expected_type: type[Any], + *args: Any, + **kwargs: Any, +) -> Any: + import_and_run_function( + func.__module__, + func.__name__, + expected_type, + {}, + *args, + **kwargs, + ) + + +def test_non_callable_excepts(started_runner: WorkerDispatcher): + # Not a valid target on main or sub process + from tests.unit_tests.core.fake_device_module import fetchable_non_callable with pytest.raises( RpcError, - match=" is anonymous, cannot be run in subprocess", + match="fetchable_non_callable: Object in subprocess is not a function", ): - started_runner.run(non_fetchable_callable) + run_rpc_function(fetchable_non_callable, Mock) def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher): @@ -169,7 +182,7 @@ def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher): ValidationError, match="1 validation error for int", ): - started_runner.run(wrong_return_type) + run_rpc_function(wrong_return_type, int) T = TypeVar("T") diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 72e98c9d2..fb918fb66 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -60,6 +60,21 @@ def test_main_no_params(): assert result.stdout == expected +@patch("blueapi.service.main.start") +@patch("blueapi.cli.scratch.setup_scratch") +@patch("blueapi.cli.cli.os.umask") +@pytest.mark.parametrize("subcommand", ["serve", "setup-scratch"]) +def test_runs_with_umask_002( + mock_umask: Mock, + mock_setup_scratch: Mock, + mock_start: Mock, + runner: CliRunner, + subcommand: str, +): + runner.invoke(main, [subcommand]) + mock_umask.assert_called_once_with(0o002) + + @patch("requests.request") def test_connection_error_caught_by_wrapper_func( mock_requests: Mock, runner: CliRunner @@ -241,7 +256,7 @@ def test_reset_env_client_behavior( reload_result = runner.invoke(main, ["controller", "env", "-r"]) # Verify if sleep was called between polling iterations - assert mock_sleep.call_count == 2 # Since the last check doesn't require a sleep + mock_sleep.assert_called() for index, call in enumerate(responses.calls): if index == 0: diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 26d353a70..67517a691 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -277,6 +277,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): }, "scratch": { "root": "/tmp/scratch/blueapi", + "required_gid": None, "repositories": [ { "name": "dodal", @@ -309,6 +310,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): }, "scratch": { "root": "/tmp/scratch/blueapi", + "required_gid": None, "repositories": [ { "name": "dodal", diff --git a/tests/unit_tests/utils/test_file_permissions.py b/tests/unit_tests/utils/test_file_permissions.py new file mode 100644 index 000000000..11f36f889 --- /dev/null +++ b/tests/unit_tests/utils/test_file_permissions.py @@ -0,0 +1,65 @@ +import stat +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from blueapi.utils import get_owner_gid, is_sgid_set + + +@pytest.mark.parametrize( + "bits", + [ + # Files + 0o10_0600, # -rw-------. + 0o10_0777, # -rwxrwxrwx. + 0o10_0000, # ----------. + 0o10_0644, # -rw-r--r--. + 0o10_0400, # -r--------. + 0o10_0666, # -rw-rw-rw-. + 0o10_0444, # -r--r--r--. + # Directories + 0o04_0777, # drwxrwxrwx. + 0o04_0000, # d---------. + 0o04_0600, # drw-------. + ], + ids=lambda p: f"{p:06o} ({stat.filemode(p)})", +) +def test_is_sgid_set_should_be_disabled(bits: int): + assert not _mocked_is_sgid_set(bits) + + +@pytest.mark.parametrize( + "bits", + [ + # Files + 0o10_2777, # -rwxrwsrwx. + 0o10_2000, # ------S---. + 0o10_2644, # -rw-r-Sr--. + 0o10_2600, # -rw---S---. + 0o10_2400, # -r----S---. + 0o10_2666, # -rw-rwSrw-. + 0o10_2444, # -r--r-Sr--. + # Directories + 0o04_2777, # drwxrwsrwx. + 0o04_2000, # d-----S---. + 0o04_2600, # drw---S---. + ], + ids=lambda p: f"{p:06o} ({stat.filemode(p)})", +) +def test_is_sgid_set_should_be_enabled(bits: int): + assert _mocked_is_sgid_set(bits) + + +def _mocked_is_sgid_set(bits: int) -> bool: + path = Mock(spec=Path) + path.stat().st_mode = bits + + return is_sgid_set(path) + + +def test_get_owner_gid(): + path = Mock(spec=Path) + path.stat().st_gid = 12345 + + assert get_owner_gid(path) == 12345