Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make OAuth2 work with ORCID #92

Merged
merged 5 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
**/__tests__/**
**/.env
**/*~
**/__tests__/**
1 change: 0 additions & 1 deletion backend/.dockerignore

This file was deleted.

587 changes: 303 additions & 284 deletions backend/Pipfile.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backend/alembic/versions/c8009ed33089_init_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def upgrade():
sa.Column("user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
sa.Column("oauth_name", sa.String(length=100), nullable=False),
sa.Column("access_token", sa.String(length=TOKEN_SIZE), nullable=False),
sa.Column("expires_at", sa.Integer(), nullable=True),
sa.Column("expires_at", sa.BigInteger(), nullable=True),
sa.Column("refresh_token", sa.String(length=TOKEN_SIZE), nullable=True),
sa.Column("account_id", sa.String(length=320), nullable=False),
sa.Column("account_email", sa.String(length=320), nullable=False),
Expand Down
78 changes: 73 additions & 5 deletions backend/app/api/api_v1/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Any, Dict, Optional, Tuple

from fastapi import APIRouter
from httpx_oauth.clients.openid import OpenID
from httpx_oauth.errors import GetIdEmailError
from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Error

from app.api.api_v1.endpoints import adminmsgs, auth
from app.core.auth import auth_backend_bearer, auth_backend_cookie, fastapi_users
Expand Down Expand Up @@ -37,18 +41,82 @@
tags=["users"],
)

#: Base scopes for OrcID authentication.
BASE_SCOPES = ["openid", "/read-limited"]


class OrcidOpenId(OpenID):
"""Custom OrcID OpenID client that fetches the user's email from the OrcID API.

Note that users must have given access to their email address for "trusted parties".
"""

def __init__(
self,
client_id: str,
client_secret: str,
openid_configuration_endpoint: str,
name: str = "orcid",
base_scopes: list[str] | None = BASE_SCOPES,
):
super().__init__(
client_id,
client_secret,
openid_configuration_endpoint,
name=name,
base_scopes=base_scopes,
)

async def get_id_email(self, token: str) -> Tuple[str, Optional[str]]:
"""Custom implementation that returns the user ID and email."""
async with self.get_httpx_client() as client:
response_user = await client.get(
self.openid_configuration["userinfo_endpoint"],
headers={**self.request_headers, "Authorization": f"Bearer {token}"},
)

if response_user.status_code >= 400:
raise GetIdEmailError(response_user.json())
data_user: Dict[str, Any] = response_user.json()

response_record = await client.get(
f"https://api.sandbox.orcid.org/v3.0/{data_user['sub']}/record",
headers={**self.request_headers, "Authorization": f"Bearer {token}"},
)
if response_user.status_code >= 400:
raise GetIdEmailError(response_user.json())
data_record: Dict[str, Any] = response_record.json()

data_record_emails = data_record.get("person", {}).get("emails", {}).get("email", [])
if data_record_emails:
email = data_record_emails[0].get("email", None)
else:
email = None

return str(data_user["sub"]), email


# For now, we only provide oauth clients for cookie-based authentication.
for config in settings.OAUTH2_PROVIDERS:
oauth_client = OpenID(
client_id=config.client_id,
client_secret=config.client_secret,
openid_configuration_endpoint=str(config.config_url),
)
if config.name == "orcid":
oauth_client: OpenID = OrcidOpenId(
client_id=config.client_id,
client_secret=config.client_secret,
openid_configuration_endpoint=str(config.config_url),
base_scopes=["openid", "/read-limited"],
)
else:
oauth_client = OpenID(
client_id=config.client_id,
client_secret=config.client_secret,
openid_configuration_endpoint=str(config.config_url),
)
oauth_router = fastapi_users.get_oauth_router(
oauth_client=oauth_client,
backend=auth_backend_cookie,
state_secret=settings.SECRET_KEY,
associate_by_email=True,
is_verified_by_default=True,
)
api_router.include_router(
oauth_router,
Expand Down
2 changes: 1 addition & 1 deletion backend/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def assemble_cors_origins(cls, v: str | list[str]) -> list[str] | str: # pragma
BACKEND_PREFIX_NGINX: str = "http://nginx:80"

#: URL to REDIS service.
REDIS_URL: str = "redis://redis:5379"
REDIS_URL: str = "redis://redis:6379"

# -- User-Related Configuration ---------------------------------------------

Expand Down
4 changes: 3 additions & 1 deletion backend/app/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
SQLAlchemyBaseOAuthAccountTableUUID,
SQLAlchemyBaseUserTableUUID,
)
from sqlalchemy import Integer, String
from sqlalchemy import BigInteger, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from app.db.base import Base
Expand All @@ -17,10 +17,12 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
if TYPE_CHECKING: # pragma: no cover
access_token: str
refresh_token: Optional[str]
expires_at: Optional[int]
else:
# We need to increase the token size for the OAuthAccount table.
access_token: Mapped[str] = mapped_column(String(TOKEN_SIZE), nullable=False)
refresh_token: Mapped[Optional[str]] = mapped_column(String(TOKEN_SIZE), nullable=True)
expires_at: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)


class User(SQLAlchemyBaseUserTableUUID, Base):
Expand Down
36 changes: 24 additions & 12 deletions frontend/src/views/LoginView.vue
Original file line number Diff line number Diff line change
Expand Up @@ -97,18 +97,30 @@ onMounted(async () => {
<v-card class="mb-5 mt-5" variant="tonal" v-if="userStore.oauth2Providers.length > 0">
<v-card-title> Login With </v-card-title>
<v-card-text class="text-medium-emphasis text-caption mt-3">
<v-btn
block
size="large"
variant="tonal"
color="green"
class="mb-3"
@click="handleProviderLogin(provider)"
v-for="provider in userStore.oauth2Providers"
v-bind:key="provider.name"
>
{{ provider.label }}
</v-btn>
<template v-for="provider in userStore.oauth2Providers" v-bind:key="provider.name">
<v-btn
block
size="large"
variant="tonal"
color="green"
class="mb-3"
@click="handleProviderLogin(provider)"
v-if="provider.name === 'orcid'"
>
Login with ORCID
</v-btn>
<v-btn
block
size="large"
variant="tonal"
color="green"
class="mb-3"
@click="handleProviderLogin(provider)"
v-else
>
Login With {{ provider.label }}
</v-btn>
</template>
<v-btn
block
prepend-icon="mdi-arrow-left"
Expand Down
Loading