Skip to content

Commit

Permalink
Merge pull request #3 from paulcwatts/relative-urls
Browse files Browse the repository at this point in the history
Fix absolute/relative URLs
  • Loading branch information
paulcwatts authored Oct 17, 2024
2 parents b921e01 + 0b772ad commit 660d69f
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/saml_idp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class Settings(BaseSettings):
saml_idp_metadata_key_file: str = ""
"""The path of the SAML metadata key file."""

saml_idp_base_url: HttpUrl | Literal[""] = ""
"""The Base URL used for the URLs in the SAML Metadata."""

saml_idp_logout_url: HttpUrl | Literal[""] = ""
"""The logout URL to redirect to."""

Expand Down
28 changes: 18 additions & 10 deletions src/saml_idp/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Annotated
from urllib.parse import urljoin

from fastapi import APIRouter, Form, Query
from lxml import etree
Expand All @@ -22,6 +23,7 @@
LogoutResponse,
SamlMetadata,
)
from .urls import rel_url_for
from .utils import is_out_of_date

template_path = Path(__file__).parent.resolve() / "templates"
Expand All @@ -36,11 +38,17 @@ def metadata_xml(request: Request) -> Response:
"""Return the IdP's metadata.xml."""
lines = [line.strip() for line in settings.saml_idp_metadata_cert.splitlines()]
cert = "".join(lines[1:-1])
if base_url := str(settings.saml_idp_base_url):
signon_url = urljoin(base_url, rel_url_for(request, "signin"))
logout_url = urljoin(base_url, rel_url_for(request, "logout"))
else:
signon_url = str(request.url_for("signin"))
logout_url = str(request.url_for("logout"))

metadata = SamlMetadata(
entity_id=settings.saml_idp_entity_id,
signon_url=str(request.url_for("signin")),
logout_url=str(request.url_for("logout")),
signon_url=signon_url,
logout_url=logout_url,
valid_until=datetime.now(UTC) + timedelta(days=365),
cert=cert,
)
Expand All @@ -55,8 +63,8 @@ async def main(request: Request, user: GetUser) -> Response:
"main.html",
{
"user": user,
"logout_url": request.url_for("logout_post"),
"login_url": request.url_for("login"),
"logout_url": rel_url_for(request, "logout_post"),
"login_url": rel_url_for(request, "login"),
},
)

Expand Down Expand Up @@ -136,7 +144,7 @@ async def signin(
"destination": destination,
"request_issuer": request_issuer,
"relay_state": relay_state,
"action": request.url_for("login"),
"action": rel_url_for(request, "login"),
}
return templates.TemplateResponse(request, "login.html", context)

Expand All @@ -150,7 +158,7 @@ async def login(request: Request) -> Response:
{
"show_users": settings.saml_idp_show_users,
"users": settings.saml_idp_users,
"action": request.url_for("login"),
"action": rel_url_for(request, "login"),
},
)

Expand Down Expand Up @@ -188,7 +196,7 @@ async def login_post(
# This is the normal login
# Set a cookie and redirect
response = RedirectResponse(
request.url_for("main"), status_code=status.HTTP_302_FOUND
rel_url_for(request, "main"), status_code=status.HTTP_302_FOUND
)
response.set_cookie("session_id", session_id, max_age=3600)
return response
Expand All @@ -201,7 +209,7 @@ async def login_post(
"destination": destination,
"request_issuer": request_issuer,
"relay_state": relay_state,
"action": request.url_for("login"),
"action": rel_url_for(request, "login"),
}
return templates.TemplateResponse(request, "login.html", context)

Expand Down Expand Up @@ -260,9 +268,9 @@ async def logout(

@router.post("/logout-form")
async def logout_post(request: Request) -> Response:
"""Provide a non-SAML login."""
"""Provide a non-SAML logout."""
response = RedirectResponse(
request.url_for("main"), status_code=status.HTTP_302_FOUND
rel_url_for(request, "main"), status_code=status.HTTP_302_FOUND
)
response.delete_cookie("session_id")
return response
20 changes: 20 additions & 0 deletions src/saml_idp/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Utilities for constructing relative URLs."""

from typing import TYPE_CHECKING, Any

from starlette.requests import Request

if TYPE_CHECKING:
from starlette.applications import Starlette # pragma: nocover
from starlette.routing import Router # pragma: nocover


def rel_url_for(req: Request, name: str, /, **path_params: Any) -> str:
"""Provide a relative URL for a path."""
url_path_provider: Router | Starlette | None = req.scope.get(
"router"
) or req.scope.get("app")
if url_path_provider is None:
msg = "`rel_url_for` method can only be used inside a Starlette application."
raise RuntimeError(msg)
return url_path_provider.url_path_for(name, **path_params)
11 changes: 11 additions & 0 deletions tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ async def test_metadata_xml(ac: AsyncClient) -> None:
schema.assertValid(etree.fromstring(xml))


async def test_metadata_xml_base_url(ac: AsyncClient) -> None:
"""You can use the base URL to change the signin/logout URLs."""
settings.saml_idp_base_url = "https://example.com"
response = await ac.get("/metadata.xml")
assert response.status_code == status.HTTP_200_OK
assert "text/xml" in response.headers["content-type"]
xml = response.content.decode()
assert "https://example.com/signin" in xml
assert "https://example.com/logout" in xml


async def test_main_unauthenticated(ac: AsyncClient) -> None:
"""You can get the main page."""
response = await ac.get("/login")
Expand Down
11 changes: 11 additions & 0 deletions tests/test_urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import pytest
from starlette.requests import Request

from saml_idp.urls import rel_url_for


def test_rel_url_for_error() -> None:
"""Throw an error when there's no router."""
req = Request({"type": "http"})
with pytest.raises(RuntimeError, match=r"can only be used"):
rel_url_for(req, "login")

0 comments on commit 660d69f

Please sign in to comment.