From 0bf377ad841803dbc4041f0f5f320e156ba9e0e8 Mon Sep 17 00:00:00 2001 From: ChristopherSpelt Date: Fri, 26 Jul 2024 14:56:43 +0200 Subject: [PATCH] Add CSRF protection --- amt/api/deps.py | 2 +- amt/api/routes/projects.py | 8 +- amt/core/config.py | 7 ++ amt/core/csrf.py | 18 +++++ amt/core/exception_handlers.py | 20 +++++ amt/middleware/csrf.py | 76 +++++++++++++++++++ amt/middleware/htmx.py | 3 + amt/server.py | 3 + .../templates/errors/CsrfProtectError.html.j2 | 5 ++ .../errors/_CsrfProtectError.html.j2 | 2 + amt/site/templates/pages/index.html.j2 | 2 +- amt/site/templates/projects/new.html.j2 | 2 +- poetry.lock | 30 +++++++- pyproject.toml | 1 + tests/api/routes/test_projects.py | 16 ++-- tests/api/routes/test_status.py | 16 +++- tests/api/routes/test_tasks_move.py | 16 +++- tests/conftest.py | 12 +++ tests/constants.py | 4 +- tests/core/test_exception_handlers.py | 32 ++++++++ .../templates/test_template_new_project.py | 2 +- 21 files changed, 253 insertions(+), 24 deletions(-) create mode 100644 amt/core/csrf.py create mode 100644 amt/middleware/csrf.py create mode 100644 amt/site/templates/errors/CsrfProtectError.html.j2 create mode 100644 amt/site/templates/errors/_CsrfProtectError.html.j2 diff --git a/amt/api/deps.py b/amt/api/deps.py index d37aad1e0..b070dc77c 100644 --- a/amt/api/deps.py +++ b/amt/api/deps.py @@ -69,7 +69,7 @@ def TemplateResponse( # pyright: ignore [reportIncompatibleMethodOverride] if context is None: context = {} - + context["csrftoken"] = request.state.csrftoken return super().TemplateResponse(request, name, context, status_code, headers, media_type, background) def Redirect(self, request: Request, url: str) -> HTMLResponse: diff --git a/amt/api/routes/projects.py b/amt/api/routes/projects.py index 07ff815f4..0f82bfd40 100644 --- a/amt/api/routes/projects.py +++ b/amt/api/routes/projects.py @@ -37,8 +37,8 @@ async def get_new( instrument_service: Annotated[InstrumentsService, Depends(InstrumentsService)], ) -> HTMLResponse: instruments = instrument_service.fetch_instruments() - - return templates.TemplateResponse(request, "projects/new.html.j2", {"instruments": instruments}) + response = templates.TemplateResponse(request, "projects/new.html.j2", {"instruments": instruments}) + return response @router.post("/new", response_class=HTMLResponse) @@ -48,5 +48,5 @@ async def post_new( projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: project = projects_service.create(project_new) - - return templates.Redirect(request, f"/project/{project.id}") + response = templates.Redirect(request, f"/project/{project.id}") + return response diff --git a/amt/core/config.py b/amt/core/config.py index 0293c2bee..7ff4d2646 100644 --- a/amt/core/config.py +++ b/amt/core/config.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) + # Self type is not available in Python 3.10 so create our own with TypeVar SelfSettings = TypeVar("SelfSettings", bound="Settings") @@ -50,6 +51,12 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(extra="ignore") + # FastAPI CSRF Protect Settings + CSRF_PROTECT_SECRET_KEY: str = secrets.token_urlsafe(32) + CSRF_TOKEN_LOCATION: str = "header" + CSRF_TOKEN_KEY: str = "csrf-token" + CSRF_COOKIE_SAMESITE: str = "strict" + @computed_field def SQLALCHEMY_ECHO(self) -> bool: return self.DEBUG diff --git a/amt/core/csrf.py b/amt/core/csrf.py new file mode 100644 index 000000000..9bab038ee --- /dev/null +++ b/amt/core/csrf.py @@ -0,0 +1,18 @@ +from typing import Any + +from fastapi_csrf_protect import CsrfProtect # type: ignore + +from amt.core.config import get_settings + + +@CsrfProtect.load_config # type: ignore +def get_csrf_config() -> list[tuple[Any, ...]]: + settings = get_settings() + config = [ + ("secret_key", settings.CSRF_PROTECT_SECRET_KEY), + ("token_location", settings.CSRF_TOKEN_LOCATION), + ("token_key", settings.CSRF_TOKEN_KEY), + ("cookie_samesite", settings.CSRF_COOKIE_SAMESITE), + ] + + return config diff --git a/amt/core/exception_handlers.py b/amt/core/exception_handlers.py index 5bf88d0d3..046e15013 100644 --- a/amt/core/exception_handlers.py +++ b/amt/core/exception_handlers.py @@ -3,6 +3,7 @@ from fastapi import Request, status from fastapi.exceptions import RequestValidationError from fastapi.responses import HTMLResponse +from fastapi_csrf_protect.exceptions import CsrfProtectError # type: ignore from starlette.exceptions import HTTPException as StarletteHTTPException from amt.api.deps import templates @@ -45,3 +46,22 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE return templates.TemplateResponse( request, "errors/RequestValidation.html.j2", {"message": messages}, status_code=status.HTTP_400_BAD_REQUEST ) + + +async def csrf_protect_exception_handler(request: Request, exc: CsrfProtectError) -> HTMLResponse: + logger.debug(f"csrf_protect_exception_handler: {exc.status_code} - {exc.message}") + + if request.state.htmx: + return templates.TemplateResponse( + request, + "errors/_CsrfProtectError.html.j2", + {"status_code": exc.status_code, "message": exc.message}, + status_code=exc.status_code, + ) + + return templates.TemplateResponse( + request, + "errors/CsrfProtectError.html.j2", + {"status_code": exc.status_code, "message": exc.message}, + status_code=exc.status_code, + ) diff --git a/amt/middleware/csrf.py b/amt/middleware/csrf.py new file mode 100644 index 000000000..ff4290cd5 --- /dev/null +++ b/amt/middleware/csrf.py @@ -0,0 +1,76 @@ +import logging +import typing + +from fastapi_csrf_protect import CsrfProtect # type: ignore +from fastapi_csrf_protect.exceptions import CsrfProtectError # type: ignore +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp + +from amt.core.csrf import get_csrf_config # type: ignore # noqa +from amt.core.exception_handlers import csrf_protect_exception_handler + +RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] + + +logger = logging.getLogger(__name__) + + +class CSRFMiddleware(BaseHTTPMiddleware): + """ + This middleware implements CSRF signed double token protection through FastAPI CSRF Protect. + """ + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + self.csrf_protect = CsrfProtect() + self.safe_methods = ("GET", "HEAD", "OPTIONS", "TRACE") + + def _include_request(self, request: Request) -> bool: + """ + This method specifies whether a request should be protected by FastAPI CSRF Protect or not. + The method is needed because we need to in any case exclude GET requests originating from + HTMX or GET requests that fetch static pages becauses this will result in multiple tokens + which make validation impossible due to the asynchronisity of the requests. + """ + is_not_static: bool = "static" not in request.url.path + is_not_htmx: bool = request.state.htmx == "False" + return is_not_static or is_not_htmx + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + signed_token = "" + if self._include_request(request): + request.state.csrftoken = "" + if request.method in self.safe_methods: + csrf_token, signed_token = self.csrf_protect.generate_csrf_tokens() + logger.debug(f"generating tokens: csrf_token={csrf_token}, signed_token={signed_token}") + request.state.csrftoken = csrf_token + else: + csrf_token = request.headers["X-CSRF-Token"] + logger.debug(f"validating tokens: csrf_token={csrf_token}") + await self.csrf_protect.validate_csrf(request) + + response = await call_next(request) + + if self._include_request(request) and request.method in self.safe_methods: + self.csrf_protect.set_csrf_cookie(signed_token, response) + logger.debug(f"set csrf_cookie: signed_token={signed_token}") + + return response + + +class CSRFMiddlewareExceptionHandler(BaseHTTPMiddleware): + """ + This middleware is necessary to propagate CsrfProtectErrors to the csrf_protection_handler. + """ + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + try: + response = await call_next(request) + except CsrfProtectError as e: + return await csrf_protect_exception_handler(request, e) + return response diff --git a/amt/middleware/htmx.py b/amt/middleware/htmx.py index 01a262954..379e66a09 100644 --- a/amt/middleware/htmx.py +++ b/amt/middleware/htmx.py @@ -1,3 +1,4 @@ +import logging import typing from starlette.middleware.base import BaseHTTPMiddleware @@ -6,6 +7,8 @@ RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] +logger = logging.getLogger(__name__) + class HTMXMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: diff --git a/amt/server.py b/amt/server.py index 780643eec..50a6fcd8c 100644 --- a/amt/server.py +++ b/amt/server.py @@ -20,6 +20,7 @@ from amt.utils.mask import Mask from .api.http_browser_caching import static_files +from .middleware.csrf import CSRFMiddleware, CSRFMiddlewareExceptionHandler from .middleware.htmx import HTMXMiddleware from .middleware.route_logging import RequestLoggingMiddleware @@ -54,6 +55,8 @@ def create_app() -> FastAPI: ) app.add_middleware(RequestLoggingMiddleware) + app.add_middleware(CSRFMiddleware) + app.add_middleware(CSRFMiddlewareExceptionHandler) app.add_middleware(HTMXMiddleware) app.mount("/static", static_files, name="static") diff --git a/amt/site/templates/errors/CsrfProtectError.html.j2 b/amt/site/templates/errors/CsrfProtectError.html.j2 new file mode 100644 index 000000000..d77efc1b8 --- /dev/null +++ b/amt/site/templates/errors/CsrfProtectError.html.j2 @@ -0,0 +1,5 @@ +{% extends 'layouts/base.html.j2' %} + +{% block content %} +{% include 'errors/_CsrfProtectError.html.j2' %} +{% endblock %} diff --git a/amt/site/templates/errors/_CsrfProtectError.html.j2 b/amt/site/templates/errors/_CsrfProtectError.html.j2 new file mode 100644 index 000000000..3f2cb32cc --- /dev/null +++ b/amt/site/templates/errors/_CsrfProtectError.html.j2 @@ -0,0 +1,2 @@ +

{{status_code}}

+

{{status_message}}

diff --git a/amt/site/templates/pages/index.html.j2 b/amt/site/templates/pages/index.html.j2 index 52a8b81d4..3ddfe82bc 100644 --- a/amt/site/templates/pages/index.html.j2 +++ b/amt/site/templates/pages/index.html.j2 @@ -20,7 +20,7 @@ {% block content %}
-
+ diff --git a/amt/site/templates/projects/new.html.j2 b/amt/site/templates/projects/new.html.j2 index 2a6679125..5b4fe7787 100644 --- a/amt/site/templates/projects/new.html.j2 +++ b/amt/site/templates/projects/new.html.j2 @@ -3,7 +3,7 @@ {% block content %}

{% trans %}New Project{% endtrans %}

- + {% trans %}Project name{% endtrans %}
diff --git a/poetry.lock b/poetry.lock index 50b90e9d7..d83494f53 100644 --- a/poetry.lock +++ b/poetry.lock @@ -341,6 +341,23 @@ typing-extensions = ">=4.8.0" all = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] standard = ["email_validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "fastapi-csrf-protect" +version = "0.3.4" +description = "Stateless implementation of Cross-Site Request Forgery (XSRF) Protection by using Double Submit Cookie mitigation pattern" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "fastapi_csrf_protect-0.3.4-py3-none-any.whl", hash = "sha256:78ee1d5bcdc10d06f0516fa8ed9e00b4fa8c8dfa8e93bada6ede7a7bba7507ae"}, + {file = "fastapi_csrf_protect-0.3.4.tar.gz", hash = "sha256:a7d170d4e119c22bae8a4f5922529767f4d69c4c69744865c03f939e399e5822"}, +] + +[package.dependencies] +fastapi = ">=0,<1" +itsdangerous = ">=2.0.1,<3.0.0" +pydantic = ">=2.0.0,<3.0.0" +pydantic-settings = ">=2.0.0,<3.0.0" + [[package]] name = "filelock" version = "3.15.4" @@ -607,6 +624,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +description = "Safely pass data to untrusted environments and back." +optional = false +python-versions = ">=3.8" +files = [ + {file = "itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef"}, + {file = "itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173"}, +] + [[package]] name = "jinja2" version = "3.1.4" @@ -1799,4 +1827,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "8d9a0e2ff970256df9508680372d077229163364788a81e376799dc9a935ed26" +content-hash = "4ef06f3d5fc3e52310e23ddeb22d1715f495349926e433f6228eeeb1d14bb47e" diff --git a/pyproject.toml b/pyproject.toml index 2bf8b5ef5..911b295b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ httpx = "^0.27.0" pyyaml-include = "^2.1" click = "^8.1.7" python-ulid = "^2.7.0" +fastapi-csrf-protect = "^0.3.4" [tool.poetry.group.test.dependencies] diff --git a/tests/api/routes/test_projects.py b/tests/api/routes/test_projects.py index 0356f642b..a457b14fd 100644 --- a/tests/api/routes/test_projects.py +++ b/tests/api/routes/test_projects.py @@ -7,6 +7,7 @@ from amt.services.instruments import InstrumentsService from amt.services.storage import FileSystemStorageService from fastapi.testclient import TestClient +from fastapi_csrf_protect import CsrfProtect # type: ignore # noqa from tests.constants import default_instrument @@ -52,9 +53,10 @@ def test_get_new_projects(client: TestClient, init_instruments: Generator[None, ) -def test_post_new_projects_bad_request(client: TestClient) -> None: +def test_post_new_projects_bad_request(client: TestClient, mock_csrf: Generator[None, None, None]) -> None: # when - response = client.post("/projects/new", json={}) + client.cookies["fastapi-csrf-token"] = "1" + response = client.post("/projects/new", json={}, headers={"X-CSRF-Token": "1"}) # then assert response.status_code == 400 @@ -62,11 +64,12 @@ def test_post_new_projects_bad_request(client: TestClient) -> None: assert b"name: Field required" in response.content -def test_post_new_projects(client: TestClient) -> None: +def test_post_new_projects(client: TestClient, mock_csrf: Generator[None, None, None]) -> None: + client.cookies["fastapi-csrf-token"] = "1" new_project = ProjectNew(name="default project") # when - response = client.post("/projects/new", json=new_project.model_dump()) + response = client.post("/projects/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"}) # then assert response.status_code == 200 @@ -74,15 +77,16 @@ def test_post_new_projects(client: TestClient) -> None: assert response.headers["HX-Redirect"] == "/project/1" -def test_post_new_projects_write_system_card(client: TestClient) -> None: +def test_post_new_projects_write_system_card(client: TestClient, mock_csrf: Generator[None, None, None]) -> None: # Given + client.cookies["fastapi-csrf-token"] = "1" origin = FileSystemStorageService.write FileSystemStorageService.write = MagicMock() project_new = ProjectNew(name="name1") system_card = SystemCard(name=project_new.name, selected_instruments=[]) # when - client.post("/projects/new", json=project_new.model_dump()) + client.post("/projects/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"}) # then FileSystemStorageService.write.assert_called_with(system_card.model_dump()) diff --git a/tests/api/routes/test_status.py b/tests/api/routes/test_status.py index d174bf918..aa7ba67cf 100644 --- a/tests/api/routes/test_status.py +++ b/tests/api/routes/test_status.py @@ -1,3 +1,5 @@ +from collections.abc import Generator + from amt.schema.task import MovedTask from fastapi.testclient import TestClient @@ -5,22 +7,28 @@ from tests.database_test_utils import DatabaseTestUtils -def test_post_move_task(client: TestClient, db: DatabaseTestUtils) -> None: +def test_post_move_task(client: TestClient, db: DatabaseTestUtils, mock_csrf: Generator[None, None, None]) -> None: db.given([default_task(), default_task(), default_task()]) + client.cookies["fastapi-csrf-token"] = "1" move_task: MovedTask = MovedTask(taskId=2, statusId=2, previousSiblingId=1, nextSiblingId=3) - response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True)) + response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True), headers={"X-CSRF-Token": "1"}) assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" assert b'id="card-content-' in response.content -def test_post_move_task_no_siblings(client: TestClient, db: DatabaseTestUtils) -> None: +def test_post_move_task_no_siblings( + client: TestClient, + db: DatabaseTestUtils, + mock_csrf: Generator[None, None, None], +) -> None: db.given([default_task(), default_task(), default_task()]) + client.cookies["fastapi-csrf-token"] = "1" move_task: MovedTask = MovedTask(taskId=2, statusId=1, previousSiblingId=-1, nextSiblingId=-1) - response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True)) + response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True), headers={"X-CSRF-Token": "1"}) assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" diff --git a/tests/api/routes/test_tasks_move.py b/tests/api/routes/test_tasks_move.py index 09837194f..cddb8cf05 100644 --- a/tests/api/routes/test_tasks_move.py +++ b/tests/api/routes/test_tasks_move.py @@ -1,23 +1,31 @@ +from collections.abc import Generator + from fastapi.testclient import TestClient from tests.constants import default_task from tests.database_test_utils import DatabaseTestUtils -def test_post_task_move(client: TestClient, db: DatabaseTestUtils) -> None: +def test_post_task_move(client: TestClient, db: DatabaseTestUtils, mock_csrf: Generator[None, None, None]) -> None: db.given([default_task(), default_task(), default_task()]) + client.cookies["fastapi-csrf-token"] = "1" response = client.patch( - "/tasks/", json={"taskId": "1", "statusId": "1", "previousSiblingId": "2", "nextSiblingId": "-1"} + "/tasks/", + json={"taskId": "1", "statusId": "1", "previousSiblingId": "2", "nextSiblingId": "-1"}, + headers={"X-CSRF-Token": "1"}, ) assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" assert b'id="card-content-1"' in response.content -def test_task_move_error(client: TestClient, db: DatabaseTestUtils) -> None: +def test_task_move_error(client: TestClient, db: DatabaseTestUtils, mock_csrf: Generator[None, None, None]) -> None: + client.cookies["fastapi-csrf-token"] = "1" response = client.patch( - "/tasks/", json={"taskId": "1", "statusId": "1", "previousSiblingId": "2", "nextSiblingId": "-1"} + "/tasks/", + json={"taskId": "1", "statusId": "1", "previousSiblingId": "2", "nextSiblingId": "-1"}, + headers={"X-CSRF-Token": "1"}, ) assert response.status_code == 500 assert response.headers["content-type"] == "text/html; charset=utf-8" diff --git a/tests/conftest.py b/tests/conftest.py index 753902405..65cc58207 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from multiprocessing import Process from pathlib import Path from typing import Any +from unittest.mock import AsyncMock import httpx import pytest @@ -11,6 +12,7 @@ from amt.models import * # noqa from amt.server import create_app from fastapi.testclient import TestClient +from fastapi_csrf_protect import CsrfProtect # type: ignore from playwright.sync_api import Browser from sqlmodel import Session, SQLModel, create_engine @@ -80,7 +82,9 @@ def client(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch) -> Generator[ with TestClient(app, raise_server_exceptions=True) as c: app.dependency_overrides[get_session] = db.get_session + c.timeout = 5 + yield c @@ -110,3 +114,11 @@ def db(tmp_path: Path) -> Generator[DatabaseTestUtils, None, None]: with Session(engine, expire_on_commit=False) as session: yield DatabaseTestUtils(session, database_file) + + +@pytest.fixture() +def mock_csrf() -> Generator[None, None, None]: # noqa: PT004 + original = CsrfProtect.validate_csrf + CsrfProtect.validate_csrf = AsyncMock() + yield + CsrfProtect.validate_csrf = original diff --git a/tests/constants.py b/tests/constants.py index aa8e3dc14..53bbea45b 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -12,7 +12,9 @@ def default_project(name: str = "default project", model_card: str = "/tmp/1.yam def default_fastapi_request() -> Request: - return Request(scope={"type": "http", "http_version": "1.1", "method": "GET", "headers": []}) + request = Request(scope={"type": "http", "http_version": "1.1", "method": "GET", "headers": []}) + request.state.csrftoken = "" + return request def default_instrument( diff --git a/tests/core/test_exception_handlers.py b/tests/core/test_exception_handlers.py index a65dea9e4..e630cdcb2 100644 --- a/tests/core/test_exception_handlers.py +++ b/tests/core/test_exception_handlers.py @@ -1,3 +1,4 @@ +from amt.schema.project import ProjectNew from fastapi import status from fastapi.testclient import TestClient @@ -16,6 +17,16 @@ def test_request_validation_exception_handler(client: TestClient): assert response.headers["content-type"] == "text/html; charset=utf-8" +def test_request_csrf_protect_exception_handler_invalid_token_in_header(client: TestClient): + data = client.get("/projects/new") + new_project = ProjectNew(name="default project") + response = client.post( + "/projects/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"}, cookies=data.cookies + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert response.headers["content-type"] == "text/html; charset=utf-8" + + def test_http_exception_handler_htmx(client: TestClient): response = client.get("/raise-http-exception", headers={"HX-Request": "true"}) @@ -28,3 +39,24 @@ def test_request_validation_exception_handler_htmx(client: TestClient): assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.headers["content-type"] == "text/html; charset=utf-8" + + +def test_request_csrf_protect_exception_handler_invalid_token_in_header_htmx(client: TestClient): + data = client.get("/projects/new") + new_project = ProjectNew(name="default project") + response = client.post( + "/projects/new", + json=new_project.model_dump(), + headers={"HX-Request": "true", "X-CSRF-Token": "1"}, + cookies=data.cookies, + ) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert response.headers["content-type"] == "text/html; charset=utf-8" + + +def test_(client: TestClient): + response = client.get("/projects/?skip=a", headers={"HX-Request": "true"}) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.headers["content-type"] == "text/html; charset=utf-8" diff --git a/tests/site/static/templates/test_template_new_project.py b/tests/site/static/templates/test_template_new_project.py index 5008a3cfa..e966003b4 100644 --- a/tests/site/static/templates/test_template_new_project.py +++ b/tests/site/static/templates/test_template_new_project.py @@ -5,7 +5,7 @@ def test_tempate_projects_new(): # given request = default_fastapi_request() - context = {"project": ""} + context = {"project": "", "instruments": ""} # when response = templates.TemplateResponse(request, "projects/new.html.j2", context)