Skip to content

Commit

Permalink
Merge pull request #1176 from ansible/goneri/convert-the-T-C-flow-as-…
Browse files Browse the repository at this point in the history
…a-regular-post-login-page_5239

convert the T&C flow as a regular post-login page
  • Loading branch information
goneri authored Jul 10, 2024
2 parents 641cd11 + 96611e6 commit e89e315
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 318 deletions.
6 changes: 1 addition & 5 deletions ansible_ai_connect/main/settings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
"django.contrib.admin",
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"django.contrib.sessions", # Used by the admin dashboard
"django.contrib.messages",
"django.contrib.staticfiles",
"rest_framework",
Expand Down Expand Up @@ -216,9 +216,6 @@
"oauth2_provider.backends.OAuth2Backend",
]

SOCIAL_AUTH_FIELDS_STORED_IN_SESSION = [
"terms_accepted",
]
SOCIAL_AUTH_PIPELINE = (
"ansible_ai_connect.users.pipeline.block_auth_users",
"social_core.pipeline.social_auth.social_details",
Expand All @@ -233,7 +230,6 @@
"social_core.pipeline.social_auth.associate_user",
"social_core.pipeline.user.user_details",
"ansible_ai_connect.users.pipeline.load_extra_data",
"ansible_ai_connect.users.pipeline.terms_of_service",
)

# Wisdom Eng Team:
Expand Down
6 changes: 1 addition & 5 deletions ansible_ai_connect/main/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@
path("unauthorized/", UnauthorizedView.as_view(), name="unauthorized"),
path("check/status/", WisdomServiceHealthView.as_view(), name="health_check"),
path("check/", WisdomServiceLivenessProbeView.as_view(), name="liveness_probe"),
path(
"community-terms/",
TermsOfService.as_view(template_name="users/community-terms.html"),
name="community_terms",
),
path("community-terms/", TermsOfService.as_view(), name="community_terms"),
path("o/", include((base_urlpatterns, app_name), namespace="oauth2_provider")),
path(
"login/",
Expand Down
40 changes: 1 addition & 39 deletions ansible_ai_connect/users/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
import jwt
from django.conf import settings
from django.contrib.auth import get_user_model
from django.urls import reverse
from django.utils import timezone
from social_core.exceptions import AuthCanceled, AuthException
from social_core.pipeline.partial import partial
from social_core.exceptions import AuthException
from social_core.pipeline.user import get_username
from social_django.models import UserSocialAuth

Expand Down Expand Up @@ -126,41 +123,6 @@ def redhat_organization(backend, user, response, *args, **kwargs):
}


def _terms_of_service(strategy, user, backend, **kwargs):
accepted = "terms_accepted"
is_commercial = user.rh_user_has_seat
if not settings.ANSIBLE_AI_ENABLE_TECH_PREVIEW:
return {accepted: True}
# Commercial & local users are not presented with T&C page in login flow (new & existing users)
if settings.TERMS_NOT_APPLICABLE or is_commercial:
return {accepted: True}

field_name = "community_terms_accepted"
view_name = "community_terms"
terms_accepted = strategy.session_get(accepted, None)
if getattr(user, field_name, None) is not None:
# User had previously accepted, so short-circuit the T&C page.
return {accepted: True}

if terms_accepted is None:
# We haven't gone through the flow yet -- go to the T&C page
current_partial = kwargs.get("current_partial")
return strategy.redirect(f"{reverse(view_name)}?partial_token={current_partial.token}")

if not terms_accepted:
raise AuthCanceled("Terms and conditions were not accepted.")

# We've accepted the T&C, set the field on the user.
setattr(user, field_name, timezone.now())
user.save()
return {accepted: terms_accepted}


@partial
def terms_of_service(strategy, details, backend, user=None, is_new=False, *args, **kwargs):
return _terms_of_service(strategy, user, backend, **kwargs)


