From 8a6273c29c411311a632f17d9f17998534fbcdba Mon Sep 17 00:00:00 2001 From: Frederick Gnodtke Date: Sun, 22 Sep 2024 00:15:31 +0200 Subject: [PATCH 1/5] Implement reconnecting for MQTT --- myskoda/const.py | 2 + myskoda/mqtt.py | 139 ++++++++++++++++++++++++++++++++++++-------- myskoda/myskoda.py | 12 ++-- myskoda/rest_api.py | 1 - poetry.lock | 46 +++++++-------- 5 files changed, 147 insertions(+), 53 deletions(-) diff --git a/myskoda/const.py b/myskoda/const.py index 19d5597b..7035316e 100644 --- a/myskoda/const.py +++ b/myskoda/const.py @@ -44,5 +44,7 @@ MQTT_ACCOUNT_EVENT_TOPICS = [ "account-event/privacy", ] +MQTT_KEEPALIVE = 15 +MQTT_RECONNECT_DELAY = 5 MAX_RETRIES = 5 diff --git a/myskoda/mqtt.py b/myskoda/mqtt.py index 3724446a..61764067 100644 --- a/myskoda/mqtt.py +++ b/myskoda/mqtt.py @@ -4,7 +4,7 @@ import logging import re import ssl -from asyncio import Future, create_task, get_event_loop +from asyncio import Future, Lock, create_task, get_event_loop, sleep from collections.abc import Awaitable, Callable from typing import Any, cast @@ -15,7 +15,9 @@ MQTT_ACCOUNT_EVENT_TOPICS, MQTT_BROKER_HOST, MQTT_BROKER_PORT, + MQTT_KEEPALIVE, MQTT_OPERATION_TOPICS, + MQTT_RECONNECT_DELAY, MQTT_SERVICE_EVENT_TOPICS, ) from .event import ( @@ -49,46 +51,92 @@ def __init__( # noqa: D107 self.future = future +connect_lock = Lock() + + class Mqtt: api: RestApi - user: User - vehicles: list[str] + user: User | None + vehicles: list[str] | None client: AsyncioPahoClient _callbacks: list[Callable[[Event], None | Awaitable[None]]] _operation_listeners: list[OperationListener] _connected_listeners: list[Future[None]] + should_reconnect: bool + is_connected: bool - def __init__(self, api: RestApi) -> None: # noqa: D107 + def __init__(self, api: RestApi, ssl_context: ssl.SSLContext | None = None) -> None: # noqa: D107 self.api = api self._callbacks = [] self._operation_listeners = [] self._connected_listeners = [] + self.ssl_context = ssl_context + self.is_connected = False + self.user = None + self.vehicles = None def subscribe(self, callback: Callable[[Event], None | Awaitable[None]]) -> None: """Listen for events emitted by MySkoda's MQTT broker.""" self._callbacks.append(callback) - async def connect(self, ssl_context: ssl.SSLContext | None = None) -> None: + async def connect(self) -> None: """Connect to the MQTT broker and listen for messages.""" - _LOGGER.debug(f"Connecting to MQTT on {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT}...") - self.user = await self.api.get_user() - _LOGGER.debug(f"Using user id {self.user.id}...") - self.vehicles = await self.api.list_vehicles() - self.client = AsyncioPahoClient() - self.client.on_connect = self._on_connect - self.client.on_message = self._on_message - if ssl_context is not None: - self.client.tls_set_context(context=ssl_context) - else: - self.client.tls_set_context(context=ssl.create_default_context()) - self.client.username_pw_set( - self.user.id, await self.api.idk_session.get_access_token(self.api.session) - ) - self.client.connect_async(MQTT_BROKER_HOST, MQTT_BROKER_PORT, 60) - await self._wait_for_connection() + + async def perform_connect() -> bool: + try: + if self.is_connected: + return True + + _LOGGER.debug(f"Connecting to MQTT on {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT}...") + + if not self.user: + self.user = await self.api.get_user() + _LOGGER.debug(f"Using user id {self.user.id}...") + if not self.vehicles: + self.vehicles = await self.api.list_vehicles() + + self.should_reconnect = True + + self.client = AsyncioPahoClient() + self.client.on_connect = self._on_connect + self.client.on_message = self._on_message + self.client.on_disconnect = self._on_disconnect + self.client.on_socket_close = self._on_socket_close + self.client.on_connect_fail = self._on_connect_fail + if self.ssl_context is not None: + self.client.tls_set_context(context=self.ssl_context) + else: + self.client.tls_set_context(context=ssl.create_default_context()) + self.client.username_pw_set( + self.user.id, + await self.api.idk_session.get_access_token(self.api.session), + ) + self.client.connect_async(MQTT_BROKER_HOST, MQTT_BROKER_PORT, MQTT_KEEPALIVE) + + await self._wait_for_connection() + except FailedToConnectError: + return False + else: + return True + + async with connect_lock: + while not await perform_connect(): # noqa: ASYNC110 + await sleep(MQTT_RECONNECT_DELAY) + + def reconnect(self) -> None: + """Reconnect a client that was previously connected and was disconnected.""" + _LOGGER.info("Scheduling to reconnect MQTT.") + + if not self.should_reconnect: + return + + task = create_task(cast(Any, self.connect())) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) def disconnect(self) -> None: """Stop the thread for processing MQTT messages.""" + self.should_reconnect = False self.client.disconnect() # pyright: ignore [reportArgumentType] def _wait_for_connection(self) -> Future[None]: @@ -108,10 +156,46 @@ def wait_for_operation(self, operation_name: OperationName) -> Future[OperationR return future + def _on_socket_close(self, client: AsyncioPahoClient, _data: None, _socket: None) -> None: + if client is not self.client: + return + _LOGGER.info("Socket to MQTT broker closed.") + self.is_connected = False + self.reconnect() + + def _on_connect_fail(self, client: AsyncioPahoClient, _data: None) -> None: + if client is not self.client: + return + _LOGGER.error("Failed to connect to MQTT.") + for future in self._connected_listeners: + future.set_exception(FailedToConnectError) + self._connected_listeners = [] + + def _on_disconnect( + self, + client: AsyncioPahoClient, + _userdata: None, + reason_code: int, + ) -> None: + if client is not self.client: + return + _LOGGER.info("Connection to MQTT broker lost, reason %d.", reason_code) + self.is_connected = False + self.reconnect() + def _on_connect( - self, _client: AsyncioPahoClient, _data: None, _flags: dict, _reason: int + self, client: AsyncioPahoClient, _data: None, _flags: dict, _reason: int ) -> None: + if client is not self.client: + return + + self.is_connected = True + _LOGGER.info("MQTT Connected.") + if not self.user or not self.vehicles: + _LOGGER.error("Reached on_connect, but user and vehicles not loaded") + return + user_id = self.user.id for vin in self.vehicles: @@ -181,7 +265,12 @@ def _handle_operation(self, event: Event) -> None: ) self._handle_operation_completed(event.operation) - def _on_message(self, _client: AsyncioPahoClient, _data: None, msg: MQTTMessage) -> None: + def _on_message( # noqa: C901 + self, client: AsyncioPahoClient, _data: None, msg: MQTTMessage + ) -> None: + if client is not self.client: + return + # Extract the topic, user id and vin from the topic's name. # Internally, the topic will always look like this: # `/{user_id}/{vin}/path/to/topic` @@ -226,3 +315,7 @@ def __init__(self, operation: OperationRequest) -> None: # noqa: D107 error = operation.error_code trace = operation.trace_id super().__init__(f"Operation {op} with trace {trace} failed: {error}") + + +class FailedToConnectError(Exception): + pass diff --git a/myskoda/myskoda.py b/myskoda/myskoda.py index e5e80e94..b57e0fd6 100644 --- a/myskoda/myskoda.py +++ b/myskoda/myskoda.py @@ -52,17 +52,17 @@ class MySkoda: rest_api: RestApi mqtt: Mqtt - def __init__(self, session: ClientSession) -> None: # noqa: D107 + def __init__( # noqa: D107 + self, session: ClientSession, ssl_context: SSLContext | None = None + ) -> None: self.session = session self.rest_api = RestApi(self.session) - self.mqtt = Mqtt(self.rest_api) + self.mqtt = Mqtt(self.rest_api, ssl_context=ssl_context) - async def connect( - self, email: str, password: str, ssl_context: SSLContext | None = None - ) -> None: + async def connect(self, email: str, password: str) -> None: """Authenticate on the rest api and connect to the MQTT broker.""" await self.rest_api.authenticate(email, password) - await self.mqtt.connect(ssl_context=ssl_context) + await self.mqtt.connect() _LOGGER.debug("Myskoda ready.") def subscribe(self, callback: Callable[[Event], None | Awaitable[None]]) -> None: diff --git a/myskoda/rest_api.py b/myskoda/rest_api.py index 32a12e7b..1fa30895 100644 --- a/myskoda/rest_api.py +++ b/myskoda/rest_api.py @@ -57,7 +57,6 @@ async def get_charging(self, vin: str) -> Charging: ) as response: response_text = await response.text() _LOGGER.debug(f"vin {vin}: received charging info: {response_text}") - print(await response.json()) return Charging.parse_raw(response_text) async def get_status(self, vin: str) -> Status: diff --git a/poetry.lock b/poetry.lock index 806ad94d..d1f69b4a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -138,13 +138,13 @@ frozenlist = ">=1.1.0" [[package]] name = "anyio" -version = "4.5.0" +version = "4.6.0" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = true -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "anyio-4.5.0-py3-none-any.whl", hash = "sha256:fdeb095b7cc5a5563175eedd926ec4ae55413bb4be5770c424af0ba46ccb4a78"}, - {file = "anyio-4.5.0.tar.gz", hash = "sha256:c5a275fe5ca0afd788001f58fca1e69e29ce706d746e317d660e21f70c530ef9"}, + {file = "anyio-4.6.0-py3-none-any.whl", hash = "sha256:c7d2e9d63e31599eeb636c8c5c03a7e108d73b345f064f1c19fdc87b79036a9a"}, + {file = "anyio-4.6.0.tar.gz", hash = "sha256:137b4559cbb034c477165047febb6ff83f390fc3b20bf181c1fc0a728cb8beeb"}, ] [package.dependencies] @@ -752,29 +752,29 @@ files = [ [[package]] name = "ruff" -version = "0.6.6" +version = "0.6.7" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.6-py3-none-linux_armv6l.whl", hash = "sha256:f5bc5398457484fc0374425b43b030e4668ed4d2da8ee7fdda0e926c9f11ccfb"}, - {file = "ruff-0.6.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:515a698254c9c47bb84335281a170213b3ee5eb47feebe903e1be10087a167ce"}, - {file = "ruff-0.6.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6bb1b4995775f1837ab70f26698dd73852bbb82e8f70b175d2713c0354fe9182"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69c546f412dfae8bb9cc4f27f0e45cdd554e42fecbb34f03312b93368e1cd0a6"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59627e97364329e4eae7d86fa7980c10e2b129e2293d25c478ebcb861b3e3fd6"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:94c3f78c3d32190aafbb6bc5410c96cfed0a88aadb49c3f852bbc2aa9783a7d8"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:704da526c1e137f38c8a067a4a975fe6834b9f8ba7dbc5fd7503d58148851b8f"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:efeede5815a24104579a0f6320660536c5ffc1c91ae94f8c65659af915fb9de9"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e368aef0cc02ca3593eae2fb8186b81c9c2b3f39acaaa1108eb6b4d04617e61f"}, - {file = "ruff-0.6.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2653fc3b2a9315bd809725c88dd2446550099728d077a04191febb5ea79a4f79"}, - {file = "ruff-0.6.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:bb858cd9ce2d062503337c5b9784d7b583bcf9d1a43c4df6ccb5eab774fbafcb"}, - {file = "ruff-0.6.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:488f8e15c01ea9afb8c0ba35d55bd951f484d0c1b7c5fd746ce3c47ccdedce68"}, - {file = "ruff-0.6.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:aefb0bd15f1cfa4c9c227b6120573bb3d6c4ee3b29fb54a5ad58f03859bc43c6"}, - {file = "ruff-0.6.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a4c0698cc780bcb2c61496cbd56b6a3ac0ad858c966652f7dbf4ceb029252fbe"}, - {file = "ruff-0.6.6-py3-none-win32.whl", hash = "sha256:aadf81ddc8ab5b62da7aae78a91ec933cbae9f8f1663ec0325dae2c364e4ad84"}, - {file = "ruff-0.6.6-py3-none-win_amd64.whl", hash = "sha256:0adb801771bc1f1b8cf4e0a6fdc30776e7c1894810ff3b344e50da82ef50eeb1"}, - {file = "ruff-0.6.6-py3-none-win_arm64.whl", hash = "sha256:4b4d32c137bc781c298964dd4e52f07d6f7d57c03eae97a72d97856844aa510a"}, - {file = "ruff-0.6.6.tar.gz", hash = "sha256:0fc030b6fd14814d69ac0196396f6761921bd20831725c7361e1b8100b818034"}, + {file = "ruff-0.6.7-py3-none-linux_armv6l.whl", hash = "sha256:08277b217534bfdcc2e1377f7f933e1c7957453e8a79764d004e44c40db923f2"}, + {file = "ruff-0.6.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c6707a32e03b791f4448dc0dce24b636cbcdee4dd5607adc24e5ee73fd86c00a"}, + {file = "ruff-0.6.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:533d66b7774ef224e7cf91506a7dafcc9e8ec7c059263ec46629e54e7b1f90ab"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17a86aac6f915932d259f7bec79173e356165518859f94649d8c50b81ff087e9"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3f8822defd260ae2460ea3832b24d37d203c3577f48b055590a426a722d50ef"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9ba4efe5c6dbbb58be58dd83feedb83b5e95c00091bf09987b4baf510fee5c99"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:525201b77f94d2b54868f0cbe5edc018e64c22563da6c5c2e5c107a4e85c1c0d"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8854450839f339e1049fdbe15d875384242b8e85d5c6947bb2faad33c651020b"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2f0b62056246234d59cbf2ea66e84812dc9ec4540518e37553513392c171cb18"}, + {file = "ruff-0.6.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b1462fa56c832dc0cea5b4041cfc9c97813505d11cce74ebc6d1aae068de36b"}, + {file = "ruff-0.6.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:02b083770e4cdb1495ed313f5694c62808e71764ec6ee5db84eedd82fd32d8f5"}, + {file = "ruff-0.6.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0c05fd37013de36dfa883a3854fae57b3113aaa8abf5dea79202675991d48624"}, + {file = "ruff-0.6.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f49c9caa28d9bbfac4a637ae10327b3db00f47d038f3fbb2195c4d682e925b14"}, + {file = "ruff-0.6.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a0e1655868164e114ba43a908fd2d64a271a23660195017c17691fb6355d59bb"}, + {file = "ruff-0.6.7-py3-none-win32.whl", hash = "sha256:a939ca435b49f6966a7dd64b765c9df16f1faed0ca3b6f16acdf7731969deb35"}, + {file = "ruff-0.6.7-py3-none-win_amd64.whl", hash = "sha256:590445eec5653f36248584579c06252ad2e110a5d1f32db5420de35fb0e1c977"}, + {file = "ruff-0.6.7-py3-none-win_arm64.whl", hash = "sha256:b28f0d5e2f771c1fe3c7a45d3f53916fc74a480698c4b5731f0bea61e52137c8"}, + {file = "ruff-0.6.7.tar.gz", hash = "sha256:44e52129d82266fa59b587e2cd74def5637b730a69c4542525dfdecfaae38bd5"}, ] [[package]] From e89d0fb3458436988a24cae571879dc12cf3fd77 Mon Sep 17 00:00:00 2001 From: Frederick Gnodtke Date: Sun, 22 Sep 2024 11:29:42 +0200 Subject: [PATCH 2/5] Re-login if refreshing failed --- myskoda/authorization.py | 88 +++++++++++++++++++++++++++++----------- myskoda/mqtt.py | 36 ++++++++++++---- myskoda/rest_api.py | 14 +++++-- 3 files changed, 102 insertions(+), 36 deletions(-) diff --git a/myskoda/authorization.py b/myskoda/authorization.py index c8b557f3..176271a0 100644 --- a/myskoda/authorization.py +++ b/myskoda/authorization.py @@ -1,5 +1,6 @@ """Handles authorization to the MySkoda API.""" +from asyncio import Lock import base64 import hashlib import json @@ -58,6 +59,9 @@ class IDKAuthorizationCode(BaseModel): id_token: str +refresh_token_lock = Lock() + + class IDKSession(BaseModel): """Stores the JWT tokens relevant for a session at the IDK server. @@ -68,30 +72,58 @@ class IDKSession(BaseModel): refresh_token: str = Field(None, refreshToken="accessToken") id_token: str = Field(None, idToken="accessToken") - async def perform_refresh(self, session: ClientSession, attempt: int = 0) -> None: + async def perform_refresh(self, session: ClientSession, username: str, password: str) -> None: """Refresh the authorization token. This will consume the `refresh_token` and exchange it for a new set of tokens. """ json_data = {"token": self.refresh_token} - async with session.post( - f"{BASE_URL_SKODA}/api/v1/authentication/refresh-token?tokenType=CONNECT", - json=json_data, - ) as response: - try: - if not response.ok: - raise InvalidStatusError(response.status) # noqa: TRY301 - data = json.loads(await response.text()) - self.access_token = data.get("accessToken") - self.refresh_token = data.get("refreshToken") - self.id_token = data.get("idToken") - except Exception: - if attempt >= MAX_RETRIES: - raise - _LOGGER.warning("Retrying failed request to refresh token.") - await self.perform_refresh(session, attempt=attempt + 1) - - async def get_access_token(self, session: ClientSession) -> str: + + async def perform_request() -> bool: + meta = jwt.decode( + self.access_token, options={"verify_signature": False} + ) + expiry = datetime.fromtimestamp(cast(float, meta.get("exp")), tz=UTC) + if datetime.now(tz=UTC) + timedelta(minutes=10) < expiry: + return True + async with session.post( + f"{BASE_URL_SKODA}/api/v1/authentication/refresh-token?tokenType=CONNECT", + json=json_data, + ) as response: + try: + if not response.ok: + raise InvalidStatusError(response.status) # noqa: TRY301 + data = json.loads(await response.text()) + except Exception: + return False + else: + self.access_token = data.get("accessToken") + self.refresh_token = data.get("refreshToken") + self.id_token = data.get("idToken") + return True + + async with refresh_token_lock: + attempts = 0 + while not await perform_request(): + if attempts >= MAX_RETRIES: + _LOGGER.error( + "Refreshing token failed after %d attempts.", MAX_RETRIES + ) + _LOGGER.info("Trying to recover by logging in again...") + try: + tokens = await idk_authorize(session, username, password) + except Exception: + _LOGGER.exception("Failed to login.") + else: + self.access_token = tokens.access_token + self.refresh_token = tokens.refresh_token + self.id_token = tokens.id_token + _LOGGER.info("Successfully recovered by logging in.") + return + _LOGGER.warning("Retrying failed request to refresh token. Retrying...") + attempts = attempts + 1 + + async def get_access_token(self, session: ClientSession, username: str, password: str) -> str: """Get the access token. Use this method instead of using `access_token` directly. It will automatically @@ -101,7 +133,7 @@ async def get_access_token(self, session: ClientSession) -> str: expiry = datetime.fromtimestamp(cast(float, meta.get("exp")), tz=UTC) if datetime.now(tz=UTC) + timedelta(minutes=10) > expiry: _LOGGER.info("Refreshing IDK access token") - await self.perform_refresh(session) + await self.perform_refresh(session, username, password) return self.access_token @@ -168,7 +200,9 @@ async def _initial_oidc_authorize( "code_challenge_method": "s256", "prompt": "login", } - async with session.get(f"{BASE_URL_IDENT}/oidc/v1/authorize", params=params) as response: + async with session.get( + f"{BASE_URL_IDENT}/oidc/v1/authorize", params=params + ) as response: data = _extract_states_from_website(await response.text()) return IDKCredentials(data, email, password) @@ -258,14 +292,18 @@ async def _exchange_auth_code_for_idk_session( return IDKSession(**await response.json()) -async def idk_authorize(session: ClientSession, email: str, password: str) -> IDKSession: +async def idk_authorize( + session: ClientSession, email: str, password: str +) -> IDKSession: """Perform the full login process. Must be called before any other methods on the class can be called. """ # Generate a random string for the OAUTH2 PKCE challenge. # (https://www.oauth.com/oauth2-servers/pkce/authorization-request/) - verifier = "".join(random.choices(string.ascii_uppercase + string.digits, k=16)) # noqa: S311 + verifier = "".join( + random.choices(string.ascii_uppercase + string.digits, k=16) + ) # noqa: S311 # Call the initial OIDC (OpenID Connect) authorization, giving us the initial SSO information. # The full flow is explain a little bit here: @@ -281,7 +319,9 @@ async def idk_authorize(session: ClientSession, email: str, password: str) -> ID authentication = await _enter_password(session, login_meta) # Exchange the token for access and refresh tokens (JWT format). - return await _exchange_auth_code_for_idk_session(session, authentication.code, verifier) + return await _exchange_auth_code_for_idk_session( + session, authentication.code, verifier + ) class AuthorizationError(Exception): diff --git a/myskoda/mqtt.py b/myskoda/mqtt.py index 61764067..7461730d 100644 --- a/myskoda/mqtt.py +++ b/myskoda/mqtt.py @@ -65,7 +65,9 @@ class Mqtt: should_reconnect: bool is_connected: bool - def __init__(self, api: RestApi, ssl_context: ssl.SSLContext | None = None) -> None: # noqa: D107 + def __init__( + self, api: RestApi, ssl_context: ssl.SSLContext | None = None + ) -> None: # noqa: D107 self.api = api self._callbacks = [] self._operation_listeners = [] @@ -87,7 +89,9 @@ async def perform_connect() -> bool: if self.is_connected: return True - _LOGGER.debug(f"Connecting to MQTT on {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT}...") + _LOGGER.debug( + f"Connecting to MQTT on {MQTT_BROKER_HOST}:{MQTT_BROKER_PORT}..." + ) if not self.user: self.user = await self.api.get_user() @@ -109,9 +113,13 @@ async def perform_connect() -> bool: self.client.tls_set_context(context=ssl.create_default_context()) self.client.username_pw_set( self.user.id, - await self.api.idk_session.get_access_token(self.api.session), + await self.api.idk_session.get_access_token( + self.api.session, self.api.email, self.api.password + ), + ) + self.client.connect_async( + MQTT_BROKER_HOST, MQTT_BROKER_PORT, MQTT_KEEPALIVE ) - self.client.connect_async(MQTT_BROKER_HOST, MQTT_BROKER_PORT, MQTT_KEEPALIVE) await self._wait_for_connection() except FailedToConnectError: @@ -147,7 +155,9 @@ def _wait_for_connection(self) -> Future[None]: return future - def wait_for_operation(self, operation_name: OperationName) -> Future[OperationRequest]: + def wait_for_operation( + self, operation_name: OperationName + ) -> Future[OperationRequest]: """Wait until the next operation of the specified type completes.""" _LOGGER.debug("Waiting for operation %s complete.", operation_name) future: Future[OperationRequest] = get_event_loop().create_future() @@ -156,7 +166,9 @@ def wait_for_operation(self, operation_name: OperationName) -> Future[OperationR return future - def _on_socket_close(self, client: AsyncioPahoClient, _data: None, _socket: None) -> None: + def _on_socket_close( + self, client: AsyncioPahoClient, _data: None, _socket: None + ) -> None: if client is not self.client: return _LOGGER.info("Socket to MQTT broker closed.") @@ -241,9 +253,13 @@ def _handle_operation_completed(self, operation: OperationRequest) -> None: listener.future.set_exception(OperationFailedError(operation)) else: if operation.status == OperationStatus.COMPLETED_WARNING: - _LOGGER.warning("Operation '%s' completed with warnings.", operation.operation) + _LOGGER.warning( + "Operation '%s' completed with warnings.", operation.operation + ) - _LOGGER.debug("Resolving listener for operation '%s'.", operation.operation) + _LOGGER.debug( + "Resolving listener for operation '%s'.", operation.operation + ) listener.future.set_result(operation) def _handle_operation(self, event: Event) -> None: @@ -287,7 +303,9 @@ def _on_message( # noqa: C901 if len(data) == 0: return - _LOGGER.debug("Message (%s) received for %s on topic %s: %s", event_type, vin, topic, data) + _LOGGER.debug( + "Message (%s) received for %s on topic %s: %s", event_type, vin, topic, data + ) # Messages will contain payload as JSON. data = json.loads(msg.payload) diff --git a/myskoda/rest_api.py b/myskoda/rest_api.py index 1fa30895..0909c179 100644 --- a/myskoda/rest_api.py +++ b/myskoda/rest_api.py @@ -34,6 +34,8 @@ async def authenticate(self, email: str, password: str) -> bool: Must be called before any other methods on the class can be called. """ + self.email = email + self.password = password self.idk_session = await idk_authorize(self.session, email, password) _LOGGER.info("IDK Authorization was successful.") @@ -149,7 +151,9 @@ async def list_vehicles(self) -> list[str]: return [vehicle["vin"] for vehicle in json["vehicles"]] async def _headers(self) -> dict[str, str]: - return {"authorization": f"Bearer {await self.idk_session.get_access_token(self.session)}"} + return { + "authorization": f"Bearer {await self.idk_session.get_access_token(self.session, self.email, self.password)}" + } async def stop_air_conditioning(self, vin: str) -> None: """Stop the air conditioning.""" @@ -183,7 +187,9 @@ async def start_air_conditioning(self, vin: str, temperature: float) -> None: async def set_target_temperature(self, vin: str, temperature: float) -> None: """Set the air conditioning's target temperature in °C.""" - _LOGGER.debug("Setting target temperature for vehicle %s to %s", vin, str(temperature)) + _LOGGER.debug( + "Setting target temperature for vehicle %s to %s", vin, str(temperature) + ) json_data = {"temperatureValue": str(temperature), "unitInCar": "CELSIUS"} async with self.session.post( f"{BASE_URL_SKODA}/api/v2/air-conditioning/{vin}/settings/target-temperature", @@ -291,7 +297,9 @@ async def honk_flash( ) -> None: """Honk and/or flash.""" positions = await self.get_positions(vin) - position = next(pos for pos in positions.positions if pos.type == PositionType.VEHICLE) + position = next( + pos for pos in positions.positions if pos.type == PositionType.VEHICLE + ) json_data = { "mode": "HONK_AND_FLASH" if honk else "FLASH", "vehiclePosition": { From d8e18dd2459a59fd735fe349b916dd6c62e0b932 Mon Sep 17 00:00:00 2001 From: Frederick Gnodtke Date: Mon, 23 Sep 2024 11:33:48 +0200 Subject: [PATCH 3/5] Move authorization into own class --- myskoda/__init__.py | 11 +- myskoda/auth/__init__.py | 1 + myskoda/auth/authorization.py | 306 +++++++++++++++++++++++++++++++ myskoda/auth/csrf_parser.py | 59 ++++++ myskoda/authorization.py | 335 ---------------------------------- myskoda/mqtt.py | 108 +++++------ myskoda/myskoda.py | 15 +- myskoda/rest_api.py | 32 +--- tests/test_rest_api.py | 6 +- 9 files changed, 438 insertions(+), 435 deletions(-) create mode 100644 myskoda/auth/__init__.py create mode 100644 myskoda/auth/authorization.py create mode 100644 myskoda/auth/csrf_parser.py delete mode 100644 myskoda/authorization.py diff --git a/myskoda/__init__.py b/myskoda/__init__.py index 6afb6d80..0cedc033 100644 --- a/myskoda/__init__.py +++ b/myskoda/__init__.py @@ -1,12 +1,6 @@ """A library for interacting with the MySkoda APIs.""" -from .authorization import ( - AuthorizationError, - IDKAuthorizationCode, - IDKCredentials, - IDKSession, - idk_authorize, -) +from .auth.authorization import Authorization, AuthorizationError, IDKAuthorizationCode, IDKSession from .models import ( air_conditioning, charging, @@ -25,11 +19,10 @@ from .vehicle import Vehicle __all__ = [ + "Authorization", "AuthorizationError", "IDKAuthorizationCode", - "IDKCredentials", "IDKSession", - "idk_authorize", "air_conditioning", "charging", "common", diff --git a/myskoda/auth/__init__.py b/myskoda/auth/__init__.py new file mode 100644 index 00000000..29963fb6 --- /dev/null +++ b/myskoda/auth/__init__.py @@ -0,0 +1 @@ +"""Authorization for VW IDK servers.""" diff --git a/myskoda/auth/authorization.py b/myskoda/auth/authorization.py new file mode 100644 index 00000000..88e0a4c9 --- /dev/null +++ b/myskoda/auth/authorization.py @@ -0,0 +1,306 @@ +"""Handles authorization to the MySkoda API.""" + +import base64 +import hashlib +import logging +import random +import string +import uuid +from asyncio import Lock +from datetime import UTC, datetime, timedelta +from typing import cast + +import jwt +from aiohttp import ClientSession, FormData +from pydantic import BaseModel, Field, ValidationError + +from myskoda.auth.csrf_parser import CSRFParser, CSRFState +from myskoda.const import BASE_URL_IDENT, BASE_URL_SKODA, CLIENT_ID, MAX_RETRIES + +_LOGGER = logging.getLogger(__name__) + + +class IDKAuthorizationCode(BaseModel): + """One-time authorization code that can be obtained by logging in. + + This authorization code can later be exchanged for a set of JWT tokens. + """ + + code: str + token_type: str + id_token: str + + +refresh_token_lock = Lock() + + +class IDKSession(BaseModel): + """Stores the JWT tokens relevant for a session at the IDK server. + + Can be used to authorized and refresh the authorization token. + """ + + access_token: str = Field(None, alias="accessToken") + refresh_token: str = Field(None, alias="refreshToken") + id_token: str = Field(None, alias="idToken") + + +class Authorization: + session: ClientSession + idk_session: IDKSession | None = None + + def __init__(self, session: ClientSession) -> None: # noqa: D107 + self.session = session + + def _extract_csrf(self, html: str) -> CSRFState: + parser = CSRFParser() + parser.feed(html) + + if parser.csrf_state is None: + raise CSRFError + + return parser.csrf_state + + async def authorize(self, email: str, password: str) -> None: + """Authorize on the VW IDK servers.""" + self.email = email + self.password = password + self.idk_session = await self._get_idk_session() + + if self.idk_session is None: + raise AuthorizationFailedError + + async def _initial_oidc_authorize(self, verifier: str) -> CSRFState: + """First step of the login process. + + This calls the route for initial authorization, + which will contain the initial SSO information such as the CSRF or the HMAC. + """ + # A SHA256 hash of the random "verifier" string will be transmitted as a challenge. + # This is part of the OAUTH2 PKCE process. It is described here in detail: + # https://www.oauth.com/oauth2-servers/pkce/authorization-request/ + verifier_hash = hashlib.sha256(verifier.encode("utf-8")).digest() + challenge = ( + base64.b64encode(verifier_hash) + .decode("utf-8") + .replace("+", "-") + .replace("/", "_") + .rstrip("=") + ) + + params = { + "client_id": CLIENT_ID, + "nonce": str(uuid.uuid4()), + "redirect_uri": "myskoda://redirect/login/", + "response_type": "code id_token", + # OpenID scopes. Can be found here: https://identity.vwgroup.io/.well-known/openid-configuration + "scope": "address badge birthdate cars driversLicense dealers email mileage mbb nationalIdentifier openid phone profession profile vin", # noqa: E501 + "code_challenge": challenge, + "code_challenge_method": "s256", + "prompt": "login", + } + async with self.session.get( + f"{BASE_URL_IDENT}/oidc/v1/authorize", params=params + ) as response: + return self._extract_csrf(await response.text()) + + async def _enter_email_address(self, csrf: CSRFState) -> CSRFState: + """Second step in the login process. + + Will post only the email address to the backend. + The password will follow in a later request. + """ + form_data = FormData() + form_data.add_field("relayState", csrf.template_model.relay_state) + form_data.add_field("email", self.email) + form_data.add_field("hmac", csrf.template_model.hmac) + form_data.add_field("_csrf", csrf.csrf) + + async with self.session.post( + f"{BASE_URL_IDENT}/signin-service/v1/{CLIENT_ID}/login/identifier", + data=form_data(), + ) as response: + return self._extract_csrf(await response.text()) + + async def _enter_password(self, csrf: CSRFState) -> IDKAuthorizationCode: + """Third step in the login process. + + Post both the email address and the password to the backend. + This will return a token which can then be used in the skoda services to authenticate. + """ + form_data = FormData() + form_data.add_field("relayState", csrf.template_model.relay_state) + form_data.add_field("email", self.email) + form_data.add_field("password", self.password) + form_data.add_field("hmac", csrf.template_model.hmac) + form_data.add_field("_csrf", csrf.csrf) + + # The following is a bit hacky: + # The backend will redirect multiple times after the login was successful. + # The last redirect will redirect back to the `MySkoda` app in Android, + # using the `myskoda://` URL prefix. + # The following loop will follow all redirects until the last redirect to `myskoda://` is + # encountered. This last URL will contain the token. + async with self.session.post( + f"{BASE_URL_IDENT}/signin-service/v1/{CLIENT_ID}/login/authenticate", + data=form_data(), + allow_redirects=False, + ) as auth_response: + location = auth_response.headers["Location"] + while not location.startswith("myskoda://"): + async with self.session.get(location, allow_redirects=False) as response: + location = response.headers["Location"] + codes = location.replace("myskoda://redirect/login/#", "") + + # The last redirection starting with `myskoda://` was encountered. + # The URL will contain the information we need as query parameters, + # without the leading `?`. + data = {} + for code in codes.split("&"): + [key, value] = code.split("=") + data[key] = value + + return IDKAuthorizationCode(**data) + + async def _exchange_auth_code_for_idk_session(self, code: str, verifier: str) -> IDKSession: + """Exchange the ident login code for an auth token from Skoda. + + This will return multiple tokens, such as an access token and a refresh token. + """ + json_data = { + "code": code, + "redirectUri": "myskoda://redirect/login/", + "verifier": verifier, + } + + async with self.session.post( + f"{BASE_URL_SKODA}/api/v1/authentication/exchange-authorization-code?tokenType=CONNECT", + json=json_data, + allow_redirects=False, + ) as response: + return IDKSession(**await response.json()) + + async def _get_idk_session(self) -> IDKSession: + """Perform the full login process. + + Must be called before any other methods on the class can be called. + """ + # Generate a random string for the OAUTH2 PKCE challenge. + # (https://www.oauth.com/oauth2-servers/pkce/authorization-request/) + verifier = "".join( + random.choices(string.ascii_uppercase + string.digits, k=16) # noqa: S311 + ) + + # Call the initial OIDC (OpenID Connect) authorization, + # giving us the initial SSO information. + # The full flow is explain a little bit here: + # https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth + login_meta = await self._initial_oidc_authorize(verifier) + + # Use the information to login with the email address, + # which is an extra step before the actual login. + login_meta = await self._enter_email_address(login_meta) + + # Perform the actual login which will result in a token that can be exchanged for + # an access token at the Skoda server. + authentication = await self._enter_password(login_meta) + + # Exchange the token for access and refresh tokens (JWT format). + return await self._exchange_auth_code_for_idk_session(authentication.code, verifier) + + def is_token_expired(self) -> bool: + """Check whether the login token is expired.""" + if not self.idk_session: + raise NotAuthorizedError + + meta = jwt.decode(self.idk_session.access_token, options={"verify_signature": False}) + expiry = datetime.fromtimestamp(cast(float, meta.get("exp")), tz=UTC) + return datetime.now(tz=UTC) + timedelta(minutes=10) > expiry + + async def _perform_refresh_token(self) -> bool: + if not self.idk_session: + raise NotAuthorizedError + + if not self.is_token_expired(): + return True + + async with self.session.post( + f"{BASE_URL_SKODA}/api/v1/authentication/refresh-token?tokenType=CONNECT", + json={"token": self.idk_session.refresh_token}, + ) as response: + if not response.ok: + return False + try: + self.idk_session = IDKSession.parse_raw(await response.text()) + except ValidationError: + _LOGGER.exception("Failed to parse tokens from refresh endpoint.") + return False + else: + return True + + async def refresh_token(self) -> None: + """Refresh the authorization token. + + This will consume the `refresh_token` and exchange it for a new set of tokens. + """ + async with refresh_token_lock: + for attempt in range(MAX_RETRIES): + if await self._perform_refresh_token(): + return + _LOGGER.warning( + "Retrying failed request to refresh token (%d/%d). Retrying...", + attempt, + MAX_RETRIES, + ) + + _LOGGER.error("Refreshing token failed after %d attempts.", MAX_RETRIES) + _LOGGER.info("Trying to recover by logging in again...") + + try: + idk_session = await self._get_idk_session() + except Exception: + _LOGGER.exception("Failed to login.") + else: + self.idk_session = idk_session + _LOGGER.info("Successfully recovered by logging in.") + return + + async def get_access_token(self) -> str: + """Get the access token. + + Use this method instead of using `access_token` directly. It will automatically + check if the JWT token is about to expire and refresh it using the `refresh_token`. + """ + if self.idk_session is None: + raise NotAuthorizedError + + if self.is_token_expired(): + _LOGGER.info("Token expired. Refreshing IDK access token") + await self.refresh_token() + return self.idk_session.access_token + + +class AuthorizationError(Exception): + """Error to indicate that something unexpected happened during authorization.""" + + +class InvalidStatusError(Exception): + """An invalid HTTP status code was received.""" + + def __init__(self, status: int) -> None: # noqa: D107 + super().__init__(f"Received invalid HTTP status code {status}.") + + +class CSRFError(Exception): + """Failed to parse the CSRF information from the website.""" + + +class NotAuthorizedError(Exception): + """Not authorized. + + Did you forget to call Authorization.authorize()? + """ + + +class AuthorizationFailedError(Exception): + """Failed to authorize.""" diff --git a/myskoda/auth/csrf_parser.py b/myskoda/auth/csrf_parser.py new file mode 100644 index 00000000..64f891f6 --- /dev/null +++ b/myskoda/auth/csrf_parser.py @@ -0,0 +1,59 @@ +"""Parse CSRF information from the website.""" + +import re +from html.parser import HTMLParser + +import yaml +from pydantic import BaseModel, Field + +json_object = re.compile(r"window\._IDK\s=\s((?:\n|.)*?)$") + + +class TemplateModel(BaseModel): + hmac: str + relay_state: str = Field(None, alias="relayState") + + +class CSRFState(BaseModel): + csrf: str = Field(None, alias="csrf_token") + template_model: TemplateModel = Field(None, alias="templateModel") + + +class CSRFParser(HTMLParser): + """Information such as the CSRF or the hmac will be available in the HTML. + + This will parse the information from a `