Skip to content

Commit

Permalink
Simple type fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
athornton committed Dec 19, 2023
1 parent 508c014 commit 0fe98e6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
12 changes: 9 additions & 3 deletions giftless/auth/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _generate_token_for_action(self, identity: Identity, org: str, repo: str, ac
if lifetime:
token_payload['exp'] = datetime.now(tz=UTC) + timedelta(seconds=lifetime)

return self._generate_token(**token_payload).decode('ascii')
return self._generate_token(**token_payload)

@staticmethod
def _generate_action_scopes(org: str, repo: str, actions: Optional[Set[str]] = None, oid: Optional[str] = None) \
Expand All @@ -163,7 +163,7 @@ def _generate_action_scopes(org: str, repo: str, actions: Optional[Set[str]] = N
obj_id = f'{org}/{repo}/{oid}'
return str(Scope('obj', obj_id, actions))

def _generate_token(self, **kwargs) -> bytes:
def _generate_token(self, **kwargs) -> str:
"""Generate a JWT token that can be used later to authenticate a request
"""
if not self.private_key:
Expand All @@ -187,7 +187,13 @@ def _generate_token(self, **kwargs) -> bytes:
if self.key_id:
headers['kid'] = self.key_id

return jwt.encode(payload, self.private_key, algorithm=self.algorithm, headers=headers) # type: ignore
token = jwt.encode(payload, self.private_key, algorithm=self.algorithm, headers=headers)
# Type of jwt.encode() went from bytes to str in jwt 2.x, but the
# typing hints somehow aren't keeping up. This lets us do the
# right thing with jwt 2.x.
if type(token) == str:
return token # type: ignore
return token.decode('ascii')

def _authenticate(self, request: Request):
"""Authenticate a request
Expand Down
16 changes: 11 additions & 5 deletions giftless/storage/google_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import json
import posixpath
from datetime import timedelta
from typing import Any, BinaryIO, Dict, Optional
from typing import Any, BinaryIO, Dict, Optional, Union

import google.auth
import google.auth # type: ignore
from google.auth import impersonated_credentials
from google.cloud import storage # type: ignore
from google.oauth2 import service_account # type: ignore
Expand All @@ -30,15 +30,19 @@ def __init__(self,
**_):
self.bucket_name = bucket_name
self.path_prefix = path_prefix
self.credentials: Optional[service_account.Credentials] = self._load_credentials(account_key_file, account_key_base64)
self.credentials: Optional[Union[
service_account.Credentials, impersonated_credentials.Credentials
]] = self._load_credentials(
account_key_file, account_key_base64
)
self.storage_client = storage.Client(project=project_name, credentials=self.credentials)
if not self.credentials:
if not serviceaccount_email:
raise ValueError(
"If no account key is given, serviceaccount_email must "
"be set in order to use workload identity."
)
self._serviceaccount_email=serviceaccount_email
self._serviceaccount_email = serviceaccount_email

def get(self, prefix: str, oid: str) -> BinaryIO:
bucket = self.storage_client.bucket(self.bucket_name)
Expand Down Expand Up @@ -137,7 +141,9 @@ def _load_credentials(account_key_file: Optional[str], account_key_base64: Optio
else:
return None # Will use Workload Identity if available

def _get_workload_identity_credentials(self, expires_in: int) -> None:
def _get_workload_identity_credentials(
self, expires_in: int
) -> impersonated_credentials.Credentials:
lifetime = expires_in
if lifetime > 3600:
lifetime = 3600 # Signing credentials are good for one hour max
Expand Down
2 changes: 1 addition & 1 deletion tests/auth/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,4 +263,4 @@ def _get_test_token(lifetime=300, headers=None, algo='HS256', **kwargs):
else:
raise ValueError("Don't know how to test algo: {}".format(algo))

return jwt.encode(payload, key, algorithm=algo, headers=headers).decode('utf8')
return jwt.encode(payload, key, algorithm=algo, headers=headers)

0 comments on commit 0fe98e6

Please sign in to comment.