Skip to content

Commit

Permalink
add JwtRedirectToLoginIfUnauthenticatedMiddleware
Browse files Browse the repository at this point in the history
Middleware enables the DRF JwtAuthentication authentication class for
endpoints using the LoginRedirectIfUnauthenticated permission class.

Enables a DRF view to redirect the user to login when they are
unauthenticated. It automatically enables JWT-cookie-based
authentication by setting the `USE_JWT_COOKIE_HEADER` for endpoints
using the LoginRedirectIfUnauthenticated permission.

This can be used to convert a plain Django view using @login_required
into a DRF APIView, which is useful to enable our DRF JwtAuthentication
class.

NOTE: This includes a breaking change that is unlikely to affect anyone
unless they subclassed JwtAuthCookieMiddleware, which switched from
using `process_request` to `process_view` so it would not run before
this new middleware.

ARCH-1051
  • Loading branch information
robrap committed Aug 14, 2019
1 parent 9a14294 commit 0351010
Show file tree
Hide file tree
Showing 6 changed files with 272 additions and 36 deletions.
2 changes: 1 addition & 1 deletion edx_rest_framework_extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
""" edx Django REST Framework extensions. """

__version__ = '2.3.8' # pragma: no cover
__version__ = '2.4.0' # pragma: no cover
123 changes: 104 additions & 19 deletions edx_rest_framework_extensions/auth/jwt/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
"""
import logging

from django.contrib.auth.decorators import login_required
from django.utils.deprecation import MiddlewareMixin
from edx_django_utils import monitoring
from edx_django_utils.cache import RequestCache
from rest_framework_jwt.authentication import BaseJSONWebTokenAuthentication

from edx_rest_framework_extensions.auth.jwt.cookies import (
Expand All @@ -13,7 +15,7 @@
jwt_cookie_signature_name,
)
from edx_rest_framework_extensions.auth.jwt.constants import JWT_DELIMITER
from edx_rest_framework_extensions.permissions import NotJwtRestrictedApplication
from edx_rest_framework_extensions.permissions import LoginRedirectIfUnauthenticated, NotJwtRestrictedApplication

log = logging.getLogger(__name__)
USE_JWT_COOKIE_HEADER = 'HTTP_USE_JWT_COOKIE'
Expand All @@ -26,14 +28,6 @@ class EnsureJWTAuthSettingsMiddleware(MiddlewareMixin):
"""
_required_permission_classes = (NotJwtRestrictedApplication,)

def _includes_base_class(self, iter_classes, base_class):
"""
Returns whether any class in iter_class is a subclass of the given base_class.
"""
return any(
issubclass(auth_class, base_class) for auth_class in iter_classes,
)

