From 6687092c4067bd3cb4c5b0f5c3d424c5daf27c66 Mon Sep 17 00:00:00 2001 From: Matthias Dellweg Date: Tue, 5 Nov 2024 12:10:18 +0100 Subject: [PATCH 1/2] Refactor internals of openapi class This will mask attributes of the class private and add a bit of abstraction from requests. --- pulp-glue/pulp_glue/common/openapi.py | 112 +++++++++++++++----------- pulp-glue/pyproject.toml | 1 + 2 files changed, 64 insertions(+), 49 deletions(-) diff --git a/pulp-glue/pulp_glue/common/openapi.py b/pulp-glue/pulp_glue/common/openapi.py index 35858ea79..2e2e3666e 100644 --- a/pulp-glue/pulp_glue/common/openapi.py +++ b/pulp-glue/pulp_glue/common/openapi.py @@ -30,20 +30,14 @@ class OpenAPIError(Exception): """Base Exception for errors related to using the openapi spec.""" - pass - class OpenAPIValidationError(OpenAPIError): """Exception raised for failed client side validation of parameters or request bodies.""" - pass - class UnsafeCallError(OpenAPIError): """Exception raised for POST, PUT, PATCH or DELETE calls with `safe_calls_only=True`.""" - pass - class AuthProviderBase: """ @@ -174,45 +168,62 @@ def __init__( user_agent: t.Optional[str] = None, cid: t.Optional[str] = None, ): - if verify is False: - urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - - self.debug_callback: t.Callable[[int, str], t.Any] = debug_callback or (lambda i, x: None) - self.base_url: str = base_url - self.doc_path: str = doc_path - self.safe_calls_only: bool = safe_calls_only - self.auth_provider = auth_provider - - self._session: requests.Session = requests.session() - if self.auth_provider: - if cert or key: - raise OpenAPIError(_("Cannot use both 'auth' and 'cert'.")) - else: - if cert and key: - self._session.cert = (cert, key) - elif cert: - self._session.cert = cert - elif key: - raise OpenAPIError(_("Cert is required if key is set.")) - self._session.headers.update( + self._debug_callback: t.Callable[[int, str], t.Any] = debug_callback or (lambda i, x: None) + self._base_url: str = base_url + self._doc_path: str = doc_path + self._safe_calls_only: bool = safe_calls_only + self._headers = headers or {} + self._verify = verify + self._auth_provider = auth_provider + self._cert = cert + self._key = key + + self._headers.update( { "User-Agent": user_agent or f"Pulp-glue openapi parser ({__version__})", "Accept": "application/json", } ) if cid: - self._session.headers["Correlation-Id"] = cid - if headers: - self._session.headers.update(headers) - self._session.max_redirects = 0 + self._headers["Correlation-Id"] = cid + + self._setup_session() + self.load_api(refresh_cache=refresh_cache) + + def _setup_session(self) -> None: + # This is specific requests library. + + if self._verify is False: + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + self._session: requests.Session = requests.session() + # Don't redirect, because carrying auth accross redirects is unsafe. + self._session.max_redirects = 0 + self._session.headers.update(self._headers) + if self._auth_provider: + if self._cert or self._key: + raise OpenAPIError(_("Cannot use both 'auth' and 'cert'.")) + else: + if self._cert and self._key: + self._session.cert = (self._cert, self._key) + elif self._cert: + self._session.cert = self._cert + elif self._key: + raise OpenAPIError(_("Cert is required if key is set.")) session_settings = self._session.merge_environment_settings( - base_url, {}, None, verify, None + self._base_url, {}, None, self._verify, None ) self._session.verify = session_settings["verify"] self._session.proxies = session_settings["proxies"] - self.load_api(refresh_cache=refresh_cache) + @property + def base_url(self) -> str: + return self._base_url + + @property + def cid(self) -> t.Optional[str]: + return self._headers.get("Correlation-Id") def load_api(self, refresh_cache: bool = False) -> None: # TODO: Find a way to invalidate caches on upstream change @@ -220,7 +231,8 @@ def load_api(self, refresh_cache: bool = False) -> None: apidoc_cache: str = os.path.join( os.path.expanduser(xdg_cache_home), "squeezer", - (self.base_url + "_" + self.doc_path).replace(":", "_").replace("/", "_") + "api.json", + (self._base_url + "_" + self._doc_path).replace(":", "_").replace("/", "_") + + "api.json", ) try: if refresh_cache: @@ -252,7 +264,7 @@ def _parse_api(self, data: bytes) -> None: def _download_api(self) -> bytes: try: - response: requests.Response = self._session.get(urljoin(self.base_url, self.doc_path)) + response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path)) except requests.RequestException as e: raise OpenAPIError(str(e)) response.raise_for_status() @@ -261,14 +273,16 @@ def _download_api(self) -> bytes: return response.content def _set_correlation_id(self, correlation_id: str) -> None: - if "Correlation-ID" in self._session.headers: - if self._session.headers["Correlation-ID"] != correlation_id: + if "Correlation-ID" in self._headers: + if self._headers["Correlation-ID"] != correlation_id: raise OpenAPIError( _("Correlation ID returned from server did not match. {} != {}").format( - self._session.headers["Correlation-ID"], correlation_id + self._headers["Correlation-ID"], correlation_id ) ) else: + self._headers["Correlation-ID"] = correlation_id + # Do it for requests too... self._session.headers["Correlation-ID"] = correlation_id def param_spec( @@ -420,7 +434,7 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any: # TextField # DateTimeField etc. # ChoiceField - # FielField (binary data) + # FileField (binary data) value = self.validate_string(schema, name, value) elif schema_type == "integer": # IntegerField @@ -654,12 +668,12 @@ def render_request( security: t.List[t.Dict[str, t.List[str]]] = method_spec.get( "security", self.api_spec.get("security", {}) ) - if security and self.auth_provider: + if security and self._auth_provider: if "Authorization" in self._session.headers: # Bad idea, but you wanted it that way. auth = None else: - auth = self.auth_provider(security, self.api_spec["components"]["securitySchemes"]) + auth = self._auth_provider(security, self.api_spec["components"]["securitySchemes"]) else: # No auth required? Don't provide it. # No auth_provider available? Hope for the best (should do the trick for cert auth). @@ -751,7 +765,7 @@ def call( names=", ".join(parameters.keys()), operation_id=operation_id ) ) - url = urljoin(self.base_url, path) + url = urljoin(self._base_url, path) request: requests.PreparedRequest = self.render_request( path_spec, @@ -763,12 +777,12 @@ def call( validate_body=validate_body, ) - self.debug_callback(1, f"{operation_id} : {method} {request.url}") + self._debug_callback(1, f"{operation_id} : {method} {request.url}") for key, value in request.headers.items(): - self.debug_callback(2, f" {key}: {value}") + self._debug_callback(2, f" {key}: {value}") if request.body is not None: - self.debug_callback(3, f"{request.body!r}") - if self.safe_calls_only and method.upper() not in SAFE_METHODS: + self._debug_callback(3, f"{request.body!r}") + if self._safe_calls_only and method.upper() not in SAFE_METHODS: raise UnsafeCallError(_("Call aborted due to safe mode")) try: response: requests.Response = self._session.send(request) @@ -781,13 +795,13 @@ def call( ) except requests.RequestException as e: raise OpenAPIError(str(e)) - self.debug_callback( + self._debug_callback( 1, _("Response: {status_code}").format(status_code=response.status_code) ) for key, value in response.headers.items(): - self.debug_callback(2, f" {key}: {value}") + self._debug_callback(2, f" {key}: {value}") if response.text: - self.debug_callback(3, f"{response.text}") + self._debug_callback(3, f"{response.text}") if "Correlation-ID" in response.headers: self._set_correlation_id(response.headers["Correlation-ID"]) response.raise_for_status() diff --git a/pulp-glue/pyproject.toml b/pulp-glue/pyproject.toml index 13415e30e..ccd1062fc 100644 --- a/pulp-glue/pyproject.toml +++ b/pulp-glue/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Typing :: Typed", ] dependencies = [ + "aiohttp>=3.10.10,<3.11", "packaging>=20.0,<25", "requests>=2.24.0,<2.33", "importlib_resources>=5.4.0,<6.2;python_version<'3.9'", From 034bd3096896b85aabaa6bca7575814b519dd78d Mon Sep 17 00:00:00 2001 From: Matthias Dellweg Date: Tue, 5 Nov 2024 12:58:05 +0100 Subject: [PATCH 2/2] WIP --- pulp-glue/pulp_glue/common/openapi.py | 55 +++++++++++++++++++-------- pulp_cli/config.py | 6 +-- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/pulp-glue/pulp_glue/common/openapi.py b/pulp-glue/pulp_glue/common/openapi.py index 2e2e3666e..226a99e78 100644 --- a/pulp-glue/pulp_glue/common/openapi.py +++ b/pulp-glue/pulp_glue/common/openapi.py @@ -1,16 +1,19 @@ # copyright (c) 2020, Matthias Dellweg # GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt) +import asyncio import base64 import datetime import json import os +import ssl import typing as t from collections import defaultdict from contextlib import suppress from io import BufferedReader from urllib.parse import urljoin +import aiohttp import requests import urllib3 @@ -174,6 +177,9 @@ def __init__( self._safe_calls_only: bool = safe_calls_only self._headers = headers or {} self._verify = verify + # Shall we make that a parameter? + self._ssl_context: t.Optional[t.Union[ssl.SSLContext, bool]] = None + self._auth_provider = auth_provider self._cert = cert self._key = key @@ -225,6 +231,22 @@ def base_url(self) -> str: def cid(self) -> t.Optional[str]: return self._headers.get("Correlation-Id") + @property + def ssl_context(self) -> t.Union[ssl.SSLContext, bool]: + if self._ssl_context is None: + if self._verify is False: + self._ssl_context = False + else: + if isinstance(self._verify, str): + self._ssl_context = ssl.create_default_context(cafile=self._verify) + else: + self._ssl_context = ssl.create_default_context() + if self._cert is not None: + self._ssl_context.load_cert_chain(self._cert, self._key) + # Type inference is failing here. + self._ssl_context = t.cast(t.Union[ssl.SSLContext, bool], self._ssl_context) + return self._ssl_context + def load_api(self, refresh_cache: bool = False) -> None: # TODO: Find a way to invalidate caches on upstream change xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache" @@ -242,7 +264,7 @@ def load_api(self, refresh_cache: bool = False) -> None: self._parse_api(data) except Exception: # Try again with a freshly downloaded version - data = self._download_api() + data = asyncio.run(self._download_api()) self._parse_api(data) # Write to cache as it seems to be valid os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True) @@ -262,28 +284,31 @@ def _parse_api(self, data: bytes) -> None: if method in {"get", "put", "post", "delete", "options", "head", "patch", "trace"} } - def _download_api(self) -> bytes: + async def _download_api(self) -> bytes: try: - response: requests.Response = self._session.get(urljoin(self._base_url, self._doc_path)) - except requests.RequestException as e: + connector = aiohttp.TCPConnector(ssl=self.ssl_context) + async with aiohttp.ClientSession(connector=connector, headers=self._headers) as session: + async with session.get(urljoin(self._base_url, self._doc_path)) as response: + response.raise_for_status() + data = await response.read() + if "Correlation-Id" in response.headers: + self._set_correlation_id(response.headers["Correlation-Id"]) + except aiohttp.ClientError as e: raise OpenAPIError(str(e)) - response.raise_for_status() - if "Correlation-ID" in response.headers: - self._set_correlation_id(response.headers["Correlation-ID"]) - return response.content + return data def _set_correlation_id(self, correlation_id: str) -> None: - if "Correlation-ID" in self._headers: - if self._headers["Correlation-ID"] != correlation_id: + if "Correlation-Id" in self._headers: + if self._headers["Correlation-Id"] != correlation_id: raise OpenAPIError( _("Correlation ID returned from server did not match. {} != {}").format( - self._headers["Correlation-ID"], correlation_id + self._headers["Correlation-Id"], correlation_id ) ) else: - self._headers["Correlation-ID"] = correlation_id + self._headers["Correlation-Id"] = correlation_id # Do it for requests too... - self._session.headers["Correlation-ID"] = correlation_id + self._session.headers["Correlation-Id"] = correlation_id def param_spec( self, operation_id: str, param_type: str, required: bool = False @@ -802,7 +827,7 @@ def call( self._debug_callback(2, f" {key}: {value}") if response.text: self._debug_callback(3, f"{response.text}") - if "Correlation-ID" in response.headers: - self._set_correlation_id(response.headers["Correlation-ID"]) + if "Correlation-Id" in response.headers: + self._set_correlation_id(response.headers["Correlation-Id"]) response.raise_for_status() return self.parse_response(method_spec, response) diff --git a/pulp_cli/config.py b/pulp_cli/config.py index 5c5668bc7..dcf0e4809 100644 --- a/pulp_cli/config.py +++ b/pulp_cli/config.py @@ -70,10 +70,10 @@ click.option("--password", default=None, help=_("Password on pulp server")), click.option("--client-id", default=None, help=_("OAuth2 client ID")), click.option("--client-secret", default=None, help=_("OAuth2 client secret")), - click.option("--cert", default="", help=_("Path to client certificate")), + click.option("--cert", default=None, help=_("Path to client certificate")), click.option( "--key", - default="", + default=None, help=_("Path to client private key. Not required if client cert contains this."), ), click.option("--verify-ssl/--no-verify-ssl", default=True, help=_("Verify SSL connection")), @@ -167,7 +167,7 @@ def validate_config(config: t.Dict[str, t.Any], strict: bool = False) -> None: missing_settings = ( set(SETTINGS) - set(config.keys()) - - {"plugins", "username", "password", "client_id", "client_secret"} + - {"plugins", "username", "password", "client_id", "client_secret", "cert", "key"} ) if missing_settings: errors.append(_("Missing settings: '{}'.").format("','".join(missing_settings)))