From db828f4241cca2c39564d171b0fbfefac9b621cd Mon Sep 17 00:00:00 2001 From: Simon Gurcke Date: Fri, 22 Mar 2024 13:05:53 +1000 Subject: [PATCH] Support multiple urlconfs --- apitally/django.py | 118 ++++++++++++++++------------ tests/test_django_ninja.py | 12 ++- tests/test_django_rest_framework.py | 12 ++- 3 files changed, 87 insertions(+), 55 deletions(-) diff --git a/apitally/django.py b/apitally/django.py index a0305bb..df04fc1 100644 --- a/apitally/django.py +++ b/apitally/django.py @@ -6,10 +6,10 @@ import time from dataclasses import dataclass from importlib import import_module -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union from django.conf import settings -from django.urls import Resolver404, URLPattern, URLResolver, get_resolver, resolve +from django.urls import URLPattern, URLResolver, get_resolver from django.utils.module_loading import import_string from apitally.client.logging import get_logger @@ -32,6 +32,7 @@ class ApitallyMiddlewareConfig: env: str app_version: Optional[str] identify_consumer_callback: Optional[Callable[[HttpRequest], Optional[str]]] + urlconfs: List[Optional[str]] class ApitallyMiddleware: @@ -53,7 +54,12 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None: self.client = ApitallyClient(client_id=self.config.client_id, env=self.config.env) self.client.start_sync_loop() - self.client.set_app_info(app_info=_get_app_info(app_version=self.config.app_version)) + self.client.set_app_info( + app_info=_get_app_info( + app_version=self.config.app_version, + urlconfs=self.config.urlconfs, + ) + ) @classmethod def configure( @@ -62,6 +68,7 @@ def configure( env: str = "dev", app_version: Optional[str] = None, identify_consumer_callback: Optional[str] = None, + urlconf: Optional[Union[List[Optional[str]], str]] = None, ) -> None: cls.config = ApitallyMiddlewareConfig( client_id=client_id, @@ -70,15 +77,16 @@ def configure( identify_consumer_callback=import_string(identify_consumer_callback) if identify_consumer_callback else None, + urlconfs=[urlconf] if urlconf is None or isinstance(urlconf, str) else urlconf, ) def __call__(self, request: HttpRequest) -> HttpResponse: - path = self.get_path(request) start_time = time.perf_counter() response = self.get_response(request) + path = self.get_path(request) if request.method is not None and path is not None: + consumer = self.get_consumer(request) try: - consumer = self.get_consumer(request) self.client.request_counter.add_request( consumer=consumer, method=request.method, @@ -113,24 +121,21 @@ def __call__(self, request: HttpRequest) -> HttpResponse: return response def get_path(self, request: HttpRequest) -> Optional[str]: - try: - resolver_match = resolve(request.path_info) - except Resolver404: - return None - try: - if self.drf_endpoint_enumerator is not None: - from rest_framework.schemas.generators import is_api_view + if (match := request.resolver_match) is not None: + try: + if self.drf_endpoint_enumerator is not None: + from rest_framework.schemas.generators import is_api_view - if is_api_view(resolver_match.func): - return self.drf_endpoint_enumerator.get_path_from_regex(resolver_match.route) - if self.ninja_available: - from ninja.operation import PathView + if is_api_view(match.func): + return self.drf_endpoint_enumerator.get_path_from_regex(match.route) + if self.ninja_available: + from ninja.operation import PathView - if hasattr(resolver_match.func, "__self__") and isinstance(resolver_match.func.__self__, PathView): - path = "/" + resolver_match.route.lstrip("/") - return re.sub(r"<(?:[^:]+:)?([^>:]+)>", r"{\1}", path) - except Exception: - logger.exception("Failed to get path for request") + if hasattr(match.func, "__self__") and isinstance(match.func.__self__, PathView): + path = "/" + match.route.lstrip("/") + return re.sub(r"<(?:[^:]+:)?([^>:]+)>", r"{\1}", path) + except Exception: + logger.exception("Failed to get path for request") return None def get_consumer(self, request: HttpRequest) -> Optional[str]: @@ -146,15 +151,15 @@ def get_consumer(self, request: HttpRequest) -> Optional[str]: return None -def _get_app_info(app_version: Optional[str] = None) -> Dict[str, Any]: +def _get_app_info(app_version: Optional[str], urlconfs: List[Optional[str]]) -> Dict[str, Any]: app_info: Dict[str, Any] = {} try: - app_info["paths"] = _get_paths() + app_info["paths"] = _get_paths(urlconfs) except Exception: app_info["paths"] = [] logger.exception("Failed to get paths") try: - app_info["openapi"] = _get_openapi() + app_info["openapi"] = _get_openapi(urlconfs) except Exception: logger.exception("Failed to get OpenAPI schema") app_info["versions"] = get_versions("django", "djangorestframework", "django-ninja", app_version=app_version) @@ -162,13 +167,13 @@ def _get_app_info(app_version: Optional[str] = None) -> Dict[str, Any]: return app_info -def _get_openapi() -> Optional[str]: +def _get_openapi(urlconfs: List[Optional[str]]) -> Optional[str]: drf_schema = None ninja_schema = None with contextlib.suppress(ImportError): - drf_schema = _get_drf_schema() + drf_schema = _get_drf_schema(urlconfs) with contextlib.suppress(ImportError): - ninja_schema = _get_ninja_schema() + ninja_schema = _get_ninja_schema(urlconfs) if drf_schema is not None and ninja_schema is None: return json.dumps(drf_schema) elif ninja_schema is not None and drf_schema is None: @@ -176,43 +181,46 @@ def _get_openapi() -> Optional[str]: return None -def _get_paths() -> List[Dict[str, str]]: +def _get_paths(urlconfs: List[Optional[str]]) -> List[Dict[str, str]]: paths = [] with contextlib.suppress(ImportError): - paths.extend(_get_drf_paths()) + paths.extend(_get_drf_paths(urlconfs)) with contextlib.suppress(ImportError): - paths.extend(_get_ninja_paths()) + paths.extend(_get_ninja_paths(urlconfs)) return paths -def _get_drf_paths() -> List[Dict[str, str]]: +def _get_drf_paths(urlconfs: List[Optional[str]]) -> List[Dict[str, str]]: from rest_framework.schemas.generators import EndpointEnumerator - enumerator = EndpointEnumerator() + enumerators = [EndpointEnumerator(urlconf=urlconf) for urlconf in urlconfs] return [ { "method": method.upper(), "path": path, } + for enumerator in enumerators for path, method, _ in enumerator.get_api_endpoints() if method not in ["HEAD", "OPTIONS"] ] -def _get_drf_schema() -> Optional[Dict[str, Any]]: +def _get_drf_schema(urlconfs: List[Optional[str]]) -> Optional[Dict[str, Any]]: from rest_framework.schemas.openapi import SchemaGenerator - with contextlib.suppress(AssertionError): # uritemplate must be installed for OpenAPI schema support - generator = SchemaGenerator() - schema = generator.get_schema() - if schema is not None and len(schema["paths"]) > 0: - return schema # type: ignore[return-value] - return None + schemas = [] + with contextlib.suppress(AssertionError): # uritemplate and inflection must be installed for OpenAPI schema support + for urlconf in urlconfs: + generator = SchemaGenerator(urlconf=urlconf) + schema = generator.get_schema() + if schema is not None and len(schema["paths"]) > 0: + schemas.append(schema) + return None if len(schemas) != 1 else schemas[0] # type: ignore[return-value] -def _get_ninja_paths() -> List[Dict[str, str]]: +def _get_ninja_paths(urlconfs: List[Optional[str]]) -> List[Dict[str, str]]: endpoints = [] - for api in _get_ninja_api_instances(): + for api in _get_ninja_api_instances(urlconfs=urlconfs): schema = api.get_openapi_schema() for path, operations in schema["paths"].items(): for method, operation in operations.items(): @@ -228,25 +236,33 @@ def _get_ninja_paths() -> List[Dict[str, str]]: return endpoints -def _get_ninja_schema() -> Optional[Dict[str, Any]]: - if len(apis := _get_ninja_api_instances()) == 1: - api = list(apis)[0] +def _get_ninja_schema(urlconfs: List[Optional[str]]) -> Optional[Dict[str, Any]]: + schemas = [] + for api in _get_ninja_api_instances(urlconfs=urlconfs): schema = api.get_openapi_schema() if len(schema["paths"]) > 0: - return schema - return None + schemas.append(schema) + return None if len(schemas) != 1 else schemas[0] -def _get_ninja_api_instances(url_patterns: Optional[List[Any]] = None) -> Set[NinjaAPI]: +def _get_ninja_api_instances( + urlconfs: Optional[List[Optional[str]]] = None, + patterns: Optional[List[Any]] = None, +) -> Set[NinjaAPI]: from ninja import NinjaAPI - if url_patterns is None: - url_patterns = get_resolver().url_patterns + if urlconfs is None: + urlconfs = [None] + if patterns is None: + patterns = [] + for urlconf in urlconfs: + patterns.extend(get_resolver(urlconf).url_patterns) + apis: Set[NinjaAPI] = set() - for p in url_patterns: + for p in patterns: if isinstance(p, URLResolver): if p.app_name != "ninja": - apis.update(_get_ninja_api_instances(p.url_patterns)) + apis.update(_get_ninja_api_instances(patterns=p.url_patterns)) else: for pattern in p.url_patterns: if isinstance(pattern, URLPattern) and pattern.lookup_str.startswith("ninja."): diff --git a/tests/test_django_ninja.py b/tests/test_django_ninja.py index 0724a28..02459e0 100644 --- a/tests/test_django_ninja.py +++ b/tests/test_django_ninja.py @@ -88,6 +88,14 @@ def test_middleware_requests_ok(client: Client, mocker: MockerFixture): assert int(mock.call_args.kwargs["request_size"]) > 0 +def test_middleware_requests_404(client: Client, mocker: MockerFixture): + mock = mocker.patch("apitally.client.base.RequestCounter.add_request") + + response = client.get("/api/none") + assert response.status_code == 404 + mock.assert_not_called() + + def test_middleware_requests_error(client: Client, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.RequestCounter.add_request") @@ -117,7 +125,7 @@ def test_middleware_validation_error(client: Client, mocker: MockerFixture): def test_get_app_info(): from apitally.django import _get_app_info - app_info = _get_app_info(app_version="1.2.3") + app_info = _get_app_info(app_version="1.2.3", urlconfs=[None]) openapi = json.loads(app_info["openapi"]) assert len(app_info["paths"]) == 5 assert len(openapi["paths"]) == 5 @@ -142,7 +150,7 @@ def test_get_ninja_api_instances(): def test_get_ninja_api_endpoints(): from apitally.django import _get_ninja_paths - endpoints = _get_ninja_paths() + endpoints = _get_ninja_paths([None]) assert len(endpoints) == 5 assert all(len(e["summary"]) > 0 for e in endpoints) assert any(e["description"] is not None and len(e["description"]) > 0 for e in endpoints) diff --git a/tests/test_django_rest_framework.py b/tests/test_django_rest_framework.py index 5a1896a..d57002a 100644 --- a/tests/test_django_rest_framework.py +++ b/tests/test_django_rest_framework.py @@ -85,6 +85,14 @@ def test_middleware_requests_ok(client: APIClient, mocker: MockerFixture): assert int(mock.call_args.kwargs["request_size"]) > 0 +def test_middleware_requests_404(client: APIClient, mocker: MockerFixture): + mock = mocker.patch("apitally.client.base.RequestCounter.add_request") + + response = client.get("/api/none") + assert response.status_code == 404 + mock.assert_not_called() + + def test_middleware_requests_error(client: APIClient, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.RequestCounter.add_request") @@ -101,7 +109,7 @@ def test_middleware_requests_error(client: APIClient, mocker: MockerFixture): def test_get_app_info(): from apitally.django import _get_app_info - app_info = _get_app_info(app_version="1.2.3") + app_info = _get_app_info(app_version="1.2.3", urlconfs=[None]) openapi = json.loads(app_info["openapi"]) assert len(app_info["paths"]) == 4 assert len(openapi["paths"]) == 4 @@ -115,7 +123,7 @@ def test_get_app_info(): def test_get_drf_api_endpoints(): from apitally.django import _get_drf_paths - endpoints = _get_drf_paths() + endpoints = _get_drf_paths([None]) assert len(endpoints) == 4 assert endpoints[0]["method"] == "GET" assert endpoints[0]["path"] == "/foo/"