Skip to content

Commit

Permalink
Detect expired tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
jl-wynen committed Apr 19, 2024
1 parent 77977d9 commit e8f8070
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 17 deletions.
1 change: 1 addition & 0 deletions docs/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Features

* Added functions for downloading and uploading samples: :meth:`Client.get_sample`, :meth:`Client.upload_new_sample_now`.
* Added :class:`transfer.link.LinkFileTransfer`.
* :class:`Client` and :class:`ScicatClient` now check whether a token has expired and raise an exception if it has.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
32 changes: 32 additions & 0 deletions src/scitacean/_internal/jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)
"""Tools for JSON web tokens."""

import base64
import json
from datetime import datetime, timezone
from typing import cast


def decode(token: str) -> tuple[dict[str, str | int], dict[str, str | int], str]:
"""Decode the components of a JSOn web token."""
h, p, signature = token.split(".")
header = _decode_part(h)
payload = _decode_part(p)
return header, payload, signature


def expiry(token: str) -> datetime:
    """Return the expiration time of a JWT as a timezone-aware UTC datetime.

    The time is read from the ``exp`` entry of the token's payload.
    """
    payload = decode(token)[1]
    # 'exp' should always be given in UTC. There is no way of checking that
    # from the token itself, so assume that it is the case.
    exp_timestamp = float(payload["exp"])
    return datetime.fromtimestamp(exp_timestamp, tz=timezone.utc)


def _decode_part(s: str) -> dict[str, str | int]:
# urlsafe_b64decode requires a properly padded input but SciCat
# doesn't pad its tokens.
padded = s + "=" * (len(s) % 4)
decoded_str = base64.urlsafe_b64decode(padded).decode("utf-8")
return cast(dict[str, str | int], json.loads(decoded_str))
6 changes: 4 additions & 2 deletions src/scitacean/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .logging import get_logger
from .pid import PID
from .typing import DownloadConnection, FileTransfer, UploadConnection
from .util.credentials import SecretStr, StrStorage
from .util.credentials import ExpiringToken, SecretStr, StrStorage


