Skip to content

Commit

Permalink
get training week fix + integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
voynow committed Nov 10, 2024
1 parent c77b005 commit 9761651
Show file tree
Hide file tree
Showing 9 changed files with 617 additions and 760 deletions.
344 changes: 343 additions & 1 deletion api/poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ uvicorn = "^0.32.0"
supabase = "^2.10.0"
python-dotenv = "^1.0.1"
pyjwt = "^2.9.0"
stravalib = "^2.1"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.3"

[build-system]
requires = ["poetry-core"]
Expand Down
788 changes: 52 additions & 736 deletions api/requirements.txt

Large diffs are not rendered by default.

126 changes: 108 additions & 18 deletions api/src/auth_manager.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,135 @@
import logging
import os
from typing import Optional

import jwt
from fastapi import HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jwt import PyJWTError

from api.src.types.user import UserRow
from src import supabase_client
from src.types.user import UserAuthRow, UserRow
from stravalib.client import Client

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
bearer_scheme = HTTPBearer()

SECRET_KEY = os.getenv("JWT_SECRET_KEY", "default_secret")
strava_client = Client()


def generate_jwt(athlete_id: int, expires_at: int) -> str:
"""
Generate a JWT token using athlete_id and expiration time, aligning token
expiration cycle with the athlete's Strava token
:param athlete_id: strava internal identifier
:param expires_at: expiration time of strava token
:return: str
"""
payload = {"athlete_id": athlete_id, "exp": expires_at}
token = jwt.encode(payload, os.environ["JWT_SECRET"], algorithm="HS256")
return token


def decode_jwt(token: str) -> Optional[dict]:
def decode_jwt(jwt_token: str, verify_exp: bool = True) -> int:
"""
Decode and validate a JWT token
Decode JWT token and return athlete_id
:param token: The JWT token
:return: Decoded payload if valid, otherwise raises an HTTPException
:param jwt_token: JWT token
:param verify_exp: whether to verify expiration
:return: int if successful, None if decoding fails
:raises: jwt.DecodeError if token is invalid
"""
payload = jwt.decode(
jwt_token,
os.environ["JWT_SECRET"],
algorithms=["HS256"],
options={"verify_exp": verify_exp},
)
return payload["athlete_id"]


def refresh_and_update_user_token(athlete_id: int, refresh_token: str) -> UserAuthRow:
"""
Refresh the user's Strava token and update database
:param athlete_id: strava internal identifier
:param refresh_token: refresh token for Strava API
:return: UserAuthRow
"""
logger.info(f"Refreshing and updating token for athlete {athlete_id}")
access_info = strava_client.refresh_access_token(
client_id=os.environ["STRAVA_CLIENT_ID"],
client_secret=os.environ["STRAVA_CLIENT_SECRET"],
refresh_token=refresh_token,
)
new_jwt_token = generate_jwt(
athlete_id=athlete_id, expires_at=access_info["expires_at"]
)

user_auth = UserAuthRow(
athlete_id=athlete_id,
access_token=access_info["access_token"],
refresh_token=access_info["refresh_token"],
expires_at=access_info["expires_at"],
jwt_token=new_jwt_token,
device_token=supabase_client.get_device_token(athlete_id),
)
supabase_client.upsert_user_auth(user_auth)
return user_auth


def validate_and_refresh_token(token: str) -> int:
"""
Validate and refresh the user's credentials in DB
:param token: JWT token
:return: athlete_id
"""
try:
decoded_token = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
return decoded_token
except PyJWTError:
raise HTTPException(status_code=401, detail="Invalid or expired token")
athlete_id = decode_jwt(token)
except jwt.ExpiredSignatureError:
try:
# If the token is expired, decode athlete_id and refresh
athlete_id = decode_jwt(token, verify_exp=False)
user_auth = supabase_client.get_user_auth(athlete_id)
refresh_and_update_user_token(
athlete_id=athlete_id, refresh_token=user_auth.refresh_token
)
except jwt.DecodeError:
logger.error("Invalid JWT token")
raise HTTPException(status_code=401, detail="Invalid JWT token")
except Exception as e:
logger.error(
f"Unknown error validating and refreshing token: {e}",
exc_info=True,
)
raise HTTPException(status_code=500, detail="Internal server error")
except jwt.DecodeError:
logger.error("Invalid JWT token")
raise HTTPException(status_code=401, detail="Invalid JWT token")
except Exception as e:
logger.error(
f"Unknown error validating and refreshing token: {e}",
exc_info=True,
)
raise HTTPException(status_code=500, detail="Internal server error")

return athlete_id


