Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix absolute/relative URLs #3

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/saml_idp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Settings(BaseSettings):
saml_idp_metadata_key_file: str = ""
"""The path of the SAML metadata key file."""

saml_idp_base_url: HttpUrl | Literal[""] = ""
"""The Base URL used for the URLs in the SAML Metadata."""

saml_idp_logout_url: HttpUrl | Literal[""] = ""
"""The logout URL to redirect to."""

Expand Down
28 changes: 18 additions & 10 deletions src/saml_idp/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Annotated
from urllib.parse import urljoin

from fastapi import APIRouter, Form, Query
from lxml import etree
Expand All @@ -22,6 +23,7 @@
LogoutResponse,
SamlMetadata,
)
from .urls import rel_url_for
from .utils import is_out_of_date

template_path = Path(__file__).parent.resolve() / "templates"
Expand All @@ -36,11 +38,17 @@ def metadata_xml(request: Request) -> Response:
"""Return the IdP's metadata.xml."""
lines = [line.strip() for line in settings.saml_idp_metadata_cert.splitlines()]
cert = "".join(lines[1:-1])
if base_url := str(settings.saml_idp_base_url):
signon_url = urljoin(base_url, rel_url_for(request, "signin"))
logout_url = urljoin(base_url, rel_url_for(request, "logout"))
else:
signon_url = str(request.url_for("signin"))
logout_url = str(request.url_for("logout"))

metadata = SamlMetadata(
entity_id=settings.saml_idp_entity_id,
signon_url=str(request.url_for("signin")),
logout_url=str(request.url_for("logout")),
signon_url=signon_url,
logout_url=logout_url,
valid_until=datetime.now(UTC) + timedelta(days=365),
cert=cert,
)
Expand All @@ -55,8 +63,8 @@ async def main(request: Request, user: GetUser) -> Response:
"main.html",
{
"user": user,
"logout_url": request.url_for("logout_post"),
"login_url": request.url_for("login"),
"logout_url": rel_url_for(request, "logout_post"),
"login_url": rel_url_for(request, "login"),
},
)

Expand Down Expand Up @@ -136,7 +144,7 @@ async def signin(
"destination": destination,
"request_issuer": request_issuer,
"relay_state": relay_state,
"action": request.url_for("login"),
"action": rel_url_for(request, "login"),
}
return templates.TemplateResponse(request, "login.html", context)

Expand All @@ -150,7 +158,7 @@ async def login(request: Request) -> Response:
{
"show_users": settings.saml_idp_show_users,
"users": settings.saml_idp_users,
"action": request.url_for("login"),
"action": rel_url_for(request, "login"),
},
)

Expand Down Expand Up @@ -188,7 +196,7 @@ async def login_post(
# This is the normal login
# Set a cookie and redirect
response = RedirectResponse(
request.url_for("main"), status_code=status.HTTP_302_FOUND
rel_url_for(request, "main"), status_code=status.HTTP_302_FOUND
)
response.set_cookie("session_id", session_id, max_age=3600)
return response
Expand All @@ -201,7 +209,7 @@ async def login_post(
"destination": destination,
"request_issuer": request_issuer,
"relay_state": relay_state,
"action": request.url_for("login"),
"action": rel_url_for(request, "login"),
}
return templates.TemplateResponse(request, "login.html", context)

Expand Down Expand Up @@ -260,9 +268,9 @@ async def logout(

@router.post("/logout-form")
async def logout_post(request: Request) -> Response:
"""Provide a non-SAML login."""
"""Provide a non-SAML logout."""
response = RedirectResponse(
request.url_for("main"), status_code=status.HTTP_302_FOUND
rel_url_for(request, "main"), status_code=status.HTTP_302_FOUND
)
response.delete_cookie("session_id")
return response
20 changes: 20 additions & 0 deletions src/saml_idp/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Utilities for constructing relative URLs."""

from typing import TYPE_CHECKING, Any

from starlette.requests import Request

if TYPE_CHECKING:
from starlette.applications import Starlette # pragma: nocover
from starlette.routing import Router # pragma: nocover


def rel_url_for(req: Request, name: str, /, **path_params: Any) -> str:
"""Provide a relative URL for a path."""
url_path_provider: Router | Starlette | None = req.scope.get(
"router"
) or req.scope.get("app")
if url_path_provider is None:
msg = "`rel_url_for` method can only be used inside a Starlette application."
raise RuntimeError(msg)
return url_path_provider.url_path_for(name, **path_params)
11 changes: 11 additions & 0 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ async def test_metadata_xml(ac: AsyncClient) -> None:
schema.assertValid(etree.fromstring(xml))


async def test_metadata_xml_base_url(ac: AsyncClient) -> None:
"""You can use the base URL to change the signin/logout URLs."""
settings.saml_idp_base_url = "https://example.com"
response = await ac.get("/metadata.xml")
assert response.status_code == status.HTTP_200_OK
assert "text/xml" in response.headers["content-type"]
xml = response.content.decode()
assert "https://example.com/signin" in xml
assert "https://example.com/logout" in xml


async def test_main_unauthenticated(ac: AsyncClient) -> None:
"""You can get the main page."""
response = await ac.get("/login")
Expand Down
11 changes: 11 additions & 0 deletions tests/test_urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from starlette.requests import Request

from saml_idp.urls import rel_url_for


def test_rel_url_for_error() -> None:
"""Throw an error when there's no router."""
req = Request({"type": "http"})
with pytest.raises(RuntimeError, match=r"can only be used"):
rel_url_for(req, "login")