Skip to content

Commit

Permalink
refactor to decouple the event source
Browse files Browse the repository at this point in the history
move retrying logic to event source (with exp backoff + jitter)
  • Loading branch information
catwell committed Jan 29, 2025
1 parent cad3fb6 commit 62d24d5
Showing 1 changed file with 208 additions and 126 deletions.
334 changes: 208 additions & 126 deletions finegrain/src/finegrain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import json
import logging
import random
from collections import defaultdict
from collections.abc import Awaitable, Callable, Mapping
from collections.abc import AsyncIterator, Awaitable, Callable, Mapping
from typing import Any, Literal, cast

import httpx
Expand All @@ -15,7 +16,22 @@


class SSELoopStopped(RuntimeError):
pass
first_error: Exception | None
last_error: Exception | None

def __init__(
self,
message: str | None = None,
first_error: Exception | None = None,
last_error: Exception | None = None,
) -> None:
self.first_error = first_error
self.last_error = last_error
super().__init__(message or self.default_message)

@property
def default_message(self) -> str:
return f"SSE loop stopped (first error: {self.first_error}, last error: {self.last_error})"


class Futures[T]:
Expand All @@ -42,6 +58,159 @@ def __delitem__(self, key: str) -> None:
pass


class RetryContext:
max_failures: int
max_jitter: float
max_backoff: float
exp_base: float
exp_factor: float

failures: int
first_error: Exception | None
last_error: Exception | None

def __init__(
self,
max_failures: int = 10,
max_jitter: float = 1.0,
max_backoff: float = 15.0,
exp_base: float = 2.0,
exp_factor: float = 0.1,
):
self.max_failures = max_failures
self.max_jitter = max_jitter
self.max_backoff = max_backoff
self.exp_base = exp_base
self.exp_factor = exp_factor

self.reset()

def reset(self) -> None:
self.failures = 0
self.first_error = None
self.last_error = None

@property
def backoff(self) -> float:
if self.failures == 0:
return 0
jitter = random.uniform(0, self.max_jitter)
return min(self.exp_factor * (self.exp_base**self.failures) + jitter, self.max_backoff)

@property
def remaining_attempts(self) -> int:
return max(self.max_failures - self.failures, 0)

def failure(self, exc: Exception | None) -> None:
if self.failures == 0:
self.first_error = exc
self.last_error = exc
self.failures += 1

def success(self) -> None:
self.failures = 0


class ResilientEventSource:
get_url: Callable[[], Awaitable[str]]
verify: bool | str
retry_ctx: RetryContext

logger: logging.Logger

_last_event_id: str
_retry_ms: int

active: asyncio.Future[None]

def __init__(
self,
url: str | Callable[[], Awaitable[str]],
verify: bool | str = True,
retry_ctx: RetryContext | None = None,
) -> None:
self.get_url = self.async_return(url) if isinstance(url, str) else url
self.verify = verify
self.retry_ctx = RetryContext() if retry_ctx is None else retry_ctx

self.logger = logger

def reset(self) -> None:
self._last_event_id = ""
self._retry_ms = 0
self.retry_ctx.reset()
self.active = asyncio.get_running_loop().create_future()

@staticmethod
def async_return[T](x: T) -> Callable[[], Awaitable[T]]:
async def f() -> T:
return x

return f

@staticmethod
def decode_json(data: str) -> dict[str, Any] | None:
try:
r = json.loads(data)
except json.JSONDecodeError:
return None
if type(r) is not dict:
return None
return cast(dict[str, Any], r)

@property
def headers(self) -> dict[str, str]:
r = {"Accept": "text/event-stream"}
if self._last_event_id:
r["Last-Event-ID"] = self._last_event_id
return r

def failure(self, exc: Exception | None) -> None:
self.active = asyncio.get_running_loop().create_future()
self.retry_ctx.failure(exc)

def success(self) -> None:
self.retry_ctx.success()
self.active.set_result(None)