async def get_current_user(
async def validate_user(
credentials: HTTPAuthorizationCredentials = Security(bearer_scheme),
) -> UserRow:
"""
Dependency that validates the JWT token from the Authorization header
:param credentials: Bearer token credentials
:return: Decoded JWT payload containing user details
:return: UserRow
"""
token = credentials.credentials
payload = decode_jwt(token)
if payload is None:
athlete_id = validate_and_refresh_token(credentials.credentials)
if athlete_id is None:
logger.error("Invalid authentication credentials")
raise HTTPException(
status_code=401, detail="Invalid authentication credentials"
)
return payload
return supabase_client.get_user(athlete_id)
17 changes: 12 additions & 5 deletions api/src/main.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
import logging

from fastapi import Depends, FastAPI, HTTPException
from src.auth_manager import get_current_user
from src.supabase_client import TrainingWeek, get_training_week
from src import supabase_client
from src.auth_manager import validate_user
from src.types.training_week import TrainingWeek
from src.types.user import UserRow

app = FastAPI()

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@app.get("/training_week/", response_model=TrainingWeek)
async def training_week_endpoint(user: UserRow = Depends(get_current_user)):
async def training_week(user: UserRow = Depends(validate_user)):
"""
Retrieve the most recent training_week row by athlete_id
curl -X GET "http://localhost:8000/training_week/{athlete_id}" \
curl -X GET "http://trackflow-alb-499532887.us-east-1.elb.amazonaws.com/training_week/" \
-H "Authorization: Bearer YOUR_JWT_TOKEN"
:param athlete_id: The athlete_id to retrieve the training_week for
:return: The most recent training_week row for the athlete
"""
try:
return get_training_week(user.athlete_id)
return supabase_client.get_training_week(user.athlete_id)
except ValueError as e:
logger.error(f"Error retrieving training week: {e}", exc_info=True)
raise HTTPException(status_code=404, detail=str(e))
65 changes: 65 additions & 0 deletions api/src/supabase_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import datetime
import json
import os
from typing import Optional

from dotenv import load_dotenv
from src.types.training_week import TrainingWeek
from src.types.user import UserAuthRow, UserRow
from supabase import Client, create_client

load_dotenv()
Expand All @@ -19,6 +22,52 @@ def init() -> Client:
client = init()


def get_device_token(athlete_id: int) -> Optional[str]:
"""
Get the device token for a user in the database.
:param athlete_id: The athlete's ID
:return: The device token for the user, or None if the user does not exist
"""
try:
user_auth = get_user_auth(athlete_id)
return user_auth.device_token
except ValueError:
return None


def get_user(athlete_id: int) -> UserRow:
"""
Get a user by athlete_id
:param athlete_id: int
:return: UserRow
"""
table = client.table("user")
response = table.select("*").eq("athlete_id", athlete_id).execute()

if not response.data:
raise ValueError(f"Could not find user with {athlete_id=}")

return UserRow(**response.data[0])


def get_user_auth(athlete_id: int) -> UserAuthRow:
"""
Get user_auth row by athlete_id
:param athlete_id: int
:return: APIResponse
"""
table = client.table("user_auth")
response = table.select("*").eq("athlete_id", athlete_id).execute()

if not response.data:
raise ValueError(f"Cound not find user_auth row with {athlete_id=}")

return UserAuthRow(**response.data[0])


def get_training_week(athlete_id: int) -> TrainingWeek:
"""
Get the most recent training_week row by athlete_id.
Expand All @@ -39,3 +88,19 @@ def get_training_week(athlete_id: int) -> TrainingWeek:
raise ValueError(
f"Could not find training_week row for athlete_id {athlete_id}"
)


def upsert_user_auth(user_auth_row: UserAuthRow) -> None:
"""
Convert UserAuthRow to a dictionary, ensure json serializable expires_at,
and upsert into user_auth table handling duplicates on athlete_id
:param user_auth_row: A dictionary representation of UserAuthRow
:return: APIResponse
"""
user_auth_row = user_auth_row.dict()
if isinstance(user_auth_row["expires_at"], datetime.datetime):
user_auth_row["expires_at"] = user_auth_row["expires_at"].isoformat()

table = client.table("user_auth")
table.upsert(user_auth_row, on_conflict="athlete_id").execute()
9 changes: 9 additions & 0 deletions api/src/types/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,12 @@ class UserRow(BaseModel):
preferences: Optional[Preferences] = Preferences()
email: Optional[str] = None
created_at: datetime = datetime.now()


class UserAuthRow(BaseModel):
athlete_id: int
access_token: str
refresh_token: str
expires_at: datetime
jwt_token: str
device_token: Optional[str] = None
5 changes: 5 additions & 0 deletions api/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os
import sys

src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(src_path)
19 changes: 19 additions & 0 deletions api/tests/test_get_training_week.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os

from fastapi.testclient import TestClient
from src.main import app
from src.supabase_client import get_user_auth
from src.types.training_week import TrainingWeek

client = TestClient(app)


def test_get_training_week_success():
"""Test successful retrieval of training week"""
user_auth = get_user_auth(os.environ["JAMIES_ATHLETE_ID"])

response = client.get(
"/training_week/", headers={"Authorization": f"Bearer {user_auth.jwt_token}"}
)
assert TrainingWeek(**response.json())
assert response.status_code == 200

0 comments on commit 9761651

Please sign in to comment.