From 62d24d5896a2921e6eda5fe4576f31f09c72f060 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Wed, 29 Jan 2025 18:57:13 +0100 Subject: [PATCH] refactor to decouple the event source move retrying logic to event source (with exp backoff + jitter) --- finegrain/src/finegrain/__init__.py | 334 +++++++++++++++++----------- 1 file changed, 208 insertions(+), 126 deletions(-) diff --git a/finegrain/src/finegrain/__init__.py b/finegrain/src/finegrain/__init__.py index 2d6cb07..f3b6dc5 100644 --- a/finegrain/src/finegrain/__init__.py +++ b/finegrain/src/finegrain/__init__.py @@ -1,8 +1,9 @@ import asyncio import json import logging +import random from collections import defaultdict -from collections.abc import Awaitable, Callable, Mapping +from collections.abc import AsyncIterator, Awaitable, Callable, Mapping from typing import Any, Literal, cast import httpx @@ -15,7 +16,22 @@ class SSELoopStopped(RuntimeError): - pass + first_error: Exception | None + last_error: Exception | None + + def __init__( + self, + message: str | None = None, + first_error: Exception | None = None, + last_error: Exception | None = None, + ) -> None: + self.first_error = first_error + self.last_error = last_error + super().__init__(message or self.default_message) + + @property + def default_message(self) -> str: + return f"SSE loop stopped (first error: {self.first_error}, last error: {self.last_error})" class Futures[T]: @@ -42,6 +58,159 @@ def __delitem__(self, key: str) -> None: pass +class RetryContext: + max_failures: int + max_jitter: float + max_backoff: float + exp_base: float + exp_factor: float + + failures: int + first_error: Exception | None + last_error: Exception | None + + def __init__( + self, + max_failures: int = 10, + max_jitter: float = 1.0, + max_backoff: float = 15.0, + exp_base: float = 2.0, + exp_factor: float = 0.1, + ): + self.max_failures = max_failures + self.max_jitter = max_jitter + self.max_backoff = max_backoff + self.exp_base = exp_base + self.exp_factor = exp_factor + + self.reset() + + def reset(self) -> None: + self.failures = 0 + self.first_error = None + self.last_error = None + + @property + def backoff(self) -> float: + if self.failures == 0: + return 0 + jitter = random.uniform(0, self.max_jitter) + return min(self.exp_factor * (self.exp_base**self.failures) + jitter, self.max_backoff) + + @property + def remaining_attempts(self) -> int: + return max(self.max_failures - self.failures, 0) + + def failure(self, exc: Exception | None) -> None: + if self.failures == 0: + self.first_error = exc + self.last_error = exc + self.failures += 1 + + def success(self) -> None: + self.failures = 0 + + +class ResilientEventSource: + get_url: Callable[[], Awaitable[str]] + verify: bool | str + retry_ctx: RetryContext + + logger: logging.Logger + + _last_event_id: str + _retry_ms: int + + active: asyncio.Future[None] + + def __init__( + self, + url: str | Callable[[], Awaitable[str]], + verify: bool | str = True, + retry_ctx: RetryContext | None = None, + ) -> None: + self.get_url = self.async_return(url) if isinstance(url, str) else url + self.verify = verify + self.retry_ctx = RetryContext() if retry_ctx is None else retry_ctx + + self.logger = logger + + def reset(self) -> None: + self._last_event_id = "" + self._retry_ms = 0 + self.retry_ctx.reset() + self.active = asyncio.get_running_loop().create_future() + + @staticmethod + def async_return[T](x: T) -> Callable[[], Awaitable[T]]: + async def f() -> T: + return x + + return f + + @staticmethod + def decode_json(data: str) -> dict[str, Any] | None: + try: + r = json.loads(data) + except json.JSONDecodeError: + return None + if type(r) is not dict: + return None + return cast(dict[str, Any], r) + + @property + def headers(self) -> dict[str, str]: + r = {"Accept": "text/event-stream"} + if self._last_event_id: + r["Last-Event-ID"] = self._last_event_id + return r + + def failure(self, exc: Exception | None) -> None: + self.active = asyncio.get_running_loop().create_future() + self.retry_ctx.failure(exc) + + def success(self) -> None: + self.retry_ctx.success() + self.active.set_result(None) + + async def __aiter__(self) -> AsyncIterator[dict[str, Any]]: + while True: + if self.retry_ctx.remaining_attempts == 0: + raise SSELoopStopped( + first_error=self.retry_ctx.first_error, + last_error=self.retry_ctx.last_error, + ) + try: + self.logger.info( + f"SSE loop retry attempt {self.retry_ctx.failures} " + f"(backoff {self.retry_ctx.backoff:.3f}, retry_ms {self._retry_ms})" + ) + await asyncio.sleep(self.retry_ctx.backoff + self._retry_ms / 1000) + url = await self.get_url() + async with ( + httpx.AsyncClient(timeout=None, verify=self.verify) as c, + httpx_sse.aconnect_sse(c, "GET", url, headers=self.headers) as es, + ): + es.response.raise_for_status() + self.success() + async for sse in es.aiter_sse(): + self._last_event_id = sse.id + self._retry_ms = sse.retry or 0 + if sse.event == "ping": + self.logger.debug("got SSE ping") + continue + if sse.event != "message": + self.logger.warning(f"unexpected SSE event: {sse.event} ({sse.data})") + continue + if (event := self.decode_json(sse.data)) is None: + self.logger.warning(f"unexpected SSE message: {sse.data}") + continue + yield event + raise SSELoopStopped(message="SSE loop exited") + except (SSELoopStopped, httpx.HTTPError) as exc: + self.failure(exc) + + class EditorAPIContext: user: str password: str @@ -80,15 +249,13 @@ def __init__( self.token = None self.logger = logger - self.max_sse_failures = 5 self._client = None self._client_ctx_depth = 0 + self._sse_futures = Futures() + self._sse_source = ResilientEventSource(self.get_sub_url, verify=self.verify) self._sse_task = None - self._sse_failures = 0 - self._sse_last_event_id = "" - self._sse_retry_ms = 0 async def __aenter__(self) -> httpx.AsyncClient: if self._client: @@ -154,147 +321,62 @@ async def _q() -> httpx.Response: r.raise_for_status() return r - @classmethod - def decode_json(cls, data: str) -> dict[str, Any] | None: - try: - r = json.loads(data) - except json.JSONDecodeError: - return None - if type(r) is not dict: - return None - return cast(dict[str, Any], r) - - async def _sse_loop(self) -> None: - sub_headers = {"Accept": "text/event-stream"} - retry_ms = self._sse_retry_ms + 1000 * (2**self._sse_failures - 1) - if self._sse_last_event_id: - self.logger.info(f"resuming SSE from event {self._sse_last_event_id} in {retry_ms} ms") - sub_headers["Last-Event-ID"] = self._sse_last_event_id - elif retry_ms > 0: - self.logger.info(f"resuming SSE in {retry_ms} ms") - await asyncio.sleep(retry_ms / 1000) - + async def get_sub_url(self) -> str: response = await self.request("POST", "sub-auth") sub_token = response.json()["token"] - sub_url = f"{self.base_url}/sub/{sub_token}" - - async with ( - httpx.AsyncClient(timeout=None, verify=self.verify) as c, - httpx_sse.aconnect_sse(c, "GET", sub_url, headers=sub_headers) as es, - ): - es.response.raise_for_status() - self._sse_futures["_sse_loop"].set_result({"status": "ok"}) - async for sse in es.aiter_sse(): - self._sse_last_event_id = sse.id - self._sse_retry_ms = sse.retry or 0 - if sse.event == "ping": - self.logger.debug("got SSE ping") - continue - elif sse.event == "message": - jdata = self.decode_json(sse.data) - if (jdata is None) or ("state" not in jdata): - # When the server restarts we can get an empty string here. - self.logger.warning(f"unexpected SSE message: {sse.data}") - continue - self.logger.debug(f"got message: {jdata}") - self._sse_futures[jdata["state"]].set_result(jdata) - else: - self.logger.warning(f"unexpected SSE event: {sse.event} ({sse.data})") - self.logger.info("SSE loop exited") + return f"{self.base_url}/sub/{sub_token}" + + async def _sse_loop(self) -> None: + async for event in self._sse_source: + if "state" not in event: + self.logger.warning(f"unexpected SSE message: {event}") + continue + self.logger.debug(f"got message: {event}") + self._sse_futures[event["state"]].set_result(event) async def sse_start(self) -> None: assert self._sse_task is None - self._sse_last_event_id = "" - self._sse_retry_ms = 0 + self._sse_source.reset() self._sse_task = asyncio.create_task(self._sse_loop()) - assert await self.sse_await("_sse_loop") - self._sse_failures = 0 - - async def sse_recover(self) -> bool: - while True: - if self._sse_failures > self.max_sse_failures: - return False - self._sse_task = asyncio.create_task(self._sse_loop()) - try: - assert await self.sse_await("_sse_loop") - return True - except SSELoopStopped: - pass + await self._sse_source.active async def sse_stop(self) -> None: assert self._sse_task self._sse_task.cancel() - await self._sse_task + exc = await asyncio.gather(self._sse_task, return_exceptions=True) + assert len(exc) == 1 and isinstance(exc[0], asyncio.CancelledError) self._sse_task = None - async def sse_monitor(self) -> None: - while True: - assert (task := self._sse_task) - exc = (await asyncio.gather(task, return_exceptions=True))[0] # catch `CancelledError` properly - if exc is None: - self.logger.info("SSE loop stopped, recovering") - elif isinstance(exc, asyncio.CancelledError): - self.logger.info("SSE loop cancelled, exiting") - break - else: - self.logger.info(f"Got exception {exc} in SSE loop, recovering") - self._sse_failures += 1 - r = await self.sse_recover() - if not r: - raise RuntimeError("SSE loop failed to recover") - self._sse_failures = 0 - async def sse_await(self, state_id: str, timeout: float | None = None) -> bool: assert self._sse_task future = self._sse_futures[state_id] timeout = timeout or self.default_timeout - while True: - sse_task = self._sse_task - done, _ = await asyncio.wait( - {future, sse_task}, - timeout=timeout, - return_when=asyncio.FIRST_COMPLETED, - ) - if not done: - assert state_id != "_sse_loop" - # check if SSE is not disconnected (waited too long) - r = await self.request("GET", f"state/meta/{state_id}", raise_for_status=False) - if r.is_success: - is_ok = r.json()["status"] == "ok" - done, _ = await asyncio.wait( - {future, sse_task}, - timeout=2.0, - return_when=asyncio.FIRST_COMPLETED, - ) - del self._sse_futures[state_id] - if done == {future}: # we just didn't wait enough - return is_ok - self.logger.warning(f"got timeout for state {state_id} and meta is OK, restarting loop") - aw = asyncio.gather(sse_task, return_exceptions=True) # catch `CancelledError` properly - sse_task.cancel() - await aw - return is_ok - elif r.status_code != 404: - raise TimeoutError(f"state {state_id} timed out after {timeout}") - else: - raise RuntimeError(f"getting state {state_id} after timeout {timeout} returned {r.status_code}") - if sse_task in done: - self._sse_failures += 1 - if state_id != "_sse_loop": - self.logger.info(f"SSE loop stopped while waiting for state {state_id}, recovering") - if await self.sse_recover(): - self._sse_failures = 0 - continue - exception = sse_task.exception() - raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception - break + sse_task = self._sse_task + done, _ = await asyncio.wait( + {future, self._sse_task}, + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED, + ) + if sse_task in done: + exception = sse_task.exception() + raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception + if not done: + r = await self.request("GET", f"state/meta/{state_id}", raise_for_status=False) + if r.is_success: + status = r.json()["status"] + self.logger.warning(f"got timeout for state {state_id}, found metadata with status {status}") + return status == "ok" + elif r.status_code != 404: + raise TimeoutError(f"state {state_id} timed out after {timeout}") + else: + raise RuntimeError(f"getting state {state_id} after timeout {timeout} returned {r.status_code}") assert done == {future} - jdata = future.result() + event = future.result() del self._sse_futures[state_id] - return jdata["status"] == "ok" + return event["status"] == "ok" async def get_meta(self, state_id: str) -> dict[str, Any]: response = await self.request("GET", f"state/meta/{state_id}")