diff --git a/docs/release-notes.rst b/docs/release-notes.rst
index aba3c878..d26395fe 100644
--- a/docs/release-notes.rst
+++ b/docs/release-notes.rst
@@ -43,6 +43,7 @@ Features
 
 * Added functions for downloading and uploading samples: :meth:`Client.get_sample`, :meth:`Client.upload_new_sample_now`.
 * Added :class:`transfer.link.LinkFileTransfer`.
+* :class:`Client` and :class:`ScicatClient` now check whether a token has expired and raise an exception if it has.
 
 Breaking changes
 ~~~~~~~~~~~~~~~~
diff --git a/src/scitacean/_internal/jwt.py b/src/scitacean/_internal/jwt.py
new file mode 100644
index 00000000..0266fe98
--- /dev/null
+++ b/src/scitacean/_internal/jwt.py
@@ -0,0 +1,46 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)
+"""Tools for JSON web tokens."""
+
+import base64
+import json
+from datetime import datetime, timezone
+from typing import cast
+
+
+def decode(token: str) -> tuple[dict[str, str | int], dict[str, str | int], str]:
+    """Decode the components of a JSON web token.
+
+    Returns the decoded header, the decoded payload, and the signature.
+    The signature is returned as the raw string from the token;
+    it is neither decoded nor verified.
+
+    Raises ``ValueError`` (or a subclass) if the token does not have
+    exactly three dot-separated parts or if a part cannot be decoded.
+    """
+    h, p, signature = token.split(".")
+    header = _decode_part(h)
+    payload = _decode_part(p)
+    return header, payload, signature
+
+
+def expiry(token: str) -> datetime:
+    """Return the expiration time of a JWT in UTC.
+
+    Raises ``KeyError`` if the payload does not contain ``exp``.
+    """
+    _, payload, _ = decode(token)
+    # 'exp' should always be given in UTC. Since we have no way of checking that,
+    # assume that it is the case.
+    return datetime.fromtimestamp(float(payload["exp"]), tz=timezone.utc)
+
+
+def _decode_part(s: str) -> dict[str, str | int]:
+    """Decode a single base64url-encoded JSON part of a token."""
+    # urlsafe_b64decode requires a properly padded input but SciCat
+    # doesn't pad its tokens.
+    # ``-len(s) % 4`` is the number of characters missing to reach
+    # a multiple of 4 (e.g. 1 for a length of 3, not 3).
+    padded = s + "=" * (-len(s) % 4)
+    decoded_str = base64.urlsafe_b64decode(padded).decode("utf-8")
+    return cast(dict[str, str | int], json.loads(decoded_str))
diff --git a/src/scitacean/client.py b/src/scitacean/client.py
index 4e9506cd..102c2e7c 100644
--- a/src/scitacean/client.py
+++ b/src/scitacean/client.py
@@ -25,7 +25,7 @@
 from .logging import get_logger
 from .pid import PID
 from .typing import DownloadConnection, FileTransfer, UploadConnection
-from .util.credentials import SecretStr, StrStorage
+from .util.credentials import ExpiringToken, SecretStr, StrStorage
 
 
 class Client:
