diff --git a/web_app/api/dashboard.py b/web_app/api/dashboard.py index aee329e60..46a08ed1a 100644 --- a/web_app/api/dashboard.py +++ b/web_app/api/dashboard.py @@ -55,7 +55,7 @@ async def get_dashboard(wallet_id: str) -> DashboardResponse: # Fetching first 10 positions at the moment opened_positions = position_db_connector.get_positions_by_wallet_id( - wallet_id, 0, 10 + wallet_id ) # At the moment, we only support one position per wallet diff --git a/web_app/api/position.py b/web_app/api/position.py index 11662482b..8bda2a491 100644 --- a/web_app/api/position.py +++ b/web_app/api/position.py @@ -249,7 +249,7 @@ async def get_user_positions(wallet_id: str, start: Optional[int] = None) -> lis start_index = max(0, start) if start is not None else 0 - positions = position_db_connector.get_positions_by_wallet_id( + positions = position_db_connector.get_all_positions_by_wallet_id( wallet_id, start_index, PAGINATION_STEP ) return positions diff --git a/web_app/db/crud/position.py b/web_app/db/crud/position.py index 1e042f01b..f0c2738bc 100644 --- a/web_app/db/crud/position.py +++ b/web_app/db/crud/position.py @@ -97,6 +97,42 @@ def get_positions_by_wallet_id( logger.error(f"Failed to retrieve positions: {str(e)}") return [] + def get_all_positions_by_wallet_id( + self, wallet_id: str, start: int, limit: int + ) -> list: + """ + Retrieves paginated positions for a user by their wallet ID + and returns them as a list of dictionaries. + :param wallet_id: str + :param start: starting index for pagination + :param limit: number of records to return + :return: list of dict + """ + with self.Session() as db: + user = self._get_user_by_wallet_id(wallet_id) + if not user: + return [] + + try: + positions = ( + db.query(Position) + .filter( + Position.user_id == user.id, + ) + .offset(start) + .limit(limit) + .all() + ) + # Convert positions to a list of dictionaries + positions_dicts = [ + self._position_to_dict(position) for position in positions + ] + return positions_dict + + except SQLAlchemyError as e: + logger.error(f"Failed to retrieve positions: {str(e)}") + return [] + def has_opened_position(self, wallet_id: str) -> bool: """ Checks if a user has any opened positions. diff --git a/web_app/tests/test_positions.py b/web_app/tests/test_positions.py index 165b6fe62..963a26551 100644 --- a/web_app/tests/test_positions.py +++ b/web_app/tests/test_positions.py @@ -470,7 +470,7 @@ async def test_get_user_positions_success(client: TestClient) -> None: ] with patch( - "web_app.db.crud.PositionDBConnector.get_positions_by_wallet_id" + "web_app.db.crud.PositionDBConnector.get_all_positions_by_wallet_id" ) as mock_get_positions: mock_get_positions.return_value = mock_positions response = client.get(f"/api/user-positions/{wallet_id}")