async def __aiter__(self) -> AsyncIterator[dict[str, Any]]:
while True:
if self.retry_ctx.remaining_attempts == 0:
raise SSELoopStopped(
first_error=self.retry_ctx.first_error,
last_error=self.retry_ctx.last_error,
)
try:
self.logger.info(
f"SSE loop retry attempt {self.retry_ctx.failures} "
f"(backoff {self.retry_ctx.backoff:.3f}, retry_ms {self._retry_ms})"
)
await asyncio.sleep(self.retry_ctx.backoff + self._retry_ms / 1000)
url = await self.get_url()
async with (
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
httpx_sse.aconnect_sse(c, "GET", url, headers=self.headers) as es,
):
es.response.raise_for_status()
self.success()
async for sse in es.aiter_sse():
self._last_event_id = sse.id
self._retry_ms = sse.retry or 0
if sse.event == "ping":
self.logger.debug("got SSE ping")
continue
if sse.event != "message":
self.logger.warning(f"unexpected SSE event: {sse.event} ({sse.data})")
continue
if (event := self.decode_json(sse.data)) is None:
self.logger.warning(f"unexpected SSE message: {sse.data}")
continue
yield event
raise SSELoopStopped(message="SSE loop exited")
except (SSELoopStopped, httpx.HTTPError) as exc:
self.failure(exc)


