diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 2eb184a..443527b 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -108,6 +108,7 @@ def __init__( receive_notification_type: type[ReceiveNotificationT], # If none, reading will never time out read_timeout_seconds: timedelta | None = None, + cleanup_interval_seconds: float = 60.0, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -116,6 +117,8 @@ def __init__( self._receive_request_type = receive_request_type self._receive_notification_type = receive_notification_type self._read_timeout_seconds = read_timeout_seconds + self._cleanup_interval = cleanup_interval_seconds + self._in_flight: dict[RequestId, RequestResponder] = {} self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ @@ -129,6 +132,7 @@ async def __aenter__(self): self._task_group = anyio.create_task_group() await self._task_group.__aenter__() self._task_group.start_soon(self._receive_loop) + self._task_group.start_soon(self._cleanup_loop) return self async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -221,8 +225,20 @@ async def _send_response( ) await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) + async def _cleanup_loop(self) -> None: + """Periodically clean up completed and cancelled requests.""" + while True: + with anyio.move_on_after(self._cleanup_interval): + # Clean up completed requests + self._in_flight = { + req_id: responder + for req_id, responder in self._in_flight.items() + if responder.in_flight + } + await anyio.sleep(self._cleanup_interval) + async def _receive_loop(self) -> None: - in_flight: dict[RequestId, RequestResponder] = {} + """Handle incoming messages and maintain request state.""" async with ( self._read_stream, @@ -231,9 +247,9 @@ async def _receive_loop(self) -> None: ): async for message in self._read_stream: # Clean up completed requests - in_flight = { + self._in_flight = { req_id: responder - for req_id, responder in in_flight.items() + for req_id, responder in self._in_flight.items() if responder.in_flight } @@ -257,7 +273,7 @@ async def _receive_loop(self) -> None: cancel_scope=scope, ) - in_flight[message.root.id] = responder + self._in_flight[message.root.id] = responder await self._received_request(responder) if not responder._responded: @@ -272,8 +288,8 @@ async def _receive_loop(self) -> None: # Handle cancellation notifications if isinstance(notification.root, CancelledNotification): cancelled_id = notification.root.params.requestId - if cancelled_id in in_flight: - await in_flight[cancelled_id].cancel() + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() else: await self._received_notification(notification) await self._incoming_message_stream_writer.send(notification)