Skip to content

Commit

Permalink
Allow passing access_token directly
Browse files Browse the repository at this point in the history
  • Loading branch information
Koncopd committed Feb 26, 2024
1 parent 20e9767 commit f7eca90
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 27 deletions.
40 changes: 30 additions & 10 deletions lamindb_setup/dev/_hub_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,21 @@ def connect_hub(


def connect_hub_with_auth(
fallback_env: bool = False, renew_token: bool = False
fallback_env: bool = False,
renew_token: bool = False,
access_token: Optional[str] = None,
) -> Client:
from lamindb_setup import settings

hub = connect_hub(fallback_env=fallback_env)
if renew_token:
settings.user.access_token = get_access_token(
settings.user.email, settings.user.password
)
hub.postgrest.auth(settings.user.access_token)
hub.functions.set_auth(settings.user.access_token)
if access_token is None:
from lamindb_setup import settings

if renew_token:
settings.user.access_token = get_access_token(
settings.user.email, settings.user.password
)
access_token = settings.user.access_token
hub.postgrest.auth(access_token)
hub.functions.set_auth(access_token)
return hub


Expand Down Expand Up @@ -98,6 +102,19 @@ def call_with_fallback_auth(
callable,
**kwargs,
):
access_token = kwargs.pop("access_token", None)

if access_token is not None:
try:
client = connect_hub_with_auth(access_token=access_token)
result = callable(**kwargs, client=client)

Check warning on line 110 in lamindb_setup/dev/_hub_client.py

View check run for this annotation

Codecov / codecov/patch

lamindb_setup/dev/_hub_client.py#L108-L110

Added lines #L108 - L110 were not covered by tests
finally:
try:
client.auth.sign_out()
except NameError:
pass
return result

Check warning on line 116 in lamindb_setup/dev/_hub_client.py

View check run for this annotation

Codecov / codecov/patch

lamindb_setup/dev/_hub_client.py#L112-L116

Added lines #L112 - L116 were not covered by tests

for renew_token, fallback_env in [(False, False), (True, False), (False, True)]:
try:
client = connect_hub_with_auth(
Expand All @@ -109,7 +126,10 @@ def call_with_fallback_auth(
if fallback_env:
raise e
finally:
client.auth.sign_out()
try:
client.auth.sign_out()
except NameError:
pass

Check warning on line 132 in lamindb_setup/dev/_hub_client.py

View check run for this annotation

Codecov / codecov/patch

lamindb_setup/dev/_hub_client.py#L131-L132

Added lines #L131 - L132 were not covered by tests
return result


Expand Down
6 changes: 3 additions & 3 deletions lamindb_setup/dev/_hub_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ def _load_instance(
return instance, storage


def access_aws() -> Dict[str, str]:
def access_aws(access_token: Optional[str] = None) -> Dict[str, str]:
from .._settings import settings

if settings.user.handle != "anonymous":
credentials = call_with_fallback_auth(_access_aws)
if settings.user.handle != "anonymous" or access_token is not None:
credentials = call_with_fallback_auth(_access_aws, access_token=access_token)
logger.important("loaded AWS credentials")
return credentials
else:
Expand Down
5 changes: 4 additions & 1 deletion lamindb_setup/dev/_settings_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
region: Optional[str] = None,
uid: Optional[str] = None,
uuid: Optional[UUID] = None,
access_token: Optional[str] = None,
):
self._uid = uid
self._uuid = uuid
Expand All @@ -141,6 +142,8 @@ def __init__(
self._cache_dir = _process_cache_path(cache_path)
else:
self._cache_dir = None
# save access_token here for use in self.root
self.access_token = access_token

@property
def id(self) -> int:
Expand Down Expand Up @@ -185,7 +188,7 @@ def root(self) -> UPath:
if self._root is None:
# below also makes network requests to get credentials
# right
root_path = create_path(self._root_init)
root_path = create_path(self._root_init, access_token=self.access_token)
self._root = root_path
return self._root

Expand Down
29 changes: 16 additions & 13 deletions lamindb_setup/dev/upath.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def convert_pathlike(pathlike: Union[str, Path, UPath]) -> UPath:
AWS_CREDENTIALS_PRESENT = None


def create_path(path: UPath) -> UPath:
def create_path(path: UPath, access_token: Optional[str] = None) -> UPath:
path = convert_pathlike(path)
# test whether we have an AWS S3 path
if not isinstance(path, S3Path):
Expand Down Expand Up @@ -465,18 +465,21 @@ def create_path(path: UPath) -> UPath:
if not is_hosted_storage:
# make anon request if no credentials present
return UPath(path, cache_regions=True, anon=anon)

root_folder = path_str.replace("s3://", "").split("/")[0]

if access_token is None and root_folder in credentials_cache:
credentials = credentials_cache[root_folder]
else:
root_folder = path_str.replace("s3://", "").split("/")[0]
if root_folder in credentials_cache:
credentials = credentials_cache[root_folder]
else:
from ._hub_core import access_aws
from ._hub_core import access_aws

credentials = access_aws()
credentials = access_aws(access_token=access_token)
if access_token is None:
credentials_cache[root_folder] = credentials
return UPath(
path,
key=credentials["key"],
secret=credentials["secret"],
token=credentials["token"],
)

return UPath(
path,
key=credentials["key"],
secret=credentials["secret"],
token=credentials["token"],
)

0 comments on commit f7eca90

Please sign in to comment.