class EditorAPIContext:
user: str
password: str
Expand Down Expand Up @@ -80,15 +249,13 @@ def __init__(

self.token = None
self.logger = logger
self.max_sse_failures = 5

self._client = None
self._client_ctx_depth = 0

self._sse_futures = Futures()
self._sse_source = ResilientEventSource(self.get_sub_url, verify=self.verify)
self._sse_task = None
self._sse_failures = 0
self._sse_last_event_id = ""
self._sse_retry_ms = 0

async def __aenter__(self) -> httpx.AsyncClient:
if self._client:
Expand Down Expand Up @@ -154,147 +321,62 @@ async def _q() -> httpx.Response:
r.raise_for_status()
return r

@classmethod
def decode_json(cls, data: str) -> dict[str, Any] | None:
try:
r = json.loads(data)
except json.JSONDecodeError:
return None
if type(r) is not dict:
return None
return cast(dict[str, Any], r)

async def _sse_loop(self) -> None:
sub_headers = {"Accept": "text/event-stream"}
retry_ms = self._sse_retry_ms + 1000 * (2**self._sse_failures - 1)
if self._sse_last_event_id:
self.logger.info(f"resuming SSE from event {self._sse_last_event_id} in {retry_ms} ms")
sub_headers["Last-Event-ID"] = self._sse_last_event_id
elif retry_ms > 0:
self.logger.info(f"resuming SSE in {retry_ms} ms")
await asyncio.sleep(retry_ms / 1000)

async def get_sub_url(self) -> str:
response = await self.request("POST", "sub-auth")
sub_token = response.json()["token"]
sub_url = f"{self.base_url}/sub/{sub_token}"

async with (
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
httpx_sse.aconnect_sse(c, "GET", sub_url, headers=sub_headers) as es,
):
es.response.raise_for_status()
self._sse_futures["_sse_loop"].set_result({"status": "ok"})
async for sse in es.aiter_sse():
self._sse_last_event_id = sse.id
self._sse_retry_ms = sse.retry or 0
if sse.event == "ping":
self.logger.debug("got SSE ping")
continue
elif sse.event == "message":
jdata = self.decode_json(sse.data)
if (jdata is None) or ("state" not in jdata):
# When the server restarts we can get an empty string here.
self.logger.warning(f"unexpected SSE message: {sse.data}")
continue
self.logger.debug(f"got message: {jdata}")
self._sse_futures[jdata["state"]].set_result(jdata)
else:
self.logger.warning(f"unexpected SSE event: {sse.event} ({sse.data})")
self.logger.info("SSE loop exited")
return f"{self.base_url}/sub/{sub_token}"

async def _sse_loop(self) -> None:
async for event in self._sse_source:
if "state" not in event:
self.logger.warning(f"unexpected SSE message: {event}")
continue
self.logger.debug(f"got message: {event}")
self._sse_futures[event["state"]].set_result(event)

async def sse_start(self) -> None:
assert self._sse_task is None
self._sse_last_event_id = ""
self._sse_retry_ms = 0
self._sse_source.reset()
self._sse_task = asyncio.create_task(self._sse_loop())
assert await self.sse_await("_sse_loop")
self._sse_failures = 0

async def sse_recover(self) -> bool:
while True:
if self._sse_failures > self.max_sse_failures:
return False
self._sse_task = asyncio.create_task(self._sse_loop())
try:
assert await self.sse_await("_sse_loop")
return True
except SSELoopStopped:
pass
await self._sse_source.active

async def sse_stop(self) -> None:
assert self._sse_task
self._sse_task.cancel()
await self._sse_task
exc = await asyncio.gather(self._sse_task, return_exceptions=True)
assert len(exc) == 1 and isinstance(exc[0], asyncio.CancelledError)
self._sse_task = None

async def sse_monitor(self) -> None:
while True:
assert (task := self._sse_task)
exc = (await asyncio.gather(task, return_exceptions=True))[0] # catch `CancelledError` properly
if exc is None:
self.logger.info("SSE loop stopped, recovering")
elif isinstance(exc, asyncio.CancelledError):
self.logger.info("SSE loop cancelled, exiting")
break
else:
self.logger.info(f"Got exception {exc} in SSE loop, recovering")
self._sse_failures += 1
r = await self.sse_recover()
if not r:
raise RuntimeError("SSE loop failed to recover")
self._sse_failures = 0

async def sse_await(self, state_id: str, timeout: float | None = None) -> bool:
assert self._sse_task
future = self._sse_futures[state_id]
timeout = timeout or self.default_timeout

while True:
sse_task = self._sse_task
done, _ = await asyncio.wait(
{future, sse_task},
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED,
)
if not done:
assert state_id != "_sse_loop"
# check if SSE is not disconnected (waited too long)
r = await self.request("GET", f"state/meta/{state_id}", raise_for_status=False)
if r.is_success:
is_ok = r.json()["status"] == "ok"
done, _ = await asyncio.wait(
{future, sse_task},
timeout=2.0,
return_when=asyncio.FIRST_COMPLETED,
)
del self._sse_futures[state_id]
if done == {future}: # we just didn't wait enough
return is_ok
self.logger.warning(f"got timeout for state {state_id} and meta is OK, restarting loop")
aw = asyncio.gather(sse_task, return_exceptions=True) # catch `CancelledError` properly
sse_task.cancel()
await aw
return is_ok
elif r.status_code != 404:
raise TimeoutError(f"state {state_id} timed out after {timeout}")
else:
raise RuntimeError(f"getting state {state_id} after timeout {timeout} returned {r.status_code}")
if sse_task in done:
self._sse_failures += 1
if state_id != "_sse_loop":
self.logger.info(f"SSE loop stopped while waiting for state {state_id}, recovering")
if await self.sse_recover():
self._sse_failures = 0
continue
exception = sse_task.exception()
raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
break
sse_task = self._sse_task
done, _ = await asyncio.wait(
{future, self._sse_task},
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED,
)
if sse_task in done:
exception = sse_task.exception()
raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
if not done:
r = await self.request("GET", f"state/meta/{state_id}", raise_for_status=False)
if r.is_success:
status = r.json()["status"]
self.logger.warning(f"got timeout for state {state_id}, found metadata with status {status}")
return status == "ok"
elif r.status_code != 404:
raise TimeoutError(f"state {state_id} timed out after {timeout}")
else:
raise RuntimeError(f"getting state {state_id} after timeout {timeout} returned {r.status_code}")

assert done == {future}

jdata = future.result()
event = future.result()
del self._sse_futures[state_id]
return jdata["status"] == "ok"
return event["status"] == "ok"

async def get_meta(self, state_id: str) -> dict[str, Any]:
response = await self.request("GET", f"state/meta/{state_id}")
Expand Down

0 comments on commit 62d24d5

Please sign in to comment.