def _add_missing_jwt_permission_classes(self, view_class):
"""
Adds permissions classes that should exist for Jwt based authentication,
Expand All @@ -56,7 +50,7 @@ def _add_missing_jwt_permission_classes(self, view_class):
view_permissions.append(child)

for perm_class in self._required_permission_classes:
if not self._includes_base_class(permission_classes, perm_class):
if not _includes_base_class(permission_classes, perm_class):
message = (
u"The view %s allows Jwt Authentication. The required permission class, %s,",
u" was automatically added."
Expand All @@ -72,19 +66,88 @@ def _add_missing_jwt_permission_classes(self, view_class):
view_class.permission_classes += tuple(classes_to_add)

def process_view(self, request, view_func, view_args, view_kwargs): # pylint: disable=unused-argument
# Views as functions store the view's class in the 'view_class' attribute.
# Viewsets store the view's class in the 'cls' attribute.
view_class = getattr(
view_func,
'view_class',
getattr(view_func, 'cls', view_func),
)
view_class = _get_view_class(view_func)

view_authentication_classes = getattr(view_class, 'authentication_classes', tuple())
if self._includes_base_class(view_authentication_classes, BaseJSONWebTokenAuthentication):
if _includes_base_class(view_authentication_classes, BaseJSONWebTokenAuthentication):
self._add_missing_jwt_permission_classes(view_class)


class JwtRedirectToLoginIfUnauthenticatedMiddleware(MiddlewareMixin):
"""
Middleware enables the DRF JwtAuthentication authentication class for endpoints
using the LoginRedirectIfUnauthenticated permission class.
Enables a DRF view to redirect the user to login when they are unauthenticated.
It automatically enables JWT-cookie-based authentication by setting the
`USE_JWT_COOKIE_HEADER` for endpoints using the LoginRedirectIfUnauthenticated
permission.
This can be used to convert a plain Django view using @login_required into a
DRF APIView, which is useful to enable our DRF JwtAuthentication class.
Usage Notes:
- This middleware must be added before JwtAuthCookieMiddleware.
- Only affects endpoints using the LoginRedirectIfUnauthenticated permission class.
See https://github.com/edx/edx-platform/blob/master/openedx/core/djangoapps/oauth_dispatch/docs/decisions/0009-jwt-in-session-cookie.rst # noqa E501 line too long
"""
def get_login_url(self, request): # pylint: disable=unused-argument
"""
Return None for default login url.
Can be overridden for slow-rollout or A/B testing of transition to other login mechanisms.
"""
return None

def is_jwt_auth_enabled_with_login_required(self, request, view_func): # pylint: disable=unused-argument
"""
Returns True if JwtAuthentication is enabled with the LoginRedirectIfUnauthenticated permission class.
Can be overridden for slow roll-out or A/B testing.
"""
return self._is_login_required_found()

def process_view(self, request, view_func, view_args, view_kwargs): # pylint: disable=unused-argument
"""
Enables Jwt Authentication for endpoints using the LoginRedirectIfUnauthenticated permission class.
"""
self._check_and_cache_login_required_found(view_func)
if self.is_jwt_auth_enabled_with_login_required(request, view_func):
request.META[USE_JWT_COOKIE_HEADER] = 'true'

def process_response(self, request, response):
"""
Redirects unauthenticated users to login when LoginRedirectIfUnauthenticated permission class was used.
"""
if self._is_login_required_found() and not request.user.is_authenticated:
login_url = self.get_login_url(request)
return login_required(function=lambda request: None, login_url=login_url)(request)

return response

_REQUEST_CACHE_NAMESPACE = 'JwtRedirectToLoginIfUnauthenticatedMiddleware'
_LOGIN_REQUIRED_FOUND_CACHE_KEY = 'login_required_found'

def _get_request_cache(self):
return RequestCache(self._REQUEST_CACHE_NAMESPACE).data

def _is_login_required_found(self):
"""
Returns True if LoginRedirectIfUnauthenticated permission was found, and False otherwise.
"""
return self._get_request_cache().get(self._LOGIN_REQUIRED_FOUND_CACHE_KEY, False)

def _check_and_cache_login_required_found(self, view_func):
"""
Checks for LoginRedirectIfUnauthenticated permission and caches the result.
"""
view_class = _get_view_class(view_func)
view_permission_classes = getattr(view_class, 'permission_classes', tuple())
is_login_required_found = _includes_base_class(view_permission_classes, LoginRedirectIfUnauthenticated)
self._get_request_cache()[self._LOGIN_REQUIRED_FOUND_CACHE_KEY] = is_login_required_found


class JwtAuthCookieMiddleware(MiddlewareMixin):
"""
Reconstitutes JWT auth cookies for use by API views which use the JwtAuthentication
Expand Down Expand Up @@ -122,7 +185,9 @@ def _get_missing_cookie_message_and_metric(self, cookie_name):
request_jwt_cookie = 'missing-{}'.format(cookie_name)
return cookie_missing_message, request_jwt_cookie

def process_request(self, request):
# Note: Using `process_view` over `process_request` so JwtRedirectToLoginIfUnauthenticatedMiddleware which
# uses `process_view` can update the request before this middleware. Method `process_request` happened too early.
def process_view(self, request, view_func, view_args, view_kwargs): # pylint: disable=unused-argument
"""
Reconstitute the full JWT and add a new cookie on the request object.
"""
Expand Down Expand Up @@ -156,3 +221,23 @@ def process_request(self, request):
metric_value = 'missing-both'

monitoring.set_custom_metric('request_jwt_cookie', metric_value)


def _includes_base_class(iter_classes, base_class):
"""
Returns whether any class in iter_class is a subclass of the given base_class.
"""
return any(
issubclass(current_class, base_class) for current_class in iter_classes,
)


def _get_view_class(view_func):
# Views as functions store the view's class in the 'view_class' attribute.
# Viewsets store the view's class in the 'cls' attribute.
view_class = getattr(
view_func,
'view_class',
getattr(view_func, 'cls', view_func),
)
return view_class
Loading

0 comments on commit 0351010

Please sign in to comment.