From 070b89563302b231e3867743e35d11c47a976ea2 Mon Sep 17 00:00:00 2001 From: MadsKK <35922365+MadsKK@users.noreply.github.com> Date: Sun, 14 Nov 2021 14:28:21 +0100 Subject: [PATCH] Fixed issue #363 Upon refreshing the token a new Outstanding token is created in the serializers.py where the user from the blacklisted token is added to the new refresh token. This insures that there is always an Outstanding token for each refresh token in use. Therefore, this will solve the issue of logging out from all devices by blacklisting all the Outstanding tokens linked to that specific user. --- rest_framework_simplejwt/serializers.py | 26 ++++++++++++++++++++++--- rest_framework_simplejwt/tokens.py | 15 ++++++-------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/rest_framework_simplejwt/serializers.py b/rest_framework_simplejwt/serializers.py index f2b12ab7e..bb5fb8570 100644 --- a/rest_framework_simplejwt/serializers.py +++ b/rest_framework_simplejwt/serializers.py @@ -8,8 +8,11 @@ from .settings import api_settings from .tokens import RefreshToken, SlidingToken, UntypedToken +from .authentication import JWTAuthentication +from .utils import datetime_from_epoch + if api_settings.BLACKLIST_AFTER_ROTATION: - from .token_blacklist.models import BlacklistedToken + from .token_blacklist.models import BlacklistedToken, OutstandingToken class PasswordField(serializers.CharField): @@ -57,7 +60,8 @@ def validate(self, attrs): @classmethod def get_token(cls, user): - raise NotImplementedError('Must implement `get_token` method for `TokenObtainSerializer` subclasses') + raise NotImplementedError( + 'Must implement `get_token` method for `TokenObtainSerializer` subclasses') class TokenObtainPairSerializer(TokenObtainSerializer): @@ -104,9 +108,12 @@ class TokenRefreshSerializer(serializers.Serializer): def validate(self, attrs): refresh = RefreshToken(attrs['refresh']) - data = {'access': str(refresh.access_token)} + data = {} if api_settings.ROTATE_REFRESH_TOKENS: + auth = JWTAuthentication() + user = auth.get_user(validated_token=refresh) + if api_settings.BLACKLIST_AFTER_ROTATION: try: # Attempt to blacklist the given refresh token @@ -120,8 +127,21 @@ def validate(self, attrs): refresh.set_exp() refresh.set_iat() + OutstandingToken.objects.create( + user=user, + jti=refresh[api_settings.JTI_CLAIM], + token=str(refresh), + created_at=refresh.current_time, + expires_at=datetime_from_epoch(refresh['exp']) + ) + data['refresh'] = str(refresh) + data['access'] = str(refresh.access_token) + + else: + data['access'] = str(refresh.access_token) + return data diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index 75ff573ea..c4d7020ec 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -165,7 +165,8 @@ def check_exp(self, claim='exp', current_time=None): claim_time = datetime_from_epoch(claim_value) if claim_time <= current_time: - raise TokenError(format_lazy(_("Token '{}' claim has expired"), claim)) + raise TokenError(format_lazy( + _("Token '{}' claim has expired"), claim)) @classmethod def for_user(cls, user): @@ -181,9 +182,9 @@ def for_user(cls, user): token[api_settings.USER_ID_CLAIM] = user_id return token - + _token_backend = None - + def get_token_backend(self): if self._token_backend is None: self._token_backend = import_string( @@ -223,13 +224,9 @@ def blacklist(self): jti = self.payload[api_settings.JTI_CLAIM] exp = self.payload['exp'] - # Ensure outstanding token exists with given jti - token, _ = OutstandingToken.objects.get_or_create( + # Outstanding token will always exist + token = OutstandingToken.objects.get( jti=jti, - defaults={ - 'token': str(self), - 'expires_at': datetime_from_epoch(exp), - }, ) return BlacklistedToken.objects.get_or_create(token=token)