Skip to content

Commit

Permalink
Support multiple urlconfs
Browse files Browse the repository at this point in the history
  • Loading branch information
itssimon committed Mar 22, 2024
1 parent 3e45ddd commit db828f4
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 55 deletions.
118 changes: 67 additions & 51 deletions apitally/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import time
from dataclasses import dataclass
from importlib import import_module
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union

from django.conf import settings
from django.urls import Resolver404, URLPattern, URLResolver, get_resolver, resolve
from django.urls import URLPattern, URLResolver, get_resolver
from django.utils.module_loading import import_string

from apitally.client.logging import get_logger
Expand All @@ -32,6 +32,7 @@ class ApitallyMiddlewareConfig:
env: str
app_version: Optional[str]
identify_consumer_callback: Optional[Callable[[HttpRequest], Optional[str]]]
urlconfs: List[Optional[str]]


class ApitallyMiddleware:
Expand All @@ -53,7 +54,12 @@ def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:

self.client = ApitallyClient(client_id=self.config.client_id, env=self.config.env)
self.client.start_sync_loop()
self.client.set_app_info(app_info=_get_app_info(app_version=self.config.app_version))
self.client.set_app_info(
app_info=_get_app_info(
app_version=self.config.app_version,
urlconfs=self.config.urlconfs,
)
)

