Skip to content

Commit

Permalink
Add API endpoint for retrieving CSRF token.
Browse files Browse the repository at this point in the history
  • Loading branch information
Douglas Hall committed Oct 9, 2018
1 parent 58f4297 commit 360b7d0
Show file tree
Hide file tree
Showing 16 changed files with 201 additions and 68 deletions.
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ edX Django REST Framework Extensions |Travis|_ |Codecov|_
This library includes extensions of `Django REST Framework <http://www.django-rest-framework.org/>`_
useful for edX applications.

CSRF API
--------

This library also includes a ``csrf`` app containing an API endpoint for retrieving CSRF tokens from
the Django service in which it is installed. This is useful for frontend apps attempting to make POST,
PUT, and DELETE requests to a Django service with Django's CSRF middleware enabled.

To make use of this API endpoint:

#. Install edx-drf-extensions in your Django project.
#. Add ``csrf.apps.CsrfAppConfig`` to ``INSTALLED_APPS``.

License
-------
Expand Down
Empty file added csrf/__init__.py
Empty file.
Empty file added csrf/api/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions csrf/api/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
URL definitions for the CSRF API endpoints.
"""

from django.conf.urls import include, url

urlpatterns = [
url(r'^v1/', include('csrf.api.v1.urls'), name='csrf_api_v1'),
]
Empty file added csrf/api/v1/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions csrf/api/v1/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
URL definitions for version 1 of the CSRF API.
"""

from django.conf.urls import url

from .views import CsrfTokenView

urlpatterns = [
url(r'^token$', CsrfTokenView.as_view(), name='csrf_token'),
]
36 changes: 36 additions & 0 deletions csrf/api/v1/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
API for CSRF application.
"""

from django.middleware.csrf import get_token
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView

from edx_rest_framework_extensions.authentication import JwtAuthentication


class CsrfTokenView(APIView):
"""
**Use Case**
Allows frontend apps to obtain a CSRF token from the Django
service in order to make POST, PUT, and DELETE requests to
API endpoints hosted on the service.
**Behavior**
GET /csrf/api/v1/token
>>> {
>>> "csrfToken": "abcdefg1234567"
>>> }
"""

authentication_classes = (JwtAuthentication,)
permission_classes = (IsAuthenticated,)

def get(self, request):
"""
GET /csrf/api/v1/token
"""
return Response({'csrfToken': get_token(request)})
12 changes: 12 additions & 0 deletions csrf/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
App for creating and distributing CSRF tokens to frontend applications.
"""

from django.apps import AppConfig


class CsrfAppConfig(AppConfig):
"""Configuration for the csrf application."""

name = 'csrf'
verbose_name = 'CSRF'
Empty file added csrf/tests/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions csrf/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
""" Tests for the CSRF API """

from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase

from edx_rest_framework_extensions.auth.jwt.tests.utils import generate_jwt
from edx_rest_framework_extensions.tests.factories import UserFactory


class CsrfTokenTests(APITestCase):
""" Tests for the CSRF token endpoint. """

