Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed issue #363 #488

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
from .settings import api_settings
from .tokens import RefreshToken, SlidingToken, UntypedToken

from .authentication import JWTAuthentication
from .utils import datetime_from_epoch

if api_settings.BLACKLIST_AFTER_ROTATION:
from .token_blacklist.models import BlacklistedToken
from .token_blacklist.models import BlacklistedToken, OutstandingToken


class PasswordField(serializers.CharField):
Expand Down Expand Up @@ -57,7 +60,8 @@ def validate(self, attrs):

@classmethod
def get_token(cls, user):
raise NotImplementedError('Must implement `get_token` method for `TokenObtainSerializer` subclasses')
raise NotImplementedError(
'Must implement `get_token` method for `TokenObtainSerializer` subclasses')


class TokenObtainPairSerializer(TokenObtainSerializer):
Expand Down Expand Up @@ -104,9 +108,12 @@ class TokenRefreshSerializer(serializers.Serializer):
def validate(self, attrs):
refresh = RefreshToken(attrs['refresh'])

data = {'access': str(refresh.access_token)}
data = {}

if api_settings.ROTATE_REFRESH_TOKENS:
auth = JWTAuthentication()
user = auth.get_user(validated_token=refresh)

if api_settings.BLACKLIST_AFTER_ROTATION:
try:
# Attempt to blacklist the given refresh token
Expand All @@ -120,8 +127,21 @@ def validate(self, attrs):
refresh.set_exp()
refresh.set_iat()

OutstandingToken.objects.create(
user=user,
jti=refresh[api_settings.JTI_CLAIM],
token=str(refresh),
created_at=refresh.current_time,
expires_at=datetime_from_epoch(refresh['exp'])
)

data['refresh'] = str(refresh)

data['access'] = str(refresh.access_token)

else:
data['access'] = str(refresh.access_token)

return data


Expand Down
15 changes: 6 additions & 9 deletions rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ def check_exp(self, claim='exp', current_time=None):

claim_time = datetime_from_epoch(claim_value)
if claim_time <= current_time:
raise TokenError(format_lazy(_("Token '{}' claim has expired"), claim))
raise TokenError(format_lazy(
_("Token '{}' claim has expired"), claim))

@classmethod
def for_user(cls, user):
Expand All @@ -181,9 +182,9 @@ def for_user(cls, user):
token[api_settings.USER_ID_CLAIM] = user_id

return token

_token_backend = None

def get_token_backend(self):
if self._token_backend is None:
self._token_backend = import_string(
Expand Down Expand Up @@ -223,13 +224,9 @@ def blacklist(self):
jti = self.payload[api_settings.JTI_CLAIM]
exp = self.payload['exp']

# Ensure outstanding token exists with given jti
token, _ = OutstandingToken.objects.get_or_create(
# Outstanding token will always exist
token = OutstandingToken.objects.get(
jti=jti,
defaults={
'token': str(self),
'expires_at': datetime_from_epoch(exp),
},
)

return BlacklistedToken.objects.get_or_create(token=token)
Expand Down