diff --git a/LICENSE b/LICENSE index bdc6f31..5a9ac0b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2017 HDE Inc +Copyright 2019 HDE Inc Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 15b0551..ba96616 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # AAPNS -[![CircleCI](https://circleci.com/gh/hde/aapns/tree/master.svg?style=svg)](https://circleci.com/gh/hde/aapns/tree/master) +[![CircleCI](https://circleci.com/gh/HDE/aapns/tree/master.svg?style=svg)](https://circleci.com/gh/HDE/aapns/tree/master) [![Documentation Status](https://readthedocs.org/projects/aapns/badge/?version=latest)](http://aapns.readthedocs.io/en/latest/?badge=latest) Asynchronous Apple Push Notification Service client. diff --git a/docs/conf.py b/docs/conf.py index faba5ac..2fd2bb2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,30 +34,30 @@ extensions = [] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'AAPNS' -copyright = '2017, HDE Inc' -author = 'HDE Inc' +project = "AAPNS" +copyright = "2019, HDE Inc" +author = "HDE Inc" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '0.1' +version = "19.1" # The full version, including alpha/beta/rc tags. -release = '0.1' +release = "19.1" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -69,10 +69,10 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = False @@ -83,7 +83,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = "alabaster" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -94,13 +94,13 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'AAPNSdoc' +htmlhelp_basename = "AAPNSdoc" # -- Options for LaTeX output --------------------------------------------- @@ -109,15 +109,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -127,8 +124,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'AAPNS.tex', 'AAPNS Documentation', - 'HDE Inc', 'manual'), + (master_doc, "AAPNS.tex", "AAPNS Documentation", "HDE Inc", "manual") ] @@ -136,10 +132,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'aapns', 'AAPNS Documentation', - [author], 1) -] +man_pages = [(master_doc, "aapns", "AAPNS Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -148,10 +141,13 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'AAPNS', 'AAPNS Documentation', - author, 'AAPNS', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "AAPNS", + "AAPNS Documentation", + author, + "AAPNS", + "One line description of project.", + "Miscellaneous", + ) ] - - - diff --git a/setup.py b/setup.py index 7c06067..18b8a6a 100644 --- a/setup.py +++ b/setup.py @@ -1,30 +1,36 @@ +import os from setuptools import setup, find_packages +with open(os.path.relpath(f"{__file__}/../README.md")) as f: + readme = f.read() setup( - version='1.0.0.dev8', - name='aapns', - package_dir={'': 'src'}, - packages=find_packages(where='src'), - python_requires='>=3.6', - install_requires=[ - 'h2>3', - 'attrs', - 'structlog', - ], - extras_require={ - 'cli': ['click'] - }, - entry_points={ - 'console_scripts': [ - 'aapns = aapns.cli:main' - ] + version="19.1", + name="aapns", + package_dir={"": "src"}, + packages=find_packages(where="src"), + python_requires=">=3.6", + install_requires=["h2>3", "attrs", "structlog"], + extras_require={"cli": ["click"]}, + entry_points={"console_scripts": ["aapns = aapns.cli:main"]}, + author="Jonas Obrist", + author_email="ojiidotch@gmail.com", + description="Asynchronous Apple Push Notification Service client", + long_description=readme, + long_description_content_type="text/markdown", + url="https://github.com/HDE/aapns", + project_urls={ + "Documentation": "https://aapns.readthedocs.io/en/latest/", + "Code": "https://github.com/HDE/aapns", + "Issue tracker": "https://github.com/HDE/aapns/issues", }, - license='APLv2', + license="APLv2", classifiers=[ - 'Framework :: AsyncIO', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3 :: Only', - 'License :: OSI Approved :: Apache Software License' - ] + "Framework :: AsyncIO", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3 :: Only", + "License :: OSI Approved :: Apache Software License", + ], ) diff --git a/src/aapns/__init__.py b/src/aapns/__init__.py index 1b3d4a9..689ad8e 100644 --- a/src/aapns/__init__.py +++ b/src/aapns/__init__.py @@ -1,3 +1,9 @@ -from .config import Production, ProductionAltPort, Development, DevelopmentAltPort, Priority +from .config import ( + Production, + ProductionAltPort, + Development, + DevelopmentAltPort, + Priority, +) from .api import connect from .models import Notification, Alert, Localized diff --git a/src/aapns/api.py b/src/aapns/api.py index 3ec6b37..f9176d8 100644 --- a/src/aapns/api.py +++ b/src/aapns/api.py @@ -38,15 +38,13 @@ async def __aenter__(self) -> APNSProtocol: if not self.do_connect: raise errors.Disconnected() self.connection = APNSProtocol( - self.server.host, - self.logger, - self.clear_connection + self.server.host, self.logger, self.clear_connection ) await loop.create_connection( lambda: self.connection, self.server.host, self.server.port, - ssl=self.ssl_context + ssl=self.ssl_context, ) return self.connection @@ -65,42 +63,44 @@ def close(self): self.connection.close() -def encode_request(*, - server: config.Server, - token: str, - notification: models.Notification, - apns_id: Optional[str]=None, - expiration: Optional[int]=None, - priority: config.Priority=config.Priority.normal, - topic: Optional[str]=None, - collapse_id: Optional[str]=None) -> Tuple[Headers, bytes]: +def encode_request( + *, + server: config.Server, + token: str, + notification: models.Notification, + apns_id: Optional[str] = None, + expiration: Optional[int] = None, + priority: config.Priority = config.Priority.normal, + topic: Optional[str] = None, + collapse_id: Optional[str] = None, +) -> Tuple[Headers, bytes]: request_body = notification.encode() request_headers = [ - (':method', 'POST'), - (':authority', server.host), - (':scheme', 'https'), - (':path', f'/3/device/{token}'), - ('content-length', str(len(request_body))), - ('apns-priority', str(priority.value)), + (":method", "POST"), + (":authority", server.host), + (":scheme", "https"), + (":path", f"/3/device/{token}"), + ("content-length", str(len(request_body))), + ("apns-priority", str(priority.value)), ] if apns_id: - request_headers.append(('apns-id', apns_id)) + request_headers.append(("apns-id", apns_id)) if expiration: - request_headers.append(('apns-expiration', str(expiration))) + request_headers.append(("apns-expiration", str(expiration))) if topic: - request_headers.append(('apns-topic', topic)) + request_headers.append(("apns-topic", topic)) if collapse_id: - request_headers.append(('apns-collapse-id', collapse_id)) + request_headers.append(("apns-collapse-id", collapse_id)) return request_headers, request_body def handle_response(response: connection.Response) -> str: - response_id = response.headers.get('apns-id', '') + response_id = response.headers.get("apns-id", "") if response.status != 200: try: - reason = json.loads(response.body)['reason'] + reason = json.loads(response.body)["reason"] except: reason = response.body exc = errors.get(reason, response_id) @@ -113,6 +113,7 @@ def ensure_task(coro: Coroutine) -> Coroutine: @wraps(coro) async def wrapper(*args, **kwargs): return await asyncio.get_event_loop().create_task(coro(*args, **kwargs)) + return wrapper @@ -122,15 +123,17 @@ class APNS: server: config.Server = attr.ib() @ensure_task - async def send_notification(self, - token: str, - notification: models.Notification, - *, - apns_id: Optional[str]=None, - expiration: Optional[int]=None, - priority: config.Priority=config.Priority.normal, - topic: Optional[str]=None, - collapse_id: Optional[str]=None) -> str: + async def send_notification( + self, + token: str, + notification: models.Notification, + *, + apns_id: Optional[str] = None, + expiration: Optional[int] = None, + priority: config.Priority = config.Priority.normal, + topic: Optional[str] = None, + collapse_id: Optional[str] = None, + ) -> str: request_headers, request_body = encode_request( token=token, notification=notification, @@ -143,10 +146,7 @@ async def send_notification(self, ) async with self.connector as conn: - response = await conn.request( - headers=request_headers, - body=request_body, - ) + response = await conn.request(headers=request_headers, body=request_body) return handle_response(response) @@ -154,18 +154,20 @@ async def close(self): self.connector.close() -async def connect(client_cert_path: str, - server: config.Server, - *, - ssl_context: Optional[ssl.SSLContext]=None, - logger: Optional[BoundLogger]=None, - auto_reconnect: bool=False, - timeout: Optional[float]=None) -> APNS: +async def connect( + client_cert_path: str, + server: config.Server, + *, + ssl_context: Optional[ssl.SSLContext] = None, + logger: Optional[BoundLogger] = None, + auto_reconnect: bool = False, + timeout: Optional[float] = None, +) -> APNS: if ssl_context is None: ssl_context: ssl.SSLContext = ssl.create_default_context() - ssl_context.set_alpn_protocols(['h2']) + ssl_context.set_alpn_protocols(["h2"]) try: - ssl_context.set_npn_protocols(['h2']) + ssl_context.set_npn_protocols(["h2"]) except AttributeError: pass ssl_context.load_cert_chain(client_cert_path) diff --git a/src/aapns/cli.py b/src/aapns/cli.py index 26f111e..d794515 100644 --- a/src/aapns/cli.py +++ b/src/aapns/cli.py @@ -9,14 +9,8 @@ SERVERS: Dict[bool, Dict[bool, config.Server]] = { - True: { - True: config.ProductionAltPort, - False: config.Production - }, - False: { - True: config.DevelopmentAltPort, - False: config.Development - } + True: {True: config.ProductionAltPort, False: config.Production}, + False: {True: config.DevelopmentAltPort, False: config.Development}, } @@ -35,9 +29,7 @@ class Context: async def do_send(context: Context, notification: models.Notification) -> str: conn = await connect( - context.cert, - context.server, - logger=get_logger() if context.verbose else None + context.cert, context.server, logger=get_logger() if context.verbose else None ) resp_id = await conn.send_notification( context.token, @@ -46,7 +38,7 @@ async def do_send(context: Context, notification: models.Notification) -> str: expiration=context.expiration, priority=context.priority, topic=context.topic, - collapse_id=context.collapse_id + collapse_id=context.collapse_id, ) await conn.close() return resp_id @@ -58,18 +50,30 @@ def send(context: Context, notification: models.Notification): @click.group() -@click.argument('token') -@click.option('--client-cert-path', envvar='CLIENT_CERT_PATH') -@click.option('--prod', is_flag=True, default=False) -@click.option('--alt-port', is_flag=True, default=False) -@click.option('--expiration', default=None) -@click.option('--immediately', is_flag=True, default=False) -@click.option('--topic', default=None) -@click.option('--collapse-id', default=None) -@click.option('--apns-id', default=None) -@click.option('--verbose', is_flag=True, default=False) +@click.argument("token") +@click.option("--client-cert-path", envvar="CLIENT_CERT_PATH") +@click.option("--prod", is_flag=True, default=False) +@click.option("--alt-port", is_flag=True, default=False) +@click.option("--expiration", default=None) +@click.option("--immediately", is_flag=True, default=False) +@click.option("--topic", default=None) +@click.option("--collapse-id", default=None) +@click.option("--apns-id", default=None) +@click.option("--verbose", is_flag=True, default=False) @click.pass_context -def main(ctx, token, client_cert_path, prod, alt_port, expiration, immediately, topic, collapse_id, apns_id, verbose): +def main( + ctx, + token, + client_cert_path, + prod, + alt_port, + expiration, + immediately, + topic, + collapse_id, + apns_id, + verbose, +): ctx.obj = Context( token=token, cert=client_cert_path, @@ -83,26 +87,21 @@ def main(ctx, token, client_cert_path, prod, alt_port, expiration, immediately, ) -@main.command('simple') -@click.argument('body') -@click.option('--title', default=None) +@main.command("simple") +@click.argument("body") +@click.option("--title", default=None) @click.pass_context def simple(ctx, title, body): - notification = models.Notification( - alert=models.Alert( - title=title, - body=body, - ) - ) + notification = models.Notification(alert=models.Alert(title=title, body=body)) send(ctx.obj, notification) -@main.command('localized') -@click.argument('body') -@click.option('--body-args', multiple=True) -@click.option('--title', default=None) -@click.option('--title-args', multiple=True) -@click.option('--badge', type=click.INT) +@main.command("localized") +@click.argument("body") +@click.option("--body-args", multiple=True) +@click.option("--title", default=None) +@click.option("--title-args", multiple=True) +@click.option("--badge", type=click.INT) @click.pass_context def localized(ctx, title, body, title_args, body_args, badge): notification = models.Notification( diff --git a/src/aapns/config.py b/src/aapns/config.py index 509d5ea..c2b4ce7 100644 --- a/src/aapns/config.py +++ b/src/aapns/config.py @@ -9,9 +9,9 @@ class Server: port = attr.ib() -Production = Server('api.push.apple.com', 443) +Production = Server("api.push.apple.com", 443) ProductionAltPort = attr.evolve(Production, port=2197) -Development = Server('api.development.push.apple.com', 443) +Development = Server("api.development.push.apple.com", 443) DevelopmentAltPort = attr.evolve(Development, port=2197) diff --git a/src/aapns/connection.py b/src/aapns/connection.py index a3fa3a6..803352c 100644 --- a/src/aapns/connection.py +++ b/src/aapns/connection.py @@ -5,10 +5,7 @@ import attr from h2.connection import H2Connection -from h2.events import ( - ResponseReceived, DataReceived, StreamEnded, - StreamReset, -) +from h2.events import ResponseReceived, DataReceived, StreamEnded, StreamReset from hyperframe.frame import SettingsFrame from structlog import wrap_logger, PrintLogger, BoundLogger @@ -24,14 +21,13 @@ class PendingResponse: logger: BoundLogger = attr.ib() future: Future = attr.ib(default=attr.Factory(Future)) headers: List[Tuple[bytes, bytes]] = attr.ib(default=None) - body: bytes = attr.ib(default=b'') + body: bytes = attr.ib(default=b"") - def to_response(self) -> 'Response': + def to_response(self) -> "Response": headers = { - key.decode('utf-8'): value.decode('utf-8') - for key, value in self.headers + key.decode("utf-8"): value.decode("utf-8") for key, value in self.headers } - status = int(headers[':status']) + status = int(headers[":status"]) return Response(status, headers, self.body) @@ -43,22 +39,25 @@ class Response: class APNSProtocol(Protocol): - def __init__(self, authority: str, logger: Optional[BoundLogger], on_close: Callable[[], None]): + def __init__( + self, + authority: str, + logger: Optional[BoundLogger], + on_close: Callable[[], None], + ): self.authority = authority - self.logger = logger or wrap_logger(PrintLogger(open(os.devnull, 'w'))) + self.logger = logger or wrap_logger(PrintLogger(open(os.devnull, "w"))) self.on_close = on_close self.conn = H2Connection() self.transport: Union[Transport, None] = None self.responses: Dict[int, PendingResponse] = {} - async def request(self, - headers: List[Tuple[str, str]], - body: bytes) -> Response: + async def request(self, headers: List[Tuple[str, str]], body: bytes) -> Response: stream_id = self.conn.get_next_available_stream_id() logger = self.logger.bind(stream_id=stream_id) pending = self.responses[stream_id] = PendingResponse(logger=logger) pending.future.add_done_callback(partial(self.responses.pop, stream_id)) - logger.debug('request', headers=headers, body=body) + logger.debug("request", headers=headers, body=body) self.conn.send_headers(stream_id, headers) self.conn.send_data(stream_id, body, end_stream=True) if self.transport is not None: @@ -77,7 +76,7 @@ def close(self): self.on_close() def connection_made(self, transport: Transport): - self.logger.debug('connected') + self.logger.debug("connected") self.transport = transport self.conn.initiate_connection() # This reproduces the error in #396, by changing the header table size. @@ -85,7 +84,7 @@ def connection_made(self, transport: Transport): self.transport.write(self.conn.data_to_send()) def connection_lost(self, exc): - self.logger.debug('disconnected') + self.logger.debug("disconnected") self.transport = None for pending in self.responses.values(): pending.future.set_exception(Disconnected()) @@ -104,52 +103,44 @@ def data_received(self, data: bytes): elif isinstance(event, StreamReset): self.reset_stream(event.stream_id) else: - self.logger.debug('ignored', h2event=event) + self.logger.debug("ignored", h2event=event) data = self.conn.data_to_send() if data: self.transport.write(data) - def handle_response(self, response_headers: List[Tuple[bytes, bytes]], stream_id: int): + def handle_response( + self, response_headers: List[Tuple[bytes, bytes]], stream_id: int + ): if stream_id in self.responses: self.responses[stream_id].logger.debug( - 'response-headers', - headers=response_headers + "response-headers", headers=response_headers ) self.responses[stream_id].headers = response_headers else: self.logger.warning( - 'unexpected-response', - stream_id=stream_id, - headers=response_headers + "unexpected-response", stream_id=stream_id, headers=response_headers ) def handle_data(self, data: bytes, stream_id: int): if stream_id in self.responses: - self.responses[stream_id].logger.debug( - 'response-body', - data=data - ) + self.responses[stream_id].logger.debug("response-body", data=data) self.responses[stream_id].body += data else: - self.logger.warning( - 'unexpected-data', - stream_id=stream_id, - data=data - ) + self.logger.warning("unexpected-data", stream_id=stream_id, data=data) def end_stream(self, stream_id: int): if stream_id in self.responses: response = self.responses[stream_id] - response.logger.debug('end-stream') + response.logger.debug("end-stream") response.future.set_result(True) else: - self.logger.warning('unexpected-end-stream', stream_id=stream_id) + self.logger.warning("unexpected-end-stream", stream_id=stream_id) def reset_stream(self, stream_id: int): if stream_id in self.responses: response = self.responses[stream_id] - response.logger.debug('reset-stream') + response.logger.debug("reset-stream") response.future.set_exception(errors.StreamResetError()) else: - self.logger.warning('unexpected-reset-stream', stream_id=stream_id) + self.logger.warning("unexpected-reset-stream", stream_id=stream_id) diff --git a/src/aapns/errors.py b/src/aapns/errors.py index 3a4d366..4d0b85f 100644 --- a/src/aapns/errors.py +++ b/src/aapns/errors.py @@ -23,45 +23,45 @@ def __init__(self, reason: str, apns_id: str): class UnknownResponseError(ResponseError): - codename = '!unknown' + codename = "!unknown" CODES: Dict[str, Type[ResponseError]] = {} def create(codename: str) -> Type[ResponseError]: - cls: Type[ResponseError] = type(codename, (ResponseError,), {'codename': codename}) + cls: Type[ResponseError] = type(codename, (ResponseError,), {"codename": codename}) CODES[codename] = cls return cls -BadCollapseId = create('BadCollapseId') -BadDeviceToken = create('BadDeviceToken') -BadExpirationDate = create('BadExpirationDate') -BadMessageId = create('BadMessageId') -BadPriority = create('BadPriority') -BadTopic = create('BadTopic') -DeviceTokenNotForTopic = create('DeviceTokenNotForTopic') -DuplicateHeaders = create('DuplicateHeaders') -IdleTimeout = create('IdleTimeout') -MissingDeviceToken = create('MissingDeviceToken') -MissingTopic = create('MissingTopic') -PayloadEmpty = create('PayloadEmpty') -BadCertificate = create('BadCertificate') -BadCertificateEnvironment = create('BadCertificateEnvironment') -ExpiredProviderToken = create('ExpiredProviderToken') -Forbidden = create('Forbidden') -InvalidProviderToken = create('InvalidProviderToken') -MissingProviderToken = create('MissingProviderToken') -BadPath = create('BadPath') -MethodNotAllowed = create('MethodNotAllowed') -Unregistered = create('Unregistered') -PayloadTooLarge = create('PayloadTooLarge') -TooManyProviderTokenUpdates = create('TooManyProviderTokenUpdates') -TooManyRequests = create('TooManyRequests') -InternalServerError = create('InternalServerError') -ServiceUnavailable = create('ServiceUnavailable') -Shutdown = create('Shutdown') +BadCollapseId = create("BadCollapseId") +BadDeviceToken = create("BadDeviceToken") +BadExpirationDate = create("BadExpirationDate") +BadMessageId = create("BadMessageId") +BadPriority = create("BadPriority") +BadTopic = create("BadTopic") +DeviceTokenNotForTopic = create("DeviceTokenNotForTopic") +DuplicateHeaders = create("DuplicateHeaders") +IdleTimeout = create("IdleTimeout") +MissingDeviceToken = create("MissingDeviceToken") +MissingTopic = create("MissingTopic") +PayloadEmpty = create("PayloadEmpty") +BadCertificate = create("BadCertificate") +BadCertificateEnvironment = create("BadCertificateEnvironment") +ExpiredProviderToken = create("ExpiredProviderToken") +Forbidden = create("Forbidden") +InvalidProviderToken = create("InvalidProviderToken") +MissingProviderToken = create("MissingProviderToken") +BadPath = create("BadPath") +MethodNotAllowed = create("MethodNotAllowed") +Unregistered = create("Unregistered") +PayloadTooLarge = create("PayloadTooLarge") +TooManyProviderTokenUpdates = create("TooManyProviderTokenUpdates") +TooManyRequests = create("TooManyRequests") +InternalServerError = create("InternalServerError") +ServiceUnavailable = create("ServiceUnavailable") +Shutdown = create("Shutdown") def get(reason: str, apns_id: str) -> ResponseError: diff --git a/src/aapns/models.py b/src/aapns/models.py index 28394c5..39822cd 100644 --- a/src/aapns/models.py +++ b/src/aapns/models.py @@ -6,10 +6,10 @@ def str_list(instance: object, attr: attr.Attribute, value: Any) -> None: if not isinstance(value, list): - raise TypeError('Must be list of strings') + raise TypeError("Must be list of strings") for arg in value: if not isinstance(arg, str): - raise TypeError('Must be list of strings') + raise TypeError("Must be list of strings") @attr.s @@ -18,15 +18,12 @@ class Localized: args = attr.ib(default=None, validator=attr.validators.optional(str_list)) -def maybe_localized(thing: Union[str, Localized], - nonloc: str, - lockey: str, - locarg: str) -> Dict[str, Union[str, List[str]]]: +def maybe_localized( + thing: Union[str, Localized], nonloc: str, lockey: str, locarg: str +) -> Dict[str, Union[str, List[str]]]: if isinstance(thing, Localized): attr.validate(thing) - localized = { - lockey: thing.key - } + localized = {lockey: thing.key} if thing.args: localized[locarg] = thing.args return localized @@ -39,37 +36,31 @@ class Alert: body = attr.ib(validator=attr.validators.instance_of((str, Localized))) title = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of((str, Localized))) + validator=attr.validators.optional( + attr.validators.instance_of((str, Localized)) + ), ) action_loc_key = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)) + validator=attr.validators.optional(attr.validators.instance_of(str)), ) launch_image = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)) + validator=attr.validators.optional(attr.validators.instance_of(str)), ) def get_dict(self) -> Dict[str, Any]: attr.validate(self) alert = {} if self.title: - alert.update(maybe_localized( - self.title, - 'title', - 'title-loc-key', - 'title-loc-args' - )) - alert.update(maybe_localized( - self.body, - 'body', - 'loc-key', - 'loc-args' - )) + alert.update( + maybe_localized(self.title, "title", "title-loc-key", "title-loc-args") + ) + alert.update(maybe_localized(self.body, "body", "loc-key", "loc-args")) if self.action_loc_key: - alert['action-loc-key'] = self.action_loc_key + alert["action-loc-key"] = self.action_loc_key if self.launch_image: - alert['launch-image'] = self.launch_image + alert["launch-image"] = self.launch_image return alert @@ -78,57 +69,47 @@ class Notification: alert = attr.ib(validator=attr.validators.instance_of(Alert)) badge = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(int)) + validator=attr.validators.optional(attr.validators.instance_of(int)), ) sound = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)) + validator=attr.validators.optional(attr.validators.instance_of(str)), ) content_available = attr.ib( - default=False, - validator=attr.validators.instance_of(bool) + default=False, validator=attr.validators.instance_of(bool) ) category = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)) + validator=attr.validators.optional(attr.validators.instance_of(str)), ) thread_id = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(str)) + validator=attr.validators.optional(attr.validators.instance_of(str)), ) extra = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(dict)) + validator=attr.validators.optional(attr.validators.instance_of(dict)), ) def get_dict(self) -> Dict[str, Any]: attr.validate(self) - apns = { - 'alert': self.alert.get_dict() - } + apns = {"alert": self.alert.get_dict()} if self.badge: - apns['badge'] = self.badge + apns["badge"] = self.badge if self.sound: - apns['sound'] = self.sound + apns["sound"] = self.sound if self.content_available: - apns['content-available'] = 1 + apns["content-available"] = 1 if self.category: - apns['category'] = self.category + apns["category"] = self.category if self.thread_id: - apns['thread-id'] = self.thread_id - raw = { - 'aps': apns - } + apns["thread-id"] = self.thread_id + raw = {"aps": apns} if self.extra: raw.update(self.extra) return raw def encode(self) -> bytes: raw = self.get_dict() - s = json.dumps( - raw, - ensure_ascii=False, - separators=(',', ':'), - sort_keys=True - ) - return s.encode('utf-8') + s = json.dumps(raw, ensure_ascii=False, separators=(",", ":"), sort_keys=True) + return s.encode("utf-8") diff --git a/tests/conftest.py b/tests/conftest.py index 24ae301..797a3e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,4 +11,3 @@ def pytest_pyfunc_call(pyfuncitem): for record in records: message = str(record.message) assert NEVER_AWAITED_RE.search(message) is None, message - diff --git a/tests/fake_apns_server.py b/tests/fake_apns_server.py index 8779813..f4c4df1 100644 --- a/tests/fake_apns_server.py +++ b/tests/fake_apns_server.py @@ -35,37 +35,35 @@ class Response: @attr.s class Request: headers = attr.ib() - body = attr.ib(default=b'', repr=False) + body = attr.ib(default=b"", repr=False) def handle(self, server): headers = dict(self.headers) - if headers[b':method'] != b'POST': - return Response([ - (b':status', b'405'), - (b'content-length', b'0'), - ]) - apns_id = headers.get(b'apns-id', str(uuid.uuid4()).upper()) - token = headers[b':path'][len(b'/3/device/'):].decode('ascii') + if headers[b":method"] != b"POST": + return Response([(b":status", b"405"), (b"content-length", b"0")]) + apns_id = headers.get(b"apns-id", str(uuid.uuid4()).upper()) + token = headers[b":path"][len(b"/3/device/") :].decode("ascii") payload = json.loads(self.body) if token not in server.devices: - data = json.dumps({ - 'apns-id': apns_id, - 'reason': BadDeviceToken.codename - }) - return Response([ - (b':status', b'400'), - (b'content-length', str(len(data)).encode('ascii')) - ], data.encode('utf-8')) + data = json.dumps({"apns-id": apns_id, "reason": BadDeviceToken.codename}) + return Response( + [ + (b":status", b"400"), + (b"content-length", str(len(data)).encode("ascii")), + ], + data.encode("utf-8"), + ) else: server.devices[token].append(payload) - data = json.dumps({ - 'apns-id': apns_id, - }) - return Response([ - (b':status', b'200'), - (b'apns-id', apns_id.encode('ascii')), - (b'content-length', b'0') - ], b'') + data = json.dumps({"apns-id": apns_id}) + return Response( + [ + (b":status", b"200"), + (b"apns-id", apns_id.encode("ascii")), + (b"content-length", b"0"), + ], + b"", + ) def coroify(coro): @@ -74,6 +72,7 @@ def func(self, *args, **kwargs): task = asyncio.ensure_future(coro(self, *args, **kwargs)) self.pending.append(task) task.add_done_callback(lambda *args: self.pending.remove(task)) + return func @@ -87,14 +86,16 @@ def __init__(self, server): @coroify async def connection_made(self, transport): - self.server.logger.info('connection-made', server=self.server) + self.server.logger.info("connection-made", server=self.server) self.transport = transport self.conn.initiate_connection() await asyncio.sleep(self.server.lag) self.transport.write(self.conn.data_to_send()) def connection_lost(self, exc): - self.server.logger.info('connection-lost', protocol=self, server=self.server, exc=exc) + self.server.logger.info( + "connection-lost", protocol=self, server=self.server, exc=exc + ) self.server.connections.remove(self) @coroify @@ -142,12 +143,12 @@ class FakeServer: connections = attr.ib(default=attr.Factory(list)) async def stop(self): - self.logger.msg('stopping', server=self.server) + self.logger.msg("stopping", server=self.server) for connection in self.connections: await connection.close() self.server.close() await self.server.wait_closed() - self.logger.msg('stopped', server=self.server) + self.logger.msg("stopped", server=self.server) def create_device(self): device_id = secrets.token_hex(32) @@ -167,11 +168,11 @@ def create_protocol(self): async def start_fake_apns_server(port=0, database=None, lag=0): database = {} if database is None else database private_key = gen_private_key() - certificate = gen_certificate(private_key, 'server') + certificate = gen_certificate(private_key, "server") with tempfile.TemporaryDirectory() as workspace: - key_path = os.path.join(workspace, 'key.pem') - cert_path = os.path.join(workspace, 'cert.pem') - with open(key_path, 'wb') as fobj: + key_path = os.path.join(workspace, "key.pem") + cert_path = os.path.join(workspace, "cert.pem") + with open(key_path, "wb") as fobj: fobj.write( private_key.private_bytes( encoding=serialization.Encoding.PEM, @@ -179,12 +180,8 @@ async def start_fake_apns_server(port=0, database=None, lag=0): encryption_algorithm=serialization.NoEncryption(), ) ) - with open(cert_path, 'wb') as fobj: - fobj.write( - certificate.public_bytes( - encoding=serialization.Encoding.PEM, - ) - ) + with open(cert_path, "wb") as fobj: + fobj.write(certificate.public_bytes(encoding=serialization.Encoding.PEM)) ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context.load_cert_chain(certfile=cert_path, keyfile=key_path) ssl_context.set_alpn_protocols(["h2"]) @@ -193,10 +190,7 @@ async def start_fake_apns_server(port=0, database=None, lag=0): loop = asyncio.get_event_loop() server = await loop.create_server( - fake_server.create_protocol, - '127.0.0.1', - port, - ssl=ssl_context + fake_server.create_protocol, "127.0.0.1", port, ssl=ssl_context ) fake_server.address = server.sockets[0].getsockname() fake_server.server = server @@ -210,8 +204,8 @@ def main(): async def helper(): async with start_fake_apns_server() as server: device_id = server.create_device() - print(f'Serving on {server.address}') - print(f'Device ID: {device_id}') + print(f"Serving on {server.address}") + print(f"Device ID: {device_id}") while True: await asyncio.sleep(1) @@ -221,5 +215,5 @@ async def helper(): return -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tests/fake_client_cert.py b/tests/fake_client_cert.py index baa9f18..dfa14a5 100644 --- a/tests/fake_client_cert.py +++ b/tests/fake_client_cert.py @@ -10,65 +10,45 @@ BACKEND = default_backend() + def gen_private_key() -> rsa.RSAPrivateKeyWithSerialization: return rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=BACKEND, + public_exponent=65537, key_size=2048, backend=BACKEND ) -def gen_certificate(key: rsa.RSAPrivateKey, - common_name: str, - *, - issuer: Optional[str]=None, - sign_key: Optional[rsa.RSAPrivateKey]=None) -> x509.Certificate: +def gen_certificate( + key: rsa.RSAPrivateKey, + common_name: str, + *, + issuer: Optional[str] = None, + sign_key: Optional[rsa.RSAPrivateKey] = None +) -> x509.Certificate: now = datetime.datetime.utcnow() - return x509.CertificateBuilder().subject_name( - x509.Name([ - x509.NameAttribute(NameOID.COMMON_NAME, common_name), - ]) - ).issuer_name( - x509.Name([ - x509.NameAttribute( - NameOID.COMMON_NAME, - issuer or common_name - ) - ]) - ).not_valid_before( - now - ).not_valid_after( - now + datetime.timedelta(seconds=86400) - ).serial_number( - x509.random_serial_number() - ).public_key( - key.public_key() - ).add_extension( - x509.BasicConstraints(ca=True, path_length=0), critical=True - ).sign( - private_key=sign_key or key, - algorithm=hashes.SHA256(), - backend=BACKEND + return ( + x509.CertificateBuilder() + .subject_name(x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])) + .issuer_name( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, issuer or common_name)]) + ) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(seconds=86400)) + .serial_number(x509.random_serial_number()) + .public_key(key.public_key()) + .add_extension(x509.BasicConstraints(ca=True, path_length=0), critical=True) + .sign(private_key=sign_key or key, algorithm=hashes.SHA256(), backend=BACKEND) ) def create_client_cert() -> bytes: ca_key = gen_private_key() - ca_cert = gen_certificate( - ca_key, - 'certificate_authority', - ) + ca_cert = gen_certificate(ca_key, "certificate_authority") client_key = gen_private_key() client_cert = gen_certificate( - client_key, - 'client', - issuer='certificate_authority', - sign_key=ca_key, + client_key, "client", issuer="certificate_authority", sign_key=ca_key ) return client_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, encryption_algorithm=serialization.NoEncryption(), - ) + client_cert.public_bytes( - encoding=serialization.Encoding.PEM, - ) + ) + client_cert.public_bytes(encoding=serialization.Encoding.PEM) diff --git a/tests/test_connection.py b/tests/test_connection.py index 60c931f..ed15dd0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -31,8 +31,8 @@ async def auto_close(event_loop): @pytest.fixture def client_cert_path(): with TemporaryDirectory() as workspace: - path = os.path.join(workspace, 'cert.pem') - with open(path, 'wb') as fobj: + path = os.path.join(workspace, "cert.pem") + with open(path, "wb") as fobj: fobj.write(create_client_cert()) yield path @@ -46,7 +46,7 @@ async def client(client_cert_path): ssl_context=non_verifying_context, auto_reconnect=True, timeout=10, - logger=get_logger() + logger=get_logger(), ) try: yield apns @@ -59,23 +59,25 @@ async def test_auto_reconnect(auto_close, client_cert_path): async with start_fake_apns_server(database=database) as server: config = Server(*server.address) port = server.address[1] - apns = auto_close(await connect( - client_cert_path, - config, - ssl_context=non_verifying_context, - auto_reconnect=True, - timeout=10, - logger=get_logger() - )) + apns = auto_close( + await connect( + client_cert_path, + config, + ssl_context=non_verifying_context, + auto_reconnect=True, + timeout=10, + logger=get_logger(), + ) + ) device_id = server.create_device() - await apns.send_notification(device_id, Notification(Alert('test1'))) + await apns.send_notification(device_id, Notification(Alert("test1"))) assert len(server.get_notifications(device_id)) == 1 with pytest.raises(Disconnected): - await apns.send_notification(device_id, Notification(Alert('test2'))) + await apns.send_notification(device_id, Notification(Alert("test2"))) async with start_fake_apns_server(port, database) as server: - await apns.send_notification(device_id, Notification(Alert('test3'))) + await apns.send_notification(device_id, Notification(Alert("test3"))) assert len(server.get_notifications(device_id)) == 2 await apns.close() @@ -85,26 +87,24 @@ async def test_no_auto_reconnect(auto_close, client_cert_path): async with start_fake_apns_server(database=database) as server: config = Server(*server.address) port = server.address[1] - apns = auto_close(await connect( - client_cert_path, - config, - ssl_context=non_verifying_context, - auto_reconnect=False, - timeout=10 - )) + apns = auto_close( + await connect( + client_cert_path, + config, + ssl_context=non_verifying_context, + auto_reconnect=False, + timeout=10, + ) + ) device_id = server.create_device() - await apns.send_notification(device_id, Notification(Alert('test1'))) + await apns.send_notification(device_id, Notification(Alert("test1"))) assert len(server.get_notifications(device_id)) == 1 with pytest.raises(Disconnected): - await apns.send_notification( - device_id, - Notification(Alert('test2')) - ) + await apns.send_notification(device_id, Notification(Alert("test2"))) - future = ensure_future(apns.send_notification( - device_id, - Notification(Alert('test3'))) + future = ensure_future( + apns.send_notification(device_id, Notification(Alert("test3"))) ) async with start_fake_apns_server(port, database) as server: with pytest.raises(Disconnected): @@ -116,18 +116,20 @@ async def test_slow_server(auto_close, client_cert_path): database = {} async with start_fake_apns_server(database=database, lag=0.5) as server: config = Server(*server.address) - apns = auto_close(await connect( - client_cert_path, - config, - ssl_context=non_verifying_context, - auto_reconnect=False, - timeout=0.1 - )) + apns = auto_close( + await connect( + client_cert_path, + config, + ssl_context=non_verifying_context, + auto_reconnect=False, + timeout=0.1, + ) + ) device_id = server.create_device() with pytest.raises(asyncio.TimeoutError): - await apns.send_notification(device_id, Notification(Alert('test1'))) + await apns.send_notification(device_id, Notification(Alert("test1"))) async def test_bad_device_id(client): with pytest.raises(BadDeviceToken): - await client.send_notification('does not exist', Notification(Alert('test'))) + await client.send_notification("does not exist", Notification(Alert("test"))) diff --git a/tests/test_models.py b/tests/test_models.py index 99081aa..8c582b5 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,28 +4,25 @@ def test_encode(): - notification = models.Notification( - alert=models.Alert( - title='Test', - body='Content', - ) + notification = models.Notification(alert=models.Alert(title="Test", body="Content")) + assert ( + notification.encode() == b'{"aps":{"alert":{"body":"Content","title":"Test"}}}' ) - assert notification.encode() == b'{"aps":{"alert":{"body":"Content","title":"Test"}}}' def test_localize(): notification = models.Notification( alert=models.Alert( - title=models.Localized('Test', ['foo', 'bar']), - body=models.Localized('Content'), + title=models.Localized("Test", ["foo", "bar"]), + body=models.Localized("Content"), ) ) assert notification.get_dict() == { - 'aps': { - 'alert': { - 'title-loc-key': 'Test', - 'title-loc-args': ['foo', 'bar'], - 'loc-key': 'Content', + "aps": { + "alert": { + "title-loc-key": "Test", + "title-loc-args": ["foo", "bar"], + "loc-key": "Content", } } } @@ -33,53 +30,47 @@ def test_localize(): def test_localized_invalid_args(): with pytest.raises(TypeError): - models.Localized('foo', [1]) + models.Localized("foo", [1]) def test_full(): notification = models.Notification( alert=models.Alert( - title=models.Localized('Test', ['foo', 'bar']), - body=models.Localized('Content', ['hoge']), - action_loc_key='action', - launch_image='my/image.png', + title=models.Localized("Test", ["foo", "bar"]), + body=models.Localized("Content", ["hoge"]), + action_loc_key="action", + launch_image="my/image.png", ), badge=13, - sound='sounds/alert.mp3', + sound="sounds/alert.mp3", content_available=True, - category='my-category', - thread_id='thread-id', - extra={ - 'myapp': { - 'is': 'awesome' - } - } + category="my-category", + thread_id="thread-id", + extra={"myapp": {"is": "awesome"}}, ) assert notification.get_dict() == { - 'aps': { - 'alert': { - 'title-loc-key': 'Test', - 'title-loc-args': ['foo', 'bar'], - 'loc-key': 'Content', - 'loc-args': ['hoge'], - 'action-loc-key': 'action', - 'launch-image': 'my/image.png' + "aps": { + "alert": { + "title-loc-key": "Test", + "title-loc-args": ["foo", "bar"], + "loc-key": "Content", + "loc-args": ["hoge"], + "action-loc-key": "action", + "launch-image": "my/image.png", }, - 'badge': 13, - 'sound': 'sounds/alert.mp3', - 'content-available': 1, - 'category': 'my-category', - 'thread-id': 'thread-id', + "badge": 13, + "sound": "sounds/alert.mp3", + "content-available": 1, + "category": "my-category", + "thread-id": "thread-id", }, - 'myapp': { - 'is': 'awesome' - } + "myapp": {"is": "awesome"}, } def test_invalid_loc_args(): with pytest.raises(TypeError): - models.Localized('Test', [1, 2, 3]) + models.Localized("Test", [1, 2, 3]) def test_invalid_alert_title():