From 0b772ad04778634fcf80d74196e545d490b16e31 Mon Sep 17 00:00:00 2001 From: Paul Watts Date: Thu, 17 Oct 2024 15:09:31 -0700 Subject: [PATCH] Fix absolute/relative URLs 1. All URLs in the HTML should be relative -- this makes configuration easier because we don't care if we're behind a proxy or not. 2. For the URLs that we need to be absolute -- the ones in the metadata -- allow for the base URL to be specified through settings, rather than inferred from the request. --- src/saml_idp/config.py | 3 +++ src/saml_idp/router.py | 28 ++++++++++++++++++---------- src/saml_idp/urls.py | 20 ++++++++++++++++++++ tests/test_router.py | 11 +++++++++++ tests/test_urls.py | 11 +++++++++++ 5 files changed, 63 insertions(+), 10 deletions(-) create mode 100644 src/saml_idp/urls.py create mode 100644 tests/test_urls.py diff --git a/src/saml_idp/config.py b/src/saml_idp/config.py index 72cdf08..0595195 100644 --- a/src/saml_idp/config.py +++ b/src/saml_idp/config.py @@ -34,6 +34,9 @@ class Settings(BaseSettings): saml_idp_metadata_key_file: str = "" """The path of the SAML metadata key file.""" + saml_idp_base_url: HttpUrl | Literal[""] = "" + """The Base URL used for the URLs in the SAML Metadata.""" + saml_idp_logout_url: HttpUrl | Literal[""] = "" """The logout URL to redirect to.""" diff --git a/src/saml_idp/router.py b/src/saml_idp/router.py index b787c50..530832b 100644 --- a/src/saml_idp/router.py +++ b/src/saml_idp/router.py @@ -4,6 +4,7 @@ from datetime import UTC, datetime, timedelta from pathlib import Path from typing import Annotated +from urllib.parse import urljoin from fastapi import APIRouter, Form, Query from lxml import etree @@ -22,6 +23,7 @@ LogoutResponse, SamlMetadata, ) +from .urls import rel_url_for from .utils import is_out_of_date template_path = Path(__file__).parent.resolve() / "templates" @@ -36,11 +38,17 @@ def metadata_xml(request: Request) -> Response: """Return the IdP's metadata.xml.""" lines = [line.strip() for line in settings.saml_idp_metadata_cert.splitlines()] cert = "".join(lines[1:-1]) + if base_url := str(settings.saml_idp_base_url): + signon_url = urljoin(base_url, rel_url_for(request, "signin")) + logout_url = urljoin(base_url, rel_url_for(request, "logout")) + else: + signon_url = str(request.url_for("signin")) + logout_url = str(request.url_for("logout")) metadata = SamlMetadata( entity_id=settings.saml_idp_entity_id, - signon_url=str(request.url_for("signin")), - logout_url=str(request.url_for("logout")), + signon_url=signon_url, + logout_url=logout_url, valid_until=datetime.now(UTC) + timedelta(days=365), cert=cert, ) @@ -55,8 +63,8 @@ async def main(request: Request, user: GetUser) -> Response: "main.html", { "user": user, - "logout_url": request.url_for("logout_post"), - "login_url": request.url_for("login"), + "logout_url": rel_url_for(request, "logout_post"), + "login_url": rel_url_for(request, "login"), }, ) @@ -136,7 +144,7 @@ async def signin( "destination": destination, "request_issuer": request_issuer, "relay_state": relay_state, - "action": request.url_for("login"), + "action": rel_url_for(request, "login"), } return templates.TemplateResponse(request, "login.html", context) @@ -150,7 +158,7 @@ async def login(request: Request) -> Response: { "show_users": settings.saml_idp_show_users, "users": settings.saml_idp_users, - "action": request.url_for("login"), + "action": rel_url_for(request, "login"), }, ) @@ -188,7 +196,7 @@ async def login_post( # This is the normal login # Set a cookie and redirect response = RedirectResponse( - request.url_for("main"), status_code=status.HTTP_302_FOUND + rel_url_for(request, "main"), status_code=status.HTTP_302_FOUND ) response.set_cookie("session_id", session_id, max_age=3600) return response @@ -201,7 +209,7 @@ async def login_post( "destination": destination, "request_issuer": request_issuer, "relay_state": relay_state, - "action": request.url_for("login"), + "action": rel_url_for(request, "login"), } return templates.TemplateResponse(request, "login.html", context) @@ -260,9 +268,9 @@ async def logout( @router.post("/logout-form") async def logout_post(request: Request) -> Response: - """Provide a non-SAML login.""" + """Provide a non-SAML logout.""" response = RedirectResponse( - request.url_for("main"), status_code=status.HTTP_302_FOUND + rel_url_for(request, "main"), status_code=status.HTTP_302_FOUND ) response.delete_cookie("session_id") return response diff --git a/src/saml_idp/urls.py b/src/saml_idp/urls.py new file mode 100644 index 0000000..6a4ffa4 --- /dev/null +++ b/src/saml_idp/urls.py @@ -0,0 +1,20 @@ +"""Utilities for constructing relative URLs.""" + +from typing import TYPE_CHECKING, Any + +from starlette.requests import Request + +if TYPE_CHECKING: + from starlette.applications import Starlette # pragma: nocover + from starlette.routing import Router # pragma: nocover + + +def rel_url_for(req: Request, name: str, /, **path_params: Any) -> str: + """Provide a relative URL for a path.""" + url_path_provider: Router | Starlette | None = req.scope.get( + "router" + ) or req.scope.get("app") + if url_path_provider is None: + msg = "`rel_url_for` method can only be used inside a Starlette application." + raise RuntimeError(msg) + return url_path_provider.url_path_for(name, **path_params) diff --git a/tests/test_router.py b/tests/test_router.py index e4355ff..bcead34 100644 --- a/tests/test_router.py +++ b/tests/test_router.py @@ -46,6 +46,17 @@ async def test_metadata_xml(ac: AsyncClient) -> None: schema.assertValid(etree.fromstring(xml)) +async def test_metadata_xml_base_url(ac: AsyncClient) -> None: + """You can use the base URL to change the signin/logout URLs.""" + settings.saml_idp_base_url = "https://example.com" + response = await ac.get("/metadata.xml") + assert response.status_code == status.HTTP_200_OK + assert "text/xml" in response.headers["content-type"] + xml = response.content.decode() + assert "https://example.com/signin" in xml + assert "https://example.com/logout" in xml + + async def test_main_unauthenticated(ac: AsyncClient) -> None: """You can get the main page.""" response = await ac.get("/login") diff --git a/tests/test_urls.py b/tests/test_urls.py new file mode 100644 index 0000000..a0a2b20 --- /dev/null +++ b/tests/test_urls.py @@ -0,0 +1,11 @@ +import pytest +from starlette.requests import Request + +from saml_idp.urls import rel_url_for + + +def test_rel_url_for_error() -> None: + """Throw an error when there's no router.""" + req = Request({"type": "http"}) + with pytest.raises(RuntimeError, match=r"can only be used"): + rel_url_for(req, "login")