Skip to content

Commit

Permalink
feat!: dpop tests, jwt has its own folder
Browse files Browse the repository at this point in the history
  • Loading branch information
peppelinux committed Jul 23, 2023
1 parent 24d62b3 commit d167f46
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 59 deletions.
12 changes: 6 additions & 6 deletions pyeudiw/jwk/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Literal, Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator


class JwkSchema(BaseModel):
Expand Down Expand Up @@ -32,11 +32,11 @@ def check_value_for_ec(value, name, values):
if "RSA" == values.get("kty") and value:
raise ValueError(f"{name} must be present only for kty = EC")

@validator("n")
@field_validator("n")
def validate_n(cls, n_value, values):
cls.check_value_for_rsa(n_value, "n", values)

@validator("e")
@field_validator("e")
def validate_e(cls, e_value, values):
cls.check_value_for_rsa(e_value, "e", values)

Expand All @@ -46,15 +46,15 @@ class JwkSchemaEC(JwkSchema):
y: Optional[str] # Base64url-encoded
crv: Optional[Literal["P-256", "P-384", "P-521"]]

@validator("x")
@field_validator("x")
def validate_x(cls, x_value, values):
cls.check_value_for_ec(x_value, "x", values)

@validator("y")
@field_validator("y")
def validate_y(cls, y_value, values):
cls.check_value_for_ec(y_value, "y", values)

@validator("crv")
@field_validator("crv")
def validate_crv(cls, crv_value, values):
cls.check_value_for_ec(crv_value, "crv", values)

Expand Down
14 changes: 5 additions & 9 deletions pyeudiw/tools/jwt.py → pyeudiw/jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Union

from pyeudiw.jwk import JWK
from pyeudiw.jwt.utils import unpad_jwt_header

DEFAULT_HASH_FUNC = "SHA-256"

Expand All @@ -25,13 +26,6 @@
DEFAULT_JWE_ENC = "A256CBC-HS512"


def unpad_jwt_header(jwt: str) -> dict:
b = jwt.split(".")[0]
padded = f"{b}{'=' * divmod(len(b), 4)[1]}"
data = json.loads(base64.urlsafe_b64decode(padded))
return data


class JWEHelper():
def __init__(self, jwk: JWK):
self.jwk = jwk
Expand Down Expand Up @@ -88,9 +82,11 @@ def decrypt(self, jwe: str) -> dict:


class JWSHelper:
def __init__(self, jwk: JWK):
def __init__(self, jwk: Union[JWK, dict]):
self.jwk = jwk
self.alg = DEFAUL_SIG_KTY_MAP[jwk.key.kty]
if isinstance(jwk, dict):
self.jwk = JWK(jwk)
self.alg = DEFAUL_SIG_KTY_MAP[self.jwk.key.kty]

