From b7cbcfc153438453a46507a586fb30f20dd6cfe6 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Fri, 13 Dec 2024 10:08:11 +0000 Subject: [PATCH 1/5] Add minimal configuration for inserting documents into Tiled --- pyproject.toml | 4 ++++ src/blueapi/config.py | 10 ++++++++++ src/blueapi/service/interface.py | 18 +++++++++++++++++- 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2311c5914..52d9367ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,10 @@ classifiers = [ ] description = "Lightweight bluesky-as-a-service wrapper application. Also usable as a library." dependencies = [ + "tiled", + "json_merge_patch", + "jsonpatch", + "pyarrow", "bluesky>=1.13", "ophyd", "nslsii", diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 8f53878cc..016925216 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -40,6 +40,15 @@ class StompConfig(BaseModel): auth: BasicAuthentication | None = None +class TiledConfig(BaseModel): + """ + Config for connecting to a tiled instance + """ + + uri: str + api_key: str + + class WorkerEventConfig(BlueapiBaseModel): """ Config for event broadcasting via the message bus @@ -138,6 +147,7 @@ class ApplicationConfig(BlueapiBaseModel): """ stomp: StompConfig | None = None + tiled: TiledConfig | None = None env: EnvironmentConfig = Field(default_factory=EnvironmentConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 803841964..620b8114a 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -3,10 +3,12 @@ from functools import cache from typing import Any +from bluesky.callbacks.tiled_writer import TiledWriter from bluesky_stomp.messaging import StompClient from bluesky_stomp.models import Broker, DestinationBase, MessageTopic +from tiled.client import from_uri -from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig +from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig, TiledConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.service.model import DeviceModel, PlanModel, WorkerTask @@ -48,6 +50,19 @@ def worker() -> TaskWorker: return worker +@cache +def tiled_inserter(): + tiled_config: TiledConfig | None = config().tiled + if tiled_config is not None: + client = from_uri(tiled_config.uri, api_key=tiled_config.api_key) + + ctx = context() + ctx.run_engine.subscribe(TiledWriter(client)) + return client + else: + return None + + @cache def stomp_client() -> StompClient | None: stomp_config: StompConfig | None = config().stomp @@ -86,6 +101,7 @@ def setup(config: ApplicationConfig) -> None: logging.basicConfig(format="%(asctime)s - %(message)s", level=config.logging.level) worker() stomp_client() + tiled_inserter() def teardown() -> None: From 7f3e707f127c8e15042f747af12486783ad17b8d Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 9 Jan 2025 09:42:25 +0000 Subject: [PATCH 2/5] Ignore UserWarning temporarily --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c0024eb5c..c2d7fa354 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,8 @@ addopts = """ --ignore=src/blueapi/startup """ # https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings -filterwarnings = ["error", "ignore::DeprecationWarning"] +# Unignore UserWarning after Pydantic warning removed from bluesky/bluesky and release +filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::UserWarning"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" asyncio_mode = "auto" From f3767e15f264489ef1eb32c6201bc6bd7b119c45 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Thu, 9 Jan 2025 10:55:21 +0000 Subject: [PATCH 3/5] Add default Tiled config to tests --- tests/unit_tests/test_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 67517a691..3caf99a19 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -224,6 +224,7 @@ def temp_yaml_config_file( "logging": {"level": "INFO"}, "api": {"host": "0.0.0.0", "port": 8000, "protocol": "http"}, "scratch": None, + "tiled": None, }, ], indirect=True, @@ -285,6 +286,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "tiled": None, }, { "stomp": { @@ -318,6 +320,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "tiled": None, }, ], indirect=True, From e9406e56d4d41c5f29b332b20f8cfaf2abe48be5 Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Fri, 10 Jan 2025 14:46:28 +0000 Subject: [PATCH 4/5] Forward existing Authencation headers --- src/blueapi/config.py | 4 ++-- src/blueapi/service/interface.py | 23 +++++------------------ src/blueapi/worker/task_worker.py | 31 +++++++++++++++++++++++++------ src/blueapi/worker/tiled.py | 23 +++++++++++++++++++++++ 4 files changed, 55 insertions(+), 26 deletions(-) create mode 100644 src/blueapi/worker/tiled.py diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e99d02b15..1055eea8b 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -46,8 +46,8 @@ class TiledConfig(BaseModel): Config for connecting to a tiled instance """ - uri: str - api_key: str + host: str + port: int class WorkerEventConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 620b8114a..162064c14 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -3,18 +3,17 @@ from functools import cache from typing import Any -from bluesky.callbacks.tiled_writer import TiledWriter from bluesky_stomp.messaging import StompClient from bluesky_stomp.models import Broker, DestinationBase, MessageTopic -from tiled.client import from_uri -from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig, TiledConfig +from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig from blueapi.core.context import BlueskyContext from blueapi.core.event import EventStream from blueapi.service.model import DeviceModel, PlanModel, WorkerTask from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask +from blueapi.worker.tiled import TiledConnection """This module provides interface between web application and underlying Bluesky context and worker""" @@ -42,27 +41,16 @@ def context() -> BlueskyContext: @cache def worker() -> TaskWorker: + conf = config() worker = TaskWorker( context(), - broadcast_statuses=config().env.events.broadcast_status_events, + broadcast_statuses=conf.env.events.broadcast_status_events, + tiled_inserter=TiledConnection(conf.tiled) if conf.tiled else None, ) worker.start() return worker -@cache -def tiled_inserter(): - tiled_config: TiledConfig | None = config().tiled - if tiled_config is not None: - client = from_uri(tiled_config.uri, api_key=tiled_config.api_key) - - ctx = context() - ctx.run_engine.subscribe(TiledWriter(client)) - return client - else: - return None - - @cache def stomp_client() -> StompClient | None: stomp_config: StompConfig | None = config().stomp @@ -101,7 +89,6 @@ def setup(config: ApplicationConfig) -> None: logging.basicConfig(format="%(asctime)s - %(message)s", level=config.logging.level) worker() stomp_client() - tiled_inserter() def teardown() -> None: diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 546c5f3b5..67c22119c 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -9,6 +9,7 @@ from typing import Any, Generic, TypeVar from bluesky.protocols import Status +from httpx import Headers from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -32,6 +33,7 @@ from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.utils.base_model import BlueapiBaseModel from blueapi.utils.thread_exception import handle_all_exceptions +from blueapi.worker.tiled import TiledConnection from .event import ( ProgressEvent, @@ -112,9 +114,11 @@ def __init__( ctx: BlueskyContext, start_stop_timeout: float = DEFAULT_START_STOP_TIMEOUT, broadcast_statuses: bool = True, + tiled_inserter: TiledConnection | None = None, ) -> None: self._ctx = ctx self._start_stop_timeout = start_stop_timeout + self._tiled_inserter = tiled_inserter self._tasks = {} @@ -194,13 +198,25 @@ def get_active_task(self) -> TrackableTask[Task] | None: return current @start_as_current_span(TRACER, "task_id") - def begin_task(self, task_id: str) -> None: + def begin_task(self, task_id: str, headers: Headers | None) -> None: task = self._tasks.get(task_id) + data_subs: list[int] = [] if task is not None: - self._submit_trackable_task(task) + if self._tiled_inserter: + data_subs.append(self._authorize_running_task(headers)) + self._submit_trackable_task(task, data_subs) + else: raise KeyError(f"No pending task with ID {task_id}") + def _authorize_running_task(self, headers: Headers | None) -> int: + assert self._tiled_inserter + # https://github.com/DiamondLightSource/blueapi/issues/774 + # If users should only be able to run their own scans, pass headers + # as part of submitting a task, cache in TrackableTask field and check + # that token belongs to same user (but may be newer token!) + return self.data_events.subscribe(self._tiled_inserter(headers)) + @start_as_current_span(TRACER, "task.name", "task.params") def submit_task(self, task: Task) -> str: task.prepare_params(self._ctx) # Will raise if parameters are invalid @@ -218,7 +234,9 @@ def submit_task(self, task: Task) -> str: "trackable_task.task.name", "trackable_task.task.params", ) - def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + def _submit_trackable_task( + self, trackable_task: TrackableTask, data_subs: list[int] | None = None + ) -> None: if self.state is not WorkerState.IDLE: raise WorkerBusyError(f"Worker is in state {self.state}") @@ -235,17 +253,18 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: sub = self.worker_events.subscribe(mark_task_as_started) try: self._current_task_otel_context = get_current() - sub = self.worker_events.subscribe(mark_task_as_started) """ Cache the current trace context as the one for this task id """ self._task_channel.put_nowait(trackable_task) - task_started.wait(timeout=5.0) - if not task_started.is_set(): + if not task_started.wait(timeout=5.0): raise TimeoutError("Failed to start plan within timeout") except Full as f: LOGGER.error("Cannot submit task while another is running") raise WorkerBusyError("Cannot submit task while another is running") from f finally: self.worker_events.unsubscribe(sub) + if data_subs: + for data_sub in data_subs: + self.data_events.unsubscribe(data_sub) @start_as_current_span(TRACER) def start(self) -> None: diff --git a/src/blueapi/worker/tiled.py b/src/blueapi/worker/tiled.py new file mode 100644 index 000000000..816f0156a --- /dev/null +++ b/src/blueapi/worker/tiled.py @@ -0,0 +1,23 @@ +from bluesky.callbacks.tiled_writer import TiledWriter +from httpx import Headers +from tiled.client import from_context +from tiled.client.context import Context as TiledContext + +from blueapi.config import TiledConfig +from blueapi.core.bluesky_types import DataEvent + + +class TiledConverter: + def __init__(self, tiled_context: TiledContext): + self._writer: TiledWriter = TiledWriter(from_context(tiled_context)) + + def __call__(self, data: DataEvent, _: str | None = None) -> None: + self._writer(data.name, data.doc) + + +class TiledConnection: + def __init__(self, config: TiledConfig): + self.uri = f"{config.host}:{config.port}" + + def __call__(self, headers: Headers | None): + return TiledConverter(TiledContext(self.uri, headers=headers)) From ad3ad9a39bfd62717148a42ccf4f6b79e52f82fc Mon Sep 17 00:00:00 2001 From: "Ware, Joseph (DLSLtd,RAL,LSCI)" Date: Mon, 13 Jan 2025 18:21:43 +0000 Subject: [PATCH 5/5] Almost working tiled insertion and configuration for use with tiled and mock-oidc-server --- src/blueapi/config.py | 4 +- src/blueapi/service/authentication.py | 4 +- src/blueapi/service/interface.py | 4 +- src/blueapi/service/main.py | 576 ++++++++++++-------------- src/blueapi/worker/task_worker.py | 18 +- src/blueapi/worker/tiled.py | 16 +- src/script/stomp_config.yml | 15 +- 7 files changed, 316 insertions(+), 321 deletions(-) diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 1055eea8b..10a1476b2 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -120,7 +120,9 @@ class OIDCConfig(BlueapiBaseModel): description="URL to fetch OIDC config from the provider" ) client_id: str = Field(description="Client ID") - client_audience: str = Field(description="Client Audience(s)", default="blueapi") + client_audience: str | list[str] | None = Field( + description="Client Audience(s)", default="blueapi" + ) @cached_property def _config_from_oidc_url(self) -> dict[str, Any]: diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index df3df8a25..2f0d41830 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -110,7 +110,7 @@ def decode_jwt(self, json_web_token: str): signing_key.key, algorithms=self._server_config.id_token_signing_alg_values_supported, verify=True, - audience=self._server_config.client_audience, + # audience=self._server_config.client_audience, issuer=self._server_config.issuer, ) @@ -169,6 +169,7 @@ def poll_for_token( "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, "client_id": self._server_config.client_id, + "client_secret": "secret", }, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) @@ -184,6 +185,7 @@ def start_device_flow(self): self._server_config.device_authorization_endpoint, data={ "client_id": self._server_config.client_id, + "client_secret": "secret", "scope": SCOPES, }, headers={"Content-Type": "application/x-www-form-urlencoded"}, diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 162064c14..2f66f738a 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -147,10 +147,10 @@ def clear_task(task_id: str) -> str: return worker().clear_task(task_id) -def begin_task(task: WorkerTask) -> WorkerTask: +def begin_task(task: WorkerTask, token: str | None) -> WorkerTask: """Trigger a task. Will fail if the worker is busy""" if task.task_id is not None: - worker().begin_task(task.task_id) + worker().begin_task(task.task_id, token=token) return task diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 37a41f30d..31d3ec274 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -47,95 +47,76 @@ REST_API_VERSION = "0.0.5" -RUNNER: WorkerDispatcher | None = None - CONTEXT_HEADER = "traceparent" +TRACER = get_tracer("interface") -def _runner() -> WorkerDispatcher: - """Intended to be used only with FastAPI Depends""" - if RUNNER is None: - raise ValueError() - return RUNNER - - -def setup_runner( - config: ApplicationConfig | None = None, - runner: WorkerDispatcher | None = None, -): - global RUNNER - runner = runner or WorkerDispatcher(config) +def setup_runner(config: ApplicationConfig | None = None) -> WorkerDispatcher: + runner = WorkerDispatcher(config) runner.start() + return runner - RUNNER = runner - -def teardown_runner(): - global RUNNER - if RUNNER is None: - return - RUNNER.stop() - RUNNER = None +def _bearer(config: OIDCConfig) -> OAuth2AuthorizationCodeBearer: + return OAuth2AuthorizationCodeBearer( + authorizationUrl=config.authorization_endpoint, + tokenUrl=config.token_endpoint, + refreshUrl=config.token_endpoint, + ) -def lifespan(config: ApplicationConfig): +def lifespan(runner: WorkerDispatcher): @asynccontextmanager async def inner(app: FastAPI): - setup_runner(config) yield - teardown_runner() + runner.stop() return inner -router = APIRouter() -auth_router = APIRouter() - - def get_app(config: ApplicationConfig): + runner = setup_runner(config) + app = FastAPI( docs_url="/docs", title="BlueAPI Control", - lifespan=lifespan(config), + lifespan=lifespan(runner), version=REST_API_VERSION, ) - dependencies = [] + if config.oidc: - dependencies = [Depends(verify_access_token(config.oidc))] - app.include_router(auth_router) - app.include_router(router, dependencies=dependencies) + bearer = _bearer(config.oidc) + dependencies = [Depends(verify_access_token(bearer, config.oidc))] + else: + bearer = None + dependencies = [] + app.include_router(get_auth_router(runner)) + app.include_router(get_api_router(runner, bearer), dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) app.middleware("http")(add_api_version_header) app.middleware("http")(inject_propagated_observability_context) + return app -def verify_access_token(config: OIDCConfig): +def verify_access_token(bearer: OAuth2AuthorizationCodeBearer, config: OIDCConfig): jwkclient = jwt.PyJWKClient(config.jwks_uri) - oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=config.authorization_endpoint, - tokenUrl=config.token_endpoint, - refreshUrl=config.token_endpoint, - ) - def inner(access_token: str = Depends(oauth_scheme)): + def inner(access_token: str = Depends(bearer)): signing_key = jwkclient.get_signing_key_from_jwt(access_token) jwt.decode( access_token, signing_key.key, algorithms=config.id_token_signing_alg_values_supported, verify=True, - audience=config.client_audience, + # audience=config.client_audience, issuer=config.issuer, ) return inner -TRACER = get_tracer("interface") - - async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -151,273 +132,267 @@ async def on_token_error_401(_: Request, __: Exception): ) -@router.get("/environment", response_model=EnvironmentResponse) -@start_as_current_span(TRACER, "runner") -def get_environment( - runner: WorkerDispatcher = Depends(_runner), -) -> EnvironmentResponse: - """Get the current state of the environment, i.e. initialization state.""" - return runner.state - - -@router.delete("/environment", response_model=EnvironmentResponse) -async def delete_environment( - background_tasks: BackgroundTasks, - runner: WorkerDispatcher = Depends(_runner), -) -> EnvironmentResponse: - """Delete the current environment, causing internal components to be reloaded.""" - - if runner.state.initialized or runner.state.error_message is not None: - background_tasks.add_task(runner.reload) - return EnvironmentResponse(initialized=False) - - -@auth_router.get("/config/oidc", tags=["auth"], response_model=OIDCConfig) -@start_as_current_span(TRACER) -def get_oidc_config(runner: WorkerDispatcher = Depends(_runner)) -> OIDCConfig | None: - """Retrieve the OpenID Connect (OIDC) configuration for the server.""" - return runner.run(interface.get_oidc_config) - - -@router.get("/plans", response_model=PlanResponse) -@start_as_current_span(TRACER) -def get_plans(runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about all available plans.""" - plans = runner.run(interface.get_plans) - return PlanResponse(plans=plans) - - -@router.get( - "/plans/{name}", - response_model=PlanModel, -) -@start_as_current_span(TRACER, "name") -def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about a plan by its (unique) name.""" - return runner.run(interface.get_plan, name) - - -@router.get("/devices", response_model=DeviceResponse) -@start_as_current_span(TRACER) -def get_devices(runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about all available devices.""" - devices = runner.run(interface.get_devices) - return DeviceResponse(devices=devices) +def get_auth_router(runner: WorkerDispatcher): + auth_router = APIRouter() + @auth_router.get("/config/oidc", tags=["auth"], response_model=OIDCConfig) + @start_as_current_span(TRACER) + def get_oidc_config() -> OIDCConfig | None: + """Retrieve the OpenID Connect (OIDC) configuration for the server.""" + return runner.run(interface.get_oidc_config) -@router.get( - "/devices/{name}", - response_model=DeviceModel, -) -@start_as_current_span(TRACER, "name") -def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about a devices by its (unique) name.""" - return runner.run(interface.get_device, name) + return auth_router -example_task = Task(name="count", params={"detectors": ["x"]}) +def get_api_router( + runner: WorkerDispatcher, bearer: OAuth2AuthorizationCodeBearer | None +): + router = APIRouter() + + @router.get("/environment", response_model=EnvironmentResponse) + @start_as_current_span(TRACER, "runner") + def get_environment() -> EnvironmentResponse: + """Get the current state of the environment, i.e. initialization state.""" + return runner.state + + @router.delete("/environment", response_model=EnvironmentResponse) + async def delete_environment( + background_tasks: BackgroundTasks, + ) -> EnvironmentResponse: + """Delete the current environment, + causing internal components to be reloaded.""" + + if runner.state.initialized or runner.state.error_message is not None: + background_tasks.add_task(runner.reload) + return EnvironmentResponse(initialized=False) + + @router.get("/plans", response_model=PlanResponse) + @start_as_current_span(TRACER) + def get_plans(): + """Retrieve information about all available plans.""" + plans = runner.run(interface.get_plans) + return PlanResponse(plans=plans) + + @router.get( + "/plans/{name}", + response_model=PlanModel, + ) + @start_as_current_span(TRACER, "name") + def get_plan_by_name(name: str): + """Retrieve information about a plan by its (unique) name.""" + return runner.run(interface.get_plan, name) + + @router.get("/devices", response_model=DeviceResponse) + @start_as_current_span(TRACER) + def get_devices(): + """Retrieve information about all available devices.""" + devices = runner.run(interface.get_devices) + return DeviceResponse(devices=devices) + + @router.get( + "/devices/{name}", + response_model=DeviceModel, + ) + @start_as_current_span(TRACER, "name") + def get_device_by_name(name: str): + """Retrieve information about a devices by its (unique) name.""" + return runner.run(interface.get_device, name) + example_task = Task(name="count", params={"detectors": ["x"]}) -@router.post( - "/tasks", - response_model=TaskResponse, - status_code=status.HTTP_201_CREATED, -) -@start_as_current_span(TRACER, "request", "task.name", "task.params") -def submit_task( - request: Request, - response: Response, - task: Task = Body(..., example=example_task), - runner: WorkerDispatcher = Depends(_runner), -): - """Submit a task to the worker.""" - plan_model = runner.run(interface.get_plan, task.name) - try: - task_id: str = runner.run(interface.submit_task, task) - response.headers["Location"] = f"{request.url}/{task_id}" - return TaskResponse(task_id=task_id) - except ValidationError as e: - errors = e.errors() - formatted_errors = "; ".join( - [f"{err['loc'][0]}: {err['msg']}" for err in errors] - ) - error_detail_response = f""" - Input validation failed: {formatted_errors}, - supplied params {task.params}, - do not match the expected params: {plan_model.parameter_schema} - """ - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=error_detail_response, - ) from e - - -@router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) -@start_as_current_span(TRACER, "task_id") -def delete_submitted_task( - task_id: str, - runner: WorkerDispatcher = Depends(_runner), -) -> TaskResponse: - return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) - - -@start_as_current_span(TRACER, "v") -def validate_task_status(v: str) -> TaskStatusEnum: - v_upper = v.upper() - if v_upper not in TaskStatusEnum.__members__: - raise ValueError("Invalid status query parameter") - return TaskStatusEnum(v_upper) - - -@router.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK) -@start_as_current_span(TRACER) -def get_tasks( - task_status: str | None = None, - runner: WorkerDispatcher = Depends(_runner), -) -> TasksListResponse: - """ - Retrieve tasks based on their status. - The status of a newly created task is 'unstarted'. - """ - tasks = [] - if task_status: - add_span_attributes({"status": task_status}) + @router.post( + "/tasks", + response_model=TaskResponse, + status_code=status.HTTP_201_CREATED, + ) + @start_as_current_span(TRACER, "request", "task.name", "task.params") + def submit_task( + request: Request, + response: Response, + task: Task = Body(..., example=example_task), + ): + """Submit a task to the worker.""" + plan_model = runner.run(interface.get_plan, task.name) try: - desired_status = validate_task_status(task_status) - except ValueError as e: + task_id: str = runner.run(interface.submit_task, task) + response.headers["Location"] = f"{request.url}/{task_id}" + return TaskResponse(task_id=task_id) + except ValidationError as e: + errors = e.errors() + formatted_errors = "; ".join( + [f"{err['loc'][0]}: {err['msg']}" for err in errors] + ) + error_detail_response = f""" + Input validation failed: {formatted_errors}, + supplied params {task.params}, + do not match the expected params: {plan_model.parameter_schema} + """ raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid status query parameter", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=error_detail_response, ) from e - tasks = runner.run(interface.get_tasks_by_status, desired_status) - else: - tasks = runner.run(interface.get_tasks) - return TasksListResponse(tasks=tasks) - - -@router.put( - "/worker/task", - response_model=WorkerTask, - responses={status.HTTP_409_CONFLICT: {"worker": "already active"}}, -) -@start_as_current_span(TRACER, "task.task_id") -def set_active_task( - task: WorkerTask, - runner: WorkerDispatcher = Depends(_runner), -) -> WorkerTask: - """Set a task to active status, the worker should begin it as soon as possible. - This will return an error response if the worker is not idle.""" - active_task = runner.run(interface.get_active_task) - if active_task is not None and not active_task.is_complete: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, detail="Worker already active" - ) - runner.run(interface.begin_task, task) - return task + @router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) + @start_as_current_span(TRACER, "task_id") + def delete_submitted_task(task_id: str) -> TaskResponse: + return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) + @start_as_current_span(TRACER, "v") + def validate_task_status(v: str) -> TaskStatusEnum: + v_upper = v.upper() + if v_upper not in TaskStatusEnum.__members__: + raise ValueError("Invalid status query parameter") + return TaskStatusEnum(v_upper) -@router.get( - "/tasks/{task_id}", - response_model=TrackableTask, -) -@start_as_current_span(TRACER, "task_id") -def get_task( - task_id: str, - runner: WorkerDispatcher = Depends(_runner), -) -> TrackableTask: - """Retrieve a task""" - task = runner.run(interface.get_task_by_id, task_id) - if task is None: - raise KeyError - return task - - -@router.get("/worker/task") -@start_as_current_span(TRACER) -def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask: - active = runner.run(interface.get_active_task) - task_id = active.task_id if active is not None else None - return WorkerTask(task_id=task_id) - - -@router.get("/worker/state") -@start_as_current_span(TRACER) -def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState: - """Get the State of the Worker""" - return runner.run(interface.get_worker_state) - - -# Map of current_state: allowed new_states -_ALLOWED_TRANSITIONS: dict[WorkerState, set[WorkerState]] = { - WorkerState.RUNNING: { - WorkerState.PAUSED, - WorkerState.ABORTING, - WorkerState.STOPPING, - }, - WorkerState.PAUSED: { - WorkerState.RUNNING, - WorkerState.ABORTING, - WorkerState.STOPPING, - }, -} - - -@router.put( - "/worker/state", - status_code=status.HTTP_202_ACCEPTED, - responses={ - status.HTTP_400_BAD_REQUEST: {"detail": "Transition not allowed"}, - status.HTTP_202_ACCEPTED: {"detail": "Transition requested"}, - }, -) -@start_as_current_span(TRACER, "state_change_request.new_state") -def set_state( - state_change_request: StateChangeRequest, - response: Response, - runner: WorkerDispatcher = Depends(_runner), -) -> WorkerState: - """ - Request that the worker is put into a particular state. - Returns the state of the worker at the end of the call. - - - **The following transitions are allowed and return 202: Accepted** - - If the worker is **PAUSED**, new_state may be **RUNNING** to resume. - - If the worker is **RUNNING**, new_state may be **PAUSED** to pause: - - If defer is False (default): pauses and rewinds to the previous checkpoint - - If defer is True: waits until the next checkpoint to pause - - **If the task has no checkpoints, the task will instead be Aborted** - - If the worker is **RUNNING/PAUSED**, new_state may be **STOPPING** to stop. - Stop marks any currently open Runs in the Task as a success and ends the task. - - If the worker is **RUNNING/PAUSED**, new_state may be **ABORTING** to abort. - Abort marks any currently open Runs in the Task as a Failure and ends the task. - - If reason is set, the reason will be passed as the reason for the Run failure. - - **All other transitions return 400: Bad Request** - """ - current_state = runner.run(interface.get_worker_state) - new_state = state_change_request.new_state - add_span_attributes({"current_state": current_state}) - if ( - current_state in _ALLOWED_TRANSITIONS - and new_state in _ALLOWED_TRANSITIONS[current_state] - ): - if new_state == WorkerState.PAUSED: - runner.run(interface.pause_worker, state_change_request.defer) - elif new_state == WorkerState.RUNNING: - runner.run(interface.resume_worker) - elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + @router.get( + "/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK + ) + @start_as_current_span(TRACER) + def get_tasks( + task_status: str | None = None, + ) -> TasksListResponse: + """ + Retrieve tasks based on their status. + The status of a newly created task is 'unstarted'. + """ + tasks = [] + if task_status: + add_span_attributes({"status": task_status}) try: - runner.run( - interface.cancel_active_task, - state_change_request.new_state is WorkerState.ABORTING, - state_change_request.reason, - ) - except TransitionError: - response.status_code = status.HTTP_400_BAD_REQUEST - else: - response.status_code = status.HTTP_400_BAD_REQUEST - - return runner.run(interface.get_worker_state) + desired_status = validate_task_status(task_status) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid status query parameter", + ) from e + + tasks = runner.run(interface.get_tasks_by_status, desired_status) + else: + tasks = runner.run(interface.get_tasks) + return TasksListResponse(tasks=tasks) + + @router.put( + "/worker/task", + response_model=WorkerTask, + responses={status.HTTP_409_CONFLICT: {"worker": "already active"}}, + ) + @start_as_current_span(TRACER, "task.task_id") + def set_active_task( + task: WorkerTask, + token: str = Depends(bearer) if bearer else Depends(lambda: None), + ) -> WorkerTask: + """Set a task to active status, the worker should begin it as soon as possible. + This will return an error response if the worker is not idle.""" + active_task = runner.run(interface.get_active_task) + if active_task is not None and not active_task.is_complete: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail="Worker already active" + ) + runner.run(interface.begin_task, task, token) + return task + + @router.get( + "/tasks/{task_id}", + response_model=TrackableTask, + ) + @start_as_current_span(TRACER, "task_id") + def get_task( + task_id: str, + ) -> TrackableTask: + """Retrieve a task""" + task = runner.run(interface.get_task_by_id, task_id) + if task is None: + raise KeyError + return task + + @router.get("/worker/task") + @start_as_current_span(TRACER) + def get_active_task() -> WorkerTask: + active = runner.run(interface.get_active_task) + task_id = active.task_id if active is not None else None + return WorkerTask(task_id=task_id) + + @router.get("/worker/state") + @start_as_current_span(TRACER) + def get_state() -> WorkerState: + """Get the State of the Worker""" + return runner.run(interface.get_worker_state) + + # Map of current_state: allowed new_states + _ALLOWED_TRANSITIONS: dict[WorkerState, set[WorkerState]] = { + WorkerState.RUNNING: { + WorkerState.PAUSED, + WorkerState.ABORTING, + WorkerState.STOPPING, + }, + WorkerState.PAUSED: { + WorkerState.RUNNING, + WorkerState.ABORTING, + WorkerState.STOPPING, + }, + } + + @router.put( + "/worker/state", + status_code=status.HTTP_202_ACCEPTED, + responses={ + status.HTTP_400_BAD_REQUEST: {"detail": "Transition not allowed"}, + status.HTTP_202_ACCEPTED: {"detail": "Transition requested"}, + }, + ) + @start_as_current_span(TRACER, "state_change_request.new_state") + def set_state( + state_change_request: StateChangeRequest, + response: Response, + ) -> WorkerState: + """ + Request that the worker is put into a particular state. + Returns the state of the worker at the end of the call. + + - **The following transitions are allowed and return 202: Accepted** + - If the worker is **PAUSED**, new_state may be **RUNNING** to resume. + - If the worker is **RUNNING**, new_state may be **PAUSED** to pause: + - If defer is False (default): pauses and rewinds to the previous checkpoint + - If defer is True: waits until the next checkpoint to pause + - **If the task has no checkpoints, the task will instead be Aborted** + - If the worker is **RUNNING/PAUSED**, + new_state may be **STOPPING** to stop. + Stop marks any currently open Runs + in the Task as a success and ends the task. + - If the worker is **RUNNING/PAUSED**, + new_state may be **ABORTING** to abort. + Abort marks any currently open Runs in the Task + as a Failure and ends the task. + - If reason is set, the reason will be passed + as the reason for the Run failure. + - **All other transitions return 400: Bad Request** + """ + current_state = runner.run(interface.get_worker_state) + new_state = state_change_request.new_state + add_span_attributes({"current_state": current_state}) + if ( + current_state in _ALLOWED_TRANSITIONS + and new_state in _ALLOWED_TRANSITIONS[current_state] + ): + if new_state == WorkerState.PAUSED: + runner.run(interface.pause_worker, state_change_request.defer) + elif new_state == WorkerState.RUNNING: + runner.run(interface.resume_worker) + elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + try: + runner.run( + interface.cancel_active_task, + state_change_request.new_state is WorkerState.ABORTING, + state_change_request.reason, + ) + except TransitionError: + response.status_code = status.HTTP_400_BAD_REQUEST + else: + response.status_code = status.HTTP_400_BAD_REQUEST + + return runner.run(interface.get_worker_state) + + return router @start_as_current_span(TRACER, "config") @@ -440,7 +415,6 @@ def start(config: ApplicationConfig): http_capture_headers_server_request=[",*"], http_capture_headers_server_response=[",*"], ) - app.state.config = config uvicorn.run(app, host=config.api.host, port=config.api.port) diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 67c22119c..e43816c05 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -9,7 +9,6 @@ from typing import Any, Generic, TypeVar from bluesky.protocols import Status -from httpx import Headers from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -198,24 +197,24 @@ def get_active_task(self) -> TrackableTask[Task] | None: return current @start_as_current_span(TRACER, "task_id") - def begin_task(self, task_id: str, headers: Headers | None) -> None: + def begin_task(self, task_id: str, token: str | None) -> None: task = self._tasks.get(task_id) data_subs: list[int] = [] if task is not None: if self._tiled_inserter: - data_subs.append(self._authorize_running_task(headers)) + data_subs.append(self._authorize_running_task(token)) self._submit_trackable_task(task, data_subs) else: raise KeyError(f"No pending task with ID {task_id}") - def _authorize_running_task(self, headers: Headers | None) -> int: + def _authorize_running_task(self, token: str | None) -> int: assert self._tiled_inserter # https://github.com/DiamondLightSource/blueapi/issues/774 # If users should only be able to run their own scans, pass headers # as part of submitting a task, cache in TrackableTask field and check # that token belongs to same user (but may be newer token!) - return self.data_events.subscribe(self._tiled_inserter(headers)) + return self.data_events.subscribe(self._tiled_inserter(token)) @start_as_current_span(TRACER, "task.name", "task.params") def submit_task(self, task: Task) -> str: @@ -249,8 +248,14 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: ): task_started.set() + def unsubscribe_watchers(event: WorkerEvent, _: str | None) -> None: + if event.task_status and event.task_status.task_complete and data_subs: + for data_sub in data_subs: + self.data_events.unsubscribe(data_sub) + LOGGER.info(f"Submitting: {trackable_task}") sub = self.worker_events.subscribe(mark_task_as_started) + self.worker_events.subscribe(unsubscribe_watchers) try: self._current_task_otel_context = get_current() """ Cache the current trace context as the one for this task id """ @@ -262,9 +267,6 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: raise WorkerBusyError("Cannot submit task while another is running") from f finally: self.worker_events.unsubscribe(sub) - if data_subs: - for data_sub in data_subs: - self.data_events.unsubscribe(data_sub) @start_as_current_span(TRACER) def start(self) -> None: diff --git a/src/blueapi/worker/tiled.py b/src/blueapi/worker/tiled.py index 816f0156a..e56a4826c 100644 --- a/src/blueapi/worker/tiled.py +++ b/src/blueapi/worker/tiled.py @@ -1,15 +1,15 @@ +from typing import Any + from bluesky.callbacks.tiled_writer import TiledWriter -from httpx import Headers -from tiled.client import from_context -from tiled.client.context import Context as TiledContext +from tiled.client import from_uri from blueapi.config import TiledConfig from blueapi.core.bluesky_types import DataEvent class TiledConverter: - def __init__(self, tiled_context: TiledContext): - self._writer: TiledWriter = TiledWriter(from_context(tiled_context)) + def __init__(self, uri: str, headers: dict[str, Any]): + self._writer: TiledWriter = TiledWriter(from_uri(uri, headers=headers)) def __call__(self, data: DataEvent, _: str | None = None) -> None: self._writer(data.name, data.doc) @@ -19,5 +19,7 @@ class TiledConnection: def __init__(self, config: TiledConfig): self.uri = f"{config.host}:{config.port}" - def __call__(self, headers: Headers | None): - return TiledConverter(TiledContext(self.uri, headers=headers)) + def __call__(self, token: str | None): + return TiledConverter( + self.uri, headers={"Authorization": f"Bearer {token}"} if token else {} + ) diff --git a/src/script/stomp_config.yml b/src/script/stomp_config.yml index 99ff4ad63..b490ccdc4 100644 --- a/src/script/stomp_config.yml +++ b/src/script/stomp_config.yml @@ -1,7 +1,20 @@ ---- stomp: host: "localhost" port: 61613 auth: username: "guest" password: "guest" +api: + port: 3000 +tiled: + host: "http://localhost" + port: 4000 +oidc: + well_known_url: http://localhost:8080/foo/.well-known/openid-configuration + client_id: blueapi +env: + sources: + - kind: dodal + module: dodal.beamlines.i22 + - kind: planFunctions + module: dodal.plans