Skip to content

Commit

Permalink
Add CSRF protection
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Aug 5, 2024
1 parent 9e5db95 commit 0bf377a
Show file tree
Hide file tree
Showing 21 changed files with 253 additions and 24 deletions.
2 changes: 1 addition & 1 deletion amt/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def TemplateResponse( # pyright: ignore [reportIncompatibleMethodOverride]

if context is None:
context = {}

context["csrftoken"] = request.state.csrftoken
return super().TemplateResponse(request, name, context, status_code, headers, media_type, background)

def Redirect(self, request: Request, url: str) -> HTMLResponse:
Expand Down
8 changes: 4 additions & 4 deletions amt/api/routes/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ async def get_new(
instrument_service: Annotated[InstrumentsService, Depends(InstrumentsService)],
) -> HTMLResponse:
instruments = instrument_service.fetch_instruments()

return templates.TemplateResponse(request, "projects/new.html.j2", {"instruments": instruments})
response = templates.TemplateResponse(request, "projects/new.html.j2", {"instruments": instruments})
return response


@router.post("/new", response_class=HTMLResponse)
Expand All @@ -48,5 +48,5 @@ async def post_new(
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = projects_service.create(project_new)

return templates.Redirect(request, f"/project/{project.id}")
response = templates.Redirect(request, f"/project/{project.id}")
return response
7 changes: 7 additions & 0 deletions amt/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

logger = logging.getLogger(__name__)


# Self type is not available in Python 3.10 so create our own with TypeVar
SelfSettings = TypeVar("SelfSettings", bound="Settings")

Expand Down Expand Up @@ -50,6 +51,12 @@ class Settings(BaseSettings):

model_config = SettingsConfigDict(extra="ignore")

# FastAPI CSRF Protect Settings
CSRF_PROTECT_SECRET_KEY: str = secrets.token_urlsafe(32)
CSRF_TOKEN_LOCATION: str = "header"
CSRF_TOKEN_KEY: str = "csrf-token"
CSRF_COOKIE_SAMESITE: str = "strict"

@computed_field
def SQLALCHEMY_ECHO(self) -> bool:
return self.DEBUG
Expand Down
18 changes: 18 additions & 0 deletions amt/core/csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Any

from fastapi_csrf_protect import CsrfProtect # type: ignore

from amt.core.config import get_settings


@CsrfProtect.load_config # type: ignore
def get_csrf_config() -> list[tuple[Any, ...]]:
settings = get_settings()
config = [
("secret_key", settings.CSRF_PROTECT_SECRET_KEY),
("token_location", settings.CSRF_TOKEN_LOCATION),
("token_key", settings.CSRF_TOKEN_KEY),
("cookie_samesite", settings.CSRF_COOKIE_SAMESITE),
]

return config
20 changes: 20 additions & 0 deletions amt/core/exception_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import HTMLResponse
from fastapi_csrf_protect.exceptions import CsrfProtectError # type: ignore
from starlette.exceptions import HTTPException as StarletteHTTPException

from amt.api.deps import templates
Expand Down Expand Up @@ -45,3 +46,22 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
return templates.TemplateResponse(
request, "errors/RequestValidation.html.j2", {"message": messages}, status_code=status.HTTP_400_BAD_REQUEST
)


async def csrf_protect_exception_handler(request: Request, exc: CsrfProtectError) -> HTMLResponse:
logger.debug(f"csrf_protect_exception_handler: {exc.status_code} - {exc.message}")

if request.state.htmx:
return templates.TemplateResponse(
request,
"errors/_CsrfProtectError.html.j2",
{"status_code": exc.status_code, "message": exc.message},
status_code=exc.status_code,
)

return templates.TemplateResponse(
request,
"errors/CsrfProtectError.html.j2",
{"status_code": exc.status_code, "message": exc.message},
status_code=exc.status_code,
)
76 changes: 76 additions & 0 deletions amt/middleware/csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import logging
import typing

from fastapi_csrf_protect import CsrfProtect # type: ignore
from fastapi_csrf_protect.exceptions import CsrfProtectError # type: ignore
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

from amt.core.csrf import get_csrf_config # type: ignore # noqa
from amt.core.exception_handlers import csrf_protect_exception_handler

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]


logger = logging.getLogger(__name__)


class CSRFMiddleware(BaseHTTPMiddleware):
"""
This middleware implements CSRF signed double token protection through FastAPI CSRF Protect.
"""

def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
self.csrf_protect = CsrfProtect()
self.safe_methods = ("GET", "HEAD", "OPTIONS", "TRACE")

def _include_request(self, request: Request) -> bool:
"""
This method specifies whether a request should be protected by FastAPI CSRF Protect or not.
The method is needed because we need to in any case exclude GET requests originating from
HTMX or GET requests that fetch static pages becauses this will result in multiple tokens
which make validation impossible due to the asynchronisity of the requests.
"""
is_not_static: bool = "static" not in request.url.path
is_not_htmx: bool = request.state.htmx == "False"
return is_not_static or is_not_htmx

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
signed_token = ""
if self._include_request(request):
request.state.csrftoken = ""
if request.method in self.safe_methods:
csrf_token, signed_token = self.csrf_protect.generate_csrf_tokens()
logger.debug(f"generating tokens: csrf_token={csrf_token}, signed_token={signed_token}")
request.state.csrftoken = csrf_token
else:
csrf_token = request.headers["X-CSRF-Token"]
logger.debug(f"validating tokens: csrf_token={csrf_token}")
await self.csrf_protect.validate_csrf(request)

response = await call_next(request)

if self._include_request(request) and request.method in self.safe_methods:
self.csrf_protect.set_csrf_cookie(signed_token, response)
logger.debug(f"set csrf_cookie: signed_token={signed_token}")

return response


class CSRFMiddlewareExceptionHandler(BaseHTTPMiddleware):
"""
This middleware is necessary to propagate CsrfProtectErrors to the csrf_protection_handler.
"""

def __init__(self, app: ASGIApp) -> None:
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
try:
response = await call_next(request)
except CsrfProtectError as e:
return await csrf_protect_exception_handler(request, e)
return response
3 changes: 3 additions & 0 deletions amt/middleware/htmx.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import typing

from starlette.middleware.base import BaseHTTPMiddleware
Expand All @@ -6,6 +7,8 @@

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]

