Skip to content

Commit

Permalink
feature to pass in audience and options into PyJWT
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard Kolb committed Apr 5, 2020
1 parent e9bb5f2 commit 189bcff
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 3 deletions.
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,25 @@ The user name field in the JWT token payload:
app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key='secret', username_field='user'))
```

*audience*

The audience field in the JWT token is validated:
```python
# Example: changes the username field to "user"
app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key='secret', username_field='user', audience='test_aud'))
```

*options*

The options set to ignore audience verification:
```python
# Example: changes the username field to "user"
app.add_middleware(AuthenticationMiddleware, backend=JWTAuthenticationBackend(secret_key='secret', username_field='user', options={"verify_aud": False}))
```

## Todo

* Support JWT token standard payload
* Set JWT options (time expiration for example)


## Developing
Expand Down
8 changes: 6 additions & 2 deletions starlette_jwt/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,14 @@ async def authenticate(self, request):
class JWTWebSocketAuthenticationBackend(AuthenticationBackend):

def __init__(self, secret_key: str, algorithm: str = 'HS256', query_param_name: str = 'jwt',
username_field: str = 'username'):
username_field: str = 'username', audience = None, options = {}):
self.secret_key = secret_key
self.algorithm = algorithm
self.query_param_name = query_param_name
self.username_field = username_field
self.audience = audience
self.options = options


async def authenticate(self, request):
if self.query_param_name not in request.query_params:
Expand All @@ -74,7 +77,8 @@ async def authenticate(self, request):
token = request.query_params[self.query_param_name]

try:
payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm)
payload = jwt.decode(token, key=self.secret_key, algorithms=self.algorithm, audience=self.audience,
options=self.options)
except jwt.InvalidTokenError as e:
raise AuthenticationError(str(e))

Expand Down
42 changes: 42 additions & 0 deletions tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,48 @@ def test_websocket_valid_authentication():
assert websocket.scope['user'].is_authenticated


def test_websocket_valid_authentication_and_audience():
secret_key = 'example'
app = create_app()
app.add_middleware(AuthenticationMiddleware, backend=JWTWebSocketAuthenticationBackend(secret_key=secret_key,
audience="test_aud"))
client = TestClient(app)
token = jwt.encode(dict(username="user", aud="test_aud"), secret_key, algorithm="HS256").decode()
with client.websocket_connect(f"/ws-auth?jwt={token}") as websocket:
data = websocket.receive_text()
assert data == 'Authentication valid'
assert websocket.scope['user'].is_authenticated


def test_websocket_valid_authentication_and_audience_list():
secret_key = 'example'
app = create_app()
app.add_middleware(AuthenticationMiddleware, backend=JWTWebSocketAuthenticationBackend(secret_key=secret_key,
audience=["test_aud"]))
client = TestClient(app)
token = jwt.encode(dict(username="user", aud="test_aud"), secret_key, algorithm="HS256").decode()
with client.websocket_connect(f"/ws-auth?jwt={token}") as websocket:
data = websocket.receive_text()
assert data == 'Authentication valid'
assert websocket.scope['user'].is_authenticated


def test_websocket_valid_authentication_and_audience_and_option_ignore_audience():
secret_key = 'example'
app = create_app()
options = {"verify_aud": False}
app.add_middleware(AuthenticationMiddleware, backend=JWTWebSocketAuthenticationBackend(secret_key=secret_key,
audience="test_aud",
options=options))
client = TestClient(app)
token = jwt.encode(dict(username="user"), secret_key, algorithm="HS256",
).decode()
with client.websocket_connect(f"/ws-auth?jwt={token}") as websocket:
data = websocket.receive_text()
assert data == 'Authentication valid'
assert websocket.scope['user'].is_authenticated


def test_websocket_invalid_token():
secret_key = 'example'
app = create_app()
Expand Down

0 comments on commit 189bcff

Please sign in to comment.