Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block until environment reloaded #743

Merged
merged 7 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ components:
additionalProperties: false
description: State of internal environment.
properties:
environment_id:
description: Unique ID for the environment instance, can be used to differentiate
between a new environment and old that has been torn down
format: uuid
title: Environment Id
type: string
error_message:
anyOf:
- minLength: 1
Expand All @@ -49,6 +55,7 @@ components:
title: Initialized
type: boolean
required:
- environment_id
- initialized
title: EnvironmentResponse
type: object
Expand Down
6 changes: 5 additions & 1 deletion src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def _wait_for_reload(
teardown_complete_time = time.time()
too_late = teardown_complete_time + timeout if timeout is not None else None

previous_environment_id = status.environment_id
# Wait forever if there was no timeout
while too_late is None or time.time() < too_late:
# Poll until the environment is restarted or the timeout is reached
Expand All @@ -415,7 +416,10 @@ def _wait_for_reload(
raise BlueskyRemoteControlError(
f"Error reloading environment: {status.error_message}"
)
elif status.initialized:
elif (
status.initialized is True
and status.environment_id != previous_environment_id
):
return status
time.sleep(polling_interval)
# If the function did not raise or return early, it timed out
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ async def delete_environment(
runner: WorkerDispatcher = Depends(_runner),
) -> EnvironmentResponse:
"""Delete the current environment, causing internal components to be reloaded."""

environment_id = runner.state.environment_id
if runner.state.initialized or runner.state.error_message is not None:
background_tasks.add_task(runner.reload)
return EnvironmentResponse(initialized=False)
return EnvironmentResponse(environment_id=environment_id, initialized=False)


@auth_router.get("/config/oidc", tags=["auth"], response_model=OIDCConfig)
Expand Down
5 changes: 5 additions & 0 deletions src/blueapi/service/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from collections.abc import Iterable
from typing import Any

Expand Down Expand Up @@ -144,6 +145,10 @@ class EnvironmentResponse(BlueapiBaseModel):
State of internal environment.
"""

environment_id: uuid.UUID = Field(
description="Unique ID for the environment instance, can be used to "
"differentiate between a new environment and old that has been torn down"
)
initialized: bool = Field(description="blueapi context initialized")
error_message: str | None = Field(
default=None,
Expand Down
12 changes: 11 additions & 1 deletion src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import logging
import signal
import uuid
from collections.abc import Callable
from importlib import import_module
from multiprocessing import Pool, set_start_method
Expand Down Expand Up @@ -57,6 +58,7 @@ def default_subprocess_factory():
self._subprocess = None
self._subprocess_factory = subprocess_factory or default_subprocess_factory
self._state = EnvironmentResponse(
environment_id=uuid.uuid4(),
initialized=False,
)

Expand All @@ -69,30 +71,38 @@ def reload(self):

@start_as_current_span(TRACER)
def start(self):
environment_id = uuid.uuid4()
try:
self._subprocess = self._subprocess_factory()
self.run(setup, self._config)
self._state = EnvironmentResponse(initialized=True)
self._state = EnvironmentResponse(
environment_id=environment_id,
initialized=True,
)
except Exception as e:
self._state = EnvironmentResponse(
environment_id=environment_id,
initialized=False,
error_message=str(e),
)
LOGGER.exception(e)

@start_as_current_span(TRACER)
def stop(self):
environment_id = self._state.environment_id
try:
self.run(teardown)
if self._subprocess is not None:
self._subprocess.close()
self._subprocess.join()
self._state = EnvironmentResponse(
environment_id=environment_id,
initialized=False,
error_message=self._state.error_message,
)
except Exception as e:
self._state = EnvironmentResponse(
environment_id=environment_id,
initialized=False,
error_message=str(e),
)
Expand Down
10 changes: 7 additions & 3 deletions tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from blueapi.service.model import (
DeviceResponse,
EnvironmentResponse,
PlanResponse,
TaskResponse,
WorkerTask,
Expand Down Expand Up @@ -335,9 +334,14 @@ def on_event(event: AnyEvent):


def test_get_current_state_of_environment(client: BlueapiClient):
assert client.get_environment() == EnvironmentResponse(initialized=True)
assert client.get_environment().initialized


def test_delete_current_environment(client: BlueapiClient):
current_env = client.get_environment()
client.reload_environment()
assert client.get_environment() == EnvironmentResponse(initialized=True)
new_env = client.get_environment()
assert (
new_env.initialized is True
and new_env.environment_id != current_env.environment_id
)
83 changes: 77 additions & 6 deletions tests/unit_tests/client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid
from collections.abc import Callable
from unittest.mock import MagicMock, Mock, call
from unittest.mock import MagicMock, Mock, call, patch

import pytest
from bluesky_stomp.messaging import MessageContext
Expand Down Expand Up @@ -42,7 +43,10 @@
TASK = TrackableTask(task_id="foo", task=Task(name="bar", params={}))
TASKS = TasksListResponse(tasks=[TASK])
ACTIVE_TASK = WorkerTask(task_id="bar")
ENV = EnvironmentResponse(initialized=True)
ENVIRONMENT_ID = uuid.uuid4()
NEW_ENVIRONMENT_ID = uuid.uuid4()
ENV = EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=True)
NEW_ENV = EnvironmentResponse(environment_id=NEW_ENVIRONMENT_ID, initialized=True)
COMPLETE_EVENT = WorkerEvent(
state=WorkerState.IDLE,
task_status=TaskStatus(
Expand Down Expand Up @@ -74,8 +78,9 @@ def mock_rest() -> BlueapiRestClient:
mock.get_all_tasks.return_value = TASKS
mock.get_active_task.return_value = ACTIVE_TASK
mock.get_environment.return_value = ENV
mock.delete_environment.return_value = EnvironmentResponse(initialized=False)

mock.delete_environment.return_value = EnvironmentResponse(
environment_id=ENVIRONMENT_ID, initialized=False
)
return mock


Expand Down Expand Up @@ -254,17 +259,82 @@ def test_reload_environment(
client: BlueapiClient,
mock_rest: Mock,
):
client.reload_environment()
mock_rest.get_environment.return_value = NEW_ENV
environment = client.reload_environment()
mock_rest.get_environment.assert_called_once()
mock_rest.delete_environment.assert_called_once()
assert environment == NEW_ENV


@patch("blueapi.client.client.time.time")
@patch("blueapi.client.client.time.sleep")
def test_reload_environment_no_timeout(
mock_sleep: Mock,
mock_time: Mock,
client: BlueapiClient,
mock_rest: Mock,
):
mock_rest.get_environment.side_effect = [ENV, ENV, ENV, NEW_ENV]
mock_time.return_value = 100.0
environment = client.reload_environment(timeout=None)
assert mock_sleep.call_count == 3
assert environment == NEW_ENV


@patch("blueapi.client.client.time.time")
@patch("blueapi.client.client.time.sleep")
def test_reload_environment_with_timeout(
_: Mock,
mock_time: Mock,
client: BlueapiClient,
mock_rest: Mock,
):
mock_rest.get_environment.side_effect = [
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
EnvironmentResponse(environment_id=ENVIRONMENT_ID, initialized=False),
]
mock_time.side_effect = [
100.0,
100.5,
101.0, # Timeout should occur here
101.5,
]
with pytest.raises(
TimeoutError,
match="Failed to reload the environment within 1.0 "
"seconds, a server restart is recommended",
):
client.reload_environment(timeout=1.0)


@patch("blueapi.client.client.time.time")
@patch("blueapi.client.client.time.sleep")
def test_reload_environment_ignores_current_environment(
mock_sleep: Mock,
mock_time: Mock,
client: BlueapiClient,
mock_rest: Mock,
):
mock_rest.get_environment.side_effect = [
ENV, # This is the old environment
ENV,
ENV,
NEW_ENV, # This is the new environment
]
mock_time.return_value = 100.0
environment = client.reload_environment(timeout=None)
assert mock_sleep.call_count == 3
assert environment == NEW_ENV


def test_reload_environment_failure(
client: BlueapiClient,
mock_rest: Mock,
):
mock_rest.get_environment.return_value = EnvironmentResponse(
initialized=False, error_message="foo"
environment_id=ENVIRONMENT_ID, initialized=False, error_message="foo"
)
with pytest.raises(BlueskyRemoteControlError, match="foo"):
client.reload_environment()
Expand Down Expand Up @@ -527,6 +597,7 @@ def test_reload_environment_span_ok(
client: BlueapiClient,
mock_rest: Mock,
):
mock_rest.get_environment.return_value = NEW_ENV
with asserting_span_exporter(exporter, "reload_environment"):
client.reload_environment()

Expand Down
19 changes: 15 additions & 4 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import uuid
from pathlib import Path
from unittest.mock import Mock, patch

Expand Down Expand Up @@ -53,16 +54,21 @@ def test_auth_request_functionality(
mock_authn_server: responses.RequestsMock,
cached_valid_token: Path,
):
environment_id = uuid.uuid4()
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_get_env = mock_authn_server.get(
"http://localhost:8000/environment",
json=EnvironmentResponse(initialized=True).model_dump(),
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
status=200,
)
result = None
with mock_authn_server:
result = rest_with_auth.get_environment()
assert result == EnvironmentResponse(initialized=True)
assert result == EnvironmentResponse(
environment_id=environment_id, initialized=True
)
calls = mock_get_env.calls
assert len(calls) == 1
cacheManager = SessionCacheManager(cached_valid_token)
Expand All @@ -75,16 +81,21 @@ def test_refresh_if_signature_expired(
mock_authn_server: responses.RequestsMock,
cached_valid_refresh: Path,
):
environment_id = uuid.uuid4()
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_get_env = mock_authn_server.get(
"http://localhost:8000/environment",
json=EnvironmentResponse(initialized=True).model_dump(),
json=EnvironmentResponse(
environment_id=environment_id, initialized=True
).model_dump(mode="json"),
status=200,
)
result = None
with mock_authn_server:
result = rest_with_auth.get_environment()
assert result == EnvironmentResponse(initialized=True)
assert result == EnvironmentResponse(
environment_id=environment_id, initialized=True
)
calls = mock_get_env.calls
assert len(calls) == 1
assert calls[0].request.headers["Authorization"] == "Bearer new_token"
14 changes: 14 additions & 0 deletions tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,20 +535,34 @@ def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient):


def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None:
environment_id = uuid.uuid4()
mock_runner.state = EnvironmentResponse(
environment_id=environment_id,
initialized=True,
error_message=None,
)

assert client.get("/environment").json() == {
"environment_id": str(environment_id),
"initialized": True,
"error_message": None,
}


def test_delete_environment(mock_runner: Mock, client: TestClient) -> None:
environment_id = uuid.uuid4()
mock_runner.state = EnvironmentResponse(
environment_id=environment_id,
initialized=True,
error_message=None,
)
response = client.delete("/environment")
assert response.status_code is status.HTTP_200_OK
assert response.json() == {
"environment_id": str(environment_id),
"initialized": False,
"error_message": None,
}


@patch("blueapi.service.runner.Pool")
Expand Down
Loading
Loading