@@ -566,7 +566,9 @@ def __init__(
         self._base_url = url[:-1] if url.endswith("/") else url
         self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout
         self._token: StrStorage | None = (
-            SecretStr(token) if isinstance(token, str) else token
+            ExpiringToken.from_jwt(SecretStr(token))
+            if isinstance(token, str)
+            else token
         )
 
     @classmethod
diff --git a/src/scitacean/util/credentials.py b/src/scitacean/util/credentials.py
index 9ba8ac2b..6d11bf1b 100644
--- a/src/scitacean/util/credentials.py
+++ b/src/scitacean/util/credentials.py
@@ -4,9 +4,11 @@
 
 from __future__ import annotations
 
-import datetime
+from datetime import datetime, timedelta, timezone
 from typing import NoReturn
 
+from .._internal.jwt import expiry
+
 
 class StrStorage:
     """Base class for storing a string.
@@ -62,29 +64,58 @@ def __reduce_ex__(self, protocol: object) -> NoReturn:
         raise TypeError("SecretStr must not be pickled")
 
 
-class TimeLimitedStr(StrStorage):
-    """A string that expires after some time."""
+class ExpiringToken(StrStorage):
+    """A JWT token that expires after some time."""
 
     def __init__(
         self,
         *,
         value: str | StrStorage,
-        expires_at: datetime.datetime,
-        tolerance: datetime.timedelta | None = None,
+        expires_at: datetime,
+        denial_period: timedelta | None = None,
     ):
         super().__init__(value)
-        if tolerance is None:
-            tolerance = datetime.timedelta(seconds=10)
-        self._expires_at = expires_at - tolerance
+        if denial_period is None:
+            denial_period = timedelta(seconds=2)
+        # Treat the token as expired ``denial_period`` before its actual expiry.
+        self._expires_at = expires_at - denial_period
+        # Reject tokens that are already expired when they are stored.
+        self._check_expiry()
+
+    @classmethod
+    def from_jwt(cls, value: str | StrStorage) -> ExpiringToken:
+        """Create a new ExpiringToken from a JSON web token.
+
+        The expiration time is read from the ``exp`` entry of the token payload.
+        If the token cannot be decoded or has no usable ``exp`` entry,
+        it is assumed to expire 100 weeks from now.
+        """
+        value_str = value if isinstance(value, str) else value.get_str()
+        # ValueError: not a decodable JWT; KeyError: no 'exp' in the payload;
+        # TypeError: the payload or 'exp' has an unexpected type.
+        try:
+            expires_at = expiry(value_str)
+        except (KeyError, TypeError, ValueError):
+            expires_at = datetime.now(tz=timezone.utc) + timedelta(weeks=100)
+        return cls(
+            value=value,
+            expires_at=expires_at,
+        )
 
     def get_str(self) -> str:
         """Return the stored plain str object."""
-        if self._is_expired():
-            raise RuntimeError("Login has expired")
+        self._check_expiry()
         return super().get_str()
 
-    def _is_expired(self) -> bool:
-        return datetime.datetime.now() > self._expires_at
+    def _check_expiry(self) -> None:
+        # Using the tzinfo of the expiry time allows comparing with both
+        # naive and timezone-aware ``expires_at``.
+        if datetime.now(tz=self._expires_at.tzinfo) > self._expires_at:
+            raise RuntimeError(
+                "SciCat login has expired. You need to create a new client either by "
+                "logging in through `Client.from_credentials` or by getting a new "
+                "access token from the SciCat web interface."
+            )
 
     def __repr__(self) -> str:
         return (
diff --git a/tests/client/client_test.py b/tests/client/client_test.py
index bdc79986..7e78c801 100644
--- a/tests/client/client_test.py
+++ b/tests/client/client_test.py
@@ -1,11 +1,17 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)
+import base64
+import json
 import pickle
+import time
+from datetime import datetime, timedelta, timezone
+from typing import Any
 
 import pytest
 
 from scitacean import PID, Client
+from scitacean.testing.backend.seed import INITIAL_DATASETS
 from scitacean.testing.client import FakeClient
 from scitacean.util.credentials import SecretStr
 
 
@@ -29,9 +35,7 @@ def test_from_credentials_fake():
     )
 
 
-def test_from_credentials_real(scicat_access, scicat_backend):
-    if not scicat_backend:
-        pytest.skip("No backend")
+def test_from_credentials_real(scicat_access, require_scicat_backend):
     Client.from_credentials(url=scicat_access.url, **scicat_access.user.credentials)
 
 
@@ -80,3 +84,48 @@ def test_fake_can_disable_functions():
         client.scicat.get_dataset_model(PID(pid="some-pid"))
     with pytest.raises(IndexError, match="custom index error"):
         client.scicat.get_orig_datablocks(PID(pid="some-pid"))
+
+
+def encode_jwt_part(part: dict[str, Any]) -> str:
+    encoded = base64.urlsafe_b64encode(json.dumps(part).encode("utf-8"))
+    # SciCat does not pad its tokens, so strip the padding here as well.
+    return encoded.decode("ascii").rstrip("=")
+
+
+def make_token(exp_in: timedelta) -> str:
+    now = datetime.now(tz=timezone.utc)
+    exp = now + exp_in
+
+    # This is what a SciCat token looks like as of 2024-04-19
+    header = {"alg": "HS256", "typ": "JWT"}
+    payload = {
+        "_id": "7fc0856e50a8",
+        "username": "Weatherwax",
+        "email": "g.weatherwax@wyrd.lancre",
+        "authStrategy": "ldap",
+        "id": "7fc0856e50a8",
+        "userId": "7fc0856e50a8",
+        "iat": now.timestamp(),
+        "exp": exp.timestamp(),
+    }
+    # Scitacean never validates the signature because it doesn't have the secret key,
+    # so it doesn't matter what we use here.
+    signature = "123abc"
+
+    return ".".join((encode_jwt_part(header), encode_jwt_part(payload), signature))
+
+
+def test_detects_expired_token_init():
+    token = make_token(timedelta(milliseconds=0))
+    with pytest.raises(RuntimeError, match="SciCat login has expired"):
+        Client.from_token(url="scicat.com", token=token)
+
+
+def test_detects_expired_token_get_dataset(scicat_access, require_scicat_backend):
+    # The token is invalid, but the expiration should be detected before
+    # even sending it to SciCat.
+    token = make_token(timedelta(milliseconds=2100))  # > denial period = 2s
+    client = Client.from_token(url=scicat_access.url, token=token)
+    time.sleep(0.5)
+    with pytest.raises(RuntimeError, match="SciCat login has expired"):
+        client.get_dataset(INITIAL_DATASETS["public"].pid)  # type: ignore[arg-type]