class AuthAlreadyLoggedIn(AuthException):
def __str__(self):
return "User already logged in"
Expand Down
238 changes: 1 addition & 237 deletions ansible_ai_connect/users/tests/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
import random
import string
from http import HTTPStatus
from types import SimpleNamespace
from typing import Optional
from unittest.mock import Mock, patch
from unittest.mock import patch
from uuid import uuid4

from django.apps import apps
Expand All @@ -26,10 +25,8 @@
from django.core.cache import cache
from django.test import override_settings
from django.urls import reverse
from django.utils import timezone
from prometheus_client.parser import text_string_to_metric_families
from rest_framework.test import APITransactionTestCase
from social_core.exceptions import AuthCanceled
from social_django.models import UserSocialAuth

import ansible_ai_connect.ai.feature_flags as feature_flags
Expand All @@ -48,8 +45,6 @@
USER_SOCIAL_AUTH_PROVIDER_GITHUB,
USER_SOCIAL_AUTH_PROVIDER_OIDC,
)
from ansible_ai_connect.users.pipeline import _terms_of_service
from ansible_ai_connect.users.views import TermsOfService


def create_user(
Expand Down Expand Up @@ -119,237 +114,6 @@ def test_users_audit_logging(self):
self.assertInLog("LOGIN successful", log)


@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=True)
class TestTermsAndConditions(WisdomServiceLogAwareTestCase):
def setUp(self) -> None:
super().setUp()

class MockSession(dict):
def save(self):
pass

class MockRequest:
GET = {}
POST = {}

def __init__(self):
self.session = MockSession()

class MockBackend:
name = "github"

class MockStrategy:
session = None
redirect_url = None

def redirect(self, redirect_url):
self.redirect_url = redirect_url

def partial_load(self, partial_token):
if partial_token == "invalid_token":
return None
else:
return SimpleNamespace(backend="backend", token=partial_token)

def session_get(self, key, default=None):
return self.session.get(key, default)

# class MockUser:
# community_terms_accepted = None
# saved = False

# def save(self):
# self.saved = True

self.request = MockRequest()
self.backend = MockBackend()
self.strategy = MockStrategy()
self.strategy.session = self.request.session
self.partial = SimpleNamespace(token="token")
self.user = Mock(
community_terms_accepted=None, commercial_terms_accepted=None, rh_user_has_seat=False
)
cache.clear()

def test_terms_of_service_community_first_call(self):
_terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)
self.assertIsNone(self.request.session.get("terms_accepted", None))
self.assertEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token")
self.assertFalse(self.user.save.called)
self.assertIsNone(self.user.community_terms_accepted)

def test_terms_of_service_first_commercial(self):
# We must be using the Red Hat SSO and be a member of the Community placeholder group
# Commercial Users enclosed Terms of Service by default earlier, no need to ask them again
self.backend.name = "oidc"
self.user.rh_user_has_seat = True

_terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)
self.assertIsNone(self.request.session.get("terms_accepted", None))
self.assertNotEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token")
self.assertFalse(self.user.save.called)
self.assertIsNone(self.user.community_terms_accepted)

def test_terms_of_service_commercial_previously_accepted(self):
now = timezone.now()
self.user.community_terms_accepted = now
self.backend.name = "oidc"
self.user.rh_user_has_seat = True
_terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)

self.assertNotEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token")
self.assertFalse(self.user.save.called)
self.assertEqual(self.user.community_terms_accepted, now)

def test_terms_of_service_community_previously_accepted(self):
now = timezone.now()
self.user.community_terms_accepted = now
_terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)

self.assertNotEqual(self.strategy.redirect_url, "/community-terms/?partial_token=token")
self.assertFalse(self.user.save.called)
self.assertEqual(self.user.community_terms_accepted, now)

def test_terms_of_service_with_acceptance(self):
self.request.session["terms_accepted"] = True
_terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)
self.assertTrue(self.user.save.called)
self.assertIsNotNone(self.user.community_terms_accepted)