logger = logging.getLogger(__name__)


class HTMXMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
Expand Down
3 changes: 3 additions & 0 deletions amt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from amt.utils.mask import Mask

from .api.http_browser_caching import static_files
from .middleware.csrf import CSRFMiddleware, CSRFMiddlewareExceptionHandler
from .middleware.htmx import HTMXMiddleware
from .middleware.route_logging import RequestLoggingMiddleware

Expand Down Expand Up @@ -54,6 +55,8 @@ def create_app() -> FastAPI:
)

app.add_middleware(RequestLoggingMiddleware)
app.add_middleware(CSRFMiddleware)
app.add_middleware(CSRFMiddlewareExceptionHandler)
app.add_middleware(HTMXMiddleware)

app.mount("/static", static_files, name="static")
Expand Down
5 changes: 5 additions & 0 deletions amt/site/templates/errors/CsrfProtectError.html.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{% extends 'layouts/base.html.j2' %}

{% block content %}
{% include 'errors/_CsrfProtectError.html.j2' %}
{% endblock %}
2 changes: 2 additions & 0 deletions amt/site/templates/errors/_CsrfProtectError.html.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
<h2>{{status_code}}</h2>
<p>{{status_message}}</p>
2 changes: 1 addition & 1 deletion amt/site/templates/pages/index.html.j2
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

{% block content %}
<div class="container">
<form id="cardMovedForm" hx-patch="/tasks/" hx-ext="json-enc" hx-target-5*="#errorContainer" hx-trigger="cardmoved" hx-swap="outerHTML" hx-target="#board" class="">
<form id="cardMovedForm" hx-patch="/tasks/" hx-ext="json-enc" hx-headers='{"X-CSRF-Token": "{{ csrftoken }}"}' hx-target-5*="#errorContainer" hx-trigger="cardmoved" hx-swap="outerHTML" hx-target="#board" class="">
<input type="hidden" name="taskId" value="">
<input type="hidden" name="statusId" value="">
<input type="hidden" name="previousSiblingId" value="">
Expand Down
2 changes: 1 addition & 1 deletion amt/site/templates/projects/new.html.j2
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{% block content %}
<h1 class="margin-bottom--large">{% trans %}New Project{% endtrans %}</h1>