def sign(
self,
Expand Down
31 changes: 31 additions & 0 deletions pyeudiw/jwt/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import base64
import json


def unpad_jwt_element(jwt: str, position: int) -> dict:
b = jwt.split(".")[position]
padded = f"{b}{'=' * divmod(len(b), 4)[1]}"
data = json.loads(base64.urlsafe_b64decode(padded))
return data


def unpad_jwt_header(jwt: str) -> dict:
return unpad_jwt_element(jwt, position=0)


def unpad_jwt_payload(jwt: str) -> dict:
return unpad_jwt_element(jwt, position=1)


def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict:
"""
docs here
"""
head = unpad_jwt_header(jwt)
kid = head["kid"]
if isinstance(provider_jwks, dict) and provider_jwks.get('keys'):
provider_jwks = provider_jwks['keys']
for jwk in provider_jwks:
if jwk["kid"] == kid:
return jwk
return {}
50 changes: 38 additions & 12 deletions pyeudiw/oauth2/dpop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from pydantic import BaseModel, HttpUrl
import hashlib
import uuid

from pydantic import BaseModel, HttpUrl
from pyeudiw.jwk.schema import JwkSchema

from pyeudiw.jwt import JWSHelper
from pyeudiw.jwt.utils import unpad_jwt_payload
from pyeudiw.tools.utils import iat_now
from typing import Literal


Expand Down Expand Up @@ -32,24 +36,46 @@ class DPoPTokenPayloadSchema(BaseModel):


class DPoPIssuer:
def __init__(self, token: str, private_jwk: dict):
def __init__(self, htu :str, token: str, private_jwk: dict):
self.token = token
self.private_jwk = private_jwk
self.signer = JWSHelper(private_jwk)
self.htu = htu

@property
def proof(self):
pass

data = {
"jti": str(uuid.uuid4()),
"htm": "GET",
"htu": self.htu,
"iat": iat_now(),
"ath": hashlib.sha256(self.token.encode()).hexdigest()
}
jwt = self.signer.sign(data)
return jwt
# TODO assertion

class DPoPVerifier:
dpop_header_prefix = 'DPoP '

def __init__(
self, token: str,
self,
public_jwk: dict,
http_header_authz: str,
http_header_dpop: str,
http_header_authz :str,
http_header_dpop :str,
):
self.token = token
self.public_jwk = public_jwk

def validate(self):
pass
self.dpop_token = (
http_header_authz.replace(self.dpop_header_prefix, '')
if self.dpop_header_prefix in http_header_authz
else http_header_authz
)
self.proof = http_header_dpop

@property
def is_valid(self):
jws_verifier = JWSHelper(self.public_jwk)
dpop_valid = jws_verifier.verify(self.dpop_token)
payload = unpad_jwt_payload(self.proof)
proof_valid = hashlib.sha256(self.dpop_token.encode()).hexdigest() == payload['ath']
return dpop_valid and proof_valid
10 changes: 2 additions & 8 deletions pyeudiw/satosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pyeudiw.satosa.html_template import Jinja2TemplateHandler
from pyeudiw.tools.qr_code import QRCode
from pyeudiw.jwk import JWK
from pyeudiw.tools.jwt import JWSHelper
from pyeudiw.jwt import JWSHelper
from pyeudiw.tools.mobile import is_smartphone

logger = logging.getLogger("openid4vp_backend")
Expand Down Expand Up @@ -155,13 +155,7 @@ def redirect_endpoint(self, context, *args):
jwk = self.metadata_jwk

helper = JWSHelper(jwk)
data = {
"jti": str(uuid.uuid4()),
"htm": "GET",
"htu": f"{self.client_id}/request_uri",
"iat": int(datetime.now().timestamp()),
"ath": "fUHyO2r2Z3DZ53EsNrWBb0xWXoaNy59IiKCAqksmQEo"
}
data = {} #TODO
jwt = helper.sign(data)
response = {"request": jwt}

Expand Down
42 changes: 35 additions & 7 deletions pyeudiw/tests/oauth2/test_dpop.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import pytest

from pyeudiw.oauth2.dpop import DPoPIssuer, DPoPVerifier
from pyeudiw.jwk import JWK
from pyeudiw.tools.jwt import JWSHelper
from pyeudiw.jwt import JWSHelper
from pyeudiw.jwt.utils import unpad_jwt_payload, unpad_jwt_header
from pyeudiw.tools.utils import iat_now



WALLET_INSTANCE_ATTESTATION = {
"iss": "https://wallet-provider.example.org",
"sub": "vbeXJksM45xphtANnCiG6mCyuU4jfGNzopGuKvogg9c",
Expand Down Expand Up @@ -54,10 +57,35 @@ def private_jwk():
def jwshelper(private_jwk):
return JWSHelper(private_jwk)

@pytest.fixture
def wia_jws(jwshelper):
wia = jwshelper.sign(
WALLET_INSTANCE_ATTESTATION,
protected={'trust_chain': [], 'x5c': []}
)
return wia

def test_create(jwshelper):
jwshelper.sign(WALLET_INSTANCE_ATTESTATION)


def test_validate():
pass
def test_create_validate_dpop_http_headers(wia_jws, private_jwk):
# create
header = unpad_jwt_header(wia_jws)
payload = unpad_jwt_payload(wia_jws)
# TODO assertions

new_dpop = DPoPIssuer(
htu='https://example.org/redirect',
token=wia_jws,
private_jwk=private_jwk
)
proof = new_dpop.proof

# TODO assertions

# verify
dpop = DPoPVerifier(
public_jwk = private_jwk.public_key,
http_header_authz = f"DPoP {wia_jws}",
http_header_dpop = proof
)

assert dpop.is_valid
# TODO assertions
3 changes: 2 additions & 1 deletion pyeudiw/tests/tools/test_jwt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

from pyeudiw.jwk import JWK
from pyeudiw.tools.jwt import JWEHelper, JWSHelper, unpad_jwt_header, DEFAULT_JWE_ALG, DEFAULT_JWE_ENC
from pyeudiw.jwt import JWEHelper, JWSHelper, DEFAULT_JWE_ALG, DEFAULT_JWE_ENC
from pyeudiw.jwt.utils import unpad_jwt_header

JWKs_EC = [
(JWK(key_type="EC"), {"key": "value"}),
Expand Down
16 changes: 0 additions & 16 deletions pyeudiw/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# from django.utils.timezone import make_aware
from secrets import token_hex

from . jwt import unpad_jwt_header


import datetime
import json
Expand Down Expand Up @@ -63,19 +61,5 @@ def get_jwks(httpc_params: dict, metadata: dict, federation_jwks: list = []) ->
return jwks_list


def get_jwk_from_jwt(jwt: str, provider_jwks: dict) -> dict:
"""
docs here
"""
head = unpad_jwt_header(jwt)
kid = head["kid"]
if isinstance(provider_jwks, dict) and provider_jwks.get('keys'):
provider_jwks = provider_jwks['keys']
for jwk in provider_jwks:
if jwk["kid"] == kid:
return jwk
return {}


def random_token(n=254):
return token_hex(n)

0 comments on commit d167f46

Please sign in to comment.