diff --git a/nibiru/query_clients/util.py b/nibiru/query_clients/util.py index 945e446c..8da29f24 100644 --- a/nibiru/query_clients/util.py +++ b/nibiru/query_clients/util.py @@ -1,10 +1,24 @@ +from typing import Dict, List + +from google.protobuf import message as protobuf_message from google.protobuf.json_format import MessageToDict -from nibiru.utils import from_sdk_dec +from nibiru.utils import from_sdk_dec, from_sdk_int + +PROTOBUF_MSG_BASE_ATTRS: List[str] = ( + dir(protobuf_message.Message) + + ['Extensions', 'FindInitializationErrors', '_CheckCalledFromGeneratedFile'] + + ['_extensions_by_name', '_extensions_by_number'] +) +"""PROTOBUF_MSG_BASE_ATTRS (List[str]): The default attributes and methods of +an instance of the 'protobuf.message.Message' class. +""" -def camel_to_snake(s): - return ''.join(['_' + c.lower() if c.isupper() else c for c in s]).lstrip('_') +def camel_to_snake(camel: str): + return ''.join( + ['_' + char.lower() if char.isupper() else char for char in camel] + ).lstrip('_') def t_dict(d): @@ -16,16 +30,75 @@ def t_dict(d): } -def deserialize(proto_message: object) -> dict: +def deserialize(pb_msg: protobuf_message.Message) -> dict: + """Deserializes a proto message into a dictionary. + + - sdk.Dec values are converted to floats. + - sdk.Int values are converted to ints. + - Missing fields become blank strings. + + Args: + pb_msg (protobuf.message.Message) + + Returns: + dict: 'pb_msg' as a JSON-able dictionary. + """ + if not isinstance(pb_msg, protobuf_message.Message): + raise TypeError(f"expted protobuf Message for 'pb_msg', not {type(pb_msg)}") + custom_dtypes: Dict[str, bytes] = { + str(field[1]): field[0].GetOptions().__getstate__().get("serialized", None) + for field in pb_msg.ListFields() + } + serialized_output = {} + expected_fields: List[str] = [ + attr for attr in dir(pb_msg) if attr not in PROTOBUF_MSG_BASE_ATTRS + ] + + for _, attr in enumerate(expected_fields): + + attr_search = pb_msg.__getattribute__(attr) + custom_dtype = custom_dtypes.get(str(attr_search)) + + if custom_dtype is not None: + + if "sdk/types.Dec" in str(custom_dtype): + serialized_output[str(attr)] = from_sdk_dec( + pb_msg.__getattribute__(attr) + ) + elif "sdk/types.Int" in str(custom_dtype): + serialized_output[str(attr)] = from_sdk_int( + pb_msg.__getattribute__(attr) + ) + else: + try: + val = pb_msg.__getattribute__(attr) + if hasattr(val, '__len__') and not isinstance(val, str): + updated_vals = [] + for v in val: + updated_vals.append(deserialize(v)) + serialized_output[str(attr)] = updated_vals + else: + serialized_output[str(attr)] = deserialize(val) + except: + serialized_output[str(attr)] = pb_msg.__getattribute__(attr) + elif (custom_dtype is None) and (attr_search == ''): + serialized_output[str(attr)] = "" + else: + serialized_output[str(attr)] = deserialize(pb_msg.__getattribute__(attr)) + + return serialized_output + + +def deserialize_exp(proto_message: protobuf_message.Message) -> dict: """ Take a proto message and convert it into a dictionnary. sdk.Dec values are converted to be consistent with txs. Args: - proto_message (object): The proto message + proto_message (protobuf.message.Message) Returns: - dict: The dictionary + dict """ output = MessageToDict(proto_message) @@ -38,7 +111,7 @@ def deserialize(proto_message: object) -> dict: if field.message_type is not None: # This is another proto object try: - output[field.camelcase_name] = deserialize( + output[field.camelcase_name] = deserialize_exp( proto_message.__getattribute__(field.camelcase_name) ) except AttributeError: diff --git a/nibiru/sdk.py b/nibiru/sdk.py index 0fb1874d..f9f5f632 100644 --- a/nibiru/sdk.py +++ b/nibiru/sdk.py @@ -18,16 +18,25 @@ class Sdk: - """ - The Sdk class creates an interface to sign and send transactions or execute queries from a node. + """The Sdk class creates an interface to sign and send transactions or execute + queries from a node. + It is associated to: - - a wallet, which can be either created or recovered from an existing mnemonic. - - a network, defining the node to connect to - - optionally a configuration defining how to behave and the gas configuration for each transaction + - a wallet, which can be either created or recovered from an existing mnemonic. + - a network, defining the node to connect to + - optionally a configuration defining how to behave and the gas configuration + for each transaction Each method starting with `with_` will replace the existing Sdk object with a new version having the defined behavior. + Attributes: + priv_key + query + tx + network + tx_config + Example :: @@ -38,6 +47,11 @@ class Sdk: ) """ + query: GrpcClient + network: Network + tx: BaseTxClient + tx_config: TxConfig + def __init__(self, _error_do_not_use_init_directly=None) -> None: """Unsupported, please use from_mnemonic to initialize.""" if not _error_do_not_use_init_directly: diff --git a/nibiru/utils.py b/nibiru/utils.py index 8747a4af..e0947c7d 100644 --- a/nibiru/utils.py +++ b/nibiru/utils.py @@ -167,8 +167,8 @@ def to_sdk_int(i: float) -> str: return str(int(i)) -def from_sdk_int(int_str: str) -> float: - return float(int_str) +def from_sdk_int(int_str: str) -> int: + return int(int_str) def toPbTimestamp(dt: datetime): diff --git a/tests/__init__.py b/tests/__init__.py index b468b72b..06c1c5e9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,7 @@ """Tests for the nibiru package""" import logging import sys +from typing import Iterable import shutup @@ -25,7 +26,10 @@ def init_test_logger() -> logging.Logger: """Simple logger to use throughout the test suite.""" -def dict_keys_must_match(dict_: dict, keys: list[str]): +def dict_keys_must_match(dict_: dict, keys: Iterable[str]): + keys = list(keys) + if not isinstance(dict_, dict): + raise TypeError(f"'dict' must be a dicitonary, not {type(dict_)}") assert len(dict_.keys()) == len(keys) for key in dict_.keys(): assert key in keys diff --git a/tests/chain_info_test.py b/tests/chain_info_test.py index 3e4144a2..a12b0433 100644 --- a/tests/chain_info_test.py +++ b/tests/chain_info_test.py @@ -33,12 +33,3 @@ def test_query_perp_params(val_node: Sdk): "twapLookbackWindow", ] assert all([(param_name in params) for param_name in perp_param_names]) - - -def test_query_vpool_reserve_assets(val_node: Sdk): - expected_pairs: List[str] = ["ubtc:unusd", "ueth:unusd"] - for pair in expected_pairs: - query_resp: dict = val_node.query.vpool.reserve_assets(pair) - assert isinstance(query_resp, dict) - assert query_resp["base_asset_reserve"] > 0 - assert query_resp["quote_asset_reserve"] > 0 diff --git a/tests/perp_test.py b/tests/perp_test.py index 1404c86e..159e6fb4 100644 --- a/tests/perp_test.py +++ b/tests/perp_test.py @@ -1,6 +1,6 @@ # perp_test.py +import pytest from grpc._channel import _InactiveRpcError -from pytest import approx, raises import nibiru import nibiru.msg @@ -25,7 +25,7 @@ def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk): ) # Exception must be raised when requesting not existing position - with raises(_InactiveRpcError, match="no position found"): + with pytest.raises(_InactiveRpcError, match="no position found"): agent.query.perp.trader_position(trader=agent.address, token_pair=pair) # Transaction open_position must succeed @@ -57,12 +57,12 @@ def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk): ], ) # Margin ratio must be ~10% - assert position_res["margin_ratio_mark"] == approx(0.1, PRECISION) + assert position_res["margin_ratio_mark"] == pytest.approx(0.1, PRECISION) position = position_res["position"] assert position["margin"] == 10.0 assert position["open_notional"] == 100.0 - assert position["size"] == approx(0.005, PRECISION) + assert position["size"] == pytest.approx(0.005, PRECISION) # Transaction add_margin must succeed tx_output = agent.tx.execute_msgs( @@ -103,5 +103,5 @@ def test_open_close_position(val_node: nibiru.Sdk, agent: nibiru.Sdk): transaction_must_succeed(tx_output) # Exception must be raised when querying closed position - with raises(_InactiveRpcError, match="no position found"): + with pytest.raises(_InactiveRpcError, match="no position found"): agent.query.perp.trader_position(trader=agent.address, token_pair=pair) diff --git a/tests/vpool_test.py b/tests/vpool_test.py new file mode 100644 index 00000000..859b0a26 --- /dev/null +++ b/tests/vpool_test.py @@ -0,0 +1,66 @@ +import pprint +from typing import Dict, List + +import nibiru +import tests +from nibiru import common + + +def test_query_vpool_reserve_assets(val_node: nibiru.Sdk): + expected_pairs: List[str] = ["ubtc:unusd", "ueth:unusd"] + for pair in expected_pairs: + query_resp: dict = val_node.query.vpool.reserve_assets(pair) + assert isinstance(query_resp, dict) + assert query_resp["base_asset_reserve"] > 0 + assert query_resp["quote_asset_reserve"] > 0 + + +def test_query_vpool_all_pools(agent: nibiru.Sdk): + """Tests deserialization and expected attributes for the + 'nibid query vpool all-pools' command. + """ + + query_resp: Dict[str, List[dict]] = agent.query.vpool.all_pools() + tests.dict_keys_must_match(query_resp, keys=["pools", "prices"]) + + all_vpools: List[dict] = query_resp["pools"] + vpool_fields: List[str] = [ + "base_asset_reserve", + "fluctuation_limit_ratio", + "maintenance_margin_ratio", + "max_leverage", + "max_oracle_spread_ratio", + "pair", + "quote_asset_reserve", + "trade_limit_ratio", + ] + tests.dict_keys_must_match(all_vpools[0], keys=vpool_fields) + + all_vpool_prices = query_resp["prices"] + price_fields: List[str] = [ + "block_number", + "index_price", + "mark_price", + "swap_invariant", + "twap_mark", + "pair", + ] + tests.dict_keys_must_match(all_vpool_prices[0], keys=price_fields) + + vpool_prices = all_vpool_prices[0] + assert isinstance(vpool_prices["block_number"], int), "block_number" + assert isinstance(vpool_prices["index_price"], float), "index_price" + assert isinstance(vpool_prices["mark_price"], float), "mark_price" + assert isinstance(vpool_prices["swap_invariant"], int), "swap_invariant" + assert isinstance(vpool_prices["twap_mark"], float), "twap_mark" + assert isinstance(vpool_prices["pair"], str), "pair" + tests.LOGGER.info(f"vpool_prices: {pprint.pformat(vpool_prices, indent=3)}") + + +def test_query_vpool_base_asset_price(agent: nibiru.Sdk): + query_resp: Dict[str, List[dict]] = agent.query.vpool.base_asset_price( + pair="ueth:unusd", direction=common.Direction.ADD, base_asset_amount="15" + ) + tests.dict_keys_must_match(query_resp, keys=["price_in_quote_denom"]) + assert isinstance(query_resp["price_in_quote_denom"], float) + assert query_resp["price_in_quote_denom"] > 0