<form hx-ext="json-enc" hx-post="/projects/new" hx-target-error="#errorContainer" hx-swap="innerHTML" method="post">
<form hx-ext="json-enc" hx-post="/projects/new" hx-headers='{"X-CSRF-Token": "{{ csrftoken }}"}' hx-target-error="#errorContainer" hx-swap="innerHTML" method="post">
{% trans %}Project name{% endtrans %} <input type="text" id="name" name="name">

<fieldset>
Expand Down
30 changes: 29 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ httpx = "^0.27.0"
pyyaml-include = "^2.1"
click = "^8.1.7"
python-ulid = "^2.7.0"
fastapi-csrf-protect = "^0.3.4"


[tool.poetry.group.test.dependencies]
Expand Down
16 changes: 10 additions & 6 deletions tests/api/routes/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from amt.services.instruments import InstrumentsService
from amt.services.storage import FileSystemStorageService
from fastapi.testclient import TestClient
from fastapi_csrf_protect import CsrfProtect # type: ignore # noqa

from tests.constants import default_instrument

Expand Down Expand Up @@ -52,37 +53,40 @@ def test_get_new_projects(client: TestClient, init_instruments: Generator[None,
)


def test_post_new_projects_bad_request(client: TestClient) -> None:
def test_post_new_projects_bad_request(client: TestClient, mock_csrf: Generator[None, None, None]) -> None:
# when
response = client.post("/projects/new", json={})
client.cookies["fastapi-csrf-token"] = "1"
response = client.post("/projects/new", json={}, headers={"X-CSRF-Token": "1"})

# then
assert response.status_code == 400
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert b"name: Field required" in response.content


def test_post_new_projects(client: TestClient) -> None:
def test_post_new_projects(client: TestClient, mock_csrf: Generator[None, None, None]) -> None:
client.cookies["fastapi-csrf-token"] = "1"
new_project = ProjectNew(name="default project")

# when
response = client.post("/projects/new", json=new_project.model_dump())
response = client.post("/projects/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"})

# then
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert response.headers["HX-Redirect"] == "/project/1"


def test_post_new_projects_write_system_card(client: TestClient) -> None:
def test_post_new_projects_write_system_card(client: TestClient, mock_csrf: Generator[None, None, None]) -> None:
# Given
client.cookies["fastapi-csrf-token"] = "1"
origin = FileSystemStorageService.write
FileSystemStorageService.write = MagicMock()
project_new = ProjectNew(name="name1")
system_card = SystemCard(name=project_new.name, selected_instruments=[])

# when
client.post("/projects/new", json=project_new.model_dump())
client.post("/projects/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"})

# then
FileSystemStorageService.write.assert_called_with(system_card.model_dump())
Expand Down
16 changes: 12 additions & 4 deletions tests/api/routes/test_status.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
from collections.abc import Generator

from amt.schema.task import MovedTask
from fastapi.testclient import TestClient

from tests.constants import default_task
from tests.database_test_utils import DatabaseTestUtils


def test_post_move_task(client: TestClient, db: DatabaseTestUtils) -> None:
def test_post_move_task(client: TestClient, db: DatabaseTestUtils, mock_csrf: Generator[None, None, None]) -> None:
db.given([default_task(), default_task(), default_task()])
client.cookies["fastapi-csrf-token"] = "1"

move_task: MovedTask = MovedTask(taskId=2, statusId=2, previousSiblingId=1, nextSiblingId=3)

response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True))
response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True), headers={"X-CSRF-Token": "1"})
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
assert b'id="card-content-' in response.content


def test_post_move_task_no_siblings(client: TestClient, db: DatabaseTestUtils) -> None:
def test_post_move_task_no_siblings(
client: TestClient,
db: DatabaseTestUtils,
mock_csrf: Generator[None, None, None],
) -> None:
db.given([default_task(), default_task(), default_task()])
client.cookies["fastapi-csrf-token"] = "1"

move_task: MovedTask = MovedTask(taskId=2, statusId=1, previousSiblingId=-1, nextSiblingId=-1)
response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True))
response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True), headers={"X-CSRF-Token": "1"})

assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
Expand Down
Loading

0 comments on commit 0bf377a

Please sign in to comment.