From 047f9683a2b2e4021bfd106651e5b981d19865f8 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 27 Jan 2025 17:03:30 +0100 Subject: [PATCH] restart SSE loop when getting a timeout and state exists This typically means there was a silent disconnection of the SSE loop. This happens in practice. There should be a better way to detect it but let's try to be more reliable for now. --- finegrain/requirements.lock | 4 ++-- finegrain/src/finegrain/__init__.py | 33 ++++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/finegrain/requirements.lock b/finegrain/requirements.lock index 270fa10..1330575 100644 --- a/finegrain/requirements.lock +++ b/finegrain/requirements.lock @@ -20,9 +20,9 @@ h11==0.14.0 httpcore==1.0.7 # via httpx httpx==0.28.1 - # via finegrain-python + # via finegrain httpx-sse==0.4.0 - # via finegrain-python + # via finegrain idna==3.10 # via anyio # via httpx diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index cf5b1d6..ff4d08c 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -195,21 +195,44 @@ async def sse_stop(self) -> None: async def sse_await(self, state_id: str, timeout: float | None = None) -> bool: assert self._sse_task future = self._sse_futures[state_id] + timeout = timeout or self.default_timeout while True: + sse_task = self._sse_task done, _ = await asyncio.wait( - {future, self._sse_task}, - timeout=timeout or self.default_timeout, + {future, sse_task}, + timeout=timeout, return_when=asyncio.FIRST_COMPLETED, ) if not done: - raise TimeoutError(f"state {state_id} timed out after {timeout}") - if self._sse_task in done: + assert state_id != "_sse_loop" + # check if SSE is not disconnected (waited too long) + r = await self.request("GET", f"state/meta/{state_id}", raise_for_status=False) + if r.is_success: + is_ok = r.json()["status"] == "ok" + done, _ = await asyncio.wait( + {future, sse_task}, + timeout=2.0, + return_when=asyncio.FIRST_COMPLETED, + ) + del self._sse_futures[state_id] + if done == {future}: # we just didn't wait enough + return is_ok + self.logger.warning(f"got timeout for state {state_id} and meta is OK, restarting loop") + aw = asyncio.gather(sse_task, return_exceptions=True) # catch `CancelledError` properly + sse_task.cancel() + await aw + return is_ok + elif r.status_code != 404: + raise TimeoutError(f"state {state_id} timed out after {timeout}") + else: + raise RuntimeError(f"getting state {state_id} after timeout {timeout} returned {r.status_code}") + if sse_task in done: self._sse_failures += 1 if state_id != "_sse_loop" and (await self.sse_recover()): self._sse_failures = 0 continue - exception = self._sse_task.exception() + exception = sse_task.exception() raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception break