diff --git a/src/backend/app/config.py b/src/backend/app/config.py index 7942006a..f682bdad 100644 --- a/src/backend/app/config.py +++ b/src/backend/app/config.py @@ -78,7 +78,8 @@ def assemble_db_connection(cls, v: Optional[str], info: ValidationInfo) -> Any: S3_BUCKET_NAME: str = "dtm-data" S3_DOWNLOAD_ROOT: Optional[str] = None - ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 + ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1 # 1 day + REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 # 8 day @lru_cache diff --git a/src/backend/app/users/user_crud.py b/src/backend/app/users/user_crud.py index 53ab5ad8..cea04bda 100644 --- a/src/backend/app/users/user_crud.py +++ b/src/backend/app/users/user_crud.py @@ -11,15 +11,26 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - ALGORITHM = "HS256" -def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: +def create_access_token( + subject: str | Any, expires_delta: timedelta, refresh_token_expiry: timedelta +): expire = datetime.utcnow() + expires_delta - to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + refresh_expire = datetime.utcnow() + refresh_token_expiry + + to_encode_access_token = {"exp": expire, "sub": str(subject)} + to_encode_refresh_token = {"exp": refresh_expire, "sub": str(subject)} + + access_token = jwt.encode( + to_encode_access_token, settings.SECRET_KEY, algorithm=ALGORITHM + ) + refresh_token = jwt.encode( + to_encode_refresh_token, settings.SECRET_KEY, algorithm=ALGORITHM + ) + + return access_token, refresh_token def verify_password(plain_password: str, hashed_password: str) -> bool: diff --git a/src/backend/app/users/user_deps.py b/src/backend/app/users/user_deps.py new file mode 100644 index 00000000..f81a8f3b --- /dev/null +++ b/src/backend/app/users/user_deps.py @@ -0,0 +1,55 @@ +import jwt +from typing import Annotated + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jwt.exceptions import InvalidTokenError +from pydantic import ValidationError +from sqlalchemy.orm import Session +from app.config import settings +from app.db import database +from app.users import user_crud, user_schemas +from app.db.db_models import DbUser + + +reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_PREFIX}/users/login") + + +SessionDep = Annotated[ + Session, + Depends(database.get_db), +] +TokenDep = Annotated[str, Depends(reusable_oauth2)] + + +def get_current_user(session: SessionDep, token: TokenDep): + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[user_crud.ALGORITHM] + ) + token_data = user_schemas.TokenPayload(**payload) + + except (InvalidTokenError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + + user = session.get(DbUser, token_data.sub) + + if not user: + raise HTTPException(status_code=404, detail="User not found") + if not user.is_active: + raise HTTPException(status_code=400, detail="Inactive user") + return user + + +CurrentUser = Annotated[DbUser, Depends(get_current_user)] + + +def get_current_active_superuser(current_user: CurrentUser): + if not current_user.is_superuser: + raise HTTPException( + status_code=403, detail="The user doesn't have enough privileges" + ) + return current_user diff --git a/src/backend/app/users/user_routes.py b/src/backend/app/users/user_routes.py index 32f857d7..0b515116 100644 --- a/src/backend/app/users/user_routes.py +++ b/src/backend/app/users/user_routes.py @@ -1,9 +1,11 @@ +from typing import Any from datetime import timedelta from fastapi import APIRouter, HTTPException, Depends from sqlalchemy.orm import Session from typing import Annotated from fastapi.security import OAuth2PasswordRequestForm from app.users.user_schemas import Token, UserPublic, UserRegister +from app.users.user_deps import CurrentUser from app.config import settings from app.users import user_crud from app.db import database @@ -31,11 +33,14 @@ def login_access_token( elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - return Token( - access_token=user_crud.create_access_token( - user.id, expires_delta=access_token_expires - ) + refresh_token_expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES) + + access_token, refresh_token = user_crud.create_access_token( + user.id, + expires_delta=access_token_expires, + refresh_token_expiry=refresh_token_expires, ) + return Token(access_token=access_token, refresh_token=refresh_token) @router.post("/signup", response_model=UserPublic) @@ -61,3 +66,24 @@ def register_user( user = user_crud.create_user(db, user_in) return user + + +@router.get("/refresh_token") +def update_token(current_user: CurrentUser): + access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + refresh_token_expires = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES) + + access_token, refresh_token = user_crud.create_access_token( + current_user.id, + expires_delta=access_token_expires, + refresh_token_expiry=refresh_token_expires, + ) + return Token(access_token=access_token, refresh_token=refresh_token) + + +@router.get("/me", response_model=UserPublic) +def read_user_me(current_user: CurrentUser) -> Any: + """ + Get current user. + """ + return current_user diff --git a/src/backend/app/users/user_schemas.py b/src/backend/app/users/user_schemas.py index 2da7a267..cad9dc84 100644 --- a/src/backend/app/users/user_schemas.py +++ b/src/backend/app/users/user_schemas.py @@ -17,8 +17,14 @@ class User(BaseModel): name: str +# Contents of JWT token +class TokenPayload(BaseModel): + sub: int | None = None + + class Token(BaseModel): access_token: str + refresh_token: str token_type: str = "bearer"