diff --git a/pyproject.toml b/pyproject.toml index dde5b0816..c2d7fa354 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,10 @@ classifiers = [ ] description = "Lightweight bluesky-as-a-service wrapper application. Also usable as a library." dependencies = [ + "tiled", + "json_merge_patch", + "jsonpatch", + "pyarrow", "bluesky>=1.13", "ophyd", "nslsii", @@ -95,7 +99,8 @@ addopts = """ --ignore=src/blueapi/startup """ # https://iscinumpy.gitlab.io/post/bound-version-constraints/#watch-for-warnings -filterwarnings = ["error", "ignore::DeprecationWarning"] +# Unignore UserWarning after Pydantic warning removed from bluesky/bluesky and release +filterwarnings = ["error", "ignore::DeprecationWarning", "ignore::UserWarning"] # Doctest python code in docs, python code in src docstrings, test functions in tests testpaths = "docs src tests" asyncio_mode = "auto" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index e4581a663..10a1476b2 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -41,6 +41,15 @@ class StompConfig(BaseModel): auth: BasicAuthentication | None = None +class TiledConfig(BaseModel): + """ + Config for connecting to a tiled instance + """ + + host: str + port: int + + class WorkerEventConfig(BlueapiBaseModel): """ Config for event broadcasting via the message bus @@ -111,7 +120,9 @@ class OIDCConfig(BlueapiBaseModel): description="URL to fetch OIDC config from the provider" ) client_id: str = Field(description="Client ID") - client_audience: str = Field(description="Client Audience(s)", default="blueapi") + client_audience: str | list[str] | None = Field( + description="Client Audience(s)", default="blueapi" + ) @cached_property def _config_from_oidc_url(self) -> dict[str, Any]: @@ -160,6 +171,7 @@ class ApplicationConfig(BlueapiBaseModel): """ stomp: StompConfig | None = None + tiled: TiledConfig | None = None env: EnvironmentConfig = Field(default_factory=EnvironmentConfig) logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index df3df8a25..2f0d41830 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -110,7 +110,7 @@ def decode_jwt(self, json_web_token: str): signing_key.key, algorithms=self._server_config.id_token_signing_alg_values_supported, verify=True, - audience=self._server_config.client_audience, + # audience=self._server_config.client_audience, issuer=self._server_config.issuer, ) @@ -169,6 +169,7 @@ def poll_for_token( "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, "client_id": self._server_config.client_id, + "client_secret": "secret", }, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) @@ -184,6 +185,7 @@ def start_device_flow(self): self._server_config.device_authorization_endpoint, data={ "client_id": self._server_config.client_id, + "client_secret": "secret", "scope": SCOPES, }, headers={"Content-Type": "application/x-www-form-urlencoded"}, diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 803841964..2f66f738a 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -13,6 +13,7 @@ from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TaskWorker, TrackableTask +from blueapi.worker.tiled import TiledConnection """This module provides interface between web application and underlying Bluesky context and worker""" @@ -40,9 +41,11 @@ def context() -> BlueskyContext: @cache def worker() -> TaskWorker: + conf = config() worker = TaskWorker( context(), - broadcast_statuses=config().env.events.broadcast_status_events, + broadcast_statuses=conf.env.events.broadcast_status_events, + tiled_inserter=TiledConnection(conf.tiled) if conf.tiled else None, ) worker.start() return worker @@ -144,10 +147,10 @@ def clear_task(task_id: str) -> str: return worker().clear_task(task_id) -def begin_task(task: WorkerTask) -> WorkerTask: +def begin_task(task: WorkerTask, token: str | None) -> WorkerTask: """Trigger a task. Will fail if the worker is busy""" if task.task_id is not None: - worker().begin_task(task.task_id) + worker().begin_task(task.task_id, token=token) return task diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 37a41f30d..31d3ec274 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -47,95 +47,76 @@ REST_API_VERSION = "0.0.5" -RUNNER: WorkerDispatcher | None = None - CONTEXT_HEADER = "traceparent" +TRACER = get_tracer("interface") -def _runner() -> WorkerDispatcher: - """Intended to be used only with FastAPI Depends""" - if RUNNER is None: - raise ValueError() - return RUNNER - - -def setup_runner( - config: ApplicationConfig | None = None, - runner: WorkerDispatcher | None = None, -): - global RUNNER - runner = runner or WorkerDispatcher(config) +def setup_runner(config: ApplicationConfig | None = None) -> WorkerDispatcher: + runner = WorkerDispatcher(config) runner.start() + return runner - RUNNER = runner - -def teardown_runner(): - global RUNNER - if RUNNER is None: - return - RUNNER.stop() - RUNNER = None +def _bearer(config: OIDCConfig) -> OAuth2AuthorizationCodeBearer: + return OAuth2AuthorizationCodeBearer( + authorizationUrl=config.authorization_endpoint, + tokenUrl=config.token_endpoint, + refreshUrl=config.token_endpoint, + ) -def lifespan(config: ApplicationConfig): +def lifespan(runner: WorkerDispatcher): @asynccontextmanager async def inner(app: FastAPI): - setup_runner(config) yield - teardown_runner() + runner.stop() return inner -router = APIRouter() -auth_router = APIRouter() - - def get_app(config: ApplicationConfig): + runner = setup_runner(config) + app = FastAPI( docs_url="/docs", title="BlueAPI Control", - lifespan=lifespan(config), + lifespan=lifespan(runner), version=REST_API_VERSION, ) - dependencies = [] + if config.oidc: - dependencies = [Depends(verify_access_token(config.oidc))] - app.include_router(auth_router) - app.include_router(router, dependencies=dependencies) + bearer = _bearer(config.oidc) + dependencies = [Depends(verify_access_token(bearer, config.oidc))] + else: + bearer = None + dependencies = [] + app.include_router(get_auth_router(runner)) + app.include_router(get_api_router(runner, bearer), dependencies=dependencies) app.add_exception_handler(KeyError, on_key_error_404) app.add_exception_handler(jwt.PyJWTError, on_token_error_401) app.middleware("http")(add_api_version_header) app.middleware("http")(inject_propagated_observability_context) + return app -def verify_access_token(config: OIDCConfig): +def verify_access_token(bearer: OAuth2AuthorizationCodeBearer, config: OIDCConfig): jwkclient = jwt.PyJWKClient(config.jwks_uri) - oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=config.authorization_endpoint, - tokenUrl=config.token_endpoint, - refreshUrl=config.token_endpoint, - ) - def inner(access_token: str = Depends(oauth_scheme)): + def inner(access_token: str = Depends(bearer)): signing_key = jwkclient.get_signing_key_from_jwt(access_token) jwt.decode( access_token, signing_key.key, algorithms=config.id_token_signing_alg_values_supported, verify=True, - audience=config.client_audience, + # audience=config.client_audience, issuer=config.issuer, ) return inner -TRACER = get_tracer("interface") - - async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -151,273 +132,267 @@ async def on_token_error_401(_: Request, __: Exception): ) -@router.get("/environment", response_model=EnvironmentResponse) -@start_as_current_span(TRACER, "runner") -def get_environment( - runner: WorkerDispatcher = Depends(_runner), -) -> EnvironmentResponse: - """Get the current state of the environment, i.e. initialization state.""" - return runner.state - - -@router.delete("/environment", response_model=EnvironmentResponse) -async def delete_environment( - background_tasks: BackgroundTasks, - runner: WorkerDispatcher = Depends(_runner), -) -> EnvironmentResponse: - """Delete the current environment, causing internal components to be reloaded.""" - - if runner.state.initialized or runner.state.error_message is not None: - background_tasks.add_task(runner.reload) - return EnvironmentResponse(initialized=False) - - -@auth_router.get("/config/oidc", tags=["auth"], response_model=OIDCConfig) -@start_as_current_span(TRACER) -def get_oidc_config(runner: WorkerDispatcher = Depends(_runner)) -> OIDCConfig | None: - """Retrieve the OpenID Connect (OIDC) configuration for the server.""" - return runner.run(interface.get_oidc_config) - - -@router.get("/plans", response_model=PlanResponse) -@start_as_current_span(TRACER) -def get_plans(runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about all available plans.""" - plans = runner.run(interface.get_plans) - return PlanResponse(plans=plans) - - -@router.get( - "/plans/{name}", - response_model=PlanModel, -) -@start_as_current_span(TRACER, "name") -def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about a plan by its (unique) name.""" - return runner.run(interface.get_plan, name) - - -@router.get("/devices", response_model=DeviceResponse) -@start_as_current_span(TRACER) -def get_devices(runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about all available devices.""" - devices = runner.run(interface.get_devices) - return DeviceResponse(devices=devices) +def get_auth_router(runner: WorkerDispatcher): + auth_router = APIRouter() + @auth_router.get("/config/oidc", tags=["auth"], response_model=OIDCConfig) + @start_as_current_span(TRACER) + def get_oidc_config() -> OIDCConfig | None: + """Retrieve the OpenID Connect (OIDC) configuration for the server.""" + return runner.run(interface.get_oidc_config) -@router.get( - "/devices/{name}", - response_model=DeviceModel, -) -@start_as_current_span(TRACER, "name") -def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)): - """Retrieve information about a devices by its (unique) name.""" - return runner.run(interface.get_device, name) + return auth_router -example_task = Task(name="count", params={"detectors": ["x"]}) +def get_api_router( + runner: WorkerDispatcher, bearer: OAuth2AuthorizationCodeBearer | None +): + router = APIRouter() + + @router.get("/environment", response_model=EnvironmentResponse) + @start_as_current_span(TRACER, "runner") + def get_environment() -> EnvironmentResponse: + """Get the current state of the environment, i.e. initialization state.""" + return runner.state + + @router.delete("/environment", response_model=EnvironmentResponse) + async def delete_environment( + background_tasks: BackgroundTasks, + ) -> EnvironmentResponse: + """Delete the current environment, + causing internal components to be reloaded.""" + + if runner.state.initialized or runner.state.error_message is not None: + background_tasks.add_task(runner.reload) + return EnvironmentResponse(initialized=False) + + @router.get("/plans", response_model=PlanResponse) + @start_as_current_span(TRACER) + def get_plans(): + """Retrieve information about all available plans.""" + plans = runner.run(interface.get_plans) + return PlanResponse(plans=plans) + + @router.get( + "/plans/{name}", + response_model=PlanModel, + ) + @start_as_current_span(TRACER, "name") + def get_plan_by_name(name: str): + """Retrieve information about a plan by its (unique) name.""" + return runner.run(interface.get_plan, name) + + @router.get("/devices", response_model=DeviceResponse) + @start_as_current_span(TRACER) + def get_devices(): + """Retrieve information about all available devices.""" + devices = runner.run(interface.get_devices) + return DeviceResponse(devices=devices) + + @router.get( + "/devices/{name}", + response_model=DeviceModel, + ) + @start_as_current_span(TRACER, "name") + def get_device_by_name(name: str): + """Retrieve information about a devices by its (unique) name.""" + return runner.run(interface.get_device, name) + example_task = Task(name="count", params={"detectors": ["x"]}) -@router.post( - "/tasks", - response_model=TaskResponse, - status_code=status.HTTP_201_CREATED, -) -@start_as_current_span(TRACER, "request", "task.name", "task.params") -def submit_task( - request: Request, - response: Response, - task: Task = Body(..., example=example_task), - runner: WorkerDispatcher = Depends(_runner), -): - """Submit a task to the worker.""" - plan_model = runner.run(interface.get_plan, task.name) - try: - task_id: str = runner.run(interface.submit_task, task) - response.headers["Location"] = f"{request.url}/{task_id}" - return TaskResponse(task_id=task_id) - except ValidationError as e: - errors = e.errors() - formatted_errors = "; ".join( - [f"{err['loc'][0]}: {err['msg']}" for err in errors] - ) - error_detail_response = f""" - Input validation failed: {formatted_errors}, - supplied params {task.params}, - do not match the expected params: {plan_model.parameter_schema} - """ - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=error_detail_response, - ) from e - - -@router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) -@start_as_current_span(TRACER, "task_id") -def delete_submitted_task( - task_id: str, - runner: WorkerDispatcher = Depends(_runner), -) -> TaskResponse: - return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) - - -@start_as_current_span(TRACER, "v") -def validate_task_status(v: str) -> TaskStatusEnum: - v_upper = v.upper() - if v_upper not in TaskStatusEnum.__members__: - raise ValueError("Invalid status query parameter") - return TaskStatusEnum(v_upper) - - -@router.get("/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK) -@start_as_current_span(TRACER) -def get_tasks( - task_status: str | None = None, - runner: WorkerDispatcher = Depends(_runner), -) -> TasksListResponse: - """ - Retrieve tasks based on their status. - The status of a newly created task is 'unstarted'. - """ - tasks = [] - if task_status: - add_span_attributes({"status": task_status}) + @router.post( + "/tasks", + response_model=TaskResponse, + status_code=status.HTTP_201_CREATED, + ) + @start_as_current_span(TRACER, "request", "task.name", "task.params") + def submit_task( + request: Request, + response: Response, + task: Task = Body(..., example=example_task), + ): + """Submit a task to the worker.""" + plan_model = runner.run(interface.get_plan, task.name) try: - desired_status = validate_task_status(task_status) - except ValueError as e: + task_id: str = runner.run(interface.submit_task, task) + response.headers["Location"] = f"{request.url}/{task_id}" + return TaskResponse(task_id=task_id) + except ValidationError as e: + errors = e.errors() + formatted_errors = "; ".join( + [f"{err['loc'][0]}: {err['msg']}" for err in errors] + ) + error_detail_response = f""" + Input validation failed: {formatted_errors}, + supplied params {task.params}, + do not match the expected params: {plan_model.parameter_schema} + """ raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid status query parameter", + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=error_detail_response, ) from e - tasks = runner.run(interface.get_tasks_by_status, desired_status) - else: - tasks = runner.run(interface.get_tasks) - return TasksListResponse(tasks=tasks) - - -@router.put( - "/worker/task", - response_model=WorkerTask, - responses={status.HTTP_409_CONFLICT: {"worker": "already active"}}, -) -@start_as_current_span(TRACER, "task.task_id") -def set_active_task( - task: WorkerTask, - runner: WorkerDispatcher = Depends(_runner), -) -> WorkerTask: - """Set a task to active status, the worker should begin it as soon as possible. - This will return an error response if the worker is not idle.""" - active_task = runner.run(interface.get_active_task) - if active_task is not None and not active_task.is_complete: - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, detail="Worker already active" - ) - runner.run(interface.begin_task, task) - return task + @router.delete("/tasks/{task_id}", status_code=status.HTTP_200_OK) + @start_as_current_span(TRACER, "task_id") + def delete_submitted_task(task_id: str) -> TaskResponse: + return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) + @start_as_current_span(TRACER, "v") + def validate_task_status(v: str) -> TaskStatusEnum: + v_upper = v.upper() + if v_upper not in TaskStatusEnum.__members__: + raise ValueError("Invalid status query parameter") + return TaskStatusEnum(v_upper) -@router.get( - "/tasks/{task_id}", - response_model=TrackableTask, -) -@start_as_current_span(TRACER, "task_id") -def get_task( - task_id: str, - runner: WorkerDispatcher = Depends(_runner), -) -> TrackableTask: - """Retrieve a task""" - task = runner.run(interface.get_task_by_id, task_id) - if task is None: - raise KeyError - return task - - -@router.get("/worker/task") -@start_as_current_span(TRACER) -def get_active_task(runner: WorkerDispatcher = Depends(_runner)) -> WorkerTask: - active = runner.run(interface.get_active_task) - task_id = active.task_id if active is not None else None - return WorkerTask(task_id=task_id) - - -@router.get("/worker/state") -@start_as_current_span(TRACER) -def get_state(runner: WorkerDispatcher = Depends(_runner)) -> WorkerState: - """Get the State of the Worker""" - return runner.run(interface.get_worker_state) - - -# Map of current_state: allowed new_states -_ALLOWED_TRANSITIONS: dict[WorkerState, set[WorkerState]] = { - WorkerState.RUNNING: { - WorkerState.PAUSED, - WorkerState.ABORTING, - WorkerState.STOPPING, - }, - WorkerState.PAUSED: { - WorkerState.RUNNING, - WorkerState.ABORTING, - WorkerState.STOPPING, - }, -} - - -@router.put( - "/worker/state", - status_code=status.HTTP_202_ACCEPTED, - responses={ - status.HTTP_400_BAD_REQUEST: {"detail": "Transition not allowed"}, - status.HTTP_202_ACCEPTED: {"detail": "Transition requested"}, - }, -) -@start_as_current_span(TRACER, "state_change_request.new_state") -def set_state( - state_change_request: StateChangeRequest, - response: Response, - runner: WorkerDispatcher = Depends(_runner), -) -> WorkerState: - """ - Request that the worker is put into a particular state. - Returns the state of the worker at the end of the call. - - - **The following transitions are allowed and return 202: Accepted** - - If the worker is **PAUSED**, new_state may be **RUNNING** to resume. - - If the worker is **RUNNING**, new_state may be **PAUSED** to pause: - - If defer is False (default): pauses and rewinds to the previous checkpoint - - If defer is True: waits until the next checkpoint to pause - - **If the task has no checkpoints, the task will instead be Aborted** - - If the worker is **RUNNING/PAUSED**, new_state may be **STOPPING** to stop. - Stop marks any currently open Runs in the Task as a success and ends the task. - - If the worker is **RUNNING/PAUSED**, new_state may be **ABORTING** to abort. - Abort marks any currently open Runs in the Task as a Failure and ends the task. - - If reason is set, the reason will be passed as the reason for the Run failure. - - **All other transitions return 400: Bad Request** - """ - current_state = runner.run(interface.get_worker_state) - new_state = state_change_request.new_state - add_span_attributes({"current_state": current_state}) - if ( - current_state in _ALLOWED_TRANSITIONS - and new_state in _ALLOWED_TRANSITIONS[current_state] - ): - if new_state == WorkerState.PAUSED: - runner.run(interface.pause_worker, state_change_request.defer) - elif new_state == WorkerState.RUNNING: - runner.run(interface.resume_worker) - elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + @router.get( + "/tasks", response_model=TasksListResponse, status_code=status.HTTP_200_OK + ) + @start_as_current_span(TRACER) + def get_tasks( + task_status: str | None = None, + ) -> TasksListResponse: + """ + Retrieve tasks based on their status. + The status of a newly created task is 'unstarted'. + """ + tasks = [] + if task_status: + add_span_attributes({"status": task_status}) try: - runner.run( - interface.cancel_active_task, - state_change_request.new_state is WorkerState.ABORTING, - state_change_request.reason, - ) - except TransitionError: - response.status_code = status.HTTP_400_BAD_REQUEST - else: - response.status_code = status.HTTP_400_BAD_REQUEST - - return runner.run(interface.get_worker_state) + desired_status = validate_task_status(task_status) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid status query parameter", + ) from e + + tasks = runner.run(interface.get_tasks_by_status, desired_status) + else: + tasks = runner.run(interface.get_tasks) + return TasksListResponse(tasks=tasks) + + @router.put( + "/worker/task", + response_model=WorkerTask, + responses={status.HTTP_409_CONFLICT: {"worker": "already active"}}, + ) + @start_as_current_span(TRACER, "task.task_id") + def set_active_task( + task: WorkerTask, + token: str = Depends(bearer) if bearer else Depends(lambda: None), + ) -> WorkerTask: + """Set a task to active status, the worker should begin it as soon as possible. + This will return an error response if the worker is not idle.""" + active_task = runner.run(interface.get_active_task) + if active_task is not None and not active_task.is_complete: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, detail="Worker already active" + ) + runner.run(interface.begin_task, task, token) + return task + + @router.get( + "/tasks/{task_id}", + response_model=TrackableTask, + ) + @start_as_current_span(TRACER, "task_id") + def get_task( + task_id: str, + ) -> TrackableTask: + """Retrieve a task""" + task = runner.run(interface.get_task_by_id, task_id) + if task is None: + raise KeyError + return task + + @router.get("/worker/task") + @start_as_current_span(TRACER) + def get_active_task() -> WorkerTask: + active = runner.run(interface.get_active_task) + task_id = active.task_id if active is not None else None + return WorkerTask(task_id=task_id) + + @router.get("/worker/state") + @start_as_current_span(TRACER) + def get_state() -> WorkerState: + """Get the State of the Worker""" + return runner.run(interface.get_worker_state) + + # Map of current_state: allowed new_states + _ALLOWED_TRANSITIONS: dict[WorkerState, set[WorkerState]] = { + WorkerState.RUNNING: { + WorkerState.PAUSED, + WorkerState.ABORTING, + WorkerState.STOPPING, + }, + WorkerState.PAUSED: { + WorkerState.RUNNING, + WorkerState.ABORTING, + WorkerState.STOPPING, + }, + } + + @router.put( + "/worker/state", + status_code=status.HTTP_202_ACCEPTED, + responses={ + status.HTTP_400_BAD_REQUEST: {"detail": "Transition not allowed"}, + status.HTTP_202_ACCEPTED: {"detail": "Transition requested"}, + }, + ) + @start_as_current_span(TRACER, "state_change_request.new_state") + def set_state( + state_change_request: StateChangeRequest, + response: Response, + ) -> WorkerState: + """ + Request that the worker is put into a particular state. + Returns the state of the worker at the end of the call. + + - **The following transitions are allowed and return 202: Accepted** + - If the worker is **PAUSED**, new_state may be **RUNNING** to resume. + - If the worker is **RUNNING**, new_state may be **PAUSED** to pause: + - If defer is False (default): pauses and rewinds to the previous checkpoint + - If defer is True: waits until the next checkpoint to pause + - **If the task has no checkpoints, the task will instead be Aborted** + - If the worker is **RUNNING/PAUSED**, + new_state may be **STOPPING** to stop. + Stop marks any currently open Runs + in the Task as a success and ends the task. + - If the worker is **RUNNING/PAUSED**, + new_state may be **ABORTING** to abort. + Abort marks any currently open Runs in the Task + as a Failure and ends the task. + - If reason is set, the reason will be passed + as the reason for the Run failure. + - **All other transitions return 400: Bad Request** + """ + current_state = runner.run(interface.get_worker_state) + new_state = state_change_request.new_state + add_span_attributes({"current_state": current_state}) + if ( + current_state in _ALLOWED_TRANSITIONS + and new_state in _ALLOWED_TRANSITIONS[current_state] + ): + if new_state == WorkerState.PAUSED: + runner.run(interface.pause_worker, state_change_request.defer) + elif new_state == WorkerState.RUNNING: + runner.run(interface.resume_worker) + elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + try: + runner.run( + interface.cancel_active_task, + state_change_request.new_state is WorkerState.ABORTING, + state_change_request.reason, + ) + except TransitionError: + response.status_code = status.HTTP_400_BAD_REQUEST + else: + response.status_code = status.HTTP_400_BAD_REQUEST + + return runner.run(interface.get_worker_state) + + return router @start_as_current_span(TRACER, "config") @@ -440,7 +415,6 @@ def start(config: ApplicationConfig): http_capture_headers_server_request=[",*"], http_capture_headers_server_response=[",*"], ) - app.state.config = config uvicorn.run(app, host=config.api.host, port=config.api.port) diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index 546c5f3b5..e43816c05 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -32,6 +32,7 @@ from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.utils.base_model import BlueapiBaseModel from blueapi.utils.thread_exception import handle_all_exceptions +from blueapi.worker.tiled import TiledConnection from .event import ( ProgressEvent, @@ -112,9 +113,11 @@ def __init__( ctx: BlueskyContext, start_stop_timeout: float = DEFAULT_START_STOP_TIMEOUT, broadcast_statuses: bool = True, + tiled_inserter: TiledConnection | None = None, ) -> None: self._ctx = ctx self._start_stop_timeout = start_stop_timeout + self._tiled_inserter = tiled_inserter self._tasks = {} @@ -194,13 +197,25 @@ def get_active_task(self) -> TrackableTask[Task] | None: return current @start_as_current_span(TRACER, "task_id") - def begin_task(self, task_id: str) -> None: + def begin_task(self, task_id: str, token: str | None) -> None: task = self._tasks.get(task_id) + data_subs: list[int] = [] if task is not None: - self._submit_trackable_task(task) + if self._tiled_inserter: + data_subs.append(self._authorize_running_task(token)) + self._submit_trackable_task(task, data_subs) + else: raise KeyError(f"No pending task with ID {task_id}") + def _authorize_running_task(self, token: str | None) -> int: + assert self._tiled_inserter + # https://github.com/DiamondLightSource/blueapi/issues/774 + # If users should only be able to run their own scans, pass headers + # as part of submitting a task, cache in TrackableTask field and check + # that token belongs to same user (but may be newer token!) + return self.data_events.subscribe(self._tiled_inserter(token)) + @start_as_current_span(TRACER, "task.name", "task.params") def submit_task(self, task: Task) -> str: task.prepare_params(self._ctx) # Will raise if parameters are invalid @@ -218,7 +233,9 @@ def submit_task(self, task: Task) -> str: "trackable_task.task.name", "trackable_task.task.params", ) - def _submit_trackable_task(self, trackable_task: TrackableTask) -> None: + def _submit_trackable_task( + self, trackable_task: TrackableTask, data_subs: list[int] | None = None + ) -> None: if self.state is not WorkerState.IDLE: raise WorkerBusyError(f"Worker is in state {self.state}") @@ -231,15 +248,19 @@ def mark_task_as_started(event: WorkerEvent, _: str | None) -> None: ): task_started.set() + def unsubscribe_watchers(event: WorkerEvent, _: str | None) -> None: + if event.task_status and event.task_status.task_complete and data_subs: + for data_sub in data_subs: + self.data_events.unsubscribe(data_sub) + LOGGER.info(f"Submitting: {trackable_task}") sub = self.worker_events.subscribe(mark_task_as_started) + self.worker_events.subscribe(unsubscribe_watchers) try: self._current_task_otel_context = get_current() - sub = self.worker_events.subscribe(mark_task_as_started) """ Cache the current trace context as the one for this task id """ self._task_channel.put_nowait(trackable_task) - task_started.wait(timeout=5.0) - if not task_started.is_set(): + if not task_started.wait(timeout=5.0): raise TimeoutError("Failed to start plan within timeout") except Full as f: LOGGER.error("Cannot submit task while another is running") diff --git a/src/blueapi/worker/tiled.py b/src/blueapi/worker/tiled.py new file mode 100644 index 000000000..e56a4826c --- /dev/null +++ b/src/blueapi/worker/tiled.py @@ -0,0 +1,25 @@ +from typing import Any + +from bluesky.callbacks.tiled_writer import TiledWriter +from tiled.client import from_uri + +from blueapi.config import TiledConfig +from blueapi.core.bluesky_types import DataEvent + + +class TiledConverter: + def __init__(self, uri: str, headers: dict[str, Any]): + self._writer: TiledWriter = TiledWriter(from_uri(uri, headers=headers)) + + def __call__(self, data: DataEvent, _: str | None = None) -> None: + self._writer(data.name, data.doc) + + +class TiledConnection: + def __init__(self, config: TiledConfig): + self.uri = f"{config.host}:{config.port}" + + def __call__(self, token: str | None): + return TiledConverter( + self.uri, headers={"Authorization": f"Bearer {token}"} if token else {} + ) diff --git a/src/script/stomp_config.yml b/src/script/stomp_config.yml index 99ff4ad63..b490ccdc4 100644 --- a/src/script/stomp_config.yml +++ b/src/script/stomp_config.yml @@ -1,7 +1,20 @@ ---- stomp: host: "localhost" port: 61613 auth: username: "guest" password: "guest" +api: + port: 3000 +tiled: + host: "http://localhost" + port: 4000 +oidc: + well_known_url: http://localhost:8080/foo/.well-known/openid-configuration + client_id: blueapi +env: + sources: + - kind: dodal + module: dodal.beamlines.i22 + - kind: planFunctions + module: dodal.plans diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 67517a691..3caf99a19 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -224,6 +224,7 @@ def temp_yaml_config_file( "logging": {"level": "INFO"}, "api": {"host": "0.0.0.0", "port": 8000, "protocol": "http"}, "scratch": None, + "tiled": None, }, ], indirect=True, @@ -285,6 +286,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "tiled": None, }, { "stomp": { @@ -318,6 +320,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "tiled": None, }, ], indirect=True,