From 20a7c2b8a29abdd3edec404029bba9db901882b5 Mon Sep 17 00:00:00 2001 From: Dzmitry Hramyka Date: Mon, 19 Feb 2024 19:24:11 +0100 Subject: [PATCH] fix: Proper naming of openid accounts (#533) (#534) --- .../versions/6f14afa8ea47_update_auth_models.py | 17 +++++++++++++++++ backend/app/api/api_v1/api.py | 2 ++ 2 files changed, 19 insertions(+) diff --git a/backend/alembic/versions/6f14afa8ea47_update_auth_models.py b/backend/alembic/versions/6f14afa8ea47_update_auth_models.py index ab827088..4766cab7 100644 --- a/backend/alembic/versions/6f14afa8ea47_update_auth_models.py +++ b/backend/alembic/versions/6f14afa8ea47_update_auth_models.py @@ -34,6 +34,23 @@ def upgrade(): type_=sa.JSON(), existing_nullable=True, ) + + # Delete conflicting rows in oauth_account before adding unique constraint + op.execute( + """ + BEGIN; + -- Create a temporary table to store the ids of the rows to keep + CREATE TEMP TABLE keep_rows AS + SELECT DISTINCT ON (oauth_name, user_id) id + FROM oauth_account; + -- Delete rows from oauth_account that are not in the keep_rows temporary table + DELETE FROM oauth_account + WHERE id NOT IN (SELECT id FROM keep_rows); + DROP TABLE keep_rows; + COMMIT; + """ + ) + op.create_unique_constraint(None, "oauth_account", ["oauth_name", "user_id"]) # ### end Alembic commands ### diff --git a/backend/app/api/api_v1/api.py b/backend/app/api/api_v1/api.py index 164b4e1a..af6694ba 100644 --- a/backend/app/api/api_v1/api.py +++ b/backend/app/api/api_v1/api.py @@ -129,6 +129,7 @@ async def get_id_email(self, token: str) -> Tuple[str, Optional[str]]: # pragma client_id=config.client_id, client_secret=config.client_secret, openid_configuration_endpoint=str(config.config_url), + name=config.name, base_scopes=["openid", "/read-limited"], ) else: @@ -136,6 +137,7 @@ async def get_id_email(self, token: str) -> Tuple[str, Optional[str]]: # pragma client_id=config.client_id, client_secret=config.client_secret, openid_configuration_endpoint=str(config.config_url), + name=config.name, ) # Add route for authentication via OAuth2 endpoint. oauth_router = fastapi_users.get_oauth_router(