diff --git a/backend/pennmobile/urls.py b/backend/pennmobile/urls.py index 5e94960e..934a8f12 100644 --- a/backend/pennmobile/urls.py +++ b/backend/pennmobile/urls.py @@ -1,11 +1,15 @@ +from typing import List, Union + from django.conf import settings from django.contrib import admin -from django.urls import include, path +from django.urls import URLPattern, URLResolver, include, path from django.views.generic import TemplateView from rest_framework.schemas import get_schema_view -urlpatterns = [ +URLPatternList = List[Union[URLPattern, URLResolver]] + +urlpatterns: URLPatternList = [ path("gsr/", include("gsr_booking.urls")), path("portal/", include("portal.urls")), path("admin/", admin.site.urls), @@ -29,7 +33,7 @@ path("sublet/", include("sublet.urls")), ] -urlpatterns = [ +urlpatterns: URLPatternList = [ path("api/", include(urlpatterns)), path("", include((urlpatterns, "apex"))), ] @@ -37,4 +41,4 @@ if settings.DEBUG: import debug_toolbar - urlpatterns = [path("__debug__/", include(debug_toolbar.urls))] + urlpatterns + urlpatterns: URLPatternList = [path("__debug__/", include(debug_toolbar.urls))] + urlpatterns diff --git a/backend/portal/admin.py b/backend/portal/admin.py index a92c248b..a39cfb05 100644 --- a/backend/portal/admin.py +++ b/backend/portal/admin.py @@ -1,51 +1,57 @@ +from typing import Any + from django.contrib import admin -from django.utils.html import escape, mark_safe +from django.db.models import QuerySet +from django.utils.html import escape +from django.utils.safestring import SafeString, mark_safe from portal.models import Content, Poll, PollOption, PollVote, Post, TargetPopulation class ContentAdmin(admin.ModelAdmin): @admin.action(description="Set status to Approved") - def action_approved(modeladmin, request, queryset): + def action_approved(modeladmin: Any, request: Any, queryset: QuerySet) -> None: queryset.update(status=Content.STATUS_APPROVED) @admin.action(description="Set status to Draft") - def action_draft(modeladmin, request, queryset): + def action_draft(modeladmin: Any, request: Any, queryset: QuerySet) -> None: queryset.update(status=Content.STATUS_DRAFT) @admin.action(description="Set status to Revision") - def action_revision(modeladmin, request, queryset): + def action_revision(modeladmin: Any, request: Any, queryset: QuerySet) -> None: queryset.update(status=Content.STATUS_REVISION) actions = [action_approved, action_draft, action_revision] - def get_queryset(self, request): + def get_queryset(self, request: Any) -> QuerySet: queryset = super().get_queryset(request) return queryset.annotate(ar=Content.ACTION_REQUIRED_CONDITION).order_by( "-ar", "-created_date" ) - def ar(self, obj): - return obj.ar + # Using any for the ar property since it comes from a queryset annotation + def ar(self, obj: Any) -> bool: + return bool(obj.ar) - ar.boolean = True + ar.boolean = True # type: ignore[attr-defined] class PostAdmin(ContentAdmin): - def image_tag(instance, height): + @staticmethod + def image_tag(instance: Post, height: int) -> SafeString: return mark_safe( f'' % escape(instance.image and instance.image.url) ) - def small_image(self, instance): + def small_image(self, instance: Post) -> SafeString: return PostAdmin.image_tag(instance, 100) - small_image.short_description = "Post Image" + small_image.short_description = "Post Image" # type: ignore[attr-defined] - def large_image(self, instance): + def large_image(self, instance: Post) -> SafeString: return PostAdmin.image_tag(instance, 300) - large_image.short_description = "Post Image" + large_image.short_description = "Post Image" # type: ignore[attr-defined] readonly_fields = ("large_image",) list_display = ( diff --git a/backend/portal/logic.py b/backend/portal/logic.py index 8234a246..874bc1c0 100644 --- a/backend/portal/logic.py +++ b/backend/portal/logic.py @@ -1,25 +1,16 @@ import json from collections import defaultdict -from typing import TYPE_CHECKING, Any +from typing import Any, Optional from accounts.ipc import authenticated_request -from django.contrib.auth import get_user_model from rest_framework.exceptions import PermissionDenied from portal.models import Poll, PollOption, PollVote, TargetPopulation +from portal.types import PopulationGroups, PopulationList +from utils.types import DjangoUserType -if TYPE_CHECKING: - from django.contrib.auth.models import AbstractUser - - UserType = AbstractUser -else: - UserType = Any - -User = get_user_model() - - -def get_user_info(user: "UserType") -> dict[str, Any]: +def get_user_info(user: DjangoUserType) -> dict[str, Any]: """Returns Platform user information""" response = authenticated_request(user, "GET", "https://platform.pennlabs.org/accounts/me/") if response.status_code == 403: @@ -27,7 +18,7 @@ def get_user_info(user: "UserType") -> dict[str, Any]: return json.loads(response.content) -def get_user_clubs(user: "UserType") -> list[dict[str, Any]]: +def get_user_clubs(user: DjangoUserType) -> list[dict[str, Any]]: """Returns list of clubs that user is a member of""" response = authenticated_request(user, "GET", "https://pennclubs.com/api/memberships/") if response.status_code == 403: @@ -36,7 +27,7 @@ def get_user_clubs(user: "UserType") -> list[dict[str, Any]]: return res_json -def get_club_info(user: "UserType", club_code: str) -> dict[str, Any]: +def get_club_info(user: DjangoUserType, club_code: str) -> dict[str, Any]: """Returns club information based on club code""" response = authenticated_request(user, "GET", f"https://pennclubs.com/api/clubs/{club_code}/") if response.status_code == 403: @@ -45,12 +36,12 @@ def get_club_info(user: "UserType", club_code: str) -> dict[str, Any]: return {"name": res_json["name"], "image": res_json["image_url"], "club_code": club_code} -def get_user_populations(user: "UserType") -> list[TargetPopulation]: +def get_user_populations(user: DjangoUserType) -> PopulationGroups: """Returns the target populations that the user belongs to""" user_info = get_user_info(user) - year = ( + year: PopulationList = ( [ TargetPopulation.objects.get( kind=TargetPopulation.KIND_YEAR, population=user_info["student"]["graduation_year"] @@ -60,7 +51,7 @@ def get_user_populations(user: "UserType") -> list[TargetPopulation]: else [] ) - school = ( + school: PopulationList = ( [ TargetPopulation.objects.get(kind=TargetPopulation.KIND_SCHOOL, population=x["name"]) for x in user_info["student"]["school"] @@ -69,7 +60,7 @@ def get_user_populations(user: "UserType") -> list[TargetPopulation]: else [] ) - major = ( + major: PopulationList = ( [ TargetPopulation.objects.get(kind=TargetPopulation.KIND_MAJOR, population=x["name"]) for x in user_info["student"]["major"] @@ -78,7 +69,7 @@ def get_user_populations(user: "UserType") -> list[TargetPopulation]: else [] ) - degree = ( + degree: PopulationList = ( [ TargetPopulation.objects.get( kind=TargetPopulation.KIND_DEGREE, population=x["degree_type"] @@ -92,29 +83,30 @@ def get_user_populations(user: "UserType") -> list[TargetPopulation]: return [year, school, major, degree] -def check_targets(obj: Poll, user: "UserType") -> bool: +def check_targets(obj: Poll, user: DjangoUserType) -> bool: """ Check if user aligns with target populations of poll or post """ - populations = get_user_populations(user) + population_groups = get_user_populations(user) - year = set(obj.target_populations.filter(kind=TargetPopulation.KIND_YEAR)) - school = set(obj.target_populations.filter(kind=TargetPopulation.KIND_SCHOOL)) - major = set(obj.target_populations.filter(kind=TargetPopulation.KIND_MAJOR)) - degree = set(obj.target_populations.filter(kind=TargetPopulation.KIND_DEGREE)) + year_targets = set(obj.target_populations.filter(kind=TargetPopulation.KIND_YEAR)) + school_targets = set(obj.target_populations.filter(kind=TargetPopulation.KIND_SCHOOL)) + major_targets = set(obj.target_populations.filter(kind=TargetPopulation.KIND_MAJOR)) + degree_targets = set(obj.target_populations.filter(kind=TargetPopulation.KIND_DEGREE)) - return ( - set(populations[0]).issubset(year) - and set(populations[1]).issubset(school) - and set(populations[2]).issubset(major) - and set(populations[3]).issubset(degree) + return all( + set(group).issubset(targets) + for group, targets in zip( + population_groups, [year_targets, school_targets, major_targets, degree_targets] + ) ) -def get_demographic_breakdown(poll_id: int) -> list[dict[str, Any]]: +def get_demographic_breakdown(poll_id: Optional[int] = None) -> list[dict[str, Any]]: """Collects Poll statistics on school and graduation year demographics""" - + if poll_id is None: + raise ValueError("poll_id is required") # passing in id is necessary because # poll info is already serialized poll = Poll.objects.get(id=poll_id) diff --git a/backend/portal/management/commands/polls_populate.py b/backend/portal/management/commands/polls_populate.py index 1414ab37..341ea0a5 100644 --- a/backend/portal/management/commands/polls_populate.py +++ b/backend/portal/management/commands/polls_populate.py @@ -1,19 +1,29 @@ import datetime +from typing import Any -from django.contrib.auth import get_user_model from django.core.management import call_command from django.core.management.base import BaseCommand from django.utils import timezone from portal.models import Poll, PollOption, PollVote, TargetPopulation from user.models import Profile - - -User = get_user_model() +from utils.types import DjangoUserModel, DjangoUserType class Command(BaseCommand): - def handle(self, *args, **kwargs): + def _create_user( + self, username: str, email: str, password: str, graduation_date: datetime.date + ) -> DjangoUserType: + """Helper to create a user with profile""" + if not DjangoUserModel.objects.filter(username=username).exists(): + user = DjangoUserModel.objects.create_user(username, email, password) + profile = Profile.objects.get(user=user) + setattr(profile, "expected_graduation", graduation_date) + profile.save() + return user + return DjangoUserModel.objects.get(username=username) + + def handle(self, *args: Any, **kwargs: Any) -> None: # Define graduation years df_2022 = datetime.date(2022, 5, 15) @@ -22,55 +32,12 @@ def handle(self, *args, **kwargs): df_2025 = datetime.date(2025, 5, 17) # Create users and set graduation years - if not User.objects.filter(username="user1").first(): - user1 = User.objects.create_user("user1", "user@seas.upenn.edu", "user") - user1_profile = Profile.objects.get(user=user1) - user1_profile.expected_graduation = df_2022 - user1_profile.save() - else: - user1 = User.objects.get(username="user1") - - if not User.objects.filter(username="user2").first(): - user2 = User.objects.create_user("user2", "user2@seas.upenn.edu", "user2") - user2_profile = Profile.objects.get(user=user2) - user2_profile.expected_graduation = df_2023 - user2_profile.save() - else: - user2 = User.objects.get(username="user2") - - if not User.objects.filter(username="user3").first(): - user3 = User.objects.create_user("user3", "user3@seas.upenn.edu", "user3") - user3_profile = Profile.objects.get(user=user3) - user3_profile.expected_graduation = df_2024 - user3_profile.save() - else: - user3 = User.objects.get(username="user3") - - if not User.objects.filter(username="user_cas").first(): - user_cas = User.objects.create_user("user_cas", "user@sas.upenn.edu", "user_cas") - user_cas_profile = Profile.objects.get(user=user_cas) - user_cas_profile.expected_graduation = df_2025 - user_cas_profile.save() - else: - user_cas = User.objects.get(username="user_cas") - - if not User.objects.filter(username="user_wh").first(): - user_wh = User.objects.create_user("user_wh", "user@wharton.upenn.edu", "user_wh") - user_wh_profile = Profile.objects.get(user=user_wh) - user_wh_profile.expected_graduation = df_2024 - user_wh_profile.save() - else: - user_wh = User.objects.get(username="user_wh") - - if not User.objects.filter(username="user_nursing").first(): - user_nursing = User.objects.create_user( - "user_nursing", "user@nursing.upenn.edu", "user_nursing" - ) - user_nursing_profile = Profile.objects.get(user=user_nursing) - user_nursing_profile.expected_graduation = df_2023 - user_nursing_profile.save() - else: - user_nursing = User.objects.get(username="user_nursing") + self._create_user("user1", "user@seas.upenn.edu", "user", df_2022) + self._create_user("user2", "user2@seas.upenn.edu", "user2", df_2023) + self._create_user("user3", "user3@seas.upenn.edu", "user3", df_2024) + self._create_user("user_cas", "user@sas.upenn.edu", "user_cas", df_2025) + self._create_user("user_wh", "user@wharton.upenn.edu", "user_wh", df_2024) + self._create_user("user_nursing", "user@nursing.upenn.edu", "user_nursing", df_2023) # Create target populations call_command("load_target_populations", "--years", "2022, 2023, 2024, 2025") diff --git a/backend/portal/models.py b/backend/portal/models.py index b15aeee1..99db4bc0 100644 --- a/backend/portal/models.py +++ b/backend/portal/models.py @@ -1,3 +1,6 @@ +import datetime +from typing import Any + from django.contrib.auth import get_user_model from django.db import models from django.db.models import Q @@ -21,8 +24,9 @@ class TargetPopulation(models.Model): (KIND_DEGREE, "Degree"), ) - kind = models.CharField(max_length=10, choices=KIND_OPTIONS, default=KIND_SCHOOL) - population = models.CharField(max_length=255) + id: int + kind: str = models.CharField(max_length=10, choices=KIND_OPTIONS, default=KIND_SCHOOL) + population: str = models.CharField(max_length=255) def __str__(self): return self.population @@ -41,24 +45,29 @@ class Content(models.Model): ACTION_REQUIRED_CONDITION = Q(expire_date__gt=timezone.now()) & Q(status=STATUS_DRAFT) - club_code = models.CharField(max_length=255, blank=True) - created_date = models.DateTimeField(default=timezone.now) - start_date = models.DateTimeField(default=timezone.now) - expire_date = models.DateTimeField() - status = models.CharField(max_length=30, choices=STATUS_OPTIONS, default=STATUS_DRAFT) - club_comment = models.CharField(max_length=255, null=True, blank=True) - admin_comment = models.CharField(max_length=255, null=True, blank=True) - target_populations = models.ManyToManyField(TargetPopulation, blank=True) - priority = models.IntegerField(default=0) - creator = models.ForeignKey(User, on_delete=models.SET_NULL, null=True, blank=True) + id: int + club_code: str = models.CharField(max_length=255, blank=True) + created_date: datetime.datetime = models.DateTimeField(default=timezone.now) + start_date: datetime.datetime = models.DateTimeField(default=timezone.now) + expire_date: datetime.datetime = models.DateTimeField() + status: str = models.CharField(max_length=30, choices=STATUS_OPTIONS, default=STATUS_DRAFT) + club_comment: str = models.CharField(max_length=255, null=True, blank=True) + admin_comment: str = models.CharField(max_length=255, null=True, blank=True) + target_populations: models.ManyToManyField = models.ManyToManyField( + TargetPopulation, blank=True + ) + priority: int = models.IntegerField(default=0) + creator: models.ForeignKey = models.ForeignKey( + User, on_delete=models.SET_NULL, null=True, blank=True + ) class Meta: abstract = True - def _get_email_subject(self): + def _get_email_subject(self) -> str: return f"[Portal] {self.__class__._meta.model_name.capitalize()} #{self.id}" - def _on_create(self): + def _on_create(self) -> None: send_automated_email.delay_on_commit( self._get_email_subject(), get_backend_manager_emails(), @@ -68,7 +77,7 @@ def _on_create(self): ), ) - def _on_status_change(self): + def _on_status_change(self) -> None: if email := getattr(self.creator, "email", None): send_automated_email.delay_on_commit( self._get_email_subject(), @@ -82,7 +91,7 @@ def _on_status_change(self): ), ) - def save(self, *args, **kwargs): + def save(self, *args: Any, **kwargs: Any) -> None: prev = self.__class__.objects.filter(id=self.id).first() super().save(*args, **kwargs) if prev is None: @@ -93,35 +102,39 @@ def save(self, *args, **kwargs): class Poll(Content): - question = models.CharField(max_length=255) - multiselect = models.BooleanField(default=False) + question: str = models.CharField(max_length=255) + multiselect: bool = models.BooleanField(default=False) def __str__(self): return self.question class PollOption(models.Model): - poll = models.ForeignKey(Poll, on_delete=models.CASCADE) - choice = models.CharField(max_length=255) - vote_count = models.IntegerField(default=0) + id: int + poll: models.ForeignKey = models.ForeignKey(Poll, on_delete=models.CASCADE) + choice: str = models.CharField(max_length=255) + vote_count: int = models.IntegerField(default=0) - def __str__(self): + def __str__(self) -> str: return f"{self.poll.id} - Option - {self.choice}" class PollVote(models.Model): - id_hash = models.CharField(max_length=255, blank=True) - poll = models.ForeignKey(Poll, on_delete=models.CASCADE) - poll_options = models.ManyToManyField(PollOption) - created_date = models.DateTimeField(default=timezone.now) - target_populations = models.ManyToManyField(TargetPopulation, blank=True) + id: int + id_hash: str = models.CharField(max_length=255, blank=True) + poll: models.ForeignKey = models.ForeignKey(Poll, on_delete=models.CASCADE) + poll_options: models.ManyToManyField = models.ManyToManyField(PollOption) + created_date: datetime.datetime = models.DateTimeField(default=timezone.now) + target_populations: models.ManyToManyField = models.ManyToManyField( + TargetPopulation, blank=True + ) class Post(Content): - title = models.CharField(max_length=255) - subtitle = models.CharField(max_length=255) - post_url = models.CharField(max_length=255, null=True, blank=True) - image = models.ImageField(upload_to="portal/images", null=True, blank=True) + title: str = models.CharField(max_length=255) + subtitle: str = models.CharField(max_length=255) + post_url: str = models.CharField(max_length=255, null=True, blank=True) + image: models.ImageField = models.ImageField(upload_to="portal/images", null=True, blank=True) - def __str__(self): + def __str__(self) -> str: return self.title diff --git a/backend/portal/permissions.py b/backend/portal/permissions.py index adc319ba..c8b09364 100644 --- a/backend/portal/permissions.py +++ b/backend/portal/permissions.py @@ -1,7 +1,11 @@ +from typing import Any, cast + from rest_framework import permissions +from rest_framework.request import Request from portal.logic import get_user_clubs -from portal.models import Poll +from portal.models import Poll, PollOption +from utils.types import get_auth_user class IsSuperUser(permissions.BasePermission): @@ -9,66 +13,79 @@ class IsSuperUser(permissions.BasePermission): Grants permission if the current user is a superuser. """ - def has_object_permission(self, request, view, obj): - return request.user.is_superuser + def has_object_permission(self, request: Request, view: Any, obj: Any) -> bool: + return get_auth_user(request).is_superuser - def has_permission(self, request, view): - return request.user.is_superuser + def has_permission(self, request: Request, view: Any) -> bool: + return get_auth_user(request).is_superuser class PollOwnerPermission(permissions.BasePermission): """Permission that checks authentication and only permits owner to update/destroy objects""" - def has_object_permission(self, request, view, obj): + def _get_club_code(self, obj: Any) -> str: + """Helper to get club_code from either Poll or PollOption object""" + if isinstance(obj, Poll): + return obj.club_code + elif isinstance(obj, PollOption): + poll = cast(Poll, obj.poll) + return poll.club_code + raise ValueError(f"Unexpected object type: {type(obj)}") + + def has_object_permission(self, request: Request, view: Any, obj: Any) -> bool: # only creator can edit + user = get_auth_user(request) if view.action in ["partial_update", "update", "destroy"]: - return obj.club_code in [x["club"]["code"] for x in get_user_clubs(request.user)] - return request.user.is_authenticated + club_code = self._get_club_code(obj) + return club_code in [x["club"]["code"] for x in get_user_clubs(user)] + return user.is_authenticated - def has_permission(self, request, view): - return request.user.is_authenticated + def has_permission(self, request: Request, view: Any) -> bool: + return get_auth_user(request).is_authenticated class OptionOwnerPermission(permissions.BasePermission): """Permission that checks authentication and only permits owner of Poll to update corresponding Option objects""" - def has_object_permission(self, request, view, obj): + def has_object_permission(self, request: Request, view: Any, obj: Any) -> bool: # only creator can edit + user = get_auth_user(request) if view.action in ["partial_update", "update", "destroy"]: - return obj.poll.club_code in [x["club"]["code"] for x in get_user_clubs(request.user)] + return obj.poll.club_code in [x["club"]["code"] for x in get_user_clubs(user)] return True - def has_permission(self, request, view): + def has_permission(self, request: Request, view: Any) -> bool: # only creator of poll can create poll option + user = get_auth_user(request) if view.action == "create" and request.data: poll = Poll.objects.get(id=request.data["poll"]) - return poll.club_code in [x["club"]["code"] for x in get_user_clubs(request.user)] - return request.user.is_authenticated + return poll.club_code in [x["club"]["code"] for x in get_user_clubs(user)] + return user.is_authenticated class TimeSeriesPermission(permissions.BasePermission): """Permission that checks for Time Series access (only creator of Poll and admins)""" - def has_permission(self, request, view): - poll = Poll.objects.filter(id=view.kwargs["poll_id"]) + def has_permission(self, request: Request, view: Any) -> bool: + poll = Poll.objects.filter(id=view.kwargs["poll_id"]).first() + user = get_auth_user(request) # checks if poll exists - if poll.exists(): + if poll is not None: # only poll creator and admin can access - return poll.first().club_code in [ - x["club"]["code"] for x in get_user_clubs(request.user) - ] + return poll.club_code in [x["club"]["code"] for x in get_user_clubs(user)] return False class PostOwnerPermission(permissions.BasePermission): """checks authentication and only permits owner to update/destroy posts""" - def has_object_permission(self, request, view, obj): + def has_object_permission(self, request: Request, view: Any, obj: Any) -> bool: # only creator can edit + user = get_auth_user(request) if view.action in ["partial_update", "update", "destroy"]: - return obj.club_code in [x["club"]["code"] for x in get_user_clubs(request.user)] + return obj.club_code in [x["club"]["code"] for x in get_user_clubs(user)] return True - def has_permission(self, request, view): + def has_permission(self, request: Request, view: Any) -> bool: return request.user.is_authenticated diff --git a/backend/portal/serializers.py b/backend/portal/serializers.py index d9a65e44..8ac796c6 100644 --- a/backend/portal/serializers.py +++ b/backend/portal/serializers.py @@ -1,14 +1,12 @@ -from typing import Any, Dict, TypeAlias +from typing import Any, ClassVar, Type, cast +from django.db.models import Model from django.http.request import QueryDict from rest_framework import serializers from portal.logic import check_targets, get_user_clubs, get_user_populations from portal.models import Content, Poll, PollOption, PollVote, Post, TargetPopulation - - -ClubCode: TypeAlias = str -ValidationData: TypeAlias = Dict[str, Any] +from portal.types import ClubCode, ValidationData class TargetPopulationSerializer(serializers.ModelSerializer): @@ -19,7 +17,8 @@ class Meta: class ContentSerializer(serializers.ModelSerializer): class Meta: - fields = ( + model: ClassVar[Type[Model]] + fields: tuple[str, ...] = ( "id", "club_code", "created_date", @@ -30,7 +29,7 @@ class Meta: "status", "target_populations", ) - read_only_fields = ("id", "created_date") + read_only_fields: tuple[str, ...] = ("id", "created_date") abstract = True def _auto_add_target_population(self, validated_data: ValidationData) -> None: @@ -52,10 +51,14 @@ def create(self, validated_data: ValidationData) -> Poll: user = self.context["request"].user # ensures user is part of club if not any([x["club"]["code"] == club_code for x in get_user_clubs(user)]): + model_name = ( + self.Meta.model._meta.model_name.capitalize() + if self.Meta.model._meta.model_name is not None + else "content" + ) raise serializers.ValidationError( detail={ - "detail": "You do not have access to create a " - + f"{self.Meta.model._meta.model_name.capitalize()} under this club." + "detail": f"You do not have access to create a {model_name} under this club." } ) @@ -82,7 +85,7 @@ def update(self, instance: Content, validated_data: ValidationData) -> Content: class PollSerializer(ContentSerializer): class Meta(ContentSerializer.Meta): model = Poll - fields = ( + fields: tuple[str, ...] = ( *ContentSerializer.Meta.fields, "question", "multiselect", @@ -92,13 +95,13 @@ class Meta(ContentSerializer.Meta): class PollOptionSerializer(serializers.ModelSerializer): class Meta: model = PollOption - fields = ( + fields: tuple[str, ...] = ( "id", "poll", "choice", "vote_count", ) - read_only_fields = ("id", "vote_count") + read_only_fields: tuple[str, ...] = ("id", "vote_count") def create(self, validated_data: ValidationData) -> PollOption: poll_options_count = PollOption.objects.filter(poll=validated_data["poll"]).count() @@ -108,10 +111,11 @@ def create(self, validated_data: ValidationData) -> PollOption: ) return super().create(validated_data) - def update(self, instance, validated_data): + def update(self, instance: PollOption, validated_data: ValidationData) -> PollOption: # if Poll Option is updated, then corresponding Poll approval should be false - instance.poll.status = Poll.STATUS_DRAFT - instance.poll.save() + poll = cast(Poll, instance.poll) + poll.status = Poll.STATUS_DRAFT + poll.save() return super().update(instance, validated_data) @@ -122,7 +126,7 @@ class RetrievePollSerializer(serializers.ModelSerializer): class Meta: model = Poll - fields = ( + fields: tuple[str, ...] = ( "id", "club_code", "question", @@ -140,8 +144,13 @@ class Meta: class PollVoteSerializer(serializers.ModelSerializer): class Meta: model = PollVote - fields = ("id", "id_hash", "poll_options", "created_date") - read_only_fields = ( + fields: tuple[str, ...] = ( + "id", + "id_hash", + "poll_options", + "created_date", + ) + read_only_fields: tuple[str, ...] = ( "id", "created_date", ) @@ -201,8 +210,8 @@ class RetrievePollVoteSerializer(serializers.ModelSerializer): class Meta: model = PollVote - fields = ("id", "id_hash", "poll", "poll_options", "created_date") - read_only_fields = ( + fields: tuple[str, ...] = ("id", "id_hash", "poll", "poll_options", "created_date") + read_only_fields: tuple[str, ...] = ( "id", "created_date", ) @@ -229,7 +238,7 @@ def get_image_url(self, obj: Post) -> str | None: class Meta(ContentSerializer.Meta): model = Post - fields = ( + fields: tuple[str, ...] = ( *ContentSerializer.Meta.fields, "title", "subtitle", @@ -240,12 +249,13 @@ class Meta(ContentSerializer.Meta): def is_valid(self, *args: Any, **kwargs: Any) -> bool: if isinstance(self.initial_data, QueryDict): - self.initial_data = self.initial_data.dict() - self.initial_data["target_populations"] = list( - ( - map(int, self.initial_data["target_populations"].split(",")) - if "target_populations" in self.initial_data + data = self.initial_data.dict() + target_populations = data.get("target_populations", "") + if isinstance(target_populations, str): + data["target_populations"] = ( + list(map(int, target_populations.split(","))) + if target_populations.strip() else [] - ), - ) + ) + self.initial_data = data return super().is_valid(*args, **kwargs) diff --git a/backend/portal/types.py b/backend/portal/types.py new file mode 100644 index 00000000..5192ba1d --- /dev/null +++ b/backend/portal/types.py @@ -0,0 +1,18 @@ +from typing import Any, Dict, List, Set, TypeAlias +from django.db.models import Manager, QuerySet +from portal.models import Poll, PollOption, PollVote, Post, TargetPopulation + +# QuerySet types +PollQuerySet: TypeAlias = QuerySet[Poll, Manager[Poll]] +PostQuerySet: TypeAlias = QuerySet[Post, Manager[Post]] +PollVoteQuerySet: TypeAlias = QuerySet[PollVote, Manager[PollVote]] +PollOptionQuerySet: TypeAlias = QuerySet[PollOption, Manager[PollOption]] + +# Data structure types +VoteStatistics: TypeAlias = Dict[str, Any] +ClubCode: TypeAlias = str +ValidationData: TypeAlias = Dict[str, Any] + +# Population types +PopulationList: TypeAlias = List[TargetPopulation] +PopulationGroups: TypeAlias = List[PopulationList] diff --git a/backend/portal/views.py b/backend/portal/views.py index 61ac2413..d3f20d12 100644 --- a/backend/portal/views.py +++ b/backend/portal/views.py @@ -1,7 +1,6 @@ -from typing import Any, Dict, List, TypeAlias +from typing import Any, List, Optional -from django.contrib.auth import get_user_model -from django.db.models import Count, Manager, Q, QuerySet +from django.db.models import Count, Q from django.db.models.functions import Trunc from django.utils import timezone from rest_framework import generics, viewsets @@ -36,17 +35,14 @@ RetrievePollVoteSerializer, TargetPopulationSerializer, ) - - -PollQuerySet: TypeAlias = QuerySet[Poll, Manager[Poll]] -PostQuerySet: TypeAlias = QuerySet[Post, Manager[Post]] -PollVoteQuerySet: TypeAlias = QuerySet[PollVote, Manager[PollVote]] -ClubData: TypeAlias = List[Dict[str, Any]] -PollOptionQuerySet: TypeAlias = QuerySet[PollOption, Manager[PollOption]] -TimeSeriesData: TypeAlias = Dict[str, Any] -VoteStatistics: TypeAlias = Dict[str, Any] - -User = get_user_model() +from portal.types import ( + PollOptionQuerySet, + PollQuerySet, + PollVoteQuerySet, + PostQuerySet, + VoteStatistics, +) +from utils.types import AuthRequest, get_auth_user class UserInfo(APIView): @@ -54,7 +50,7 @@ class UserInfo(APIView): permission_classes = [IsAuthenticated] - def get(self, request: Request) -> Response: + def get(self, request: AuthRequest) -> Response: return Response({"user": get_user_info(request.user)}) @@ -63,8 +59,8 @@ class UserClubs(APIView): permission_classes = [IsAuthenticated] - def get(self, request: Request) -> Response: - club_data: ClubData = [ + def get(self, request: AuthRequest) -> Response: + club_data = [ get_club_info(request.user, club["club"]["code"]) for club in get_user_clubs(request.user) ] @@ -104,21 +100,23 @@ class Polls(viewsets.ModelViewSet[Poll]): def get_queryset(self) -> PollQuerySet: # all polls if superuser, polls corresponding to club for regular user + user = get_auth_user(self.request) return ( Poll.objects.all() - if self.request.user.is_superuser + if user.is_superuser else Poll.objects.filter( - club_code__in=[x["club"]["code"] for x in get_user_clubs(self.request.user)] + club_code__in=[x["club"]["code"] for x in get_user_clubs(user)] ) ) @action(detail=False, methods=["post"]) - def browse(self, request: Request) -> Response: + def browse(self, request: AuthRequest) -> Response: """Returns list of all possible polls user can answer but has yet to For admins, returns list of all polls they have not voted for and have yet to expire """ id_hash = request.data["id_hash"] + user = get_auth_user(request) # unvoted polls in draft/approaved mode for superuser # unvoted and approved polls within time frame for regular user @@ -128,7 +126,7 @@ def browse(self, request: Request) -> Response: Q(status=Poll.STATUS_DRAFT) | Q(status=Poll.STATUS_APPROVED), expire_date__gte=timezone.localtime(), ) - if request.user.is_superuser + if user.is_superuser else Poll.objects.filter( ~Q(id__in=PollVote.objects.filter(id_hash=id_hash).values_list("poll_id")), status=Poll.STATUS_APPROVED, @@ -140,9 +138,9 @@ def browse(self, request: Request) -> Response: # list of polls where user doesn't identify with # target populations bad_polls = [] - if not request.user.is_superuser: + if not user.is_superuser: for unfiltered_poll in unfiltered_polls: - if not check_targets(unfiltered_poll, request.user): + if not check_targets(unfiltered_poll, user): bad_polls.append(unfiltered_poll.id) # excludes the bad polls @@ -175,7 +173,7 @@ def review(self, request: Request) -> Response: ) @action(detail=True, methods=["get"]) - def option_view(self, request: Request, pk: int = None) -> Response: + def option_view(self, request: Request, pk: Optional[int] = None) -> Response: """Returns information on specific poll, including options and vote counts""" return Response(RetrievePollSerializer(Poll.objects.filter(id=pk).first(), many=False).data) @@ -199,12 +197,13 @@ class PollOptions(viewsets.ModelViewSet[PollOption]): def get_queryset(self) -> PollOptionQuerySet: # if user is admin, they can update anything # if user is not admin, they can only update their own options + user = get_auth_user(self.request) return ( PollOption.objects.all() - if self.request.user.is_superuser + if user.is_superuser else PollOption.objects.filter( poll__in=Poll.objects.filter( - club_code__in=[x["club"]["code"] for x in get_user_clubs(self.request.user)] + club_code__in=[x["club"]["code"] for x in get_user_clubs(user)] ) ) ) @@ -239,7 +238,8 @@ def all(self, request: Request) -> Response: return Response(RetrievePollVoteSerializer(poll_votes, many=True).data) def create(self, request: Request, *args: Any, **kwargs: Any) -> Response: - record_analytics(Metric.PORTAL_POLL_VOTED, request.user.username) + user = get_auth_user(request) + record_analytics(Metric.PORTAL_POLL_VOTED, user.username) return super().create(request, *args, **kwargs) @@ -286,26 +286,28 @@ class Posts(viewsets.ModelViewSet[Post]): serializer_class = PostSerializer def get_queryset(self) -> PostQuerySet: + user = get_auth_user(self.request) return ( Post.objects.all() - if self.request.user.is_superuser + if user.is_superuser else Post.objects.filter( - club_code__in=[x["club"]["code"] for x in get_user_clubs(self.request.user)] + club_code__in=[x["club"]["code"] for x in get_user_clubs(user)] ) ) @action(detail=False, methods=["get"]) - def browse(self, request: Request) -> Response: + def browse(self, request: AuthRequest) -> Response: """ Returns a list of all posts that are targeted at the current user For admins, returns list of posts that they have not approved and have yet to expire """ + user = get_auth_user(request) unfiltered_posts = ( Post.objects.filter( Q(status=Post.STATUS_DRAFT) | Q(status=Post.STATUS_APPROVED), expire_date__gte=timezone.localtime(), ) - if request.user.is_superuser + if user.is_superuser else Post.objects.filter( status=Post.STATUS_APPROVED, start_date__lte=timezone.localtime(), diff --git a/backend/tests/portal/test_permissions.py b/backend/tests/portal/test_permissions.py index a3081b76..fa829d80 100644 --- a/backend/tests/portal/test_permissions.py +++ b/backend/tests/portal/test_permissions.py @@ -1,8 +1,8 @@ import datetime import json +from typing import Any from unittest import mock -from django.contrib.auth import get_user_model from django.core.management import call_command from django.test import TestCase from django.urls import reverse @@ -10,48 +10,58 @@ from rest_framework.test import APIClient from portal.models import Poll, PollOption, PollVote +from utils.types import DjangoUserModel, DjangoUserType -User = get_user_model() - - -def mock_get_user_clubs(*args, **kwargs): +def mock_get_user_clubs(*args: Any, **kwargs: Any) -> list[dict[str, Any]]: with open("tests/portal/get_user_clubs.json") as data: return json.load(data) class PollPermissions(TestCase): - def setUp(self): + def setUp(self) -> None: call_command("load_target_populations", "--years", "2022, 2023, 2024, 2025") - self.client = APIClient() - self.admin = User.objects.create_superuser("admin@example.com", "admin", "admin") - self.user1 = User.objects.create_user("user1", "user@seas.upenn.edu", "user") - self.user2 = User.objects.create_user("user2", "user@seas.upenn.edu", "user") + self.client: APIClient = APIClient() + self.admin: DjangoUserType = DjangoUserModel.objects.create_superuser( + "admin@example.com", "admin", "admin" + ) + self.user1: DjangoUserType = DjangoUserModel.objects.create_user( + "user1", "user@seas.upenn.edu", "user" + ) + self.user2: DjangoUserType = DjangoUserModel.objects.create_user( + "user2", "user@seas.upenn.edu", "user" + ) - self.poll_1 = Poll.objects.create( + self.poll_1: Poll = Poll.objects.create( club_code="pennlabs", question="poll question 1", expire_date=timezone.now() + datetime.timedelta(days=1), status=Poll.STATUS_APPROVED, ) - self.poll_option_1 = PollOption.objects.create(poll=self.poll_1, choice="hello!") - self.poll_option_2 = PollOption.objects.create(poll=self.poll_1, choice="hello!!!!") - self.poll_option_3 = PollOption.objects.create(poll=self.poll_1, choice="hello!!!!!!!") + self.poll_option_1: PollOption = PollOption.objects.create( + poll=self.poll_1, choice="hello!" + ) + self.poll_option_2: PollOption = PollOption.objects.create( + poll=self.poll_1, choice="hello!!!!" + ) + self.poll_option_3: PollOption = PollOption.objects.create( + poll=self.poll_1, choice="hello!!!!!!!" + ) - self.poll_2 = Poll.objects.create( + self.poll_2: Poll = Poll.objects.create( club_code="pennlabs", question="poll question 2", expire_date=timezone.now() + datetime.timedelta(days=1), status=Poll.STATUS_APPROVED, ) - self.poll_vote = PollVote.objects.create(id_hash="2", poll=self.poll_1) + self.poll_vote: PollVote = PollVote.objects.create(id_hash="2", poll=self.poll_1) self.poll_vote.poll_options.add(self.poll_option_1) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) - def test_authentication(self): + def test_authentication(self) -> None: # asserts that anonymous users cannot access any route list_urls = [ "poll-list", @@ -70,7 +80,7 @@ def test_authentication(self): @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) - def test_update_poll(self): + def test_update_poll(self) -> None: # users in same club can edit self.client.force_authenticate(user=self.user2) payload_1 = {"status": Poll.STATUS_REVISION} @@ -90,7 +100,7 @@ def test_update_poll(self): @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) - def test_create_update_options(self): + def test_create_update_options(self) -> None: # users in same club can edit poll option self.client.force_authenticate(user=self.user2) payload_1 = {"poll": self.poll_1.id, "choice": "hello"} diff --git a/backend/tests/portal/test_polls.py b/backend/tests/portal/test_polls.py index 1002d803..1ca2e149 100644 --- a/backend/tests/portal/test_polls.py +++ b/backend/tests/portal/test_polls.py @@ -1,8 +1,8 @@ import datetime import json +from typing import Any from unittest import mock -from django.contrib.auth import get_user_model from django.core.management import call_command from django.test import TestCase from django.utils import timezone @@ -10,27 +10,25 @@ from portal.models import Poll, PollOption, PollVote, TargetPopulation from utils.email import get_backend_manager_emails +from utils.types import DjangoUserModel, DjangoUserType -User = get_user_model() - - -def mock_get_user_clubs(*args, **kwargs): +def mock_get_user_clubs(*args: Any, **kwargs: Any) -> list[dict[str, Any]]: with open("tests/portal/get_user_clubs.json") as data: return json.load(data) -def mock_get_user_info(*args, **kwargs): +def mock_get_user_info(*args: Any, **kwargs: Any) -> dict[str, Any]: with open("tests/portal/get_user_info.json") as data: return json.load(data) -def mock_get_null_user_info(*args, **kwargs): +def mock_get_null_user_info(*args: Any, **kwargs: Any) -> dict[str, Any]: with open("tests/portal/get_null_user_info.json") as data: return json.load(data) -def mock_get_club_info(*args, **kwargs): +def mock_get_club_info(*args: Any, **kwargs: Any) -> dict[str, Any]: with open("tests/portal/get_club_info.json") as data: return json.load(data) @@ -38,22 +36,28 @@ def mock_get_club_info(*args, **kwargs): class TestUserClubs(TestCase): """Test User and Club information""" - def setUp(self): - self.client = APIClient() - self.test_user = User.objects.create_user("user", "user@seas.upenn.edu", "user") + def setUp(self) -> None: + self.client: APIClient = APIClient() + self.test_user: DjangoUserType = DjangoUserModel.objects.create_user( + "user", "user@seas.upenn.edu", "user" + ) self.client.force_authenticate(user=self.test_user) @mock.patch("portal.views.get_user_info", mock_get_user_info) - def test_user_info(self): + def test_user_info(self) -> None: response = self.client.get("/portal/user/") res_json = json.loads(response.content) + assert isinstance(res_json, dict) + assert isinstance(res_json["user"], dict) self.assertEqual(12345678, res_json["user"]["pennid"]) @mock.patch("portal.views.get_club_info", mock_get_club_info) @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) - def test_user_clubs(self): + def test_user_clubs(self) -> None: response = self.client.get("/portal/clubs/") res_json = json.loads(response.content) + assert isinstance(res_json, dict) + assert isinstance(res_json["clubs"], list) self.assertEqual("pennlabs", res_json["clubs"][0]["code"]) @@ -61,17 +65,21 @@ class TestPolls(TestCase): """Tests Create/Update/Retrieve for Polls and Poll Options""" @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) - def setUp(self): + def setUp(self) -> None: call_command("load_target_populations", "--years", "2022, 2023, 2024, 2025") - self.target_id = TargetPopulation.objects.get(population="2024").id - year = TargetPopulation.objects.get(population="2024").id + target = TargetPopulation.objects.get(population="2024") + self.target_id = target.id + year = target.id major = TargetPopulation.objects.get(population="Computer Science, BSE").id school = TargetPopulation.objects.get( population="School of Engineering and Applied Science" ).id degree = TargetPopulation.objects.get(population="BACHELORS").id - self.client = APIClient() - self.test_user = User.objects.create_user("user", "user@seas.upenn.edu", "user") + + self.client: APIClient = APIClient() + self.test_user: DjangoUserType = DjangoUserModel.objects.create_user( + "user", "user@seas.upenn.edu", "user" + ) self.client.force_authenticate(user=self.test_user) # creates an approved poll to work with payload = { @@ -98,10 +106,10 @@ def setUp(self): poll.save() poll_1 = Poll.objects.get(question="How is your day") - self.id = poll_1.id + self.poll_id = poll_1.id @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) - def test_create_poll(self): + def test_create_poll(self) -> None: # creates an unapproved poll payload = { "club_code": "pennlabs", @@ -119,19 +127,19 @@ def test_create_poll(self): @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) - def test_update_poll(self): + def test_update_poll(self) -> None: payload = { "question": "New question", } - response = self.client.patch(f"/portal/polls/{self.id}/", payload) + response = self.client.patch(f"/portal/polls/{self.poll_id}/", payload) res_json = json.loads(response.content) # asserts that the update worked - self.assertEqual(self.id, res_json["id"]) - self.assertEqual("New question", Poll.objects.get(id=self.id).question) + self.assertEqual(self.poll_id, res_json["id"]) + self.assertEqual("New question", Poll.objects.get(id=self.poll_id).question) @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_browse(self): + def test_browse(self) -> None: payload = { "club_code": "pennlabs", "question": "How is this question? 2", @@ -148,7 +156,7 @@ def test_browse(self): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.logic.get_user_info", mock_get_null_user_info) - def test_null_user_info_browse(self): + def test_null_user_info_browse(self) -> None: # Asserts that a user with empty user info can access all available polls response = self.client.post("/portal/polls/browse/", {"id_hash": 1}) res_json = json.loads(response.content) @@ -158,38 +166,38 @@ def test_null_user_info_browse(self): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_create_option(self): - payload_1 = {"poll": self.id, "choice": "yes!"} - payload_2 = {"poll": self.id, "choice": "no!"} + def test_create_option(self) -> None: + payload_1 = {"poll": self.poll_id, "choice": "yes!"} + payload_2 = {"poll": self.poll_id, "choice": "no!"} self.client.post("/portal/options/", payload_1) self.client.post("/portal/options/", payload_2) self.assertEqual(2, PollOption.objects.all().count()) # asserts options were created and were placed to right poll for poll_option in PollOption.objects.all(): - self.assertEqual(Poll.objects.get(id=self.id), poll_option.poll) + self.assertEqual(Poll.objects.get(id=self.poll_id), poll_option.poll) response = self.client.post("/portal/polls/browse/", {"id_hash": 1}) res_json = json.loads(response.content) self.assertEqual(2, len(res_json[0]["options"])) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) - def test_update_option(self): - payload_1 = {"poll": self.id, "choice": "yes!"} + def test_update_option(self) -> None: + payload_1 = {"poll": self.poll_id, "choice": "yes!"} response = self.client.post("/portal/options/", payload_1) res_json = json.loads(response.content) self.assertEqual("yes!", PollOption.objects.get(id=res_json["id"]).choice) - payload_2 = {"poll": self.id, "choice": "no!"} + payload_2 = {"poll": self.poll_id, "choice": "no!"} # checks that poll's option was changed self.client.patch(f'/portal/options/{res_json["id"]}/', payload_2) self.assertEqual("no!", PollOption.objects.get(id=res_json["id"]).choice) - def test_review_poll(self): + def test_review_poll(self) -> None: Poll.objects.create( club_code="pennlabs", question="hello?", expire_date=timezone.now() + datetime.timedelta(days=3), ) - admin = User.objects.create_superuser("admin@example.com", "admin", "admin") + admin = DjangoUserModel.objects.create_superuser("admin@example.com", "admin", "admin") self.client.force_authenticate(user=admin) response = self.client.get("/portal/polls/review/") res_json = json.loads(response.content) @@ -200,12 +208,12 @@ def test_review_poll(self): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_more_than_five_options(self): - payload_1 = {"poll": self.id, "choice": "1"} - payload_2 = {"poll": self.id, "choice": "2"} - payload_3 = {"poll": self.id, "choice": "3"} - payload_4 = {"poll": self.id, "choice": "4"} - payload_5 = {"poll": self.id, "choice": "5"} + def test_more_than_five_options(self) -> None: + payload_1 = {"poll": self.poll_id, "choice": "1"} + payload_2 = {"poll": self.poll_id, "choice": "2"} + payload_3 = {"poll": self.poll_id, "choice": "3"} + payload_4 = {"poll": self.poll_id, "choice": "4"} + payload_5 = {"poll": self.poll_id, "choice": "5"} self.client.post("/portal/options/", payload_1) self.client.post("/portal/options/", payload_2) self.client.post("/portal/options/", payload_3) @@ -214,17 +222,17 @@ def test_more_than_five_options(self): self.assertEqual(5, PollOption.objects.all().count()) # asserts options were created and were placed to right poll for poll_option in PollOption.objects.all(): - self.assertEqual(Poll.objects.get(id=self.id), poll_option.poll) + self.assertEqual(Poll.objects.get(id=self.poll_id), poll_option.poll) response = self.client.post("/portal/polls/browse/", {"id_hash": 1}) res_json = json.loads(response.content) self.assertEqual(5, len(res_json[0]["options"])) # adding more than 5 options to same poll should not be allowed - payload_6 = {"poll": self.id, "choice": "6"} + payload_6 = {"poll": self.poll_id, "choice": "6"} response = self.client.post("/portal/options/", payload_6) self.assertEqual(5, PollOption.objects.all().count()) - def test_option_vote_view(self): - response = self.client.get(f"/portal/polls/{self.id}/option_view/") + def test_option_vote_view(self) -> None: + response = self.client.get(f"/portal/polls/{self.poll_id}/option_view/") res_json = json.loads(response.content) self.assertEqual("pennlabs", res_json["club_code"]) # test that options key is in response @@ -233,7 +241,7 @@ def test_option_vote_view(self): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("utils.email.send_automated_email.delay_on_commit") - def test_send_email_on_create(self, mock_send_email): + def test_send_email_on_create(self, mock_send_email: mock.Mock) -> None: payload = { "club_code": "pennlabs", "question": "How is this question? 2", @@ -249,7 +257,7 @@ def test_send_email_on_create(self, mock_send_email): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("utils.email.send_automated_email.delay_on_commit") - def test_send_email_on_status_change(self, mock_send_email): + def test_send_email_on_status_change(self, mock_send_email: mock.Mock) -> None: payload = { "club_code": "pennlabs", "question": "How is this question? 2", @@ -262,6 +270,7 @@ def test_send_email_on_status_change(self, mock_send_email): mock_send_email.assert_called_once() poll = Poll.objects.last() + assert poll is not None poll.status = Poll.STATUS_REVISION poll.save() @@ -272,12 +281,14 @@ def test_send_email_on_status_change(self, mock_send_email): class TestPollVotes(TestCase): """Tests Create/Update Polls and History""" - def setUp(self): + def setUp(self) -> None: call_command("load_target_populations", "--years", "2022, 2023, 2024, 2025") self.target_id = TargetPopulation.objects.get(population="2024").id - self.client = APIClient() - self.test_user = User.objects.create_user("user", "user@seas.upenn.edu", "user") + self.client: APIClient = APIClient() + self.test_user: DjangoUserType = DjangoUserModel.objects.create_user( + "user", "user@seas.upenn.edu", "user" + ) self.client.force_authenticate(user=self.test_user) # creates 4 polls, each with 3 options @@ -331,27 +342,30 @@ def setUp(self): PollOption.objects.create(poll=p4, choice="choice 12") @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_create_vote(self): + def test_create_vote(self) -> None: payload_1 = {"id_hash": 1, "poll_options": [self.p1_op1_id]} response = self.client.post("/portal/votes/", payload_1) res_json = json.loads(response.content) + assert isinstance(res_json, dict) # tests that voting works self.assertIn(self.p1_op1_id, res_json["poll_options"]) + vote = PollVote.objects.first() + assert vote is not None self.assertEqual(1, PollVote.objects.all().count()) - self.assertEqual("1", PollVote.objects.all().first().id_hash) + self.assertEqual("1", vote.id_hash) self.assertIn( TargetPopulation.objects.get(id=self.target_id), - PollVote.objects.all().first().target_populations.all(), + vote.target_populations.all(), ) - def test_recent_poll_empty(self): + def test_recent_poll_empty(self) -> None: response = self.client.post("/portal/votes/recent/", {"id_hash": 1}) res_json = json.loads(response.content) self.assertIsNone(res_json["created_date"]) self.assertIsNone(res_json["poll"]["created_date"]) @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_recent_poll(self): + def test_recent_poll(self) -> None: # answer poll payload_1 = {"id_hash": 1, "poll_options": [self.p1_op1_id]} self.client.post("/portal/votes/", payload_1) @@ -368,7 +382,7 @@ def test_recent_poll(self): self.assertEquals(self.p4_id, res_json2["poll"]["id"]) @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_all_votes(self): + def test_all_votes(self) -> None: payload_1 = {"id_hash": 1, "poll_options": [self.p1_op1_id]} self.client.post("/portal/votes/", payload_1) payload_2 = {"id_hash": 1, "poll_options": [self.p4_op1_id]} @@ -382,7 +396,7 @@ def test_all_votes(self): @mock.patch("portal.logic.get_user_info", mock_get_user_info) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) - def test_demographic_breakdown(self): + def test_demographic_breakdown(self) -> None: # plugging in votes for breakdown payload_1 = {"id_hash": 1, "poll_options": [self.p1_op1_id]} self.client.post("/portal/votes/", payload_1) diff --git a/backend/tests/portal/test_posts.py b/backend/tests/portal/test_posts.py index 4a2b5bae..26accc79 100644 --- a/backend/tests/portal/test_posts.py +++ b/backend/tests/portal/test_posts.py @@ -1,8 +1,8 @@ import datetime import json +from typing import Any, cast from unittest import mock -from django.contrib.auth import get_user_model from django.core.management import call_command from django.test import TestCase from django.utils import timezone @@ -10,26 +10,24 @@ from portal.models import Post, TargetPopulation from utils.email import get_backend_manager_emails +from utils.types import DjangoUserModel, DjangoUserType -User = get_user_model() - - -def mock_get_user_clubs(*args, **kwargs): +def mock_get_user_clubs(*args: Any, **kwargs: Any) -> list[dict]: with open("tests/portal/get_user_clubs.json") as data: return json.load(data) -def mock_get_no_clubs(*args, **kwargs): +def mock_get_no_clubs(*args: Any, **kwargs: Any) -> list[dict]: return [] -def mock_get_user_info(*args, **kwargs): +def mock_get_user_info(*args: Any, **kwargs: Any) -> list[dict]: with open("tests/portal/get_user_info.json") as data: return json.load(data) -def mock_get_club_info(*args, **kwargs): +def mock_get_club_info(*args: Any, **kwargs: Any) -> list[dict]: with open("tests/portal/get_club_info.json") as data: return json.load(data) @@ -38,11 +36,13 @@ class TestPosts(TestCase): """Tests Created/Update/Retrieve for Posts""" @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) - def setUp(self): + def setUp(self) -> None: call_command("load_target_populations", "--years", "2022, 2023, 2024, 2025") - self.target_id = TargetPopulation.objects.get(population="2024").id - self.client = APIClient() - self.test_user = User.objects.create_user("user", "user@seas.upenn.edu", "user") + self.target_id: int = TargetPopulation.objects.get(population="2024").id + self.client: APIClient = APIClient() + self.test_user: DjangoUserType = DjangoUserModel.objects.create_user( + "user", "user@seas.upenn.edu", "user" + ) self.client.force_authenticate(user=self.test_user) payload = { @@ -57,12 +57,13 @@ def setUp(self): } self.client.post("/portal/posts/", payload) post_1 = Post.objects.all().first() + assert post_1 is not None post_1.status = Post.STATUS_APPROVED post_1.save() - self.id = post_1.id + self.post_id = post_1.id @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) - def test_create_post(self): + def test_create_post(self) -> None: # Creates an unapproved post payload = { "club_code": "pennlabs", @@ -81,7 +82,7 @@ def test_create_post(self): self.assertEqual(None, Post.objects.get(id=res_json["id"]).admin_comment) @mock.patch("portal.serializers.get_user_clubs", mock_get_no_clubs) - def test_fail_post(self): + def test_fail_post(self) -> None: # Creates an unapproved post payload = { "club_code": "pennlabs", @@ -101,29 +102,31 @@ def test_fail_post(self): @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) - def test_update_post(self): + def test_update_post(self) -> None: payload = {"title": "New Test Title 3"} - response = self.client.patch(f"/portal/posts/{self.id}/", payload) + response = self.client.patch(f"/portal/posts/{self.post_id}/", payload) res_json = json.loads(response.content) - self.assertEqual(self.id, res_json["id"]) - self.assertEqual("New Test Title 3", Post.objects.get(id=self.id).title) + self.assertEqual(self.post_id, res_json["id"]) + self.assertEqual("New Test Title 3", Post.objects.get(id=self.post_id).title) # since the user is not an admin, approved should be set to false after update self.assertEqual(Post.STATUS_DRAFT, res_json["status"]) @mock.patch("portal.views.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) - def test_update_post_admin(self): - admin = User.objects.create_superuser("admin@upenn.edu", "admin", "admin") + def test_update_post_admin(self) -> None: + admin: DjangoUserType = DjangoUserModel.objects.create_superuser( + "admin@upenn.edu", "admin", "admin" + ) self.client.force_authenticate(user=admin) payload = {"title": "New Test Title 3"} - response = self.client.patch(f"/portal/posts/{self.id}/", payload) + response = self.client.patch(f"/portal/posts/{self.post_id}/", payload) res_json = json.loads(response.content) - self.assertEqual(self.id, res_json["id"]) + self.assertEqual(self.post_id, res_json["id"]) self.assertEqual(Post.STATUS_APPROVED, res_json["status"]) @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.logic.get_user_info", mock_get_user_info) - def test_browse(self): + def test_browse(self) -> None: payload = { "club_code": "pennlabs", "title": "Test Title 2", @@ -139,7 +142,7 @@ def test_browse(self): self.assertEqual(1, len(res_json)) self.assertEqual(2, Post.objects.all().count()) - def test_review_post_no_admin_comment(self): + def test_review_post_no_admin_comment(self) -> None: # No admin comment Post.objects.create( club_code="notpennlabs", @@ -147,7 +150,9 @@ def test_review_post_no_admin_comment(self): subtitle="Test subtitle 2", expire_date=timezone.localtime() + datetime.timedelta(days=1), ) - admin = User.objects.create_superuser("admin@upenn.edu", "admin", "admin") + admin: DjangoUserType = DjangoUserModel.objects.create_superuser( + "admin@upenn.edu", "admin", "admin" + ) self.client.force_authenticate(user=admin) response = self.client.get("/portal/posts/review/") res_json = json.loads(response.content) @@ -158,7 +163,7 @@ def test_review_post_no_admin_comment(self): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("utils.email.send_automated_email.delay_on_commit") - def test_send_email_on_create(self, mock_send_email): + def test_send_email_on_create(self, mock_send_email: mock.Mock) -> None: payload = { "club_code": "pennlabs", "title": "Test Title 2", @@ -175,7 +180,7 @@ def test_send_email_on_create(self, mock_send_email): @mock.patch("portal.serializers.get_user_clubs", mock_get_user_clubs) @mock.patch("portal.permissions.get_user_clubs", mock_get_user_clubs) @mock.patch("utils.email.send_automated_email.delay_on_commit") - def test_send_email_on_status_change(self, mock_send_email): + def test_send_email_on_status_change(self, mock_send_email: mock.Mock) -> None: payload = { "club_code": "pennlabs", "title": "Test Title 2", @@ -190,8 +195,10 @@ def test_send_email_on_status_change(self, mock_send_email): mock_send_email.assert_called_once() post = Post.objects.last() + assert post is not None + creator = cast(DjangoUserType, post.creator) post.status = Post.STATUS_APPROVED post.save() self.assertEqual(mock_send_email.call_count, 2) - self.assertEqual(mock_send_email.call_args[0][1], [post.creator.email]) + self.assertEqual(mock_send_email.call_args[0][1], [creator.email]) diff --git a/backend/utils/types.py b/backend/utils/types.py new file mode 100644 index 00000000..2e5df190 --- /dev/null +++ b/backend/utils/types.py @@ -0,0 +1,42 @@ +from typing import Any, Type, TypeAlias, TypeVar, Union, cast, Protocol +from django.contrib.auth import get_user_model +from django.contrib.auth.models import AbstractUser, AnonymousUser +from django.db.models import Manager, Model, QuerySet +from rest_framework.request import Request + + +# Get the actual User model +DjangoUser = get_user_model() + +class UserManager(Protocol): + def create_user(self, username: str, email: str, password: str) -> 'DjangoUserType': ... + def create_superuser(self, username: str, email: str, password: str) -> 'DjangoUserType': ... + def get(self, **kwargs: Any) -> 'DjangoUserType': ... + def filter(self, **kwargs: Any) -> QuerySet['DjangoUserType']: ... + def all(self) -> QuerySet['DjangoUserType']: ... + +class DjangoUserType(AbstractUser, Protocol): + objects: UserManager + is_superuser: bool + id: int + username: str + is_authenticated: bool + +DjangoUserModel: Type[DjangoUserType] = cast(Type[DjangoUserType], get_user_model()) + +# Union type for all possible user types +UserType = Union[DjangoUserType, AnonymousUser] + +# Type for authenticated Django user requests +class AuthRequest(Request): + user: DjangoUserType + +# Helper function to safely cast user to DjangoUserType +def get_auth_user(request: Request) -> DjangoUserType: + if not request.user.is_authenticated: + raise ValueError("User must be authenticated") + return cast(DjangoUserType, request.user) + +# QuerySet type helpers +ModelT = TypeVar("ModelT", bound=Model) +ModelQuerySet: TypeAlias = QuerySet[ModelT, Manager[ModelT]]