diff --git a/ansible_ai_connect/main/settings/base.py b/ansible_ai_connect/main/settings/base.py index ffcb398ba..20520a229 100644 --- a/ansible_ai_connect/main/settings/base.py +++ b/ansible_ai_connect/main/settings/base.py @@ -98,7 +98,7 @@ "django.contrib.admin", "django.contrib.auth", "django.contrib.contenttypes", - "django.contrib.sessions", + "django.contrib.sessions", # Used by the admin dashboard "django.contrib.messages", "django.contrib.staticfiles", "rest_framework", @@ -216,9 +216,6 @@ "oauth2_provider.backends.OAuth2Backend", ] -SOCIAL_AUTH_FIELDS_STORED_IN_SESSION = [ - "terms_accepted", -] SOCIAL_AUTH_PIPELINE = ( "ansible_ai_connect.users.pipeline.block_auth_users", "social_core.pipeline.social_auth.social_details", @@ -233,7 +230,6 @@ "social_core.pipeline.social_auth.associate_user", "social_core.pipeline.user.user_details", "ansible_ai_connect.users.pipeline.load_extra_data", - "ansible_ai_connect.users.pipeline.terms_of_service", ) # Wisdom Eng Team: diff --git a/ansible_ai_connect/main/urls.py b/ansible_ai_connect/main/urls.py index 3b5747a4c..fe3c2d728 100644 --- a/ansible_ai_connect/main/urls.py +++ b/ansible_ai_connect/main/urls.py @@ -73,11 +73,7 @@ path("unauthorized/", UnauthorizedView.as_view(), name="unauthorized"), path("check/status/", WisdomServiceHealthView.as_view(), name="health_check"), path("check/", WisdomServiceLivenessProbeView.as_view(), name="liveness_probe"), - path( - "community-terms/", - TermsOfService.as_view(template_name="users/community-terms.html"), - name="community_terms", - ), + path("community-terms/", TermsOfService.as_view(), name="community_terms"), path("o/", include((base_urlpatterns, app_name), namespace="oauth2_provider")), path( "login/", diff --git a/ansible_ai_connect/users/pipeline.py b/ansible_ai_connect/users/pipeline.py index 366cd57b8..d704d4ae6 100644 --- a/ansible_ai_connect/users/pipeline.py +++ b/ansible_ai_connect/users/pipeline.py @@ -17,10 +17,7 @@ import jwt from django.conf import settings from django.contrib.auth import get_user_model -from django.urls import reverse -from django.utils import timezone -from social_core.exceptions import AuthCanceled, AuthException -from social_core.pipeline.partial import partial +from social_core.exceptions import AuthException from social_core.pipeline.user import get_username from social_django.models import UserSocialAuth @@ -126,41 +123,6 @@ def redhat_organization(backend, user, response, *args, **kwargs): } -def _terms_of_service(strategy, user, backend, **kwargs): - accepted = "terms_accepted" - is_commercial = user.rh_user_has_seat - if not settings.ANSIBLE_AI_ENABLE_TECH_PREVIEW: - return {accepted: True} - # Commercial & local users are not presented with T&C page in login flow (new & existing users) - if settings.TERMS_NOT_APPLICABLE or is_commercial: - return {accepted: True} - - field_name = "community_terms_accepted" - view_name = "community_terms" - terms_accepted = strategy.session_get(accepted, None) - if getattr(user, field_name, None) is not None: - # User had previously accepted, so short-circuit the T&C page. - return {accepted: True} - - if terms_accepted is None: - # We haven't gone through the flow yet -- go to the T&C page - current_partial = kwargs.get("current_partial") - return strategy.redirect(f"{reverse(view_name)}?partial_token={current_partial.token}") - - if not terms_accepted: - raise AuthCanceled("Terms and conditions were not accepted.") - - # We've accepted the T&C, set the field on the user. - setattr(user, field_name, timezone.now()) - user.save() - return {accepted: terms_accepted} - - -@partial -def terms_of_service(strategy, details, backend, user=None, is_new=False, *args, **kwargs): - return _terms_of_service(strategy, user, backend, **kwargs) - - class AuthAlreadyLoggedIn(AuthException): def __str__(self): return "User already logged in" diff --git a/ansible_ai_connect/users/tests/test_users.py b/ansible_ai_connect/users/tests/test_users.py index 96f4c9209..7d5d870bf 100644 --- a/ansible_ai_connect/users/tests/test_users.py +++ b/ansible_ai_connect/users/tests/test_users.py @@ -15,9 +15,8 @@ import random import string from http import HTTPStatus -from types import SimpleNamespace from typing import Optional -from unittest.mock import Mock, patch +from unittest.mock import patch from uuid import uuid4 from django.apps import apps @@ -26,10 +25,8 @@ from django.core.cache import cache from django.test import override_settings from django.urls import reverse -from django.utils import timezone from prometheus_client.parser import text_string_to_metric_families from rest_framework.test import APITransactionTestCase -from social_core.exceptions import AuthCanceled from social_django.models import UserSocialAuth import ansible_ai_connect.ai.feature_flags as feature_flags @@ -48,8 +45,6 @@ USER_SOCIAL_AUTH_PROVIDER_GITHUB, USER_SOCIAL_AUTH_PROVIDER_OIDC, ) -from ansible_ai_connect.users.pipeline import _terms_of_service -from ansible_ai_connect.users.views import TermsOfService def create_user( @@ -119,237 +114,6 @@ def test_users_audit_logging(self): self.assertInLog("LOGIN successful", log) -@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True) -class TestTermsAndConditions(WisdomServiceLogAwareTestCase): - def setUp(self) -> None: - super().setUp() - - class MockSession(dict): - def save(self): - pass - - class MockRequest: - GET = {} - POST = {} - - def __init__(self): - self.session = MockSession() - - class MockBackend: - name = "github" - - class MockStrategy: - session = None - redirect_url = None - - def redirect(self, redirect_url): - self.redirect_url = redirect_url - - def partial_load(self, partial_token): - if partial_token == "invalid_token": - return None - else: - return SimpleNamespace(backend="backend", token=partial_token) - - def session_get(self, key, default=None): - return self.session.get(key, default) - - # class MockUser: - # community_terms_accepted = None - # saved = False - - # def save(self): - # self.saved = True - - self.request = MockRequest() - self.backend = MockBackend() - self.strategy = MockStrategy() - self.strategy.session = self.request.session - self.partial = SimpleNamespace(token="token") - self.user = Mock( - community_terms_accepted=None, commercial_terms_accepted=None, rh_user_has_seat=False - ) - cache.clear() - - def test_terms_of_service_community_first_call(self): - _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - self.assertIsNone(self.request.session.get("terms_accepted", None)) - self.assertEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token") - self.assertFalse(self.user.save.called) - self.assertIsNone(self.user.community_terms_accepted) - - def test_terms_of_service_first_commercial(self): - # We must be using the Red Hat SSO and be a member of the Community placeholder group - # Commercial Users enclosed Terms of Service by default earlier, no need to ask them again - self.backend.name = "oidc" - self.user.rh_user_has_seat = True - - _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - self.assertIsNone(self.request.session.get("terms_accepted", None)) - self.assertNotEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token") - self.assertFalse(self.user.save.called) - self.assertIsNone(self.user.community_terms_accepted) - - def test_terms_of_service_commercial_previously_accepted(self): - now = timezone.now() - self.user.community_terms_accepted = now - self.backend.name = "oidc" - self.user.rh_user_has_seat = True - _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - - self.assertNotEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token") - self.assertFalse(self.user.save.called) - self.assertEqual(self.user.community_terms_accepted, now) - - def test_terms_of_service_community_previously_accepted(self): - now = timezone.now() - self.user.community_terms_accepted = now - _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - - self.assertNotEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token") - self.assertFalse(self.user.save.called) - self.assertEqual(self.user.community_terms_accepted, now) - - def test_terms_of_service_with_acceptance(self): - self.request.session["terms_accepted"] = True - _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - self.assertTrue(self.user.save.called) - self.assertIsNotNone(self.user.community_terms_accepted) - - def test_terms_of_service_without_acceptance(self): - self.request.session["terms_accepted"] = False - with self.assertRaises(AuthCanceled): - _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - self.assertFalse(self.user.save.called) - self.assertIsNone(self.user.community_terms_accepted) - - @override_settings(TERMS_NOT_APPLICABLE=True) - def test_terms_of_service_with_override(self): - self.request.session["terms_accepted"] = False - result = _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - self.assertEqual(result, {"terms_accepted": True}) - self.assertIsNone(self.strategy.redirect_url) - self.assertFalse(self.user.save.called) - self.assertIsNone(self.user.community_terms_accepted) - - @override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=False) - def test_terms_of_service_after_tech_preview(self): - self.request.session["terms_accepted"] = False - result = _terms_of_service( - self.strategy, - self.user, - self.backend, - request=self.request, - current_partial=self.partial, - ) - self.assertEqual(result, {"terms_accepted": True}) - self.assertIsNone(self.strategy.redirect_url) - self.assertFalse(self.user.save.called) - self.assertIsNone(self.user.community_terms_accepted) - - @patch("social_django.utils.get_strategy") - def test_post_accepted(self, get_strategy): - get_strategy.return_value = self.strategy - self.request.POST["partial_token"] = "token" - self.request.POST["accepted"] = "True" - view = TermsOfService(template_name="users/community-terms.html") - view.post(self.request) - self.assertTrue(self.request.session["terms_accepted"]) - self.assertEqual("/complete/backend/?partial_token=token", self.strategy.redirect_url) - - @patch("social_django.utils.get_strategy") - def test_post_not_accepted(self, get_strategy): - get_strategy.return_value = self.strategy - self.request.POST["partial_token"] = "token" - self.request.POST["accepted"] = "False" - view = TermsOfService(template_name="users/community-terms.html") - view.post(self.request) - self.assertFalse(self.request.session["terms_accepted"]) - self.assertEqual("/complete/backend/?partial_token=token", self.strategy.redirect_url) - - @patch("social_django.utils.get_strategy") - def test_post_without_partial_token(self, get_strategy): - get_strategy.return_value = self.strategy - # self.request.POST['partial_token'] = 'token' - self.request.POST["accepted"] = "False" - view = TermsOfService(template_name="users/community-terms.html") - with self.assertLogs(logger="root", level="WARN") as log: - res = view.post(self.request) - self.assertEqual(400, res.status_code) - self.assertInLog("POST TermsOfService was invoked without partial_token", log) - - @patch("social_django.utils.get_strategy") - def test_post_with_invalid_partial_token(self, get_strategy): - get_strategy.return_value = self.strategy - self.request.POST["partial_token"] = "invalid_token" - self.request.POST["accepted"] = "False" - view = TermsOfService(template_name="users/community-terms.html") - with self.assertLogs(logger="root", level="ERROR") as log: - res = view.post(self.request) - self.assertEqual(400, res.status_code) - self.assertInLog("strategy.partial_load(partial_token) returned None", log) - - def test_get(self): - view = TermsOfService(template_name="users/community-terms.html") - setattr(view, "request", self.request) # needed for TemplateResponseMixin - self.request.GET["partial_token"] = "token" - res = view.get(self.request) - self.assertEqual(200, res.status_code) - self.assertIn("form", res.context_data) - self.assertIn("partial_token", res.context_data) - - def test_get_without_partial_token(self): - view = TermsOfService(template_name="users/community-terms.html") - setattr(view, "request", self.request) # needed for TemplateResponseMixin - # self.request.GET['partial_token'] = 'token' - with self.assertLogs(logger="root", level="WARN") as log: - res = view.get(self.request) - self.assertEqual(403, res.status_code) - self.assertInLog("GET TermsOfService was invoked without partial_token", log) - - @override_settings(WCA_SECRET_BACKEND_TYPE="dummy") @override_settings(WCA_SECRET_DUMMY_SECRETS="1981:valid") @override_settings(AUTHZ_BACKEND_TYPE="dummy") diff --git a/ansible_ai_connect/users/views.py b/ansible_ai_connect/users/views.py index 43a587575..374e4ee99 100644 --- a/ansible_ai_connect/users/views.py +++ b/ansible_ai_connect/users/views.py @@ -17,15 +17,15 @@ from django.apps import apps from django.conf import settings from django.forms import Form -from django.http import HttpResponseBadRequest, HttpResponseForbidden +from django.http import HttpResponseRedirect from django.urls import reverse from django.utils.decorators import method_decorator +from django.utils.timezone import now from django.views.generic import TemplateView from rest_framework.generics import RetrieveAPIView from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.throttling import UserRateThrottle -from social_django.utils import load_strategy from ansible_ai_connect.ai.api.aws.exceptions import ( WcaSecretManagerMissingCredentialsError, @@ -104,37 +104,14 @@ def retrieve(self, request, *args, **kwargs): class TermsOfService(TemplateView): - template_name = None # passed in via the urlpatterns - extra_context = { - "form": Form(), - } - - def get(self, request, *args, **kwargs): - partial_token = request.GET.get("partial_token") - self.extra_context["partial_token"] = partial_token - if partial_token is None: - logger.warning("GET TermsOfService was invoked without partial_token") - return HttpResponseForbidden() - return super().get(request, args, kwargs) + template_name = "users/community-terms.html" def post(self, request, *args, **kwargs): form = Form(request.POST) form.is_valid() - partial_token = form.data.get("partial_token") - if partial_token is None: - logger.warning("POST TermsOfService was invoked without partial_token") - return HttpResponseBadRequest() - - strategy = load_strategy() - partial = strategy.partial_load(partial_token) - if partial is None: - logger.error("strategy.partial_load(partial_token) returned None") - return HttpResponseBadRequest() - - accepted = request.POST.get("accepted") == "True" - request.session["terms_accepted"] = accepted - request.session.save() - - backend = partial.backend - complete = reverse("social:complete", kwargs={"backend": backend}) - return strategy.redirect(complete + f"?partial_token={partial.token}") + + if request.POST.get("accepted") == "True": + request.user.commercial_terms_accepted = now() + request.session.save() + + return HttpResponseRedirect(reverse("home"))