-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
get training week fix + integration test
- Loading branch information
Showing
9 changed files
with
617 additions
and
760 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,45 +1,135 @@ | ||
import logging | ||
import os | ||
from typing import Optional | ||
|
||
import jwt | ||
from fastapi import HTTPException, Security | ||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer | ||
from jwt import PyJWTError | ||
|
||
from api.src.types.user import UserRow | ||
from src import supabase_client | ||
from src.types.user import UserAuthRow, UserRow | ||
from stravalib.client import Client | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
bearer_scheme = HTTPBearer() | ||
|
||
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "default_secret") | ||
strava_client = Client() | ||
|
||
|
||
def generate_jwt(athlete_id: int, expires_at: int) -> str: | ||
""" | ||
Generate a JWT token using athlete_id and expiration time, aligning token | ||
expiration cycle with the athlete's Strava token | ||
:param athlete_id: strava internal identifier | ||
:param expires_at: expiration time of strava token | ||
:return: str | ||
""" | ||
payload = {"athlete_id": athlete_id, "exp": expires_at} | ||
token = jwt.encode(payload, os.environ["JWT_SECRET"], algorithm="HS256") | ||
return token | ||
|
||
|
||
def decode_jwt(token: str) -> Optional[dict]: | ||
def decode_jwt(jwt_token: str, verify_exp: bool = True) -> int: | ||
""" | ||
Decode and validate a JWT token | ||
Decode JWT token and return athlete_id | ||
:param token: The JWT token | ||
:return: Decoded payload if valid, otherwise raises an HTTPException | ||
:param jwt_token: JWT token | ||
:param verify_exp: whether to verify expiration | ||
:return: int if successful, None if decoding fails | ||
:raises: jwt.DecodeError if token is invalid | ||
""" | ||
payload = jwt.decode( | ||
jwt_token, | ||
os.environ["JWT_SECRET"], | ||
algorithms=["HS256"], | ||
options={"verify_exp": verify_exp}, | ||
) | ||
return payload["athlete_id"] | ||
|
||
|
||
def refresh_and_update_user_token(athlete_id: int, refresh_token: str) -> UserAuthRow: | ||
""" | ||
Refresh the user's Strava token and update database | ||
:param athlete_id: strava internal identifier | ||
:param refresh_token: refresh token for Strava API | ||
:return: UserAuthRow | ||
""" | ||
logger.info(f"Refreshing and updating token for athlete {athlete_id}") | ||
access_info = strava_client.refresh_access_token( | ||
client_id=os.environ["STRAVA_CLIENT_ID"], | ||
client_secret=os.environ["STRAVA_CLIENT_SECRET"], | ||
refresh_token=refresh_token, | ||
) | ||
new_jwt_token = generate_jwt( | ||
athlete_id=athlete_id, expires_at=access_info["expires_at"] | ||
) | ||
|
||
user_auth = UserAuthRow( | ||
athlete_id=athlete_id, | ||
access_token=access_info["access_token"], | ||
refresh_token=access_info["refresh_token"], | ||
expires_at=access_info["expires_at"], | ||
jwt_token=new_jwt_token, | ||
device_token=supabase_client.get_device_token(athlete_id), | ||
) | ||
supabase_client.upsert_user_auth(user_auth) | ||
return user_auth | ||
|
||
|
||
def validate_and_refresh_token(token: str) -> int: | ||
""" | ||
Validate and refresh the user's credentials in DB | ||
:param token: JWT token | ||
:return: athlete_id | ||
""" | ||
try: | ||
decoded_token = jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) | ||
return decoded_token | ||
except PyJWTError: | ||
raise HTTPException(status_code=401, detail="Invalid or expired token") | ||
athlete_id = decode_jwt(token) | ||
except jwt.ExpiredSignatureError: | ||
try: | ||
# If the token is expired, decode athlete_id and refresh | ||
athlete_id = decode_jwt(token, verify_exp=False) | ||
user_auth = supabase_client.get_user_auth(athlete_id) | ||
refresh_and_update_user_token( | ||
athlete_id=athlete_id, refresh_token=user_auth.refresh_token | ||
) | ||
except jwt.DecodeError: | ||
logger.error("Invalid JWT token") | ||
raise HTTPException(status_code=401, detail="Invalid JWT token") | ||
except Exception as e: | ||
logger.error( | ||
f"Unknown error validating and refreshing token: {e}", | ||
exc_info=True, | ||
) | ||
raise HTTPException(status_code=500, detail="Internal server error") | ||
except jwt.DecodeError: | ||
logger.error("Invalid JWT token") | ||
raise HTTPException(status_code=401, detail="Invalid JWT token") | ||
except Exception as e: | ||
logger.error( | ||
f"Unknown error validating and refreshing token: {e}", | ||
exc_info=True, | ||
) | ||
raise HTTPException(status_code=500, detail="Internal server error") | ||
|
||
return athlete_id | ||
|
||
|
||
async def get_current_user( | ||
async def validate_user( | ||
credentials: HTTPAuthorizationCredentials = Security(bearer_scheme), | ||
) -> UserRow: | ||
""" | ||
Dependency that validates the JWT token from the Authorization header | ||
:param credentials: Bearer token credentials | ||
:return: Decoded JWT payload containing user details | ||
:return: UserRow | ||
""" | ||
token = credentials.credentials | ||
payload = decode_jwt(token) | ||
if payload is None: | ||
athlete_id = validate_and_refresh_token(credentials.credentials) | ||
if athlete_id is None: | ||
logger.error("Invalid authentication credentials") | ||
raise HTTPException( | ||
status_code=401, detail="Invalid authentication credentials" | ||
) | ||
return payload | ||
return supabase_client.get_user(athlete_id) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,29 @@ | ||
import logging | ||
|
||
from fastapi import Depends, FastAPI, HTTPException | ||
from src.auth_manager import get_current_user | ||
from src.supabase_client import TrainingWeek, get_training_week | ||
from src import supabase_client | ||
from src.auth_manager import validate_user | ||
from src.types.training_week import TrainingWeek | ||
from src.types.user import UserRow | ||
|
||
app = FastAPI() | ||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
@app.get("/training_week/", response_model=TrainingWeek) | ||
async def training_week_endpoint(user: UserRow = Depends(get_current_user)): | ||
async def training_week(user: UserRow = Depends(validate_user)): | ||
""" | ||
Retrieve the most recent training_week row by athlete_id | ||
curl -X GET "http://localhost:8000/training_week/{athlete_id}" \ | ||
curl -X GET "http://trackflow-alb-499532887.us-east-1.elb.amazonaws.com/training_week/" \ | ||
-H "Authorization: Bearer YOUR_JWT_TOKEN" | ||
:param athlete_id: The athlete_id to retrieve the training_week for | ||
:return: The most recent training_week row for the athlete | ||
""" | ||
try: | ||
return get_training_week(user.athlete_id) | ||
return supabase_client.get_training_week(user.athlete_id) | ||
except ValueError as e: | ||
logger.error(f"Error retrieving training week: {e}", exc_info=True) | ||
raise HTTPException(status_code=404, detail=str(e)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
import sys | ||
|
||
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | ||
sys.path.append(src_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
|
||
from fastapi.testclient import TestClient | ||
from src.main import app | ||
from src.supabase_client import get_user_auth | ||
from src.types.training_week import TrainingWeek | ||
|
||
client = TestClient(app) | ||
|
||
|
||
def test_get_training_week_success(): | ||
"""Test successful retrieval of training week""" | ||
user_auth = get_user_auth(os.environ["JAMIES_ATHLETE_ID"]) | ||
|
||
response = client.get( | ||
"/training_week/", headers={"Authorization": f"Bearer {user_auth.jwt_token}"} | ||
) | ||
assert TrainingWeek(**response.json()) | ||
assert response.status_code == 200 |