Skip to content

Commit

Permalink
feat: add periodic cleanup of in-flight requests
Browse files Browse the repository at this point in the history
  • Loading branch information
dsp-ant committed Jan 24, 2025
1 parent ba585a8 commit d50255e
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
receive_notification_type: type[ReceiveNotificationT],
# If none, reading will never time out
read_timeout_seconds: timedelta | None = None,
cleanup_interval_seconds: float = 60.0,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
Expand All @@ -116,6 +117,8 @@ def __init__(
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._cleanup_interval = cleanup_interval_seconds
self._in_flight: dict[RequestId, RequestResponder] = {}

self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
Expand All @@ -129,6 +132,7 @@ async def __aenter__(self):
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
self._task_group.start_soon(self._cleanup_loop)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down Expand Up @@ -221,8 +225,20 @@ async def _send_response(
)
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))

async def _cleanup_loop(self) -> None:
"""Periodically clean up completed and cancelled requests."""
while True:
with anyio.move_on_after(self._cleanup_interval):
# Clean up completed requests
self._in_flight = {
req_id: responder
for req_id, responder in self._in_flight.items()
if responder.in_flight
}
await anyio.sleep(self._cleanup_interval)

async def _receive_loop(self) -> None:
in_flight: dict[RequestId, RequestResponder] = {}
"""Handle incoming messages and maintain request state."""

async with (
self._read_stream,
Expand All @@ -231,9 +247,9 @@ async def _receive_loop(self) -> None:
):
async for message in self._read_stream:
# Clean up completed requests
in_flight = {
self._in_flight = {
req_id: responder
for req_id, responder in in_flight.items()
for req_id, responder in self._in_flight.items()
if responder.in_flight
}

Expand All @@ -257,7 +273,7 @@ async def _receive_loop(self) -> None:
cancel_scope=scope,
)

in_flight[message.root.id] = responder
self._in_flight[message.root.id] = responder

await self._received_request(responder)
if not responder._responded:
Expand All @@ -272,8 +288,8 @@ async def _receive_loop(self) -> None:
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in in_flight:
await in_flight[cancelled_id].cancel()
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
Expand Down

0 comments on commit d50255e

Please sign in to comment.