Skip to content

Commit

Permalink
Add user friendly feedback on exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
berrydenhartog committed Oct 7, 2024
1 parent 29842ab commit 167b020
Show file tree
Hide file tree
Showing 43 changed files with 466 additions and 330 deletions.
8 changes: 5 additions & 3 deletions amt/api/http_browser_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from starlette.staticfiles import StaticFiles
from starlette.types import Scope

from amt.core.exceptions import AMTNotFound, AMTOnlyStatic


class StaticFilesCache(StaticFiles):
def __init__(
Expand Down Expand Up @@ -62,17 +64,17 @@ class URLComponents(NamedTuple):
@lru_cache(maxsize=1000)
def url_for_cache(name: str, /, **path_params: str) -> str:
if name != "static":
raise ValueError("Only static files are supported.")
raise AMTOnlyStatic()

url_parts: ParseResult = urllib.parse.urlparse(path_params["path"]) # type: ignore
if url_parts.scheme or url_parts.hostname: # type: ignore
raise ValueError("Only local URLS are supported.")
raise AMTOnlyStatic()

query_list: dict[str, str] = dict(x.split("=") for x in url_parts.query.split("&")) if url_parts.query else {} # type: ignore
resolved_url_path: str = "/" + name + "/" + url_parts.path # type: ignore
_, stat_result = static_files.lookup_path(url_parts.path) # type: ignore
if not stat_result:
raise ValueError(f"Static file {url_parts.path} not found.") # type: ignore
raise AMTNotFound()

etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
etag = f"{md5_hexdigest(etag_base.encode(), usedforsecurity=False)}"
Expand Down
10 changes: 5 additions & 5 deletions amt/api/routes/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
resolve_base_navigation_items,
resolve_sub_menu,
)
from amt.core.exceptions import NotFound, RepositoryError
from amt.core.exceptions import AMTNotFound, AMTRepositoryError
from amt.enums.status import Status
from amt.models import Project
from amt.services.projects import ProjectsService
Expand All @@ -29,8 +29,8 @@ def get_project_or_error(project_id: int, projects_service: ProjectsService, req
logger.debug(f"getting project with id {project_id}")
project = projects_service.get(project_id)
request.state.path_variables = {"project_id": project_id}
except RepositoryError as e:
raise NotFound from e
except AMTRepositoryError as e:
raise AMTNotFound from e
return project


Expand Down Expand Up @@ -170,7 +170,7 @@ async def get_assessment_card(

if not assessment_card_data:
logger.warning("assessment card not found")
raise NotFound()
raise AMTNotFound()

context = {
"assessment_card": assessment_card_data,
Expand Down Expand Up @@ -215,7 +215,7 @@ async def get_model_card(

if not model_card_data:
logger.warning("model card not found")
raise NotFound()
raise AMTNotFound()

context = {
"model_card": model_card_data,
Expand Down
6 changes: 4 additions & 2 deletions amt/clients/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime, timezone

import httpx
from amt.core.exceptions import AMTNotFound
from amt.schema.github import RepositoryContent

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,7 +37,8 @@ def _get(self, url: str) -> httpx.Response:
Private function that performs a GET request to given URL.
"""
response = self.client.get(url)
response.raise_for_status()
if response.status_code != 200:
raise AMTNotFound()
return response


Expand All @@ -47,7 +49,7 @@ def get_client(repo_type: str) -> Client:
case "github":
return GitHubClient()
case _:
raise ValueError(f"unknown repository type: {repo_type}")
raise AMTNotFound()


class GitHubPagesClient(Client):
Expand Down
8 changes: 4 additions & 4 deletions amt/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic_core import MultiHostUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

from amt.core.exceptions import SettingsError
from amt.core.exceptions import AMTSettingsError
from amt.core.types import DatabaseSchemaType, EnvironmentType, LoggingLevelType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -86,19 +86,19 @@ def SQLALCHEMY_DATABASE_URI(self) -> str:
@model_validator(mode="after")
def _enforce_database_rules(self: SelfSettings) -> SelfSettings:
if self.ENVIRONMENT == "production" and self.APP_DATABASE_SCHEME == "sqlite":
raise SettingsError("APP_DATABASE_SCHEME")
raise AMTSettingsError("APP_DATABASE_SCHEME")
return self

@model_validator(mode="after")
def _enforce_debug_rules(self: SelfSettings) -> SelfSettings:
if self.ENVIRONMENT == "production" and self.DEBUG:
raise SettingsError("DEBUG")
raise AMTSettingsError("DEBUG")
return self

@model_validator(mode="after")
def _enforce_autocreate_rules(self: SelfSettings) -> SelfSettings:
if self.ENVIRONMENT == "production" and self.AUTO_CREATE_SCHEMA:
raise SettingsError("AUTO_CREATE_SCHEMA")
raise AMTSettingsError("AUTO_CREATE_SCHEMA")
return self


Expand Down
85 changes: 41 additions & 44 deletions amt/core/exception_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,65 +3,62 @@
from fastapi import Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import HTMLResponse
from fastapi_csrf_protect.exceptions import CsrfProtectError # type: ignore
from starlette.exceptions import HTTPException as StarletteHTTPException

from amt.api.deps import templates
from amt.core.exceptions import AMTHTTPException, AMTNotFound, AMTRepositoryError
from amt.core.internationalization import (
get_current_translation,
)

logger = logging.getLogger(__name__)


async def http_exception_handler(request: Request, exc: StarletteHTTPException) -> HTMLResponse:
logger.debug(f"http_exception_handler: {exc.status_code} - {exc.detail}")
async def general_exception_handler(request: Request, exc: Exception) -> HTMLResponse:
exception_name = exc.__class__.__name__

if request.state.htmx:
return templates.TemplateResponse(
request,
"errors/_HTTPException.html.j2",
{"status_code": exc.status_code, "status_message": exc.detail},
status_code=exc.status_code,
)
logger.debug(f"general_exception_handler {exception_name}: {exc}")

return templates.TemplateResponse(
request,
"errors/HTTPException.html.j2",
{"status_code": exc.status_code, "status_message": exc.detail, "breadcrumbs": []},
status_code=exc.status_code,
)
translations = get_current_translation(request)

message = None
if isinstance(exc, AMTRepositoryError | AMTHTTPException):
message = exc.getmessage(translations)
elif isinstance(exc, StarletteHTTPException):
message = AMTNotFound().getmessage(translations) if exc.status_code == status.HTTP_404_NOT_FOUND else exc.detail
elif isinstance(exc, RequestValidationError):
messages: list[str] = [f"{error['loc'][-1]}: {error['msg']}" for error in exc.errors()]
message = "\n".join(messages)

async def validation_exception_handler(request: Request, exc: RequestValidationError) -> HTMLResponse:
logger.debug(f"validation_exception_handler: {exc.errors()}")
errors = exc.errors()
messages: list[str] = [f"{error['loc'][-1]}: {error['msg']}" for error in errors]

if request.state.htmx:
return templates.TemplateResponse(
request,
"errors/_RequestValidation.html.j2",
{"message": messages},
status_code=status.HTTP_400_BAD_REQUEST,
)
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
if isinstance(exc, StarletteHTTPException):
status_code = exc.status_code
elif isinstance(exc, RequestValidationError):
status_code = status.HTTP_400_BAD_REQUEST

return templates.TemplateResponse(
request, "errors/RequestValidation.html.j2", {"message": messages}, status_code=status.HTTP_400_BAD_REQUEST
# todo: what if request.state.htmx does not exist?
template_name = (
f"errors/_{exception_name}_{status_code}.html.j2"
if request.state.htmx
else f"errors/{exception_name}_{status_code}.html.j2"
)
fallback_template_name = "errors/_Exception.html.j2" if request.state.htmx else "errors/Exception.html.j2"

response: HTMLResponse | None = None

async def csrf_protect_exception_handler(request: Request, exc: CsrfProtectError) -> HTMLResponse:
logger.debug(f"csrf_protect_exception_handler: {exc.status_code} - {exc.message}")

if request.state.htmx:
return templates.TemplateResponse(
try:
response = templates.TemplateResponse(
request,
"errors/_CsrfProtectError.html.j2",
{"status_code": exc.status_code, "message": exc.message},
status_code=exc.status_code,
template_name,
{"message": message},
status_code=status_code,
)
except Exception:
response = templates.TemplateResponse(
request,
fallback_template_name,
{"message": message},
status_code=status_code,
)

return templates.TemplateResponse(
request,
"errors/CsrfProtectError.html.j2",
{"status_code": exc.status_code, "message": exc.message},
status_code=exc.status_code,
)
return response
85 changes: 46 additions & 39 deletions amt/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,67 @@
from pathlib import Path
from gettext import gettext as _

from babel.support import NullTranslations
from fastapi import status
from fastapi.exceptions import HTTPException, ValidationException
from fastapi.exceptions import HTTPException


class AMTHTTPException(HTTPException):
pass
def getmessage(self, translations: NullTranslations) -> str:
return translations.gettext(self.detail)


class AMTValidationException(ValidationException):
pass
class AMTError(Exception):
def getmessage(self, translations: NullTranslations) -> str:
return translations.gettext(self.detail) # type: ignore


class AMTError(RuntimeError):
"""
A generic, AMT-specific error.
"""
class AMTSettingsError(AMTError):
def __init__(self, field: str) -> None:
self.detail: str = _(
"An error occurred while configuring the options for '{field}'. Please check the settings and try again."
).format(field=field)
super().__init__(self.detail)


class SettingsError(AMTError):
def __init__(self, field: str) -> None:
self.message: str = f"Settings error for options {field}"
exception_name: str = self.__class__.__name__
super().__init__(f"{exception_name}: {self.message}")
class AMTRepositoryError(AMTHTTPException):
def __init__(self, detail: str | None = "Repository error") -> None:
self.detail: str = _("An internal server error occurred while processing your request. Please try again later.")
super().__init__(status.HTTP_500_INTERNAL_SERVER_ERROR, self.detail)


class AMTInstrumentError(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("An error occurred while processing the instrument. Please try again later.")
super().__init__(status.HTTP_501_NOT_IMPLEMENTED, self.detail)

class RepositoryError(AMTHTTPException):
def __init__(self, message: str = "Repository error") -> None:
self.message: str = message
exception_name: str = self.__class__.__name__
super().__init__(status.HTTP_500_INTERNAL_SERVER_ERROR, f"{exception_name}: {self.message}")

class AMTNotFound(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _(
"The requested page or resource could not be found. Please check the URL or query and try again."
)
super().__init__(status.HTTP_404_NOT_FOUND, self.detail)

class InstrumentError(AMTHTTPException):
def __init__(self, message: str = "Instrument error") -> None:
self.message: str = message
exception_name: str = self.__class__.__name__
super().__init__(status.HTTP_501_NOT_IMPLEMENTED, f"{exception_name}: {self.message}")

class AMTCSRFProtectError(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("CSRF check failed.")
super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail)

class UnsafeFileError(AMTHTTPException):
def __init__(self, file: Path) -> None:
self.message: str = f"Unsafe file error for file {file}"
exception_name: str = self.__class__.__name__
super().__init__(status.HTTP_400_BAD_REQUEST, f"{exception_name}: {self.message}")

class AMTOnlyStatic(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("Only static files are supported.")
super().__init__(status.HTTP_400_BAD_REQUEST, self.detail)

class RepositoryNoResultFound(AMTHTTPException):
def __init__(self, message: str = "No entity found") -> None:
self.message: str = message
exception_name: str = self.__class__.__name__
super().__init__(status.HTTP_204_NO_CONTENT, f"{exception_name}: {self.message}")

class AMTKeyError(AMTHTTPException):
def __init__(self, field: str) -> None:
self.detail: str = _("Key not correct: {field}").format(field=field)
super().__init__(status.HTTP_400_BAD_REQUEST, self.detail)


class NotFound(AMTHTTPException):
def __init__(self, message: str = "Not found") -> None:
self.message: str = message
exception_name: str = self.__class__.__name__
super().__init__(status.HTTP_404_NOT_FOUND, f"{exception_name}: {self.message}")
class AMTValueError(AMTHTTPException):
def __init__(self, field: str) -> None:
self.detail: str = _("Value not correct: {field}").format(field=field)
super().__init__(status.HTTP_400_BAD_REQUEST, self.detail)
Loading

0 comments on commit 167b020

Please sign in to comment.