diff --git a/src/apn.py b/src/apn.py index d22e21f..cfd5ade 100644 --- a/src/apn.py +++ b/src/apn.py @@ -28,7 +28,7 @@ def generate_jwt_token(key_id: str, team_id: str, private_key: str) -> str: return jwt.encode(payload, private_key, algorithm="ES256", headers=headers) -def send_push_notification(device_token: str, title: str, body: str) -> dict: +def send_push_notification(device_token: str, title: str, body: str): """ Send a push notification to a user's device. @@ -71,4 +71,4 @@ def send_push_notification(device_token: str, title: str, body: str) -> dict: logging.error(f"APNs error: {response.status_code}, {error_payload}") raise ValueError(f"APNs rejected the request: {error_payload}") - return response.json() + return response diff --git a/src/auth_manager.py b/src/auth_manager.py index 5098337..bb11049 100644 --- a/src/auth_manager.py +++ b/src/auth_manager.py @@ -9,6 +9,7 @@ from src.email_manager import send_alert_email from src.supabase_client import ( + get_device_token, get_user_auth, upsert_user, upsert_user_auth, @@ -79,12 +80,14 @@ def refresh_and_update_user_token(athlete_id: int, refresh_token: str) -> UserAu new_jwt_token = generate_jwt( athlete_id=athlete_id, expires_at=access_info["expires_at"] ) + user_auth = UserAuthRow( athlete_id=athlete_id, access_token=access_info["access_token"], refresh_token=access_info["refresh_token"], expires_at=access_info["expires_at"], jwt_token=new_jwt_token, + device_token=get_device_token(athlete_id), ) upsert_user_auth(user_auth) return user_auth @@ -136,6 +139,7 @@ def authenticate_with_code(code: str) -> UserAuthRow: refresh_token=strava_client.refresh_token, expires_at=strava_client.token_expires_at, jwt_token=jwt_token, + device_token=get_device_token(athlete.id), ) upsert_user_auth(user_auth_row) return user_auth_row diff --git a/src/supabase_client.py b/src/supabase_client.py index e4b4959..3f05dca 100644 --- a/src/supabase_client.py +++ b/src/supabase_client.py @@ -2,6 +2,7 @@ import json import os from datetime import timedelta, timezone +from typing import Optional from dotenv import load_dotenv from postgrest.base_request_builder import APIResponse @@ -269,3 +270,17 @@ def update_user_device_token(athlete_id: str, device_token: str) -> None: client.table("user_auth").update({"device_token": device_token}).eq( "athlete_id", athlete_id ).execute() + + +def get_device_token(athlete_id: int) -> Optional[str]: + """ + Get the device token for a user in the database. + + :param athlete_id: The athlete's ID + :return: The device token for the user, or None if the user does not exist + """ + try: + user_auth = get_user_auth(athlete_id) + return user_auth.device_token + except ValueError: + return None