class Client:
Expand Down Expand Up @@ -566,7 +566,9 @@ def __init__(
self._base_url = url[:-1] if url.endswith("/") else url
self._timeout = datetime.timedelta(seconds=10) if timeout is None else timeout
self._token: StrStorage | None = (
SecretStr(token) if isinstance(token, str) else token
ExpiringToken.from_jwt(SecretStr(token))
if isinstance(token, str)
else token
)

@classmethod
Expand Down
44 changes: 32 additions & 12 deletions src/scitacean/util/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from __future__ import annotations

import datetime
from datetime import datetime, timedelta, timezone
from typing import NoReturn

from .._internal.jwt import expiry


class StrStorage:
"""Base class for storing a string.
Expand Down Expand Up @@ -62,29 +64,47 @@ def __reduce_ex__(self, protocol: object) -> NoReturn:
raise TypeError("SecretStr must not be pickled")


class TimeLimitedStr(StrStorage):
"""A string that expires after some time."""
class ExpiringToken(StrStorage):
"""A JWT token that expires after some time."""

def __init__(
self,
*,
value: str | StrStorage,
expires_at: datetime.datetime,
tolerance: datetime.timedelta | None = None,
expires_at: datetime,
denial_period: timedelta | None = None,
):
super().__init__(value)
if tolerance is None:
tolerance = datetime.timedelta(seconds=10)
self._expires_at = expires_at - tolerance
if denial_period is None:
denial_period = timedelta(seconds=2)
self._expires_at = expires_at - denial_period
self._check_expiry()

@classmethod
def from_jwt(cls, value: str | StrStorage) -> ExpiringToken:
    """Create a new ExpiringToken from a JSON web token.

    The expiration time is read from the token via ``expiry``.
    If it cannot be determined, the token is assumed to be valid for
    another 100 weeks from now.

    Parameters
    ----------
    value:
        The token, either as a plain string or wrapped in a ``StrStorage``.

    Returns
    -------
    :
        A new token object that stores ``value``.
    """
    value_str = value if isinstance(value, str) else value.get_str()
    try:
        expires_at = expiry(value_str)
    except (ValueError, KeyError, TypeError):
        # ValueError: the string could not be decoded as a JWT.
        # KeyError / TypeError: the payload was decoded but has no usable
        # 'exp' entry. That claim is optional in a JWT, so treat this the same
        # way as an undecodable token instead of failing with a lookup error.
        expires_at = datetime.now(tz=timezone.utc) + timedelta(weeks=100)
    return cls(value=value, expires_at=expires_at)

def get_str(self) -> str:
"""Return the stored plain str object."""
if self._is_expired():
raise RuntimeError("Login has expired")
self._check_expiry()
return super().get_str()

def _is_expired(self) -> bool:
return datetime.datetime.now() > self._expires_at
def _check_expiry(self) -> None:
    """Raise a RuntimeError if the current time is past the stored expiry time."""
    # Use the same tzinfo as the stored time so that both are comparable.
    now = datetime.now(tz=self._expires_at.tzinfo)
    if now <= self._expires_at:
        return
    raise RuntimeError(
        "SciCat login has expired. You need to create a new client either by "
        "logging in through `Client.from_credentials` or by getting a new "
        "access token from the SciCat web interface."
    )

def __repr__(self) -> str:
return (
Expand Down
53 changes: 50 additions & 3 deletions tests/client/client_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 SciCat Project (https://github.com/SciCatProject/scitacean)

import base64
import json
import pickle
import time
from datetime import datetime, timedelta, timezone
from typing import Any

import pytest

from scitacean import PID, Client
from scitacean.testing.backend.seed import INITIAL_DATASETS
from scitacean.testing.client import FakeClient
from scitacean.util.credentials import SecretStr

Expand All @@ -29,9 +35,7 @@ def test_from_credentials_fake():
)


def test_from_credentials_real(scicat_access, scicat_backend):
if not scicat_backend:
pytest.skip("No backend")
def test_from_credentials_real(scicat_access, require_scicat_backend):
Client.from_credentials(url=scicat_access.url, **scicat_access.user.credentials)


Expand Down Expand Up @@ -80,3 +84,46 @@ def test_fake_can_disable_functions():
client.scicat.get_dataset_model(PID(pid="some-pid"))
with pytest.raises(IndexError, match="custom index error"):
client.scicat.get_orig_datablocks(PID(pid="some-pid"))


def encode_jwt_part(part: dict[str, Any]) -> str:
    """Serialize ``part`` to JSON and return it as padded URL-safe base64 text."""
    raw = json.dumps(part).encode("utf-8")
    encoded = base64.urlsafe_b64encode(raw)
    return encoded.decode("ascii")


def make_token(exp_in: timedelta) -> str:
    """Build a token string in JWT layout that expires ``exp_in`` from now.

    The result is ``<header>.<payload>.<signature>`` where header and payload
    are encoded with ``encode_jwt_part`` and the signature is a fixed dummy.
    """
    issued_at = datetime.now(tz=timezone.utc)
    expires_at = issued_at + exp_in

    # This is what a SciCat token looks like as of 2024-04-19
    header = {"alg": "HS256", "typ": "JWT"}
    payload = {
        "_id": "7fc0856e50a8",
        "username": "Weatherwax",
        "email": "[email protected]",
        "authStrategy": "ldap",
        "id": "7fc0856e50a8",
        "userId": "7fc0856e50a8",
        "iat": issued_at.timestamp(),
        "exp": expires_at.timestamp(),
    }
    # Scitacean never validates the signature because it doesn't have the secret key,
    # so it doesn't matter what we use here.
    signature = "123abc"

    parts = [encode_jwt_part(header), encode_jwt_part(payload), signature]
    return ".".join(parts)


def test_detects_expired_token_init():
    """Creating a client from a token that expires right now must fail."""
    expired_token = make_token(timedelta(milliseconds=0))
    with pytest.raises(RuntimeError, match="SciCat login has expired"):
        Client.from_token(url="scicat.com", token=expired_token)


def test_detects_expired_token_get_dataset(scicat_access, require_scicat_backend):
    """A token that expires after client creation must be rejected on use."""
    # The token is invalid, but the expiration should be detected before
    # even sending it to SciCat.
    # The lifetime is slightly longer than the denial period of 2s so that the
    # client can still be created but the token has expired after the sleep.
    short_lived_token = make_token(timedelta(milliseconds=2100))
    client = Client.from_token(url=scicat_access.url, token=short_lived_token)
    time.sleep(0.5)
    with pytest.raises(RuntimeError, match="SciCat login has expired"):
        client.get_dataset(INITIAL_DATASETS["public"].pid)  # type: ignore[arg-type]

0 comments on commit e8f8070

Please sign in to comment.