From 748d21ae984af090c84d506af11d4778e5a0198f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Birk=20Jernstr=C3=B6m?= Date: Tue, 14 May 2024 17:14:18 +0200 Subject: [PATCH] server: Disable user & organization API for blocked resources --- server/polar/auth/dependencies.py | 5 ++ server/polar/organization/service.py | 30 +++++++- server/polar/user/service.py | 3 + server/tests/fixtures/auth.py | 12 ++- server/tests/fixtures/random_objects.py | 46 ++++++++++++ server/tests/organization/test_endpoints.py | 83 +++++++++++++++++++++ server/tests/user/test_endpoints.py | 40 ++++++++++ 7 files changed, 215 insertions(+), 4 deletions(-) diff --git a/server/polar/auth/dependencies.py b/server/polar/auth/dependencies.py index 06241b5856..a91fc570e1 100644 --- a/server/polar/auth/dependencies.py +++ b/server/polar/auth/dependencies.py @@ -97,6 +97,11 @@ def __call__(self, auth_subject: AuthSubject[Subject]) -> AuthSubject[Subject]: else: raise Unauthorized() + # Blocked subjects + blocked_at = getattr(auth_subject.subject, "blocked_at", None) + if blocked_at is not None: + raise Unauthorized() + # Not allowed subject subject_type = type(auth_subject.subject) if subject_type not in self.allowed_subjects: diff --git a/server/polar/organization/service.py b/server/polar/organization/service.py index 63d375f8f7..bcc76644bd 100644 --- a/server/polar/organization/service.py +++ b/server/polar/organization/service.py @@ -46,26 +46,49 @@ class OrganizationService(ResourceServiceReader[Organization]): async def list_installed(self, session: AsyncSession) -> Sequence[Organization]: stmt = sql.select(Organization).where( Organization.deleted_at.is_(None), + Organization.blocked_at.is_(None), Organization.installation_id.is_not(None), ) res = await session.execute(stmt) return res.scalars().all() + # Override get method to include `blocked_at` filter + async def get( + self, session: AsyncSession, id: UUID, allow_deleted: bool = False + ) -> Organization | None: + conditions = [Organization.id == id] + if not allow_deleted: + conditions.append(Organization.deleted_at.is_(None)) + + conditions.append(Organization.blocked_at.is_(None)) + query = sql.select(Organization).where(*conditions) + res = await session.execute(query) + return res.scalars().unique().one_or_none() + async def get_by_platform( self, session: AsyncSession, platform: Platforms, external_id: int ) -> Organization | None: - return await self.get_by(session, platform=platform, external_id=external_id) + # TODO: Also add deleted_at=None in a separate commit + return await self.get_by( + session, + platform=platform, + external_id=external_id, + blocked_at=None, + ) async def get_by_name( self, session: AsyncSession, platform: Platforms, name: str ) -> Organization | None: - return await self.get_by(session, platform=platform, name=name) + # TODO: Also add deleted_at=None in a separate commit + return await self.get_by(session, platform=platform, name=name, blocked_at=None) async def get_by_custom_domain( self, session: AsyncSession, custom_domain: str ) -> Organization | None: + # TODO: Also add deleted_at=None in a separate commit query = sql.select(Organization).where( - Organization.custom_domain == custom_domain + Organization.custom_domain == custom_domain, + Organization.blocked_at.is_(None), ) res = await session.execute(query) return res.scalars().unique().one_or_none() @@ -78,6 +101,7 @@ async def get_personal( .join(UserOrganization) .where( Organization.deleted_at.is_(None), + Organization.blocked_at.is_(None), Organization.is_personal.is_(True), UserOrganization.user_id == user_id, UserOrganization.deleted_at.is_(None), diff --git a/server/polar/user/service.py b/server/polar/user/service.py index e02f8cb1d4..c343b5189b 100644 --- a/server/polar/user/service.py +++ b/server/polar/user/service.py @@ -38,6 +38,7 @@ async def get_by_email(self, session: AsyncSession, email: str) -> User | None: query = sql.select(User).where( func.lower(User.email) == email.lower(), User.deleted_at.is_(None), + User.blocked_at.is_(None), ) res = await session.execute(query) return res.scalars().unique().one_or_none() @@ -48,6 +49,7 @@ async def get_by_username( query = sql.select(User).where( User.username == username, User.deleted_at.is_(None), + User.blocked_at.is_(None), ) res = await session.execute(query) return res.scalars().unique().one_or_none() @@ -58,6 +60,7 @@ async def get_by_stripe_customer_id( query = sql.select(User).where( User.stripe_customer_id == stripe_customer_id, User.deleted_at.is_(None), + User.blocked_at.is_(None), ) res = await session.execute(query) return res.scalars().unique().one_or_none() diff --git a/server/tests/fixtures/auth.py b/server/tests/fixtures/auth.py index 1166fad799..e15a105e88 100644 --- a/server/tests/fixtures/auth.py +++ b/server/tests/fixtures/auth.py @@ -13,7 +13,13 @@ def __init__( self, *, subject: Literal[ - "anonymous", "user", "user_second", "organization", "organization_second" + "anonymous", + "user", + "user_second", + "user_blocked", + "organization", + "organization_second", + "organization_blocked", ] = "user", scopes: set[Scope] = {Scope.web_default}, method: AuthMethod = AuthMethod.COOKIE, @@ -37,8 +43,10 @@ def auth_subject( request: pytest.FixtureRequest, user: User, user_second: User, + user_blocked: User, organization: Organization, organization_second: Organization, + organization_blocked: Organization, ) -> AuthSubject[Subject]: """ This fixture generates an AuthSubject instance used by the `client` fixture @@ -53,8 +61,10 @@ def auth_subject( "anonymous": Anonymous(), "user": user, "user_second": user_second, + "user_blocked": user_blocked, "organization": organization, "organization_second": organization_second, + "organization_blocked": organization_blocked, } return AuthSubject( subjects_map[auth_subject_fixture.subject], diff --git a/server/tests/fixtures/random_objects.py b/server/tests/fixtures/random_objects.py index 4c1f9184ef..8477248cc6 100644 --- a/server/tests/fixtures/random_objects.py +++ b/server/tests/fixtures/random_objects.py @@ -56,6 +56,25 @@ async def second_organization(save_fixture: SaveFixture) -> Organization: return await create_organization(save_fixture) +@pytest_asyncio.fixture(scope="function") +async def organization_blocked(save_fixture: SaveFixture) -> Organization: + organization = Organization( + platform=Platforms.github, + name=rstr("testorg"), + external_id=secrets.randbelow(100000), + avatar_url="https://avatars.githubusercontent.com/u/105373340?s=200&v=4", + is_personal=True, + installation_id=secrets.randbelow(100000), + installation_created_at=datetime.now(), + installation_updated_at=datetime.now(), + installation_suspended_at=None, + created_from_user_maintainer_upgrade=True, + blocked_at=utc_now(), + ) + await save_fixture(organization) + return organization + + async def create_organization(save_fixture: SaveFixture) -> Organization: organization = Organization( platform=Platforms.github, @@ -218,6 +237,19 @@ async def user_second(save_fixture: SaveFixture) -> User: return user +@pytest_asyncio.fixture(scope="function") +async def user_blocked(save_fixture: SaveFixture) -> User: + user = User( + id=uuid.uuid4(), + username=rstr("DEPRECATED_testuser"), + email=rstr("test") + "@example.com", + avatar_url="https://avatars.githubusercontent.com/u/47952?v=4", + blocked_at=utc_now(), + ) + await save_fixture(user) + return user + + async def create_pledge( save_fixture: SaveFixture, organization: Organization, @@ -398,6 +430,20 @@ async def user_organization_second( return user_organization +@pytest_asyncio.fixture(scope="function") +async def user_organization_blocked( + save_fixture: SaveFixture, + organization_blocked: Organization, + user: User, +) -> UserOrganization: + user_organization = UserOrganization( + user_id=user.id, + organization_id=organization_blocked.id, + ) + await save_fixture(user_organization) + return user_organization + + @pytest_asyncio.fixture async def open_collective_account(save_fixture: SaveFixture, user: User) -> Account: account = Account( diff --git a/server/tests/organization/test_endpoints.py b/server/tests/organization/test_endpoints.py index 57bcba221a..e556708393 100644 --- a/server/tests/organization/test_endpoints.py +++ b/server/tests/organization/test_endpoints.py @@ -28,6 +28,17 @@ async def test_get_organization( assert org.id == organization.id +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth +async def test_get_blocked_organization_404( + organization_blocked: Organization, client: AsyncClient +) -> None: + response = await client.get(f"/api/v1/organizations/{organization_blocked.id}") + + assert response.status_code == 404 + + @pytest.mark.asyncio @pytest.mark.http_auto_expunge @pytest.mark.auth @@ -129,6 +140,23 @@ async def test_list_organization_member( assert len(response.json()["items"]) == 0 +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth +async def test_list_blocked_organization_member( + organization_blocked: Organization, + user_organization_blocked: UserOrganization, # makes User a member of Organization + client: AsyncClient, +) -> None: + response = await client.get("/api/v1/organizations") + + assert response.status_code == 200 + + orgs = response.json()["items"] + for org in orgs: + assert org.id != str(organization_blocked.id) + + @pytest.mark.asyncio @pytest.mark.http_auto_expunge @pytest.mark.auth @@ -189,6 +217,19 @@ async def test_organization_lookup( assert response.json()["id"] == str(organization.id) +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth +async def test_organization_blocked_lookup_404( + organization_blocked: Organization, client: AsyncClient +) -> None: + response = await client.get( + f"/api/v1/organizations/lookup?platform=github&organization_name={organization_blocked.name}" + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio @pytest.mark.http_auto_expunge @pytest.mark.auth @@ -217,6 +258,20 @@ async def test_organization_search_no_matches( assert response.json()["items"] == [] +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth +async def test_organization_blocked_search( + organization_blocked: Organization, client: AsyncClient +) -> None: + response = await client.get( + f"/api/v1/organizations/search?platform=github&organization_name={organization_blocked.name}" + ) + + assert response.status_code == 200 + assert response.json()["items"] == [] + + @pytest.mark.asyncio @pytest.mark.http_auto_expunge @pytest.mark.auth @@ -235,6 +290,20 @@ async def test_get_organization_deleted( assert response.status_code == 404 +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth +async def test_get_organization_blocked_404( + save_fixture: SaveFixture, + organization_blocked: Organization, + user_organization: UserOrganization, # makes User a member of Organization + client: AsyncClient, +) -> None: + response = await client.get(f"/api/v1/organizations/{organization_blocked.id}") + + assert response.status_code == 404 + + @pytest.mark.asyncio @pytest.mark.http_auto_expunge @pytest.mark.auth @@ -249,6 +318,20 @@ async def test_update_organization_no_admin( assert response.status_code == 401 +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth +async def test_update_blocked_organization_no_admin_404( + organization_blocked: Organization, client: AsyncClient +) -> None: + response = await client.patch( + f"/api/v1/organizations/{organization_blocked.id}", + json={"default_upfront_split_to_contributors": 85}, + ) + + assert response.status_code == 404 + + @pytest.mark.asyncio @pytest.mark.auth async def test_update_organization( diff --git a/server/tests/user/test_endpoints.py b/server/tests/user/test_endpoints.py index d211302876..2facbb4b24 100644 --- a/server/tests/user/test_endpoints.py +++ b/server/tests/user/test_endpoints.py @@ -3,6 +3,7 @@ from polar.models.account import Account from polar.models.user import User +from tests.fixtures.auth import AuthSubjectFixture @pytest.mark.asyncio @@ -28,6 +29,14 @@ async def test_get_users_me_no_auth(client: AsyncClient) -> None: assert response.status_code == 401 +@pytest.mark.asyncio +@pytest.mark.http_auto_expunge +@pytest.mark.auth(AuthSubjectFixture(subject="user_blocked")) +async def test_get_users_me_blocked(user_blocked: User, client: AsyncClient) -> None: + response = await client.get("/api/v1/users/me") + assert response.status_code == 401 + + @pytest.mark.asyncio @pytest.mark.auth @pytest.mark.http_auto_expunge @@ -68,6 +77,21 @@ async def test_set_preferences_false(client: AsyncClient) -> None: assert "oauth_accounts" in json +@pytest.mark.asyncio +@pytest.mark.auth(AuthSubjectFixture(subject="user_blocked")) +@pytest.mark.http_auto_expunge +async def test_blocked_user_set_preferences(client: AsyncClient) -> None: + response = await client.put( + "/api/v1/users/me", + json={ + "email_newsletters_and_changelogs": False, + "email_promotions_and_events": False, + }, + ) + + assert response.status_code == 401 + + @pytest.mark.asyncio @pytest.mark.auth @pytest.mark.http_auto_expunge @@ -85,3 +109,19 @@ async def test_set_account( json = response.json() assert json["account_id"] == str(open_collective_account.id) + + +@pytest.mark.asyncio +@pytest.mark.auth(AuthSubjectFixture(subject="user_blocked")) +@pytest.mark.http_auto_expunge +async def test_blocked_user_set_account( + client: AsyncClient, open_collective_account: Account +) -> None: + response = await client.patch( + "/api/v1/users/me/account", + json={ + "account_id": str(open_collective_account.id), + }, + ) + + assert response.status_code == 401