Skip to content

Commit

Permalink
Added Token Validations (#1682)
Browse files Browse the repository at this point in the history
### Feature or Bugfix
<!-- please choose -->
- Bugfix

### Detail
- This PR does the following w.r.t Cognito IdP and Auth Flow
- Changes Authorizer from built-in Cognito Authorizer to Custom
Authorizer to validate token signature, issuer, and expiry time, etc.
- Adds aditional step to execute GET API on Cognito's `/oauth/userInfo/`
endpoint to ensure access Token validity

Allows data.all API request to execute if the above criteria are met


### Relates
- <URL or Ticket>

### Security
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
  - Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
  - Is injection prevented by parametrizing queries?
  - Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
  - Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use a standard proven implementations?
  - Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
  - Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.
  • Loading branch information
noah-paige authored Nov 7, 2024
1 parent dd8e6a9 commit 5069bf8
Show file tree
Hide file tree
Showing 17 changed files with 269 additions and 202 deletions.
13 changes: 13 additions & 0 deletions .checkov.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@
"CKV_AWS_115"
]
},
{
"resource": "AWS::Lambda::Function.CustomAuthorizerFunctiondevB38B5CCB",
"check_ids": [
"CKV_AWS_115",
"CKV_AWS_116"
]
},
{
"resource": "AWS::Lambda::Function.ElasticSearchProxyHandlerDBDE7574",
"check_ids": [
Expand All @@ -210,6 +217,12 @@
"CKV_AWS_158"
]
},
{
"resource": "AWS::Logs::LogGroup.customauthorizerloggroup8F3B5B9D",
"check_ids": [
"CKV_AWS_158"
]
},
{
"resource": "AWS::Logs::LogGroup.dataalldevapigateway2625FE76",
"check_ids": [
Expand Down
8 changes: 5 additions & 3 deletions backend/dataall/base/utils/api_handler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
]
ENGINE = get_engine(envname=ENVNAME)
ALLOWED_ORIGINS = os.getenv('ALLOWED_ORIGINS', '*')
AWS_REGION = os.getenv('AWS_REGION')


def redact_creds(event):
if 'headers' in event and 'Authorization' in event['headers']:
if event.get('headers', {}).get('Authorization'):
event['headers']['Authorization'] = 'XXXXXXXXXXXX'
if 'multiValueHeaders' in event and 'Authorization' in event['multiValueHeaders']:

if event.get('multiValueHeaders', {}).get('Authorization'):
event['multiValueHeaders']['Authorization'] = 'XXXXXXXXXXXX'
return event

Expand Down Expand Up @@ -115,7 +117,7 @@ def check_reauth(query, auth_time, username):
# Determine if there are any Operations that Require ReAuth From SSM Parameter
try:
reauth_apis = ParameterStoreManager.get_parameter_value(
region=os.getenv('AWS_REGION', 'eu-west-1'), parameter_path=f'/dataall/{ENVNAME}/reauth/apis'
region=AWS_REGION, parameter_path=f'/dataall/{ENVNAME}/reauth/apis'
).split(',')
except Exception:
log.info('No ReAuth APIs Found in SSM')
Expand Down
2 changes: 1 addition & 1 deletion deploy/custom_resources/custom_authorizer/auth_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def generate_policy(verified_claims: dict, effect, incoming_resource_str: str):

for claim_name, claim_value in verified_claims.items():
if isinstance(claim_value, list):
verified_claims.update({claim_name: json.dumps(claim_value)})
verified_claims.update({claim_name: ','.join(claim_value)})

context = {**verified_claims}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import json

from auth_services import AuthServices
from jwt_services import JWTServices
Expand All @@ -16,21 +17,33 @@
Custom Lambda Authorizer is attached to the API Gateway. Check the deploy/stacks/lambda_api.py for more details on deployment
"""

OPENID_CONFIG_PATH = os.path.join(os.environ['custom_auth_url'], '.well-known', 'openid-configuration')
JWT_SERVICE = JWTServices(OPENID_CONFIG_PATH)


def lambda_handler(incoming_event, context):
# Get the Token which is sent in the Authorization Header
logger.debug(incoming_event)
auth_token = incoming_event['headers']['Authorization']
if not auth_token:
raise Exception('Unauthorized . Token not found')
raise Exception('Unauthorized. Missing JWT')

verified_claims = JWTServices.validate_jwt_token(auth_token)
logger.debug(verified_claims)
# Validate User is Active with Proper Access Token
user_info = JWT_SERVICE.validate_access_token(auth_token)

# Validate JWT
# Note: Removing the 7 Prefix Chars for 'Bearer ' from JWT
verified_claims = JWT_SERVICE.validate_jwt_token(auth_token[7:])
if not verified_claims:
raise Exception('Unauthorized. Token is not valid')
logger.debug(verified_claims)

# Generate Allow Policy w/ Context
effect = 'Allow'
verified_claims.update(user_info)
policy = AuthServices.generate_policy(verified_claims, effect, incoming_event['methodArn'])
logger.debug('Generated policy is ', policy)
logger.debug(f'Generated policy is {json.dumps(policy)}')
print(f'Generated policy is {json.dumps(policy)}')
return policy


Expand All @@ -39,12 +52,13 @@ def lambda_handler(incoming_event, context):
# AWS Lambda and any other local environments
if __name__ == '__main__':
# for testing locally you can enter the JWT ID Token here
token = ''
#
access_token = ''
account_id = ''
api_gw_id = ''
event = {
'headers': {'Authorization': access_token},
'type': 'TOKEN',
'Authorization': token,
'methodArn': f'arn:aws:execute-api:us-east-1:{account_id}:{api_gw_id}/prod/POST/graphql/api',
}
lambda_handler(event, None)
116 changes: 48 additions & 68 deletions deploy/custom_resources/custom_authorizer/jwt_services.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,81 @@
import os

import requests
from jose import jwk
from jose.jwt import get_unverified_header, decode, ExpiredSignatureError, JWTError
import jwt

import logging

logger = logging.getLogger()
logger.setLevel(os.environ.get('LOG_LEVEL', 'INFO'))

# Configs required to fetch public keys from JWKS
ISSUER_CONFIGS = {
f'{os.environ.get("custom_auth_url")}': {
'jwks_uri': f'{os.environ.get("custom_auth_jwks_url")}',
'allowed_audiences': f'{os.environ.get("custom_auth_client")}',
},
}

issuer_keys = {}


# instead of re-downloading the public keys every time
# we download them only on cold start
# https://aws.amazon.com/blogs/compute/container-reuse-in-lambda/
def fetch_public_keys():
try:
for issuer, issuer_config in ISSUER_CONFIGS.items():
jwks_response = requests.get(issuer_config['jwks_uri'])
jwks_response.raise_for_status()
jwks: dict = jwks_response.json()
for key in jwks['keys']:
value = {
'issuer': issuer,
'audience': issuer_config['allowed_audiences'],
'jwk': jwk.construct(key),
'public_key': jwk.construct(key).public_key(),
}
issuer_keys.update({key['kid']: value})
except Exception as e:
raise Exception(f'Unable to fetch public keys due to {str(e)}')


fetch_public_keys()

# Options to validate the JWT token
# Only modification from default is to turn off verify_at_hash as we don't provide the access token for this validation
# Only modification from default is to turn off verify_aud as Cognito Access Token does not provide this claim
jwt_options = {
'verify_signature': True,
'verify_aud': True,
'verify_aud': False,
'verify_iat': True,
'verify_exp': True,
'verify_nbf': True,
'verify_iss': True,
'verify_sub': True,
'verify_jti': True,
'verify_at_hash': False,
'require_aud': True,
'require_iat': True,
'require_exp': True,
'require_nbf': False,
'require_iss': True,
'require_sub': True,
'require_jti': True,
'require_at_hash': False,
'leeway': 0,
'require': ['iat', 'exp', 'iss', 'sub', 'jti'],
}


class JWTServices:
@staticmethod
def validate_jwt_token(jwt_token):
def __init__(self, openid_config_path):
# Get OpenID Config JSON
self.openid_config = self._fetch_openid_config(openid_config_path)

# Init pyJWT.JWKClient with JWK URI
self.jwks_client = jwt.PyJWKClient(self.openid_config.get('jwks_uri'))

def _fetch_openid_config(self, openid_config_path):
response = requests.get(openid_config_path)
response.raise_for_status()
return response.json()

def validate_jwt_token(self, jwt_token) -> dict:
try:
# Decode and verify the JWT token
header = get_unverified_header(jwt_token)
kid = header['kid']
if kid not in issuer_keys:
logger.info('Public key not found in provided set of keys')
# Retry Fetching the public certificates again in case rotation occurs and lambda has cached the publicKeys
fetch_public_keys()
if kid not in issuer_keys:
raise Exception('Unauthorized')
public_key = issuer_keys.get(kid)
payload = decode(
# get signing_key from JWT
signing_key = self.jwks_client.get_signing_key_from_jwt(jwt_token)

# Decode and Verify JWT
payload = jwt.decode(
jwt_token,
public_key.get('jwk'),
signing_key.key,
algorithms=['RS256', 'HS256'],
issuer=public_key.get('issuer'),
audience=public_key.get('audience'),
issuer=os.environ['custom_auth_url'],
audience=os.environ.get('custom_auth_client'),
leeway=0,
options=jwt_options,
)

# verify client_id if Cognito JWT
if 'client_id' in payload and payload['client_id'] != os.environ.get('custom_auth_client'):
raise Exception('Invalid Client ID in JWT Token')

# verify cid for other IdPs
if 'cid' in payload and payload['cid'] != os.environ.get('custom_auth_client'):
raise Exception('Invalid Client ID in JWT Token')

return payload
except ExpiredSignatureError:
except jwt.exceptions.ExpiredSignatureError as e:
logger.error('JWT token has expired.')
return None
except JWTError as e:
raise e
except jwt.exceptions.PyJWTError as e:
logger.error(f'JWT token validation failed: {str(e)}')
return None
raise e
except Exception as e:
logger.error(f'Failed to validate token - {str(e)}')
return None
raise e

def validate_access_token(self, access_token) -> dict:
# get UserInfo URI from OpenId Configuration
user_info_url = self.openid_config.get('userinfo_endpoint')
r = requests.get(user_info_url, headers={'Authorization': access_token})
r.raise_for_status()
logger.debug(r.json())
return r.json()
5 changes: 2 additions & 3 deletions deploy/custom_resources/custom_authorizer/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
certifi==2024.7.4
charset-normalizer==3.1.0
ecdsa==0.18.0
idna==3.7
pyasn1==0.5.0
python-jose==3.3.0
requests==2.32.2
rsa==4.9
six==1.16.0
urllib3==1.26.19
urllib3==1.26.19
pyjwt==2.9.0
3 changes: 2 additions & 1 deletion deploy/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ aws-cdk-lib==2.160.0
boto3==1.35.26
boto3-stubs==1.35.26
cdk-nag==2.7.2
typeguard==4.2.1
typeguard==4.2.1
cdk-klayers==0.3.0
1 change: 1 addition & 0 deletions deploy/stacks/backend_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def __init__(
apig_vpce=apig_vpce,
prod_sizing=prod_sizing,
user_pool=cognito_stack.user_pool if custom_auth is None else None,
user_pool_client=cognito_stack.client if custom_auth is None else None,
pivot_role_name=self.pivot_role_name,
reauth_ttl=reauth_config.get('ttl', 5) if reauth_config else 5,
email_notification_sender_email_id=email_sender,
Expand Down
Loading

0 comments on commit 5069bf8

Please sign in to comment.