@classmethod
def configure(
Expand All @@ -62,6 +68,7 @@ def configure(
env: str = "dev",
app_version: Optional[str] = None,
identify_consumer_callback: Optional[str] = None,
urlconf: Optional[Union[List[Optional[str]], str]] = None,
) -> None:
cls.config = ApitallyMiddlewareConfig(
client_id=client_id,
Expand All @@ -70,15 +77,16 @@ def configure(
identify_consumer_callback=import_string(identify_consumer_callback)
if identify_consumer_callback
else None,
urlconfs=[urlconf] if urlconf is None or isinstance(urlconf, str) else urlconf,
)

def __call__(self, request: HttpRequest) -> HttpResponse:
path = self.get_path(request)
start_time = time.perf_counter()
response = self.get_response(request)
path = self.get_path(request)
if request.method is not None and path is not None:
consumer = self.get_consumer(request)
try:
consumer = self.get_consumer(request)
self.client.request_counter.add_request(
consumer=consumer,
method=request.method,
Expand Down Expand Up @@ -113,24 +121,21 @@ def __call__(self, request: HttpRequest) -> HttpResponse:
return response

def get_path(self, request: HttpRequest) -> Optional[str]:
try:
resolver_match = resolve(request.path_info)
except Resolver404:
return None
try:
if self.drf_endpoint_enumerator is not None:
from rest_framework.schemas.generators import is_api_view
if (match := request.resolver_match) is not None:
try:
if self.drf_endpoint_enumerator is not None:
from rest_framework.schemas.generators import is_api_view

if is_api_view(resolver_match.func):
return self.drf_endpoint_enumerator.get_path_from_regex(resolver_match.route)
if self.ninja_available:
from ninja.operation import PathView
if is_api_view(match.func):
return self.drf_endpoint_enumerator.get_path_from_regex(match.route)
if self.ninja_available:
from ninja.operation import PathView

if hasattr(resolver_match.func, "__self__") and isinstance(resolver_match.func.__self__, PathView):
path = "/" + resolver_match.route.lstrip("/")
return re.sub(r"<(?:[^:]+:)?([^>:]+)>", r"{\1}", path)
except Exception:
logger.exception("Failed to get path for request")
if hasattr(match.func, "__self__") and isinstance(match.func.__self__, PathView):
path = "/" + match.route.lstrip("/")
return re.sub(r"<(?:[^:]+:)?([^>:]+)>", r"{\1}", path)
except Exception:
logger.exception("Failed to get path for request")

Check warning on line 138 in apitally/django.py

View check run for this annotation

Codecov / codecov/patch

apitally/django.py#L137-L138

Added lines #L137 - L138 were not covered by tests
return None

def get_consumer(self, request: HttpRequest) -> Optional[str]:
Expand All @@ -146,73 +151,76 @@ def get_consumer(self, request: HttpRequest) -> Optional[str]:
return None


def _get_app_info(app_version: Optional[str] = None) -> Dict[str, Any]:
def _get_app_info(app_version: Optional[str], urlconfs: List[Optional[str]]) -> Dict[str, Any]:
app_info: Dict[str, Any] = {}
try:
app_info["paths"] = _get_paths()
app_info["paths"] = _get_paths(urlconfs)
except Exception:
app_info["paths"] = []
logger.exception("Failed to get paths")

Check warning on line 160 in apitally/django.py

View check run for this annotation

Codecov / codecov/patch

apitally/django.py#L158-L160

Added lines #L158 - L160 were not covered by tests
try:
app_info["openapi"] = _get_openapi()
app_info["openapi"] = _get_openapi(urlconfs)
except Exception:
logger.exception("Failed to get OpenAPI schema")

Check warning on line 164 in apitally/django.py

View check run for this annotation

Codecov / codecov/patch

apitally/django.py#L163-L164

Added lines #L163 - L164 were not covered by tests
app_info["versions"] = get_versions("django", "djangorestframework", "django-ninja", app_version=app_version)
app_info["client"] = "python:django"
return app_info


def _get_openapi() -> Optional[str]:
def _get_openapi(urlconfs: List[Optional[str]]) -> Optional[str]:
drf_schema = None
ninja_schema = None
with contextlib.suppress(ImportError):
drf_schema = _get_drf_schema()
drf_schema = _get_drf_schema(urlconfs)
with contextlib.suppress(ImportError):
ninja_schema = _get_ninja_schema()
ninja_schema = _get_ninja_schema(urlconfs)
if drf_schema is not None and ninja_schema is None:
return json.dumps(drf_schema)
elif ninja_schema is not None and drf_schema is None:
return json.dumps(ninja_schema)
return None

Check warning on line 181 in apitally/django.py

View check run for this annotation

Codecov / codecov/patch

apitally/django.py#L181

Added line #L181 was not covered by tests


def _get_paths() -> List[Dict[str, str]]:
def _get_paths(urlconfs: List[Optional[str]]) -> List[Dict[str, str]]:
paths = []
with contextlib.suppress(ImportError):
paths.extend(_get_drf_paths())
paths.extend(_get_drf_paths(urlconfs))
with contextlib.suppress(ImportError):
paths.extend(_get_ninja_paths())
paths.extend(_get_ninja_paths(urlconfs))
return paths


def _get_drf_paths() -> List[Dict[str, str]]:
def _get_drf_paths(urlconfs: List[Optional[str]]) -> List[Dict[str, str]]:
from rest_framework.schemas.generators import EndpointEnumerator

enumerator = EndpointEnumerator()
enumerators = [EndpointEnumerator(urlconf=urlconf) for urlconf in urlconfs]
return [
{
"method": method.upper(),
"path": path,
}
for enumerator in enumerators
for path, method, _ in enumerator.get_api_endpoints()
if method not in ["HEAD", "OPTIONS"]
]


def _get_drf_schema() -> Optional[Dict[str, Any]]:
def _get_drf_schema(urlconfs: List[Optional[str]]) -> Optional[Dict[str, Any]]:
from rest_framework.schemas.openapi import SchemaGenerator

with contextlib.suppress(AssertionError): # uritemplate must be installed for OpenAPI schema support
generator = SchemaGenerator()
schema = generator.get_schema()
if schema is not None and len(schema["paths"]) > 0:
return schema # type: ignore[return-value]
return None
schemas = []
with contextlib.suppress(AssertionError): # uritemplate and inflection must be installed for OpenAPI schema support
for urlconf in urlconfs:
generator = SchemaGenerator(urlconf=urlconf)
schema = generator.get_schema()
if schema is not None and len(schema["paths"]) > 0:
schemas.append(schema)
return None if len(schemas) != 1 else schemas[0] # type: ignore[return-value]


def _get_ninja_paths() -> List[Dict[str, str]]:
def _get_ninja_paths(urlconfs: List[Optional[str]]) -> List[Dict[str, str]]:
endpoints = []
for api in _get_ninja_api_instances():
for api in _get_ninja_api_instances(urlconfs=urlconfs):
schema = api.get_openapi_schema()
for path, operations in schema["paths"].items():
for method, operation in operations.items():
Expand All @@ -228,25 +236,33 @@ def _get_ninja_paths() -> List[Dict[str, str]]:
return endpoints


def _get_ninja_schema() -> Optional[Dict[str, Any]]:
if len(apis := _get_ninja_api_instances()) == 1:
api = list(apis)[0]
def _get_ninja_schema(urlconfs: List[Optional[str]]) -> Optional[Dict[str, Any]]:
schemas = []
for api in _get_ninja_api_instances(urlconfs=urlconfs):
schema = api.get_openapi_schema()
if len(schema["paths"]) > 0:
return schema
return None
schemas.append(schema)
return None if len(schemas) != 1 else schemas[0]


def _get_ninja_api_instances(url_patterns: Optional[List[Any]] = None) -> Set[NinjaAPI]:
def _get_ninja_api_instances(
urlconfs: Optional[List[Optional[str]]] = None,
patterns: Optional[List[Any]] = None,
) -> Set[NinjaAPI]:
from ninja import NinjaAPI

if url_patterns is None:
url_patterns = get_resolver().url_patterns
if urlconfs is None:
urlconfs = [None]
if patterns is None:
patterns = []
for urlconf in urlconfs:
patterns.extend(get_resolver(urlconf).url_patterns)

apis: Set[NinjaAPI] = set()
for p in url_patterns:
for p in patterns:
if isinstance(p, URLResolver):
if p.app_name != "ninja":
apis.update(_get_ninja_api_instances(p.url_patterns))
apis.update(_get_ninja_api_instances(patterns=p.url_patterns))

Check warning on line 265 in apitally/django.py

View check run for this annotation

Codecov / codecov/patch

apitally/django.py#L265

Added line #L265 was not covered by tests
else:
for pattern in p.url_patterns:
if isinstance(pattern, URLPattern) and pattern.lookup_str.startswith("ninja."):
Expand Down
12 changes: 10 additions & 2 deletions tests/test_django_ninja.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ def test_middleware_requests_ok(client: Client, mocker: MockerFixture):
assert int(mock.call_args.kwargs["request_size"]) > 0


def test_middleware_requests_404(client: Client, mocker: MockerFixture):
mock = mocker.patch("apitally.client.base.RequestCounter.add_request")

response = client.get("/api/none")
assert response.status_code == 404
mock.assert_not_called()


def test_middleware_requests_error(client: Client, mocker: MockerFixture):
mock = mocker.patch("apitally.client.base.RequestCounter.add_request")

Expand Down Expand Up @@ -117,7 +125,7 @@ def test_middleware_validation_error(client: Client, mocker: MockerFixture):
def test_get_app_info():
from apitally.django import _get_app_info

app_info = _get_app_info(app_version="1.2.3")
app_info = _get_app_info(app_version="1.2.3", urlconfs=[None])
openapi = json.loads(app_info["openapi"])
assert len(app_info["paths"]) == 5
assert len(openapi["paths"]) == 5
Expand All @@ -142,7 +150,7 @@ def test_get_ninja_api_instances():
def test_get_ninja_api_endpoints():
from apitally.django import _get_ninja_paths

endpoints = _get_ninja_paths()
endpoints = _get_ninja_paths([None])
assert len(endpoints) == 5
assert all(len(e["summary"]) > 0 for e in endpoints)
assert any(e["description"] is not None and len(e["description"]) > 0 for e in endpoints)
12 changes: 10 additions & 2 deletions tests/test_django_rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ def test_middleware_requests_ok(client: APIClient, mocker: MockerFixture):
assert int(mock.call_args.kwargs["request_size"]) > 0


def test_middleware_requests_404(client: APIClient, mocker: MockerFixture):
mock = mocker.patch("apitally.client.base.RequestCounter.add_request")

response = client.get("/api/none")
assert response.status_code == 404
mock.assert_not_called()


def test_middleware_requests_error(client: APIClient, mocker: MockerFixture):
mock = mocker.patch("apitally.client.base.RequestCounter.add_request")

Expand All @@ -101,7 +109,7 @@ def test_middleware_requests_error(client: APIClient, mocker: MockerFixture):
def test_get_app_info():
from apitally.django import _get_app_info

app_info = _get_app_info(app_version="1.2.3")
app_info = _get_app_info(app_version="1.2.3", urlconfs=[None])
openapi = json.loads(app_info["openapi"])
assert len(app_info["paths"]) == 4
assert len(openapi["paths"]) == 4
Expand All @@ -115,7 +123,7 @@ def test_get_app_info():
def test_get_drf_api_endpoints():
from apitally.django import _get_drf_paths

endpoints = _get_drf_paths()
endpoints = _get_drf_paths([None])
assert len(endpoints) == 4
assert endpoints[0]["method"] == "GET"
assert endpoints[0]["path"] == "/foo/"
Expand Down

0 comments on commit db828f4

Please sign in to comment.