def test_terms_of_service_without_acceptance(self):
self.request.session["terms_accepted"] = False
with self.assertRaises(AuthCanceled):
_terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)
self.assertFalse(self.user.save.called)
self.assertIsNone(self.user.community_terms_accepted)

@override_settings(TERMS_NOT_APPLICABLE=True)
def test_terms_of_service_with_override(self):
self.request.session["terms_accepted"] = False
result = _terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)
self.assertEqual(result, {"terms_accepted": True})
self.assertIsNone(self.strategy.redirect_url)
self.assertFalse(self.user.save.called)
self.assertIsNone(self.user.community_terms_accepted)

@override_settings(ANSIBLE_AI_ENABLE_TECH_PREVIEW=False)
def test_terms_of_service_after_tech_preview(self):
self.request.session["terms_accepted"] = False
result = _terms_of_service(
self.strategy,
self.user,
self.backend,
request=self.request,
current_partial=self.partial,
)
self.assertEqual(result, {"terms_accepted": True})
self.assertIsNone(self.strategy.redirect_url)
self.assertFalse(self.user.save.called)
self.assertIsNone(self.user.community_terms_accepted)

@patch("social_django.utils.get_strategy")
def test_post_accepted(self, get_strategy):
get_strategy.return_value = self.strategy
self.request.POST["partial_token"] = "token"
self.request.POST["accepted"] = "True"
view = TermsOfService(template_name="users/community-terms.html")
view.post(self.request)
self.assertTrue(self.request.session["terms_accepted"])
self.assertEqual("/complete/backend/?partial_token=token", self.strategy.redirect_url)

@patch("social_django.utils.get_strategy")
def test_post_not_accepted(self, get_strategy):
get_strategy.return_value = self.strategy
self.request.POST["partial_token"] = "token"
self.request.POST["accepted"] = "False"
view = TermsOfService(template_name="users/community-terms.html")
view.post(self.request)
self.assertFalse(self.request.session["terms_accepted"])
self.assertEqual("/complete/backend/?partial_token=token", self.strategy.redirect_url)

@patch("social_django.utils.get_strategy")
def test_post_without_partial_token(self, get_strategy):
get_strategy.return_value = self.strategy
# self.request.POST['partial_token'] = 'token'
self.request.POST["accepted"] = "False"
view = TermsOfService(template_name="users/community-terms.html")
with self.assertLogs(logger="root", level="WARN") as log:
res = view.post(self.request)
self.assertEqual(400, res.status_code)
self.assertInLog("POST TermsOfService was invoked without partial_token", log)

@patch("social_django.utils.get_strategy")
def test_post_with_invalid_partial_token(self, get_strategy):
get_strategy.return_value = self.strategy
self.request.POST["partial_token"] = "invalid_token"
self.request.POST["accepted"] = "False"
view = TermsOfService(template_name="users/community-terms.html")
with self.assertLogs(logger="root", level="ERROR") as log:
res = view.post(self.request)
self.assertEqual(400, res.status_code)
self.assertInLog("strategy.partial_load(partial_token) returned None", log)

def test_get(self):
view = TermsOfService(template_name="users/community-terms.html")
setattr(view, "request", self.request) # needed for TemplateResponseMixin
self.request.GET["partial_token"] = "token"
res = view.get(self.request)
self.assertEqual(200, res.status_code)
self.assertIn("form", res.context_data)
self.assertIn("partial_token", res.context_data)

def test_get_without_partial_token(self):
view = TermsOfService(template_name="users/community-terms.html")
setattr(view, "request", self.request) # needed for TemplateResponseMixin
# self.request.GET['partial_token'] = 'token'
with self.assertLogs(logger="root", level="WARN") as log:
res = view.get(self.request)
self.assertEqual(403, res.status_code)
self.assertInLog("GET TermsOfService was invoked without partial_token", log)


@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
@override_settings(WCA_SECRET_DUMMY_SECRETS="1981:valid")
@override_settings(AUTHZ_BACKEND_TYPE="dummy")
Expand Down
Loading

0 comments on commit e89e315

Please sign in to comment.