Skip to content

Commit

Permalink
Merge pull request #43 from yozik04/ws_retry
Browse files Browse the repository at this point in the history
Retry ws requests 4 additional times
  • Loading branch information
yozik04 authored Jul 6, 2023
2 parents 9e63c9e + 387ddf9 commit 205c1a8
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 12 deletions.
44 changes: 43 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
import binascii
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch

import pytest
import websockets
from websockets.exceptions import InvalidMessage

from vallox_websocket_api.client import Client
Expand Down Expand Up @@ -110,3 +112,43 @@ async def test_set_new_settable_address_by_address_exception(client: Client, ws)

with pytest.raises(ValloxWebsocketException):
await client.set_values({"A_CYC_RH_VALUE": 22})

assert ws.send.call_count == 5


async def test_connection_closed_ws_exception(client: Client, ws):
ws.recv.side_effect = AsyncMock(side_effect=websockets.ConnectionClosed(None, None))

with pytest.raises(ValloxWebsocketException):
await client.fetch_metric("A_CYC_ENABLED")

assert ws.send.call_count == 5


async def test_ws_recv_timeout_exception(client: Client, ws):
ws.recv.side_effect = AsyncMock(side_effect=asyncio.TimeoutError())

with pytest.raises(ValloxWebsocketException):
await client.fetch_metric("A_CYC_ENABLED")

assert ws.send.call_count == 5


async def test_invalid_ws_url_exception(client: Client):
with patch("websockets.connect") as connect:
connect.side_effect = websockets.InvalidURI("test", "test")

with pytest.raises(ValloxWebsocketException):
await client.fetch_metric("A_CYC_ENABLED")

assert connect.call_count == 1


async def test_ws_connection_timeout_exception(client: Client):
with patch("websockets.connect") as connect:
connect.side_effect = asyncio.TimeoutError()

with pytest.raises(ValloxWebsocketException):
await client.fetch_metric("A_CYC_ENABLED")

assert connect.call_count == 5
2 changes: 1 addition & 1 deletion vallox_websocket_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
"ValloxWebsocketException",
]

__version__ = "3.2.1"
__version__ = "3.3.0"
54 changes: 44 additions & 10 deletions vallox_websocket_api/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from functools import wraps
import logging
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast

Expand All @@ -20,8 +21,14 @@
WriteMessageRequest,
)

logger = logging.getLogger("vallox").getChild(__name__)

KPageSize = 65536

WEBSOCKETS_OPEN_TIMEOUT = 1
WEBSOCKETS_RECV_TIMEOUT = 1
WEBSOCKET_RETRY_DELAYS = [0.1, 0.2, 0.5, 1]


def calculate_offset(aIndex: int) -> int:
offset = 0
Expand Down Expand Up @@ -146,11 +153,28 @@ def to_kelvin(value: float) -> int:
FuncT = TypeVar("FuncT", bound=Callable[..., Any])


def _websocket_exception_handler(request_fn: FuncT) -> FuncT:
def _websocket_retry_wrapper(request_fn: FuncT) -> FuncT:
retry_on_exceptions = (
websockets.InvalidHandshake,
websockets.InvalidState,
websockets.WebSocketProtocolError,
websockets.ConnectionClosed,
OSError,
asyncio.TimeoutError,
)

@wraps(request_fn)
async def wrapped(*args: Any, **kwargs: Any) -> Any:
try:
return await request_fn(*args, **kwargs)
delays = WEBSOCKET_RETRY_DELAYS.copy()
while len(delays) >= 0:
try:
return await request_fn(*args, **kwargs)
except Exception as e:
if isinstance(e, retry_on_exceptions) and len(delays) > 0:
await asyncio.sleep(delays.pop(0))
else:
raise e
except websockets.InvalidHandshake as e:
raise ValloxWebsocketException("Websocket handshake failed") from e
except websockets.InvalidURI as e:
Expand All @@ -161,8 +185,12 @@ async def wrapped(*args: Any, **kwargs: Any) -> Any:
raise ValloxWebsocketException("Websocket invalid state") from e
except websockets.WebSocketProtocolError as e:
raise ValloxWebsocketException("Websocket protocol error") from e
except websockets.ConnectionClosed as e:
raise ValloxWebsocketException("Websocket connection closed") from e
except OSError as e:
raise ValloxWebsocketException("Websocket connection failed") from e
except asyncio.TimeoutError as e:
raise ValloxWebsocketException("Websocket connection timed out") from e

return cast(FuncT, wrapped)

Expand Down Expand Up @@ -232,20 +260,26 @@ def _encode_pair(

return address, raw_value

@_websocket_exception_handler
async def _websocket_request(self, payload: bytes) -> bytes:
async with websockets.connect(f"ws://{self.ip_address}/") as ws:
await ws.send(payload)
r: bytes = await ws.recv()
return r
return (await self._websocket_request_multiple(payload, 1))[0]

@_websocket_exception_handler
@_websocket_retry_wrapper
async def _websocket_request_multiple(
self, payload: bytes, read_packets: int
) -> List[bytes]:
async with websockets.connect(f"ws://{self.ip_address}/") as ws:
async with websockets.connect(
f"ws://{self.ip_address}/",
open_timeout=WEBSOCKETS_OPEN_TIMEOUT,
logger=logger,
) as ws:
await ws.send(payload)
return await asyncio.gather(*[ws.recv() for _ in range(0, read_packets)])

async def _get_responses() -> List[bytes]:
return [await ws.recv() for _ in range(0, read_packets)]

return await asyncio.wait_for(
_get_responses(), timeout=WEBSOCKETS_RECV_TIMEOUT * read_packets
)

async def fetch_metrics(
self, metric_keys: Optional[List[str]] = None
Expand Down

0 comments on commit 205c1a8

Please sign in to comment.