Skip to content

Commit

Permalink
refactor: use batching for update_price
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-bahjati committed Sep 13, 2024
1 parent a591d8a commit 4e53387
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 127 deletions.
94 changes: 68 additions & 26 deletions example_publisher/pythd.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass, field
import sys
import traceback
from dataclasses_json import config, DataClassJsonMixin
from typing import List
from typing import List, Any, Optional
from structlog import get_logger
from jsonrpc_websocket import Server
from websockets.client import connect, WebSocketClientProtocol
from asyncio import Lock

log = get_logger()

Expand All @@ -20,6 +19,14 @@ class Price(DataClassJsonMixin):
exponent: int = field(metadata=config(field_name="price_exponent"))


@dataclass
class PriceUpdate(DataClassJsonMixin):
account: str
price: int
conf: int
status: str


@dataclass
class Metadata(DataClassJsonMixin):
symbol: str
Expand All @@ -33,36 +40,71 @@ class Product(DataClassJsonMixin):
prices: List[Price] = field(metadata=config(field_name="price"))


@dataclass
class JSONRPCRequest(DataClassJsonMixin):
id: int
method: str
params: List[Any] | Any
jsonrpc: str = "2.0"


@dataclass
class JSONRPCResponse(DataClassJsonMixin):
id: int
result: Optional[Any]
error: Optional[Any]
jsonrpc: str = "2.0"


class Pythd:
def __init__(
self,
address: str,
) -> None:
self.address = address
self.server: Server
self._tasks = set()
self.client: WebSocketClientProtocol
self.id_counter = 0
self.lock = Lock()

async def connect(self):
self.server = Server(self.address)
task = await self.server.ws_connect()
task.add_done_callback(Pythd._on_connection_done)
self._tasks.add(task)

@staticmethod
def _on_connection_done(task):
log.error("pythd connection closed")
if not task.cancelled() and task.exception() is not None:
e = task.exception()
traceback.print_exception(None, e, e.__traceback__)
sys.exit(1)
self.client = await connect(self.address)

def _create_request(self, method: str, params: List[Any] | Any) -> JSONRPCRequest:
self.id_counter += 1
return JSONRPCRequest(
id=self.id_counter,
method=method,
params=params,
)

async def send_request(self, request: JSONRPCRequest) -> JSONRPCResponse:
async with self.lock:
await self.client.send(request.to_json())
response = await self.client.recv()
return JSONRPCResponse.from_json(response)

async def send_batch_request(
self, requests: List[JSONRPCRequest]
) -> List[JSONRPCResponse]:
async with self.lock:
await self.client.send(JSONRPCRequest.schema().dumps(requests, many=True))
response = await self.client.recv()
return JSONRPCResponse.schema().loads(response, many=True)

async def all_products(self) -> List[Product]:
result = await self.server.get_product_list()
return [Product.from_dict(d) for d in result]
request = self._create_request("get_product_list", [])
result = await self.send_request(request)
if result.result:
return Product.schema().load(result.result, many=True)
else:
raise ValueError(f"Error fetching products: {result.to_json()}")

async def update_price(
self, account: str, price: int, conf: int, status: str
) -> None:
await self.server.update_price(
account=account, price=price, conf=conf, status=status
)
async def update_price_batch(self, price_updates: List[PriceUpdate]) -> None:
requests = [
self._create_request("update_price", price_update.to_dict())
for price_update in price_updates
]
results = await self.send_batch_request(requests)
if any(result.error for result in results):
results_json_str = JSONRPCResponse.schema().dumps(results, many=True)
raise ValueError(f"Error updating prices: {results_json_str}")
Loading

0 comments on commit 4e53387

Please sign in to comment.