From 574847d0cde25220f6e610de950b901da1fcf833 Mon Sep 17 00:00:00 2001 From: zizou <111426680+flopell@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:13:38 +0100 Subject: [PATCH] refactor(simulation-py): make `price()` function optional --- .../tycho_simulation_py/evm/pool_state.py | 67 +++++++++++++------ 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py b/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py index ebcb0ae4..baa4334a 100644 --- a/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py +++ b/tycho_simulation_py/python/tycho_simulation_py/evm/pool_state.py @@ -166,28 +166,53 @@ def _set_engine(self): self._engine = engine def _set_marginal_prices(self): - """Set the spot prices for this pool. + """Set the spot prices for this pool.""" + if Capability.PriceFunction in self.capabilities: + for t0, t1 in itertools.permutations(self.tokens, 2): + sell_amount = t0.to_onchain_amount( + self.get_sell_amount_limit(t0, t1) * Decimal("0.01") + ) + frac = self._adapter_contract.price( + cast(HexStr, self.id_), + t0, + t1, + [sell_amount], + block=self.block, + overwrites=self._get_overwrites(t0, t1), + )[0] + if Capability.ScaledPrices in self.capabilities: + self.marginal_prices[(t0, t1)] = frac_to_decimal(frac) + else: + scaled = frac * Fraction(10**t0.decimals, 10**t1.decimals) + self.marginal_prices[(t0, t1)] = frac_to_decimal(scaled) + else: - We currently require the price function capability for now. - """ - self._ensure_capability(Capability.PriceFunction) - for t0, t1 in itertools.permutations(self.tokens, 2): - sell_amount = t0.to_onchain_amount( - self.get_sell_amount_limit(t0, t1) * Decimal("0.01") - ) - frac = self._adapter_contract.price( - cast(HexStr, self.id_), - t0, - t1, - [sell_amount], - block=self.block, - overwrites=self._get_overwrites(t0,t1), - )[0] - if Capability.ScaledPrices in self.capabilities: - self.marginal_prices[(t0, t1)] = frac_to_decimal(frac) - else: - scaled = frac * Fraction(10**t0.decimals, 10**t1.decimals) - self.marginal_prices[(t0, t1)] = frac_to_decimal(scaled) + def swap( + sell_token: EthereumToken, + sell_amount: Decimal, + buy_token: EthereumToken, + ) -> Decimal: + overwrites = self._get_overwrites(sell_token, buy_token) + trade, _ = self._adapter_contract.swap( + cast(HexStr, self.id_), + sell_token, + buy_token, + False, + sell_token.to_onchain_amount(sell_amount), + block=self.block, + overwrites=overwrites, + ) + + buy_amount = buy_token.from_onchain_amount(trade.received_amount) + + return buy_amount + + for t0, t1 in itertools.permutations(self.tokens, 2): + x1 = self.get_sell_amount_limit(t0, t1) * Decimal("0.01") + x2 = x1 + (x1 / 100) + y1 = swap(t0, x1, t1) + y2 = swap(t0, x2, t1) + self.marginal_prices[(t0, t1)] = (y2 - y1) / (x2 - x1) def _ensure_capability(self, capability: Capability): """Ensures the protocol/adapter implement a certain capability."""