Skip to content

Commit

Permalink
Merge pull request #523 from djeck1432/feat/add-limit-to-possition-hi…
Browse files Browse the repository at this point in the history
…story-endpoint

[Backend] add limit prams to possiton history endpoint
  • Loading branch information
djeck1432 authored Jan 28, 2025
2 parents 9a1e604 + 9bc8b85 commit 09db61c
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 23 deletions.
43 changes: 26 additions & 17 deletions web_app/api/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, HTTPException, Request
from fastapi import APIRouter, HTTPException, Query, Request

from web_app.api.serializers.position import (
AddPositionDepositData,
PositionFormData,
TokenMultiplierResponse,
UserPositionResponse,
UserPositionExtraDepositsResponse,
UserPositionHistoryResponse,
)
from web_app.api.serializers.transaction import (
LoopLiquidityData,
Expand Down Expand Up @@ -311,27 +311,37 @@ async def add_extra_deposit(position_id: UUID, data: AddPositionDepositData):
@router.get(
"/api/user-positions/{wallet_id}",
tags=["Position Operations"],
response_model=list[UserPositionResponse],
response_model=UserPositionHistoryResponse,
summary="Get all positions for a user",
response_description="Returns paginated list of positions for the given wallet ID",
response_description="Returns paginated of positions for the given wallet ID",
)
async def get_user_positions(wallet_id: str, start: Optional[int] = None) -> list:
async def get_user_positions(
wallet_id: str,
start: int = Query(0, ge=0),
limit: int = Query(PAGINATION_STEP, ge=1, le=100),
) -> UserPositionHistoryResponse:
"""
Get all positions for a specific user by their wallet ID.
:param wallet_id: The wallet ID of the user
:param start: Optional starting index for pagination (0-based). If not provided, defaults to 0
:return: UserPositionsListResponse containing paginated list of positions
:param wallet_id: Wallet ID of the user
:param start: Starting index for pagination (default: 0)
:param limit: Number of items per page (default: 10 from PAGINATION_STEP variable)
:return: UserPositionHistoryResponse with positions and total count
:raises: HTTPException: If wallet ID is empty or invalid
"""
if not wallet_id:
raise HTTPException(status_code=400, detail="Wallet ID is required")

start_index = max(0, start) if start is not None else 0

positions = position_db_connector.get_all_positions_by_wallet_id(
wallet_id, start_index, PAGINATION_STEP
wallet_id, start=start, limit=limit
)
total_positions = position_db_connector.get_count_positions_by_wallet_id(wallet_id)

return UserPositionHistoryResponse(
positions=positions,
total_count=total_positions
)
return positions


@router.get(
Expand All @@ -348,8 +358,7 @@ async def get_list_of_deposited_tokens(position_id: UUID):
:return Dict containing main position and extra positions
"""
main_position = position_db_connector.get_position_by_id(position_id)
extra_deposits = position_db_connector.get_extra_deposits_by_position_id(position_id)
return {
"main": main_position,
"extra_deposits": extra_deposits
}
extra_deposits = position_db_connector.get_extra_deposits_by_position_id(
position_id
)
return {"main": main_position, "extra_deposits": extra_deposits}
13 changes: 13 additions & 0 deletions web_app/api/serializers/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,16 @@ class UserPositionExtraDepositsResponse(BaseModel):
"""
main: UserPositionResponse
extra_deposits: list[UserExtraDeposit]


class UserPositionHistoryResponse(BaseModel):
"""
Response model for user position history with pagination.
### Attributes:
- **positions**: List of user positions
- **total_count**: Total number of positions for pagination
"""

positions: List[UserPositionResponse] = []
total_count: int = 0
26 changes: 25 additions & 1 deletion web_app/db/crud/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_positions_by_wallet_id(

def get_all_positions_by_wallet_id(
self, wallet_id: str, start: int, limit: int
) -> list:
) -> list[dict]:
"""
Retrieves paginated positions for a user by their wallet ID
and returns them as a list of dictionaries.
Expand Down Expand Up @@ -138,6 +138,30 @@ def get_all_positions_by_wallet_id(
except SQLAlchemyError as e:
logger.error(f"Failed to retrieve positions: {str(e)}")
return []

def get_count_positions_by_wallet_id(self, wallet_id: str) -> int:
"""
Counts total number of positions for a user.
:param wallet_id: Wallet ID of the user
:return: Total number of positions
"""
with self.Session() as db:
user = self._get_user_by_wallet_id(wallet_id)
if not user:
return 0

try:
total_positions = (
db.query(func.count(Position.id))
.filter(Position.user_id == user.id)
.scalar()
)
return total_positions or 0

except SQLAlchemyError as e:
logger.error(f"Failed to count user positions: {str(e)}")
return 0

def has_opened_position(self, wallet_id: str) -> bool:
"""
Expand Down
16 changes: 11 additions & 5 deletions web_app/tests/test_positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,18 +469,24 @@ async def test_get_user_positions_success(client: TestClient) -> None:
"is_liquidated": False,
}
]
mock_total_count = len(mock_positions)

with patch(
"web_app.db.crud.PositionDBConnector.get_all_positions_by_wallet_id"
) as mock_get_positions:
) as mock_get_positions, patch(
"web_app.db.crud.PositionDBConnector.get_count_positions_by_wallet_id"
) as mock_get_count_positions:
mock_get_positions.return_value = mock_positions
mock_get_count_positions.return_value = mock_total_count

response = client.get(f"/api/user-positions/{wallet_id}")

assert response.status_code == 200
data = response.json()
assert len(data) == len(mock_positions)
assert data[0]["token_symbol"] == mock_positions[0]["token_symbol"]
assert data[0]["amount"] == mock_positions[0]["amount"]
assert len(data["positions"]) == len(mock_positions)
assert data["total_count"] == mock_total_count
assert data["positions"][0]["token_symbol"] == mock_positions[0]["token_symbol"]
assert data["positions"][0]["amount"] == mock_positions[0]["amount"]


@pytest.mark.asyncio
Expand All @@ -506,7 +512,7 @@ async def test_get_user_positions_no_positions(client: AsyncClient) -> None:

assert response.status_code == 200
data = response.json()
assert data == []
assert data == {"positions": [], "total_count": 0}


@pytest.mark.parametrize(
Expand Down

0 comments on commit 09db61c

Please sign in to comment.