def test_get_token(self):
"""
Ensure we can get a CSRF token.
"""
url = reverse('csrf_token')
user = UserFactory()
jwt = generate_jwt(user)
self.client.credentials(HTTP_AUTHORIZATION='JWT {}'.format(jwt))
response = self.client.get(url, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('csrfToken', response.data)
9 changes: 9 additions & 0 deletions csrf/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
URLs for the CSRF application.
"""

from django.conf.urls import include, url

urlpatterns = [
url(r'^csrf/api/', include('csrf.api.urls'), name='csrf_api'),
]
2 changes: 1 addition & 1 deletion edx_rest_framework_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
""" edx Django REST Framework extensions. """

__version__ = '1.10.0' # pragma: no cover
__version__ = '1.11.0' # pragma: no cover
82 changes: 16 additions & 66 deletions edx_rest_framework_extensions/auth/jwt/tests/test_decoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
""" Tests for utility functions. """
import copy
from time import time

import ddt
import jwt
Expand All @@ -9,63 +8,14 @@
from django.test import override_settings, TestCase

from edx_rest_framework_extensions.auth.jwt.decoder import jwt_decode_handler
from edx_rest_framework_extensions.auth.jwt.tests.utils import (
generate_jwt_token,
generate_latest_version_payload,
generate_unversioned_payload,
)
from edx_rest_framework_extensions.tests.factories import UserFactory


def generate_jwt(user, scopes=None, filters=None, is_restricted=None):
"""
Generate a valid JWT for authenticated requests.
"""
access_token = _generate_latest_version_payload(user, scopes=scopes, filters=filters, is_restricted=is_restricted)
return _generate_jwt_token(access_token)


def _generate_jwt_token(payload, signing_key=None):
"""
Generate a valid JWT token for authenticated requests.
"""
signing_key = signing_key or settings.JWT_AUTH['JWT_ISSUERS'][0]['SECRET_KEY']
return jwt.encode(payload, signing_key).decode('utf-8')


def _generate_latest_version_payload(user, scopes=None, filters=None, version=None, is_restricted=None):
"""
Generate a valid JWT payload given a user and optionally scopes and filters.
"""
payload = _generate_starting_version_payload(user)
payload.update({
# fix this version and add newly introduced fields as the version updates.
'version': '1.1.0',
'filters': [],
'is_restricted': False,
})
if scopes is not None:
payload['scopes'] = scopes
if version is not None:
payload['version'] = version
if filters is not None:
payload['filters'] = filters
if is_restricted is not None:
payload['is_restricted'] = is_restricted
return payload


def _generate_starting_version_payload(user):
jwt_issuer_data = settings.JWT_AUTH['JWT_ISSUERS'][0]
now = int(time())
ttl = 600
payload = {
'iss': jwt_issuer_data['ISSUER'],
'aud': jwt_issuer_data['AUDIENCE'],
'username': user.username,
'email': user.email,
'iat': now,
'exp': now + ttl,
'scopes': [],
}
return payload


def exclude_from_jwt_auth_setting(key):
"""
Clone the JWT_AUTH setting dict and remove the given key.
Expand All @@ -90,8 +40,8 @@ class JWTDecodeHandlerTests(TestCase):
def setUp(self):
super(JWTDecodeHandlerTests, self).setUp()
self.user = UserFactory()
self.payload = _generate_latest_version_payload(self.user)
self.jwt = _generate_jwt_token(self.payload)
self.payload = generate_latest_version_payload(self.user)
self.jwt = generate_jwt_token(self.payload)

def test_success(self):
"""
Expand All @@ -108,7 +58,7 @@ def test_valid_token_multiple_valid_issuers(self, jwt_issuer):
# Verify that each valid issuer is properly matched against the valid issuers list
# and used to decode the token that was generated using said valid issuer data
self.payload['iss'] = jwt_issuer['ISSUER']
token = _generate_jwt_token(self.payload, jwt_issuer['SECRET_KEY'])
token = generate_jwt_token(self.payload, jwt_issuer['SECRET_KEY'])
self.assertEqual(jwt_decode_handler(token), self.payload)

def test_failure_invalid_issuer(self):
Expand All @@ -123,7 +73,7 @@ def test_failure_invalid_issuer(self):
self.payload['iss'] = 'invalid-issuer'
signing_key = 'invalid-secret-key'
# Generate a token using the invalid issuer data
token = _generate_jwt_token(self.payload, signing_key)
token = generate_jwt_token(self.payload, signing_key)
# Attempt to decode the token against the entries in the valid issuers list,
# which will fail with an InvalidTokenError
jwt_decode_handler(token)
Expand All @@ -149,16 +99,16 @@ def test_supported_jwt_version_not_specified(self):
"""
Verifies the JWT is decoded successfully when the JWT_SUPPORTED_VERSION setting is not specified.
"""
token = _generate_jwt_token(self.payload)
token = generate_jwt_token(self.payload)
self.assertDictEqual(jwt_decode_handler(token), self.payload)

@ddt.data(None, '0.5.0', '1.0.0', '1.0.5', '1.5.0', '1.5.5')
def test_supported_jwt_version(self, jwt_version):
"""
Verifies the JWT is decoded successfully with different supported versions in the token.
"""
jwt_payload = _generate_latest_version_payload(self.user, version=jwt_version)
token = _generate_jwt_token(jwt_payload)
jwt_payload = generate_latest_version_payload(self.user, version=jwt_version)
token = generate_jwt_token(jwt_payload)
self.assertDictEqual(jwt_decode_handler(token), jwt_payload)

@override_settings(JWT_AUTH=update_jwt_auth_setting({'JWT_SUPPORTED_VERSION': '0.5.0'}))
Expand All @@ -169,7 +119,7 @@ def test_unsupported_jwt_version(self):
"""
with mock.patch('edx_rest_framework_extensions.auth.jwt.decoder.logger') as patched_log:
with self.assertRaises(jwt.InvalidTokenError):
token = _generate_jwt_token(self.payload)
token = generate_jwt_token(self.payload)
jwt_decode_handler(token)

msg = "Token decode failed due to unsupported JWT version number [%s]"
Expand All @@ -179,10 +129,10 @@ def test_upgrade(self):
"""
Verifies the JWT is upgraded when an old (starting) version is provided.
"""
jwt_payload = _generate_starting_version_payload(self.user)
token = _generate_jwt_token(jwt_payload)
jwt_payload = generate_unversioned_payload(self.user)
token = generate_jwt_token(jwt_payload)

upgraded_payload = _generate_latest_version_payload(self.user, version='1.0.0')
upgraded_payload = generate_latest_version_payload(self.user, version='1.0.0')

# Keep time-related values constant for full-proof comparison.
upgraded_payload['iat'], upgraded_payload['exp'] = jwt_payload['iat'], jwt_payload['exp']
Expand Down
68 changes: 68 additions & 0 deletions edx_rest_framework_extensions/auth/jwt/tests/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
""" Utility functions for tests. """
from time import time

import jwt
from django.conf import settings


def generate_jwt(user, scopes=None, filters=None, is_restricted=None):
"""
Generate a valid JWT for authenticated requests.
"""
access_token = generate_latest_version_payload(
user,
scopes=scopes,
filters=filters,
is_restricted=is_restricted
)
return generate_jwt_token(access_token)


def generate_jwt_token(payload, signing_key=None):
"""
Generate a valid JWT token for authenticated requests.
"""
signing_key = signing_key or settings.JWT_AUTH['JWT_ISSUERS'][0]['SECRET_KEY']
return jwt.encode(payload, signing_key).decode('utf-8')


def generate_latest_version_payload(user, scopes=None, filters=None, version=None,
is_restricted=None):
"""
Generate a valid JWT payload given a user and optionally scopes and filters.
"""
payload = generate_unversioned_payload(user)
payload.update({
# fix this version and add newly introduced fields as the version updates.
'version': '1.1.0',
'filters': [],
'is_restricted': False,
})
if scopes is not None:
payload['scopes'] = scopes
if version is not None:
payload['version'] = version
if filters is not None:
payload['filters'] = filters
if is_restricted is not None:
payload['is_restricted'] = is_restricted
return payload


def generate_unversioned_payload(user):
"""
Generate an unversioned valid JWT payload given a user.
"""
jwt_issuer_data = settings.JWT_AUTH['JWT_ISSUERS'][0]
now = int(time())
ttl = 600
payload = {
'iss': jwt_issuer_data['ISSUER'],
'aud': jwt_issuer_data['AUDIENCE'],
'username': user.username,
'email': user.email,
'iat': now,
'exp': now + ttl,
'scopes': [],
}
return payload
2 changes: 1 addition & 1 deletion edx_rest_framework_extensions/tests/test_permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from edx_rest_framework_extensions import permissions
from edx_rest_framework_extensions.auth.jwt.authentication import JwtAuthentication
from edx_rest_framework_extensions.auth.jwt.tests.test_decoder import generate_jwt
from edx_rest_framework_extensions.auth.jwt.tests.utils import generate_jwt
from edx_rest_framework_extensions.tests import factories
from edx_rest_framework_extensions.tests.factories import UserFactory

Expand Down
3 changes: 3 additions & 0 deletions test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

SECRET_KEY = 'insecure-secret-key'

ROOT_URLCONF = 'csrf.urls'

INSTALLED_APPS = (
'csrf.apps.CsrfAppConfig',
'django.contrib.auth',
'django.contrib.contenttypes',
'django_nose',
Expand Down

0 comments on commit 360b7d0

Please sign in to comment.