diff --git a/pyeudiw/jwk/schema.py b/pyeudiw/jwk/schema.py index a372eb6b..dede8c9c 100644 --- a/pyeudiw/jwk/schema.py +++ b/pyeudiw/jwk/schema.py @@ -1,6 +1,6 @@ from typing import List, Literal, Optional -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator class JwkSchema(BaseModel): @@ -32,11 +32,11 @@ def check_value_for_ec(value, name, values): if "RSA" == values.get("kty") and value: raise ValueError(f"{name} must be present only for kty = EC") - @validator("n") + @field_validator("n") def validate_n(cls, n_value, values): cls.check_value_for_rsa(n_value, "n", values) - @validator("e") + @field_validator("e") def validate_e(cls, e_value, values): cls.check_value_for_rsa(e_value, "e", values) @@ -46,15 +46,15 @@ class JwkSchemaEC(JwkSchema): y: Optional[str] # Base64url-encoded crv: Optional[Literal["P-256", "P-384", "P-521"]] - @validator("x") + @field_validator("x") def validate_x(cls, x_value, values): cls.check_value_for_ec(x_value, "x", values) - @validator("y") + @field_validator("y") def validate_y(cls, y_value, values): cls.check_value_for_ec(y_value, "y", values) - @validator("crv") + @field_validator("crv") def validate_crv(cls, crv_value, values): cls.check_value_for_ec(crv_value, "crv", values) diff --git a/pyeudiw/tools/jwt.py b/pyeudiw/jwt/__init__.py similarity index 92% rename from pyeudiw/tools/jwt.py rename to pyeudiw/jwt/__init__.py index 04f9b28f..5a55e2e0 100644 --- a/pyeudiw/tools/jwt.py +++ b/pyeudiw/jwt/__init__.py @@ -12,6 +12,7 @@ from typing import Union from pyeudiw.jwk import JWK +from pyeudiw.jwt.utils import unpad_jwt_header DEFAULT_HASH_FUNC = "SHA-256" @@ -25,13 +26,6 @@ DEFAULT_JWE_ENC = "A256CBC-HS512" -def unpad_jwt_header(jwt: str) -> dict: - b = jwt.split(".")[0] - padded = f"{b}{'=' * divmod(len(b), 4)[1]}" - data = json.loads(base64.urlsafe_b64decode(padded)) - return data - - class JWEHelper(): def __init__(self, jwk: JWK): self.jwk = jwk @@ -88,9 +82,11 @@ def decrypt(self, jwe: str) -> dict: class JWSHelper: - def __init__(self, jwk: JWK): + def __init__(self, jwk: Union[JWK, dict]): self.jwk = jwk - self.alg = DEFAUL_SIG_KTY_MAP[jwk.key.kty] + if isinstance(jwk, dict): + self.jwk = JWK(jwk) + self.alg = DEFAUL_SIG_KTY_MAP[self.jwk.key.kty] def sign( self, diff --git a/pyeudiw/jwt/utils.py b/pyeudiw/jwt/utils.py new file mode 100644 index 00000000..245c5cd2 --- /dev/null +++ b/pyeudiw/jwt/utils.py @@ -0,0 +1,31 @@ +import base64 +import json + + +def unpad_jwt_element(jwt: str, position: int) -> dict: + b = jwt.split(".")[position] + padded = f"{b}{'=' * divmod(len(b), 4)[1]}" + data = json.loads(base64.urlsafe_b64decode(padded)) + return data + + +def unpad_jwt_header(jwt: str) -> dict: + return unpad_jwt_element(jwt, position=0) + + +def unpad_jwt_payload(jwt: str) -> dict: + return unpad_jwt_element(jwt, position=1) + + +def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: + """ + docs here + """ + head = unpad_jwt_header(jwt) + kid = head["kid"] + if isinstance(provider_jwks, dict) and provider_jwks.get('keys'): + provider_jwks = provider_jwks['keys'] + for jwk in provider_jwks: + if jwk["kid"] == kid: + return jwk + return {} diff --git a/pyeudiw/oauth2/dpop.py b/pyeudiw/oauth2/dpop.py index 065f6ef0..1ccec7b4 100644 --- a/pyeudiw/oauth2/dpop.py +++ b/pyeudiw/oauth2/dpop.py @@ -1,7 +1,11 @@ -from pydantic import BaseModel, HttpUrl +import hashlib +import uuid +from pydantic import BaseModel, HttpUrl from pyeudiw.jwk.schema import JwkSchema - +from pyeudiw.jwt import JWSHelper +from pyeudiw.jwt.utils import unpad_jwt_payload +from pyeudiw.tools.utils import iat_now from typing import Literal @@ -32,24 +36,46 @@ class DPoPTokenPayloadSchema(BaseModel): class DPoPIssuer: - def __init__(self, token: str, private_jwk: dict): + def __init__(self, htu :str, token: str, private_jwk: dict): self.token = token self.private_jwk = private_jwk + self.signer = JWSHelper(private_jwk) + self.htu = htu @property def proof(self): - pass - + data = { + "jti": str(uuid.uuid4()), + "htm": "GET", + "htu": self.htu, + "iat": iat_now(), + "ath": hashlib.sha256(self.token.encode()).hexdigest() + } + jwt = self.signer.sign(data) + return jwt + # TODO assertion class DPoPVerifier: + dpop_header_prefix = 'DPoP ' + def __init__( - self, token: str, + self, public_jwk: dict, - http_header_authz: str, - http_header_dpop: str, + http_header_authz :str, + http_header_dpop :str, ): - self.token = token self.public_jwk = public_jwk - - def validate(self): - pass + self.dpop_token = ( + http_header_authz.replace(self.dpop_header_prefix, '') + if self.dpop_header_prefix in http_header_authz + else http_header_authz + ) + self.proof = http_header_dpop + + @property + def is_valid(self): + jws_verifier = JWSHelper(self.public_jwk) + dpop_valid = jws_verifier.verify(self.dpop_token) + payload = unpad_jwt_payload(self.proof) + proof_valid = hashlib.sha256(self.dpop_token.encode()).hexdigest() == payload['ath'] + return dpop_valid and proof_valid diff --git a/pyeudiw/satosa/backend.py b/pyeudiw/satosa/backend.py index 15a6804f..e03e1c33 100644 --- a/pyeudiw/satosa/backend.py +++ b/pyeudiw/satosa/backend.py @@ -12,7 +12,7 @@ from pyeudiw.satosa.html_template import Jinja2TemplateHandler from pyeudiw.tools.qr_code import QRCode from pyeudiw.jwk import JWK -from pyeudiw.tools.jwt import JWSHelper +from pyeudiw.jwt import JWSHelper from pyeudiw.tools.mobile import is_smartphone logger = logging.getLogger("openid4vp_backend") @@ -155,13 +155,7 @@ def redirect_endpoint(self, context, *args): jwk = self.metadata_jwk helper = JWSHelper(jwk) - data = { - "jti": str(uuid.uuid4()), - "htm": "GET", - "htu": f"{self.client_id}/request_uri", - "iat": int(datetime.now().timestamp()), - "ath": "fUHyO2r2Z3DZ53EsNrWBb0xWXoaNy59IiKCAqksmQEo" - } + data = {} #TODO jwt = helper.sign(data) response = {"request": jwt} diff --git a/pyeudiw/tests/oauth2/test_dpop.py b/pyeudiw/tests/oauth2/test_dpop.py index 0273085e..9d342782 100644 --- a/pyeudiw/tests/oauth2/test_dpop.py +++ b/pyeudiw/tests/oauth2/test_dpop.py @@ -1,10 +1,13 @@ import pytest +from pyeudiw.oauth2.dpop import DPoPIssuer, DPoPVerifier from pyeudiw.jwk import JWK -from pyeudiw.tools.jwt import JWSHelper +from pyeudiw.jwt import JWSHelper +from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header from pyeudiw.tools.utils import iat_now + WALLET_INSTANCE_ATTESTATION = { "iss": "https://wallet-provider.example.org", "sub": "vbeXJksM45xphtANnCiG6mCyuU4jfGNzopGuKvogg9c", @@ -54,10 +57,35 @@ def private_jwk(): def jwshelper(private_jwk): return JWSHelper(private_jwk) +@pytest.fixture +def wia_jws(jwshelper): + wia = jwshelper.sign( + WALLET_INSTANCE_ATTESTATION, + protected={'trust_chain': [], 'x5c': []} + ) + return wia -def test_create(jwshelper): - jwshelper.sign(WALLET_INSTANCE_ATTESTATION) - - -def test_validate(): - pass +def test_create_validate_dpop_http_headers(wia_jws, private_jwk): + # create + header = unpad_jwt_header(wia_jws) + payload = unpad_jwt_payload(wia_jws) + # TODO assertions + + new_dpop = DPoPIssuer( + htu='https://example.org/redirect', + token=wia_jws, + private_jwk=private_jwk + ) + proof = new_dpop.proof + + # TODO assertions + + # verify + dpop = DPoPVerifier( + public_jwk = private_jwk.public_key, + http_header_authz = f"DPoP {wia_jws}", + http_header_dpop = proof + ) + + assert dpop.is_valid + # TODO assertions diff --git a/pyeudiw/tests/tools/test_jwt.py b/pyeudiw/tests/tools/test_jwt.py index 7d2fe990..78988a1b 100644 --- a/pyeudiw/tests/tools/test_jwt.py +++ b/pyeudiw/tests/tools/test_jwt.py @@ -1,7 +1,8 @@ import pytest from pyeudiw.jwk import JWK -from pyeudiw.tools.jwt import JWEHelper, JWSHelper, unpad_jwt_header, DEFAULT_JWE_ALG, DEFAULT_JWE_ENC +from pyeudiw.jwt import JWEHelper, JWSHelper, DEFAULT_JWE_ALG, DEFAULT_JWE_ENC +from pyeudiw.jwt.utils import unpad_jwt_header JWKs_EC = [ (JWK(key_type="EC"), {"key": "value"}), diff --git a/pyeudiw/tools/utils.py b/pyeudiw/tools/utils.py index ea1d6e8b..5758ea08 100644 --- a/pyeudiw/tools/utils.py +++ b/pyeudiw/tools/utils.py @@ -2,8 +2,6 @@ # from django.utils.timezone import make_aware from secrets import token_hex -from . jwt import unpad_jwt_header - import datetime import json @@ -63,19 +61,5 @@ def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) -> return jwks_list -def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict: - """ - docs here - """ - head = unpad_jwt_header(jwt) - kid = head["kid"] - if isinstance(provider_jwks, dict) and provider_jwks.get('keys'): - provider_jwks = provider_jwks['keys'] - for jwk in provider_jwks: - if jwk["kid"] == kid: - return jwk - return {} - - def random_token(n=254): return token_hex(n)