Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async experiment #1105

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 98 additions & 59 deletions pulp-glue/pulp_glue/common/openapi.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# copyright (c) 2020, Matthias Dellweg
# GNU General Public License v3.0+ (see LICENSE or https://www.gnu.org/licenses/gpl-3.0.txt)

import asyncio
import base64
import datetime
import json
import os
import ssl
import typing as t
from collections import defaultdict
from contextlib import suppress
from io import BufferedReader
from urllib.parse import urljoin

import aiohttp
import requests
import urllib3

Expand All @@ -30,20 +33,14 @@
class OpenAPIError(Exception):
"""Base Exception for errors related to using the openapi spec."""

pass


class OpenAPIValidationError(OpenAPIError):
"""Exception raised for failed client side validation of parameters or request bodies."""

pass


class UnsafeCallError(OpenAPIError):
"""Exception raised for POST, PUT, PATCH or DELETE calls with `safe_calls_only=True`."""

pass


class AuthProviderBase:
"""
Expand Down Expand Up @@ -174,53 +171,90 @@ def __init__(
user_agent: t.Optional[str] = None,
cid: t.Optional[str] = None,
):
if verify is False:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

self.debug_callback: t.Callable[[int, str], t.Any] = debug_callback or (lambda i, x: None)
self.base_url: str = base_url
self.doc_path: str = doc_path
self.safe_calls_only: bool = safe_calls_only
self.auth_provider = auth_provider

self._session: requests.Session = requests.session()
if self.auth_provider:
if cert or key:
raise OpenAPIError(_("Cannot use both 'auth' and 'cert'."))
else:
if cert and key:
self._session.cert = (cert, key)
elif cert:
self._session.cert = cert
elif key:
raise OpenAPIError(_("Cert is required if key is set."))
self._session.headers.update(
self._debug_callback: t.Callable[[int, str], t.Any] = debug_callback or (lambda i, x: None)
self._base_url: str = base_url
self._doc_path: str = doc_path
self._safe_calls_only: bool = safe_calls_only
self._headers = headers or {}
self._verify = verify
# Shall we make that a parameter?
self._ssl_context: t.Optional[t.Union[ssl.SSLContext, bool]] = None

self._auth_provider = auth_provider
self._cert = cert
self._key = key

self._headers.update(
{
"User-Agent": user_agent or f"Pulp-glue openapi parser ({__version__})",
"Accept": "application/json",
}
)
if cid:
self._session.headers["Correlation-Id"] = cid
if headers:
self._session.headers.update(headers)
self._session.max_redirects = 0
self._headers["Correlation-Id"] = cid

self._setup_session()

self.load_api(refresh_cache=refresh_cache)

def _setup_session(self) -> None:
# This is specific requests library.

if self._verify is False:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

self._session: requests.Session = requests.session()
# Don't redirect, because carrying auth accross redirects is unsafe.
self._session.max_redirects = 0
self._session.headers.update(self._headers)
if self._auth_provider:
if self._cert or self._key:
raise OpenAPIError(_("Cannot use both 'auth' and 'cert'."))
else:
if self._cert and self._key:
self._session.cert = (self._cert, self._key)
elif self._cert:
self._session.cert = self._cert
elif self._key:
raise OpenAPIError(_("Cert is required if key is set."))
session_settings = self._session.merge_environment_settings(
base_url, {}, None, verify, None
self._base_url, {}, None, self._verify, None
)
self._session.verify = session_settings["verify"]
self._session.proxies = session_settings["proxies"]

self.load_api(refresh_cache=refresh_cache)
@property
def base_url(self) -> str:
return self._base_url

@property
def cid(self) -> t.Optional[str]:
return self._headers.get("Correlation-Id")

@property
def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
if self._ssl_context is None:
if self._verify is False:
self._ssl_context = False
else:
if isinstance(self._verify, str):
self._ssl_context = ssl.create_default_context(cafile=self._verify)
else:
self._ssl_context = ssl.create_default_context()
if self._cert is not None:
self._ssl_context.load_cert_chain(self._cert, self._key)
# Type inference is failing here.
self._ssl_context = t.cast(t.Union[ssl.SSLContext, bool], self._ssl_context)
return self._ssl_context

def load_api(self, refresh_cache: bool = False) -> None:
# TODO: Find a way to invalidate caches on upstream change
xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
apidoc_cache: str = os.path.join(
os.path.expanduser(xdg_cache_home),
"squeezer",
(self.base_url + "_" + self.doc_path).replace(":", "_").replace("/", "_") + "api.json",
(self._base_url + "_" + self._doc_path).replace(":", "_").replace("/", "_")
+ "api.json",
)
try:
if refresh_cache:
Expand All @@ -230,7 +264,7 @@ def load_api(self, refresh_cache: bool = False) -> None:
self._parse_api(data)
except Exception:
# Try again with a freshly downloaded version
data = self._download_api()
data = asyncio.run(self._download_api())
self._parse_api(data)
# Write to cache as it seems to be valid
os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
Expand All @@ -250,26 +284,31 @@ def _parse_api(self, data: bytes) -> None:
if method in {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
}

def _download_api(self) -> bytes:
async def _download_api(self) -> bytes:
try:
response: requests.Response = self._session.get(urljoin(self.base_url, self.doc_path))
except requests.RequestException as e:
connector = aiohttp.TCPConnector(ssl=self.ssl_context)
async with aiohttp.ClientSession(connector=connector, headers=self._headers) as session:
async with session.get(urljoin(self._base_url, self._doc_path)) as response:
response.raise_for_status()
data = await response.read()
if "Correlation-Id" in response.headers:
self._set_correlation_id(response.headers["Correlation-Id"])
except aiohttp.ClientError as e:
raise OpenAPIError(str(e))
response.raise_for_status()
if "Correlation-ID" in response.headers:
self._set_correlation_id(response.headers["Correlation-ID"])
return response.content
return data

def _set_correlation_id(self, correlation_id: str) -> None:
if "Correlation-ID" in self._session.headers:
if self._session.headers["Correlation-ID"] != correlation_id:
if "Correlation-Id" in self._headers:
if self._headers["Correlation-Id"] != correlation_id:
raise OpenAPIError(
_("Correlation ID returned from server did not match. {} != {}").format(
self._session.headers["Correlation-ID"], correlation_id
self._headers["Correlation-Id"], correlation_id
)
)
else:
self._session.headers["Correlation-ID"] = correlation_id
self._headers["Correlation-Id"] = correlation_id
# Do it for requests too...
self._session.headers["Correlation-Id"] = correlation_id

def param_spec(
self, operation_id: str, param_type: str, required: bool = False
Expand Down Expand Up @@ -420,7 +459,7 @@ def validate_schema(self, schema: t.Any, name: str, value: t.Any) -> t.Any:
# TextField
# DateTimeField etc.
# ChoiceField
# FielField (binary data)
# FileField (binary data)
value = self.validate_string(schema, name, value)
elif schema_type == "integer":
# IntegerField
Expand Down Expand Up @@ -654,12 +693,12 @@ def render_request(
security: t.List[t.Dict[str, t.List[str]]] = method_spec.get(
"security", self.api_spec.get("security", {})
)
if security and self.auth_provider:
if security and self._auth_provider:
if "Authorization" in self._session.headers:
# Bad idea, but you wanted it that way.
auth = None
else:
auth = self.auth_provider(security, self.api_spec["components"]["securitySchemes"])
auth = self._auth_provider(security, self.api_spec["components"]["securitySchemes"])
else:
# No auth required? Don't provide it.
# No auth_provider available? Hope for the best (should do the trick for cert auth).
Expand Down Expand Up @@ -751,7 +790,7 @@ def call(
names=", ".join(parameters.keys()), operation_id=operation_id
)
)
url = urljoin(self.base_url, path)
url = urljoin(self._base_url, path)

request: requests.PreparedRequest = self.render_request(
path_spec,
Expand All @@ -763,12 +802,12 @@ def call(
validate_body=validate_body,
)

self.debug_callback(1, f"{operation_id} : {method} {request.url}")
self._debug_callback(1, f"{operation_id} : {method} {request.url}")
for key, value in request.headers.items():
self.debug_callback(2, f" {key}: {value}")
self._debug_callback(2, f" {key}: {value}")
if request.body is not None:
self.debug_callback(3, f"{request.body!r}")
if self.safe_calls_only and method.upper() not in SAFE_METHODS:
self._debug_callback(3, f"{request.body!r}")
if self._safe_calls_only and method.upper() not in SAFE_METHODS:
raise UnsafeCallError(_("Call aborted due to safe mode"))
try:
response: requests.Response = self._session.send(request)
Expand All @@ -781,14 +820,14 @@ def call(
)
except requests.RequestException as e:
raise OpenAPIError(str(e))
self.debug_callback(
self._debug_callback(
1, _("Response: {status_code}").format(status_code=response.status_code)
)
for key, value in response.headers.items():
self.debug_callback(2, f" {key}: {value}")
self._debug_callback(2, f" {key}: {value}")
if response.text:
self.debug_callback(3, f"{response.text}")
if "Correlation-ID" in response.headers:
self._set_correlation_id(response.headers["Correlation-ID"])
self._debug_callback(3, f"{response.text}")
if "Correlation-Id" in response.headers:
self._set_correlation_id(response.headers["Correlation-Id"])
response.raise_for_status()
return self.parse_response(method_spec, response)
1 change: 1 addition & 0 deletions pulp-glue/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
"Typing :: Typed",
]
dependencies = [
"aiohttp>=3.10.10,<3.11",
"packaging>=20.0,<25",
"requests>=2.24.0,<2.33",
"importlib_resources>=5.4.0,<6.2;python_version<'3.9'",
Expand Down
6 changes: 3 additions & 3 deletions pulp_cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@
click.option("--password", default=None, help=_("Password on pulp server")),
click.option("--client-id", default=None, help=_("OAuth2 client ID")),
click.option("--client-secret", default=None, help=_("OAuth2 client secret")),
click.option("--cert", default="", help=_("Path to client certificate")),
click.option("--cert", default=None, help=_("Path to client certificate")),
click.option(
"--key",
default="",
default=None,
help=_("Path to client private key. Not required if client cert contains this."),
),
click.option("--verify-ssl/--no-verify-ssl", default=True, help=_("Verify SSL connection")),
Expand Down Expand Up @@ -167,7 +167,7 @@ def validate_config(config: t.Dict[str, t.Any], strict: bool = False) -> None:
missing_settings = (
set(SETTINGS)
- set(config.keys())
- {"plugins", "username", "password", "client_id", "client_secret"}
- {"plugins", "username", "password", "client_id", "client_secret", "cert", "key"}
)
if missing_settings:
errors.append(_("Missing settings: '{}'.").format("','".join(missing_settings)))
Expand Down
Loading