diff --git a/rest_framework_jwt/views.py b/rest_framework_jwt/views.py index 30cd4646..cb7baff9 100644 --- a/rest_framework_jwt/views.py +++ b/rest_framework_jwt/views.py @@ -5,8 +5,9 @@ from .settings import api_settings from .serializers import ( - JSONWebTokenSerializer, RefreshJSONWebTokenSerializer, - VerifyJSONWebTokenSerializer + JSONWebTokenSerializer, + RefreshJSONWebTokenSerializer, + VerifyJSONWebTokenSerializer, ) jwt_response_payload_handler = api_settings.JWT_RESPONSE_PAYLOAD_HANDLER @@ -16,6 +17,7 @@ class JSONWebTokenAPIView(APIView): """ Base API View that various JWT interactions inherit from. """ + permission_classes = () authentication_classes = () @@ -23,10 +25,7 @@ def get_serializer_context(self): """ Extra context provided to the serializer class. """ - return { - 'request': self.request, - 'view': self, - } + return {"request": self.request, "view": self} def get_serializer_class(self): """ @@ -38,8 +37,8 @@ def get_serializer_class(self): """ assert self.serializer_class is not None, ( "'%s' should either include a `serializer_class` attribute, " - "or override the `get_serializer_class()` method." - % self.__class__.__name__) + "or override the `get_serializer_class()` method." % self.__class__.__name__ + ) return self.serializer_class def get_serializer(self, *args, **kwargs): @@ -48,24 +47,32 @@ def get_serializer(self, *args, **kwargs): deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() - kwargs['context'] = self.get_serializer_context() + kwargs["context"] = self.get_serializer_context() return serializer_class(*args, **kwargs) def post(self, request, *args, **kwargs): - serializer = self.get_serializer(data=request.data) + serializer_data = dict(request.data) + if ( + "token" not in request.data + and api_settings.JWT_AUTH_COOKIE + and api_settings.JWT_AUTH_COOKIE in request.COOKIES + ): + serializer_data["token"] = request.COOKIES[api_settings.JWT_AUTH_COOKIE] + serializer = self.get_serializer(data=serializer_data) if serializer.is_valid(): - user = serializer.object.get('user') or request.user - token = serializer.object.get('token') + user = serializer.object.get("user") or request.user + token = serializer.object.get("token") response_data = jwt_response_payload_handler(token, user, request) response = Response(response_data) if api_settings.JWT_AUTH_COOKIE: - expiration = (datetime.utcnow() + - api_settings.JWT_EXPIRATION_DELTA) - response.set_cookie(api_settings.JWT_AUTH_COOKIE, - token, - expires=expiration, - httponly=True) + expiration = datetime.utcnow() + api_settings.JWT_EXPIRATION_DELTA + response.set_cookie( + api_settings.JWT_AUTH_COOKIE, + token, + expires=expiration, + httponly=True, + ) return response return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) @@ -77,6 +84,7 @@ class ObtainJSONWebToken(JSONWebTokenAPIView): Returns a JSON Web Token that can be used for authenticated requests. """ + serializer_class = JSONWebTokenSerializer @@ -85,6 +93,7 @@ class VerifyJSONWebToken(JSONWebTokenAPIView): API View that checks the veracity of a token, returning the token if it is valid. """ + serializer_class = VerifyJSONWebTokenSerializer @@ -96,6 +105,7 @@ class RefreshJSONWebToken(JSONWebTokenAPIView): If 'orig_iat' field (original issued-at-time) is found, will first check if it's within expiration window, then copy it to the new token """ + serializer_class = RefreshJSONWebTokenSerializer diff --git a/tests/test_views.py b/tests/test_views.py index c8c72465..943ad127 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -2,6 +2,7 @@ from calendar import timegm from datetime import datetime, timedelta import time +from http.cookies import SimpleCookie from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric import rsa @@ -19,32 +20,25 @@ User = get_user_model() -NO_CUSTOM_USER_MODEL = 'Custom User Model only supported after Django 1.5' +NO_CUSTOM_USER_MODEL = "Custom User Model only supported after Django 1.5" orig_datetime = datetime class BaseTestCase(TestCase): - def setUp(self): - self.email = 'jpueblo@example.com' - self.username = 'jpueblo' - self.password = 'password' - self.user = User.objects.create_user( - self.username, self.email, self.password) + self.email = "jpueblo@example.com" + self.username = "jpueblo" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) - self.data = { - 'username': self.username, - 'password': self.password - } + self.data = {"username": self.username, "password": self.password} class TestCustomResponsePayload(BaseTestCase): - def setUp(self): self.original_handler = views.jwt_response_payload_handler - views.jwt_response_payload_handler = test_utils\ - .jwt_response_payload_handler + views.jwt_response_payload_handler = test_utils.jwt_response_payload_handler return super(TestCustomResponsePayload, self).setUp() def test_jwt_login_custom_response_json(self): @@ -53,32 +47,31 @@ def test_jwt_login_custom_response_json(self): """ client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', self.data, format='json') + response = client.post("/auth-token/", self.data, format="json") - decoded_payload = utils.jwt_decode_handler(response.data['token']) + decoded_payload = utils.jwt_decode_handler(response.data["token"]) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(decoded_payload['username'], self.username) - self.assertEqual(response.data['user'], self.username) + self.assertEqual(decoded_payload["username"], self.username) + self.assertEqual(response.data["user"], self.username) def tearDown(self): views.jwt_response_payload_handler = self.original_handler class ObtainJSONWebTokenTests(BaseTestCase): - def test_jwt_login_json(self): """ Ensure JWT login view using JSON POST works. """ client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', self.data, format='json') + response = client.post("/auth-token/", self.data, format="json") - decoded_payload = utils.jwt_decode_handler(response.data['token']) + decoded_payload = utils.jwt_decode_handler(response.data["token"]) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(decoded_payload['username'], self.username) + self.assertEqual(decoded_payload["username"], self.username) def test_jwt_login_json_bad_creds(self): """ @@ -87,8 +80,8 @@ def test_jwt_login_json_bad_creds(self): """ client = APIClient(enforce_csrf_checks=True) - self.data['password'] = 'wrong' - response = client.post('/auth-token/', self.data, format='json') + self.data["password"] = "wrong" + response = client.post("/auth-token/", self.data, format="json") self.assertEqual(response.status_code, 400) @@ -98,8 +91,9 @@ def test_jwt_login_json_missing_fields(self): """ client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', - {'username': self.username}, format='json') + response = client.post( + "/auth-token/", {"username": self.username}, format="json" + ) self.assertEqual(response.status_code, 400) @@ -109,31 +103,31 @@ def test_jwt_login_form(self): """ client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', self.data) + response = client.post("/auth-token/", self.data) - decoded_payload = utils.jwt_decode_handler(response.data['token']) + decoded_payload = utils.jwt_decode_handler(response.data["token"]) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(decoded_payload['username'], self.username) + self.assertEqual(decoded_payload["username"], self.username) def test_jwt_login_with_expired_token(self): """ Ensure JWT login view works even if expired token is provided """ payload = utils.jwt_payload_handler(self.user) - payload['exp'] = 1 + payload["exp"] = 1 token = utils.jwt_encode_handler(payload) - auth = 'JWT {0}'.format(token) + auth = "JWT {0}".format(token) client = APIClient(enforce_csrf_checks=True) response = client.post( - '/auth-token/', self.data, - HTTP_AUTHORIZATION=auth, format='json') + "/auth-token/", self.data, HTTP_AUTHORIZATION=auth, format="json" + ) - decoded_payload = utils.jwt_decode_handler(response.data['token']) + decoded_payload = utils.jwt_decode_handler(response.data["token"]) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(decoded_payload['username'], self.username) + self.assertEqual(decoded_payload["username"], self.username) def test_jwt_login_using_zero(self): """ @@ -141,35 +135,29 @@ def test_jwt_login_using_zero(self): """ client = APIClient(enforce_csrf_checks=True) - data = { - 'username': '0', - 'password': '0' - } + data = {"username": "0", "password": "0"} - response = client.post('/auth-token/', data, format='json') + response = client.post("/auth-token/", data, format="json") self.assertEqual(response.status_code, 400) -@unittest.skipIf(get_version() < '1.5.0', 'No Configurable User model feature') -@override_settings(AUTH_USER_MODEL='tests.CustomUser') +@unittest.skipIf(get_version() < "1.5.0", "No Configurable User model feature") +@override_settings(AUTH_USER_MODEL="tests.CustomUser") class CustomUserObtainJSONWebTokenTests(TestCase): """JSON Web Token Authentication""" def setUp(self): from .models import CustomUser - self.email = 'jpueblo@example.com' - self.password = 'password' + self.email = "jpueblo@example.com" + self.password = "password" user = CustomUser.objects.create(email=self.email) user.set_password(self.password) user.save() self.user = user - self.data = { - 'email': self.email, - 'password': self.password - } + self.data = {"email": self.email, "password": self.password} def test_jwt_login_json(self): """ @@ -177,11 +165,11 @@ def test_jwt_login_json(self): """ client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', self.data, format='json') + response = client.post("/auth-token/", self.data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) - decoded_payload = utils.jwt_decode_handler(response.data['token']) - self.assertEqual(decoded_payload['email'], self.email) + decoded_payload = utils.jwt_decode_handler(response.data["token"]) + self.assertEqual(decoded_payload["email"], self.email) def test_jwt_login_json_bad_creds(self): """ @@ -190,30 +178,27 @@ def test_jwt_login_json_bad_creds(self): """ client = APIClient(enforce_csrf_checks=True) - self.data['password'] = 'wrong' - response = client.post('/auth-token/', self.data, format='json') + self.data["password"] = "wrong" + response = client.post("/auth-token/", self.data, format="json") self.assertEqual(response.status_code, 400) -@override_settings(AUTH_USER_MODEL='tests.CustomUserUUID') +@override_settings(AUTH_USER_MODEL="tests.CustomUserUUID") class CustomUserUUIDObtainJSONWebTokenTests(TestCase): """JSON Web Token Authentication""" def setUp(self): from .models import CustomUserUUID - self.email = 'jpueblo@example.com' - self.password = 'password' + self.email = "jpueblo@example.com" + self.password = "password" user = CustomUserUUID.objects.create(email=self.email) user.set_password(self.password) user.save() self.user = user - self.data = { - 'email': self.email, - 'password': self.password - } + self.data = {"email": self.email, "password": self.password} def test_jwt_login_json(self): """ @@ -221,11 +206,11 @@ def test_jwt_login_json(self): """ client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', self.data, format='json') + response = client.post("/auth-token/", self.data, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) - decoded_payload = utils.jwt_decode_handler(response.data['token']) - self.assertEqual(decoded_payload['user_id'], str(self.user.id)) + decoded_payload = utils.jwt_decode_handler(response.data["token"]) + self.assertEqual(decoded_payload["user_id"], str(self.user.id)) def test_jwt_login_json_bad_creds(self): """ @@ -234,8 +219,8 @@ def test_jwt_login_json_bad_creds(self): """ client = APIClient(enforce_csrf_checks=True) - self.data['password'] = 'wrong' - response = client.post('/auth-token/', self.data, format='json') + self.data["password"] = "wrong" + response = client.post("/auth-token/", self.data, format="json") self.assertEqual(response.status_code, 400) @@ -250,23 +235,22 @@ def setUp(self): def get_token(self): client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', self.data, format='json') - return response.data['token'] + response = client.post("/auth-token/", self.data, format="json") + return response.data["token"] def create_token(self, user, exp=None, orig_iat=None): payload = utils.jwt_payload_handler(user) if exp: - payload['exp'] = exp + payload["exp"] = exp if orig_iat: - payload['orig_iat'] = timegm(orig_iat.utctimetuple()) + payload["orig_iat"] = timegm(orig_iat.utctimetuple()) token = utils.jwt_encode_handler(payload) return token class VerifyJSONWebTokenTestsSymmetric(TokenTestCase): - def test_verify_jwt(self): """ Test that a valid, non-expired token will return a 200 response @@ -276,12 +260,13 @@ def test_verify_jwt(self): orig_token = self.get_token() # Now try to get a refreshed token - response = client.post('/auth-token-verify/', {'token': orig_token}, - format='json') + response = client.post( + "/auth-token-verify/", {"token": orig_token}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['token'], orig_token) + self.assertEqual(response.data["token"], orig_token) def test_verify_jwt_fails_with_expired_token(self): """ @@ -293,14 +278,14 @@ def test_verify_jwt_fails_with_expired_token(self): token = self.create_token( self.user, exp=datetime.utcnow() - timedelta(seconds=5), - orig_iat=datetime.utcnow() - timedelta(hours=1) + orig_iat=datetime.utcnow() - timedelta(hours=1), ) - response = client.post('/auth-token-verify/', {'token': token}, - format='json') + response = client.post("/auth-token-verify/", {"token": token}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertRegexpMatches(response.data['non_field_errors'][0], - 'Signature has expired') + self.assertRegexpMatches( + response.data["non_field_errors"][0], "Signature has expired" + ) def test_verify_jwt_fails_with_bad_token(self): """ @@ -310,11 +295,11 @@ def test_verify_jwt_fails_with_bad_token(self): token = "i am not a correctly formed token" - response = client.post('/auth-token-verify/', {'token': token}, - format='json') + response = client.post("/auth-token-verify/", {"token": token}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertRegexpMatches(response.data['non_field_errors'][0], - 'Error decoding signature') + self.assertRegexpMatches( + response.data["non_field_errors"][0], "Error decoding signature" + ) def test_verify_jwt_fails_with_missing_user(self): """ @@ -323,33 +308,33 @@ def test_verify_jwt_fails_with_missing_user(self): client = APIClient(enforce_csrf_checks=True) user = User.objects.create_user( - email='jsmith@example.com', username='jsmith', password='password') + email="jsmith@example.com", username="jsmith", password="password" + ) token = self.create_token(user) # Delete the user used to make the token user.delete() - response = client.post('/auth-token-verify/', {'token': token}, - format='json') + response = client.post("/auth-token-verify/", {"token": token}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertRegexpMatches(response.data['non_field_errors'][0], - "User doesn't exist") + self.assertRegexpMatches( + response.data["non_field_errors"][0], "User doesn't exist" + ) class VerifyJSONWebTokenTestsAsymmetric(TokenTestCase): - def setUp(self): super(VerifyJSONWebTokenTestsAsymmetric, self).setUp() - private_key = rsa.generate_private_key(public_exponent=65537, - key_size=2048, - backend=default_backend()) + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) public_key = private_key.public_key() api_settings.JWT_PRIVATE_KEY = private_key api_settings.JWT_PUBLIC_KEY = public_key - api_settings.JWT_ALGORITHM = 'RS512' + api_settings.JWT_ALGORITHM = "RS512" def test_verify_jwt_with_pub_pvt_key(self): """ @@ -360,11 +345,12 @@ def test_verify_jwt_with_pub_pvt_key(self): orig_token = self.get_token() # Now try to get a refreshed token - response = client.post('/auth-token-verify/', {'token': orig_token}, - format='json') + response = client.post( + "/auth-token-verify/", {"token": orig_token}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(response.data['token'], orig_token) + self.assertEqual(response.data["token"], orig_token) def test_verify_jwt_fails_with_expired_token(self): """ @@ -376,14 +362,14 @@ def test_verify_jwt_fails_with_expired_token(self): token = self.create_token( self.user, exp=datetime.utcnow() - timedelta(seconds=5), - orig_iat=datetime.utcnow() - timedelta(hours=1) + orig_iat=datetime.utcnow() - timedelta(hours=1), ) - response = client.post('/auth-token-verify/', {'token': token}, - format='json') + response = client.post("/auth-token-verify/", {"token": token}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertRegexpMatches(response.data['non_field_errors'][0], - 'Signature has expired') + self.assertRegexpMatches( + response.data["non_field_errors"][0], "Signature has expired" + ) def test_verify_jwt_fails_with_bad_token(self): """ @@ -394,11 +380,11 @@ def test_verify_jwt_fails_with_bad_token(self): token = "i am not a correctly formed token" - response = client.post('/auth-token-verify/', {'token': token}, - format='json') + response = client.post("/auth-token-verify/", {"token": token}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertRegexpMatches(response.data['non_field_errors'][0], - 'Error decoding signature') + self.assertRegexpMatches( + response.data["non_field_errors"][0], "Error decoding signature" + ) def test_verify_jwt_fails_with_bad_pvt_key(self): """ @@ -407,9 +393,9 @@ def test_verify_jwt_fails_with_bad_pvt_key(self): """ # Generate a new private key - private_key = rsa.generate_private_key(public_exponent=65537, - key_size=2048, - backend=default_backend()) + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) # Don't set the private key api_settings.JWT_PRIVATE_KEY = private_key @@ -418,22 +404,23 @@ def test_verify_jwt_fails_with_bad_pvt_key(self): orig_token = self.get_token() # Now try to get a refreshed token - response = client.post('/auth-token-verify/', {'token': orig_token}, - format='json') + response = client.post( + "/auth-token-verify/", {"token": orig_token}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertRegexpMatches(response.data['non_field_errors'][0], - 'Error decoding signature') + self.assertRegexpMatches( + response.data["non_field_errors"][0], "Error decoding signature" + ) def tearDown(self): # Restore original settings - api_settings.JWT_ALGORITHM = DEFAULTS['JWT_ALGORITHM'] - api_settings.JWT_PRIVATE_KEY = DEFAULTS['JWT_PRIVATE_KEY'] - api_settings.JWT_PUBLIC_KEY = DEFAULTS['JWT_PUBLIC_KEY'] + api_settings.JWT_ALGORITHM = DEFAULTS["JWT_ALGORITHM"] + api_settings.JWT_PRIVATE_KEY = DEFAULTS["JWT_PRIVATE_KEY"] + api_settings.JWT_PUBLIC_KEY = DEFAULTS["JWT_PUBLIC_KEY"] class RefreshJSONWebTokenTests(TokenTestCase): - def setUp(self): super(RefreshJSONWebTokenTests, self).setUp() api_settings.JWT_ALLOW_REFRESH = True @@ -452,22 +439,23 @@ def test_refresh_jwt(self): expected_orig_iat = timegm(datetime.utcnow().utctimetuple()) # Make sure 'orig_iat' exists and is the current time (give some slack) - orig_iat = orig_token_decoded['orig_iat'] + orig_iat = orig_token_decoded["orig_iat"] self.assertLessEqual(orig_iat - expected_orig_iat, 1) time.sleep(1) # Now try to get a refreshed token - response = client.post('/auth-token-refresh/', {'token': orig_token}, - format='json') + response = client.post( + "/auth-token-refresh/", {"token": orig_token}, format="json" + ) self.assertEqual(response.status_code, status.HTTP_200_OK) - new_token = response.data['token'] + new_token = response.data["token"] new_token_decoded = utils.jwt_decode_handler(new_token) # Make sure 'orig_iat' on the new token is same as original - self.assertEquals(new_token_decoded['orig_iat'], orig_iat) - self.assertGreater(new_token_decoded['exp'], orig_token_decoded['exp']) + self.assertEquals(new_token_decoded["orig_iat"], orig_iat) + self.assertGreater(new_token_decoded["exp"], orig_token_decoded["exp"]) def test_refresh_jwt_after_refresh_expiration(self): """ @@ -475,20 +463,55 @@ def test_refresh_jwt_after_refresh_expiration(self): """ client = APIClient(enforce_csrf_checks=True) - orig_iat = (datetime.utcnow() - api_settings.JWT_REFRESH_EXPIRATION_DELTA - - timedelta(seconds=5)) + orig_iat = ( + datetime.utcnow() + - api_settings.JWT_REFRESH_EXPIRATION_DELTA + - timedelta(seconds=5) + ) token = self.create_token( - self.user, - exp=datetime.utcnow() + timedelta(hours=1), - orig_iat=orig_iat + self.user, exp=datetime.utcnow() + timedelta(hours=1), orig_iat=orig_iat ) - response = client.post('/auth-token-refresh/', {'token': token}, - format='json') + response = client.post("/auth-token-refresh/", {"token": token}, format="json") self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - self.assertEqual(response.data['non_field_errors'][0], - 'Refresh has expired.') + self.assertEqual(response.data["non_field_errors"][0], "Refresh has expired.") + + def test_refresh_jwt_with_cookies(self): + """ + Test getting a refreshed token from original token works + + No date/time modifications are neccessary because it is assumed + that this operation will take less than 300 seconds. + The token is passed through cookies rather than POST data. + """ + api_settings.JWT_AUTH_COOKIE = "plop" + client = APIClient(enforce_csrf_checks=True) + # This way the client will have its token stored in a cookie + response = client.post("/auth-token/", self.data, format="json") + + orig_token = response["token"] + orig_token_decoded = utils.jwt_decode_handler(orig_token) + + expected_orig_iat = timegm(datetime.utcnow().utctimetuple()) + + # Make sure 'orig_iat' exists and is the current time (give some slack) + orig_iat = orig_token_decoded["orig_iat"] + self.assertLessEqual(orig_iat - expected_orig_iat, 1) + + time.sleep(1) + + # Now try to get a refreshed token + response = client.post("/auth-token-refresh/", {}, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + + new_token = response.data["token"] + new_token_decoded = utils.jwt_decode_handler(new_token) + + # Make sure 'orig_iat' on the new token is same as original + self.assertEquals(new_token_decoded["orig_iat"], orig_iat) + self.assertGreater(new_token_decoded["exp"], orig_token_decoded["exp"]) def tearDown(self): # Restore original settings - api_settings.JWT_ALLOW_REFRESH = DEFAULTS['JWT_ALLOW_REFRESH'] + api_settings.JWT_ALLOW_REFRESH = DEFAULTS["JWT_ALLOW_REFRESH"] + api_settings.JWT_AUTH_COOKIE = DEFAULTS["JWT_AUTH_COOKIE"]