Skip to content

Commit

Permalink
refactor: remove subscription and use internal interval
Browse files Browse the repository at this point in the history
  • Loading branch information
ali-bahjati committed Sep 13, 2024
1 parent 0f95b2d commit a591d8a
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 82 deletions.
3 changes: 2 additions & 1 deletion config/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# the configuration for the chosen engine as described below.
provider_engine = 'pyth_replicator'

product_update_interval_secs = 10
price_update_interval_secs = 1.0
product_update_interval_secs = 60
health_check_port = 8000

# The health check will return a failure status if no price data has been published within the specified time frame.
Expand Down
1 change: 1 addition & 0 deletions example_publisher/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class Config:
pythd: Pythd
health_check_port: int
health_check_threshold_secs: int
price_update_interval_secs: float = ts.option(default=1.0)
product_update_interval_secs: int = ts.option(default=60)
coin_gecko: Optional[CoinGeckoConfig] = ts.option(default=None)
pyth_replicator: Optional[PythReplicatorConfig] = ts.option(default=None)
4 changes: 2 additions & 2 deletions example_publisher/providers/pyth_replicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def _update_loop(self) -> None:

while True:
update = await self._ws.next_update()
log.debug("Received a WS update", account_key=update.key, slot=update.slot)
log.trace("Received a WS update", account_key=update.key, slot=update.slot)
if isinstance(update, PythPriceAccount) and update.product is not None:
symbol = update.product.symbol

Expand Down Expand Up @@ -118,7 +118,7 @@ async def _update_accounts_loop(self) -> None:

await asyncio.sleep(self._config.account_update_interval_secs)

def upd_products(self, *args) -> None:
def upd_products(self, product_symbols: List[Symbol]) -> None:
# This provider stores all the possible feeds and
# does not care about the desired products as knowing
# them does not improve the performance of the replicator
Expand Down
89 changes: 32 additions & 57 deletions example_publisher/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def __init__(self, config: Config) -> None:

self.pythd: Pythd = Pythd(
address=config.pythd.endpoint,
on_notify_price_sched=self.on_notify_price_sched,
)
self.subscriptions: Dict[SubscriptionId, Product] = {}
self.products: List[Product] = []
Expand All @@ -66,18 +65,17 @@ def is_healthy(self) -> bool:
async def start(self):
await self.pythd.connect()

self._product_update_task = asyncio.create_task(
self._start_product_update_loop()
)
self._product_update_task = asyncio.create_task(self._product_update_loop())

self._price_update_task = asyncio.create_task(self._price_update_loop())

async def _start_product_update_loop(self):
await self._upd_products()
self.provider.start()

async def _product_update_loop(self):
while True:
await self._upd_products()
await self._subscribe_notify_price_sched()
await asyncio.sleep(self.config.product_update_interval_secs)
await self._upd_products()

async def _upd_products(self):
log.debug("fetching product accounts from Pythd")
Expand Down Expand Up @@ -114,58 +112,35 @@ async def _upd_products(self):

self.provider.upd_products([product.symbol for product in self.products])

async def _subscribe_notify_price_sched(self):
# Subscribe to Pythd's notify_price_sched for each product that
# is not subscribed yet. Unfortunately there is no way to unsubscribe
# to the prices that are no longer available.
log.debug("subscribing to notify_price_sched")

subscriptions = {}
for product in self.products:
if not product.subscription_id:
subscription_id = await self.pythd.subscribe_price_sched(
product.price_account
async def _price_update_loop(self):
while True:
for product in self.products:
price = self.provider.latest_price(product.symbol)
if not price:
log.info("latest price not available", symbol=product.symbol)
continue

scaled_price = self.apply_exponent(price.price, product.exponent)
scaled_conf = self.apply_exponent(price.conf, product.exponent)

log.info(
"sending update_price",
product_account=product.product_account,
price_account=product.price_account,
price=scaled_price,
conf=scaled_conf,
symbol=product.symbol,
)
await self.pythd.update_price(
product.price_account, scaled_price, scaled_conf, TRADING
)
self.last_successful_update = (
price.timestamp
if self.last_successful_update is None
else max(self.last_successful_update, price.timestamp)
)
product.subscription_id = subscription_id

subscriptions[product.subscription_id] = product

self.subscriptions = subscriptions

async def on_notify_price_sched(self, subscription: int) -> None:

log.debug("received notify_price_sched", subscription=subscription)
if subscription not in self.subscriptions:
return

# Look up the current price and confidence interval of the product
product = self.subscriptions[subscription]
price = self.provider.latest_price(product.symbol)
if not price:
log.info("latest price not available", symbol=product.symbol)
return

# Scale the price and confidence interval using the Pyth exponent
scaled_price = self.apply_exponent(price.price, product.exponent)
scaled_conf = self.apply_exponent(price.conf, product.exponent)

# Send the price update
log.info(
"sending update_price",
product_account=product.product_account,
price_account=product.price_account,
price=scaled_price,
conf=scaled_conf,
symbol=product.symbol,
)
await self.pythd.update_price(
product.price_account, scaled_price, scaled_conf, TRADING
)
self.last_successful_update = (
price.timestamp
if self.last_successful_update is None
else max(self.last_successful_update, price.timestamp)
)
await asyncio.sleep(self.config.price_update_interval_secs)

def apply_exponent(self, x: float, exp: int) -> int:
return int(x * (10 ** (-exp)))
23 changes: 1 addition & 22 deletions example_publisher/pythd.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
from dataclasses import dataclass, field
import sys
import traceback
from dataclasses_json import config, DataClassJsonMixin
from typing import Callable, Coroutine, List
from typing import List
from structlog import get_logger
from jsonrpc_websocket import Server

Expand Down Expand Up @@ -38,16 +37,13 @@ class Pythd:
def __init__(
self,
address: str,
on_notify_price_sched: Callable[[SubscriptionId], Coroutine[None, None, None]],
) -> None:
self.address = address
self.server: Server
self.on_notify_price_sched = on_notify_price_sched
self._tasks = set()

async def connect(self):
self.server = Server(self.address)
self.server.notify_price_sched = self._notify_price_sched
task = await self.server.ws_connect()
task.add_done_callback(Pythd._on_connection_done)
self._tasks.add(task)
Expand All @@ -60,23 +56,6 @@ def _on_connection_done(task):
traceback.print_exception(None, e, e.__traceback__)
sys.exit(1)

async def subscribe_price_sched(self, account: str) -> int:
subscription = (await self.server.subscribe_price_sched(account=account))[
"subscription"
]
log.debug(
"subscribed to price_sched", account=account, subscription=subscription
)
return subscription

def _notify_price_sched(self, subscription: int) -> None:
log.debug("notify_price_sched RPC call received", subscription=subscription)
task = asyncio.get_event_loop().create_task(
self.on_notify_price_sched(subscription)
)
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)

async def all_products(self) -> List[Product]:
result = await self.server.get_product_list()
return [Product.from_dict(d) for d in result]
Expand Down

0 comments on commit a591d8a

Please sign in to comment.