Skip to content

Commit

Permalink
restart SSE loop when getting a timeout and state exists
Browse files Browse the repository at this point in the history
A timeout while the state still exists typically means there was a silent
disconnection of the SSE loop. This happens in practice. There should be a
better way to detect it, but let's try to be more reliable for now.
  • Loading branch information
catwell committed Jan 27, 2025
1 parent 22bb9cf commit 047f968
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 7 deletions.
4 changes: 2 additions & 2 deletions finegrain/requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ h11==0.14.0
httpcore==1.0.7
# via httpx
httpx==0.28.1
# via finegrain-python
# via finegrain
httpx-sse==0.4.0
# via finegrain-python
# via finegrain
idna==3.10
# via anyio
# via httpx
Expand Down
33 changes: 28 additions & 5 deletions finegrain/src/finegrain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,44 @@ async def sse_stop(self) -> None:
async def sse_await(self, state_id: str, timeout: float | None = None) -> bool:
assert self._sse_task
future = self._sse_futures[state_id]
timeout = timeout or self.default_timeout

while True:
sse_task = self._sse_task
done, _ = await asyncio.wait(
{future, self._sse_task},
timeout=timeout or self.default_timeout,
{future, sse_task},
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED,
)
if not done:
raise TimeoutError(f"state {state_id} timed out after {timeout}")
if self._sse_task in done:
assert state_id != "_sse_loop"
# check if SSE is not disconnected (waited too long)
r = await self.request("GET", f"state/meta/{state_id}", raise_for_status=False)
if r.is_success:
is_ok = r.json()["status"] == "ok"
done, _ = await asyncio.wait(
{future, sse_task},
timeout=2.0,
return_when=asyncio.FIRST_COMPLETED,
)
del self._sse_futures[state_id]
if done == {future}: # we just didn't wait enough
return is_ok
self.logger.warning(f"got timeout for state {state_id} and meta is OK, restarting loop")
aw = asyncio.gather(sse_task, return_exceptions=True) # catch `CancelledError` properly
sse_task.cancel()
await aw
return is_ok
elif r.status_code != 404:
raise TimeoutError(f"state {state_id} timed out after {timeout}")
else:
raise RuntimeError(f"getting state {state_id} after timeout {timeout} returned {r.status_code}")
if sse_task in done:
self._sse_failures += 1
if state_id != "_sse_loop" and (await self.sse_recover()):
self._sse_failures = 0
continue
exception = self._sse_task.exception()
exception = sse_task.exception()
raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
break

Expand Down

0 comments on commit 047f968

Please sign in to comment.